"""
MPI/CGLS

Xiangrui Meng <mengxr@stanford.edu>
iCME, Stanford University
"""

from math import sqrt, log

import numpy as np
from numpy.random import seed, randn
from numpy.linalg import norm, svd

import scipy as sp
import scipy.sparse as sps
from scipy.sparse.linalg import aslinearoperator

from mpi4py import MPI

from _gen_prob import _gen_prob

def cgls( A, b, over = True, tol = 1e-15, iter_lim = None, comm = None ):
    """
    MPI/CGLS

    Conjugate gradient applied to the normal equations, with the matrix
    distributed over the ranks of an MPI communicator.

    Parameters
    ----------

    A : m-by-n {ndarray, matrix, sparse, LinearOperator}. If over is true, A contains a
        subset of rows of the full matrix. Otherwise, A contains a subset of columns.

    b : (m,) array_like (an (m,1) column is flattened). If over is true, b is
        partitioned the same way as A. If over is false, b is the RHS.

    over : over-determined (true) or under-determined (false). Any truthy value
        selects the over-determined path.

    tol : float, relative tolerance. Iteration stops once
        norm(A'*r) < tol*norm(A'*b), where r = b - A*x is the residual.

    iter_lim : max number of iterations. Defaults to 2*n (over) or 2*m (under).

    comm : mpi communicator. None (the default) means MPI.COMM_WORLD.

    Returns
    -------

    x : (n,) ndarray, solution. n is the number of columns of the local A, so in
        the under-determined case this is the local block of the solution.

    flag : int, 0 means solved, 1 means cgls didn't converge within iter_lim

    itn : int, number of iterations
    """

    if comm is None:
        comm = MPI.COMM_WORLD

    A    = aslinearoperator(A)
    m, n = A.shape

    def local_sq(v):
        return np.dot(v, v)

    def global_sq(v):
        return comm.allreduce(np.dot(v, v))

    def local_forward(v):
        return A.matvec(v)

    def global_forward(v):
        return comm.allreduce(A.matvec(v))

    def local_adjoint(v):
        return A.rmatvec(v)

    def global_adjoint(v):
        return comm.allreduce(A.rmatvec(v))

    if over:
        # Rows are distributed: A*p stays local (only its squared norm is summed
        # over the ranks), whereas A'*r has to be summed over the ranks.
        forward, sq_range   = local_forward, global_sq
        adjoint, sq_domain  = global_adjoint, local_sq
        default_lim         = 2*n
    else:
        # Columns are distributed: A*p has to be summed over the ranks, whereas
        # A'*r stays local (only its squared norm is summed over the ranks).
        forward, sq_range   = global_forward, local_sq
        adjoint, sq_domain  = local_adjoint, global_sq
        default_lim         = 2*m

    if iter_lim is None:
        iter_lim = default_lim

    x       = np.zeros(n)
    # Private float copy of b, flattened so the in-place updates below are valid
    # for integer or column-shaped input as well.
    r       = np.array(b, dtype=float).ravel()
    s       = adjoint(r)
    gamma   = sq_domain(s)
    nrm_s_0 = sqrt(gamma)
    itn     = 0

    # gamma is the result of (or computed from) an allreduce, so every rank
    # takes the same decision here and in the loop below.
    if nrm_s_0 == 0.0:
        # A'*b = 0: x = 0 already satisfies the normal equations. Iterating
        # would only compute 0/0.
        return x, 0, itn

    p         = s.copy()
    converged = False

    while (not converged) and (itn < iter_lim):

        q    = forward(p)
        sq_q = sq_range(q)
        if sq_q == 0.0:
            # A*p vanished, so no step length can be computed. In exact
            # arithmetic this only happens once the gradient s is zero.
            converged = True
            break

        itn    += 1

        alpha   = gamma / sq_q
        x      += alpha*p
        r      -= alpha*q
        s       = adjoint(r)
        gamma_0 = gamma
        gamma   = sq_domain(s)
        beta    = gamma / gamma_0
        p       = s + beta*p

        converged = sqrt(gamma) < tol*nrm_s_0

    # Report what actually happened: converging exactly on the last allowed
    # iteration is still a success.
    flag = 0 if converged else 1

    return x, flag, itn

def _test():
    """
    Self-test: solve one over-determined and one under-determined problem with
    cgls and print, on rank 0, whether the relative error is below tol.
    """

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # NOTE(review): m and n are floats; _gen_prob is defined elsewhere, so
    # whether it needs integer sizes could not be checked here -- confirm.
    m    = 1e5
    n    = 1e2
    r    = n
    cond = 10.0

    tol      = 1e-14
    # Smallest k with 2*((cond-1)/(cond+1))**k <= tol.
    iter_lim = np.ceil( (log(tol)-log(2.0))/log((cond-1.0)/(cond+1.0)) )

    def report_convergence(flag, itn):
        # Only rank 0 prints, so the message appears once.
        if rank != 0:
            return
        if flag == 0:
            print("CGLS converged in %d iterations." % (itn,))
        else:
            print("CGLS didn't converge in %d iterations." % (itn,))

    def report_relerr(case, relerr):
        verdict = "passed" if relerr < tol else "failed"
        print("%s test %s with relerr (AtA-norm) %G" % (case, verdict, relerr))

    # an over-determined case

    A, b, x_opt  = _gen_prob(m,n,cond,r,comm)
    x, flag, itn = cgls(A,b,True,tol/cond,iter_lim,comm)
    report_convergence(flag, itn)

    # Each rank contributes a squared norm; comm.reduce delivers the sum to
    # rank 0 only.
    sq_A_x_opt = comm.reduce(norm(A.dot(x_opt))**2)
    sq_A_err   = comm.reduce(norm(A.dot(x-x_opt))**2)

    if rank == 0:
        report_relerr("Over-determined", sqrt(sq_A_err)/sqrt(sq_A_x_opt))

    # an under-determined case

    A, b, x_opt  = _gen_prob(n,m,cond,r,comm)
    x, flag, itn = cgls(A,b,False,tol/cond,iter_lim,comm)
    report_convergence(flag, itn)

    # Here the products themselves are summed onto rank 0 before the norms
    # are taken.
    A_x_opt = comm.reduce(A.dot(x_opt))
    A_err   = comm.reduce(A.dot(x-x_opt))

    if rank == 0:
        report_relerr("Under-determined", norm(A_err)/norm(A_x_opt))

# Run the self-test when this file is executed as a script.
if __name__ == '__main__':
    _test()
