"""
MPI/LSQR for MPI/LSRN

Xiangrui Meng <mengxr@stanford.edu>
iCME, Stanford
"""

from exceptions import NotImplementedError

from math import sqrt, log

import numpy as np
from numpy.linalg import norm

import scipy as sp
from scipy.sparse.linalg import aslinearoperator

from mpi4py import MPI

from _gen_prob import _gen_prob

def lsqr( A, b, over=True, tol=1e-14, iter_lim=None, comm=None ):
    """
    A simple version of MPI/LSQR for MPI/LSRN

    Runs the LSQR iteration (Golub-Kahan bidiagonalization plus Givens
    rotations) on the operator A and right-hand side b, summing selected
    quantities across the ranks of `comm` with `comm.allreduce`.

    Parameters
    ----------
    A        : anything accepted by scipy's `aslinearoperator`; its shape
               (m, n) is the shape of the piece seen by this rank.
    b        : array with a `squeeze` method.  A squeezed copy is used as the
               first left vector and normalised in place, so it is expected
               to be a floating-point array.
    over     : truthy -> "over-determined" mode: the squared norm of u and
               the product A'*u are summed across ranks, v/w/x are handled
               locally.  (Presumably A and b are partitioned by rows and x is
               replicated -- confirm against the caller.)
               falsy  -> "under-determined" mode: the product A*v and the
               squared norms of v and w are summed across ranks, u is handled
               locally.  (Presumably A is partitioned by columns and every
               rank holds its own slice of x -- confirm against the caller.)
    tol      : relative tolerance on the estimate of norm(A'*r); clipped to
               [eps, 1-eps] with eps = 32 * machine epsilon.
    iter_lim : maximum number of iterations.  Defaults to max(20, 2*n) in
               over-determined mode and max(20, m) otherwise.
    comm     : communicator providing `allreduce`.  None (the default) means
               MPI.COMM_WORLD, looked up when the function is called.

    Returns
    -------
    x    : array of length n.
    flag :  0  converged (estimate of norm(A'*r) small enough),
            1  stopped early: estimated cond(A) exceeded 1/eps, or the update
               to x was negligible for 3 consecutive iterations,
           -1  iteration limit reached without meeting a stopping test.
    itn  : zero-based index of the last iteration executed (0 if the loop did
           not run at all).

    If the initial alpha*beta is zero the function returns (x, 0, 0) with
    x = 0 immediately.
    """

    if comm is None:
        comm = MPI.COMM_WORLD

    A    = aslinearoperator(A)
    m, n = A.shape

    eps  = 32*np.finfo(float).eps        # should be larger than the real eps

    if   tol < eps:
         tol = eps
    elif tol > 1-eps:
         tol = 1-eps

    max_n_stag = 3
    stag       = 0

    # The two modes run the same recurrence; they differ only in which
    # quantities have to be summed across ranks.
    if over:                            # over-determined
        def sq_norm_left(vec):
            return comm.allreduce(np.dot(vec,vec))
        def sq_norm_right(vec):
            return np.dot(vec,vec)
        def apply_a(vec):
            return A.matvec(vec)
        def apply_at(vec):
            return comm.allreduce(A.rmatvec(vec))
    else:                               # under-determined
        def sq_norm_left(vec):
            return np.dot(vec,vec)
        def sq_norm_right(vec):
            return comm.allreduce(np.dot(vec,vec))
        def apply_a(vec):
            return comm.allreduce(A.matvec(vec))
        def apply_at(vec):
            return A.rmatvec(vec)

    u    = b.squeeze().copy()
    beta = sqrt(sq_norm_left(u))
    if beta != 0:
        u   /= beta

    v     = apply_at(u)
    alpha = sqrt(sq_norm_right(v))
    if alpha != 0:
        v    /= alpha

    w     = v.copy()
    x     = np.zeros(n)

    phibar = beta
    rhobar = alpha

    nrm_ar_0 = alpha*beta

    if nrm_ar_0 == 0:
        return x, 0, 0

    sq_d   =  0.0
    cnd_a  =  0.0
    nrm_a  =  0.0
    nrm_r  =  beta
    nrm_x  =  0.0
    sq_x   =  0.0
    z      =  0.0
    cs2    = -1.0
    sn2    =  0.0

    flag = -1
    if iter_lim is None:
        if over:
            iter_lim = np.max( [20, 2*n] )
        else:
            iter_lim = np.max( [ 20, m ] )

    itn = 0                             # defined even when iter_lim == 0
    for itn in range(int(iter_lim)):

        # Next step of the bidiagonalization.  A zero beta (or alpha) means
        # the process has terminated: skip the normalisation instead of
        # dividing by zero.  The rotation below then yields nrm_ar == 0 and
        # the convergence test stops the loop with the current x.
        u    = apply_a(v) - alpha*u
        beta = sqrt(sq_norm_left(u))
        if beta > 0:
            u   /= beta

            # estimate of norm(A)
            nrm_a = sqrt(nrm_a**2 + alpha**2 + beta**2)

            v     = apply_at(u) - beta*v
            alpha = sqrt(sq_norm_right(v))
            if alpha > 0:
                v    /= alpha

        rho    = sqrt(rhobar**2+beta**2)
        cs     = rhobar/rho
        sn     = beta/rho
        theta  = sn*alpha
        rhobar = -cs*alpha
        phi    = cs*phibar
        phibar = sn*phibar

        x     += (phi/rho)*w
        w      = v-(theta/rho)*w

        # estimate of norm(r)
        nrm_r  = phibar

        # estimate of norm(A'*r)
        nrm_ar = phibar*alpha*np.abs(cs)

        # check convergence
        if nrm_ar < tol * nrm_ar_0:
            flag = 0
            break

        if nrm_ar < eps * nrm_a*nrm_r:
            flag = 0
            break

        # estimate of cond(A)
        sq_w   = sq_norm_right(w)
        nrm_w  = sqrt(sq_w)
        sq_d  += sq_w/(rho**2)
        cnd_a  = nrm_a*sqrt(sq_d)

        # check condition number
        if cnd_a > 1.0/eps:
            flag = 1
            break

        # check stagnation
        if np.abs(phi/rho)*nrm_w < eps*nrm_x:
            stag += 1
        else:
            stag  = 0
        if stag >= max_n_stag:
            flag = 1
            break

        # estimate of norm(x)
        delta   =  sn2*rho
        gambar  = -cs2*rho
        rhs     =  phi - delta*z
        zbar    =  rhs/gambar
        nrm_x   =  sqrt(sq_x + zbar**2)
        gamma   =  sqrt(gambar**2 + theta**2)
        cs2     =  gambar/gamma
        sn2     =  theta /gamma
        z       =  rhs   /gamma
        sq_x   +=  z**2

    return x, flag, itn
            
def _test():
    """
    Solve one over-determined and one under-determined problem produced by
    `_gen_prob` with `lsqr` on MPI.COMM_WORLD and report the outcome.

    Only rank 0 prints: first the iteration count and flag returned by
    `lsqr`, then whether the relative error against the reference solution
    returned by `_gen_prob` is below tol*cond.
    """

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # NOTE(review): the sizes are float literals and are passed to _gen_prob
    # unchanged -- confirm that _gen_prob accepts non-integer dimensions.
    m    = 1e4
    n    = 1e2
    r    = 50                            # we only need full-rank case
    cond = 8

    tol      = 1e-14
    # smallest k with 2*((cond-1)/(cond+1))**k <= tol
    iter_lim = np.ceil( (log(tol)-log(2.0))/log((cond-1.0)/(cond+1.0)))

    def run_case(label, n_rows, n_cols, over):
        # Generate a problem of the given shape, solve it, report on rank 0.
        A, b, x_opt  = _gen_prob(n_rows,n_cols,cond,r,comm)
        x, flag, itn = lsqr(A,b,over,tol/cond,iter_lim,comm)

        if rank != 0:
            return

        if flag >= 0:
            print("LSQR converged in %d/%d iterations with flag %d." % (itn,iter_lim,flag))
        else:
            print("LSQR didn't converge in %d iterations." % (itn,))

        relerr = norm(x-x_opt)/norm(x_opt)
        if relerr < tol*cond:
            print("%s test passed with relerr %G" % (label,relerr))
        else:
            print("%s test failed with relerr %G" % (label,relerr))

    # an over-determined case
    run_case("Over-determined", m, n, True)

    # an under-determined case
    run_case("Under-determined", n, m, False)

# Run the self-test only when this file is executed as a script.
if "__main__" == __name__:
    _test()
    
