"""
Distributed LSRN: MPI-parallel least-squares solver with randomized preconditioning

Xiangrui Meng <mengxr@stanford.edu>
iCME, Stanford University
"""

# Public API of this module: only the LSRN driver is exported.
__all__ = [ 'lsrn' ]

from math import sqrt, log
from time import time

from exceptions import NotImplementedError

import numpy as np
from numpy.linalg import norm, svd, lstsq

from scipy.sparse.linalg import aslinearoperator, LinearOperator

from mpi4py       import MPI

from ziggurat     import seed, randn

from lsqr         import lsqr
from ls_chebyshev import ls_chebyshev
from barrier      import barrier
from _gen_prob    import _gen_prob

def _sketch_blocks(mult, g_rows, out_rows, s, blk_sz, timing):
    """
    Form the local sketch mult(G) for a g_rows-by-s Gaussian matrix G.

    G is drawn in column blocks of blk_sz to limit the memory it uses; the
    result is an out_rows-by-s ndarray. Time spent drawing G and multiplying
    is accumulated into timing['randn'] and timing['mult'].
    """
    As = np.zeros((out_rows, s))

    blk_sz = int(blk_sz)
    if blk_sz <= 0 or blk_sz > s:
        # Non-positive (default -1) or oversized block size: draw G at once.
        blk_sz = s

    for blk_begin in range(0, s, blk_sz):
        blk_end = min(blk_begin + blk_sz, s)

        tic_randn = time()
        G = randn(g_rows, blk_end - blk_begin)
        timing['randn'] += time() - tic_randn

        tic_mult = time()
        As[:, blk_begin:blk_end] = mult(G)
        timing['mult'] += time() - tic_mult

    return As


def _sketch_preconditioner(As, rcond, comm, timing):
    """
    Sum the local sketches over `comm`, factor the sum on rank 0, and return
    V[:, :r] / S[:r] on every rank.

    V and S are the left singular vectors and singular values from numpy's
    thin SVD of the summed sketch, and r is the number of singular values
    larger than S[0]*rcond. Communication and SVD wall-clock times are
    accumulated into timing['comm'] and timing['svd'].
    """
    tic_comm = time()
    As = comm.reduce(As)    # mpi4py default: sum onto rank 0, None elsewhere
    timing['comm'] += time() - tic_comm

    if comm.Get_rank() == 0:
        tic_svd = time()
        V, S = svd(As, False)[:2]
        r = np.sum(S > S[0]*rcond)
        N = V[:, :r]/S[:r]
        timing['svd'] += time() - tic_svd
    else:
        N = None

    barrier(comm)

    tic_comm = time()
    N = comm.bcast(N)
    timing['comm'] += time() - tic_comm

    return N


def _lsqr_iter_lim(tol, gamma):
    """
    Iteration limit handed to LSQR: ceil(-2*(log(tol) - log(2)) / log(gamma)),
    where gamma is the effective oversampling ratio s/r (expected > 1).
    """
    return int(np.ceil(-2.0*(log(tol) - log(2.0))/log(gamma)))


def _chebyshev_bounds(s, r, p_fail=0.1):
    """
    Return (s_max, s_min), the bounds passed to ls_chebyshev for an s-column
    sketch of numerical rank r.

    t is chosen so that 2*exp(-t**2/2) == p_fail, the accepted failure rate.
    """
    t = sqrt(-2.0*log(p_fail/2.0))
    s_max = 1.0/(sqrt(s) - sqrt(r) - t)
    s_min = 1.0/(sqrt(s) + sqrt(r) + t)
    return s_max, s_min


def lsrn(A, b, over=True, gamma=2.0, tol=1e-14, rcond=1e-10, blk_sz=-1, solver=lsqr,
         comm=MPI.COMM_WORLD):
    """
    LSRN computes the min-length solution of linear least squares,
    min ||A*x - b||, with an iterative solver (LSQR or Chebyshev
    semi-iteration) preconditioned by a randomized Gaussian sketch. The data
    are distributed across the ranks of `comm`; both over-determined and
    under-determined systems are supported.

    Parameters
    ----------

    A : m-by-n {ndarray, matrix, sparse}, the local part of the matrix. If
        over is True, A contains a subset of rows of the full matrix.
        Otherwise, A contains a subset of columns.

    b : (m,) ndarray. If over is True, b is partitioned the same way as A.
        Otherwise, b is the full RHS vector.

    over : bool, over-determined (true) or under-determined (false).

    gamma : float (>1), oversampling factor. The sketch has ceil(gamma*n)
        columns when over-determined and ceil(gamma*m) columns otherwise.

    tol : float, tolerance such that norm(A*x-A*x_opt)<tol*norm(A*x_opt)

    rcond : float, reciprocal condition number used to truncate the
        sketch's singular values. If negative, min(m,n)*eps is used.

    blk_sz : int, number of columns of the Gaussian matrix G drawn at a
        time. A non-positive value (default) or one larger than the sketch
        size means all at once.

    solver : iterative solver, either `lsqr` or `ls_chebyshev`. Anything
        else raises NotImplementedError.

    comm : mpi communicator

    Returns
    -------

    x : ndarray, the min-length solution. If over is True this is the full
        (n,) vector; otherwise it has one entry per local column of A.

    r : int, numerical rank of the sketch (the rank estimate of A).

    flag : int, status flag returned by `solver` (0 means solved, 1 means
        the iteration limit was reached -- TODO confirm for both solvers).

    itn : int, solver iteration count

    timing : dict, wall-clock seconds for 'randn', 'mult', 'svd', 'iter',
        'comm' and 'all'.
    """
    if solver is not lsqr and solver is not ls_chebyshev:
        # Fail fast, before any sketching or collective communication.
        raise NotImplementedError('unsupported solver: %r' % (solver,))

    tic_all = time()

    timing = {'randn': 0.0, 'mult': 0.0,
              'svd': 0.0, 'iter': 0.0,
              'comm': 0.0, 'all': 0.0}

    rank = comm.Get_rank()

    m, n = A.shape
    if rcond < 0:
        rcond = min(m, n) * np.finfo(float).eps
    # Seed differs per rank so each rank draws a different block of G.
    seed(7984579*(rank+1))

    if over:                    # over-determined: A holds a block of rows
        # The sketch size is an array dimension, so it must be an int.
        s = int(np.ceil(gamma*n))
        # Right preconditioner N (used as A*N) from the summed n-by-s sketch.
        N = _sketch_preconditioner(_sketch_blocks(A.T.dot, m, n, s, blk_sz, timing),
                                   rcond, comm, timing)

        tic_iter = time()

        r = N.shape[1]
        gamma_eff = float(s)/r  # effective oversampling ratio
        condest = (sqrt(gamma_eff)+1.0)/(sqrt(gamma_eff)-1.0)  # estimate of cond(AN)

        AN_op = LinearOperator((m, r),
                               matvec=lambda v: A.dot(N.dot(v)),
                               rmatvec=lambda v: N.T.dot(A.T.dot(v)))

        if solver is lsqr:
            y, flag, itn = lsqr(AN_op, b, True, tol/condest,
                                _lsqr_iter_lim(tol, gamma_eff), comm=comm)[:3]
        else:
            s_max, s_min = _chebyshev_bounds(s, r)
            y, flag, itn = ls_chebyshev(AN_op, b, True, s_max, s_min, tol, None, comm)

        x = N.dot(y)

        timing['iter'] += time() - tic_iter

    else:                       # under-determined: A holds a block of columns
        s = int(np.ceil(gamma*m))
        # Left preconditioner M (used as M*A) from the summed m-by-s sketch.
        M = _sketch_preconditioner(_sketch_blocks(A.dot, n, m, s, blk_sz, timing),
                                   rcond, comm, timing).T

        tic_iter = time()

        r = M.shape[0]
        Mb = M.dot(b)
        gamma_eff = float(s)/r  # effective oversampling ratio

        MA_op = LinearOperator((r, n),
                               matvec=lambda v: M.dot(A.dot(v)),
                               rmatvec=lambda v: A.T.dot(M.T.dot(v)))

        if solver is lsqr:
            x, flag, itn = lsqr(MA_op, Mb, False, tol,
                                _lsqr_iter_lim(tol, gamma_eff), comm=comm)[:3]
        else:
            s_max, s_min = _chebyshev_bounds(s, r)
            x, flag, itn = ls_chebyshev(MA_op, Mb, False, s_max, s_min, tol, None, comm)

        timing['iter'] += time() - tic_iter

    timing['all'] = time() - tic_all

    return x, r, flag, itn, timing

def _test():
    """
    Self-test: solve a generated over-determined and under-determined
    problem with LSRN and print timing and accuracy on rank 0.

    Meant to be launched under MPI (e.g. via mpiexec).
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Problem sizes are used as array dimensions, so keep them integers.
    m     = 100000
    n     = 1000
    cond  = 1e6
    r     = n              # target rank passed to _gen_prob
    tol   = 1e-14
    gamma = 2.0

    def report(label, timing, flag, itn):
        # Print the timing summary and solver status on rank 0 only.
        if rank != 0:
            return
        print('LSRN(%s) time: %f.' % (label, timing['all']))
        print(timing)
        # NOTE(review): with the flag convention documented in lsrn
        # (0 solved, 1 iteration limit) this test is always true --
        # confirm against lsqr's return convention.
        if flag >= 0:
            print("LSQR converged in %d iterations." % (itn,))
        else:
            print("LSQR didn't converge in %d iterations." % (itn,))

    # Over-determined: the returned rank is discarded so that `r` keeps the
    # intended target rank for the second problem.
    A, b, x_opt = _gen_prob(m, n, cond, r, comm)
    x, _, flag, itn, timing = lsrn(A, b, True, gamma=gamma, tol=tol, solver=lsqr, comm=comm)
    report('over', timing, flag, itn)

    if rank == 0:
        relerr = norm(x-x_opt)/norm(x_opt)
        if relerr < tol*cond:
            print("Over-determined test passed with relerr %G." % (relerr,))
        else:
            print("Over-determined test failed with relerr %G." % (relerr,))

    # Under-determined: x is split across ranks, so the squared norms are
    # summed over ranks before forming the relative error.
    A, b, x_opt = _gen_prob(n, m, cond, r, comm)
    x, _, flag, itn, timing = lsrn(A, b, False, gamma=gamma, tol=tol, solver=lsqr, comm=comm)
    report('under', timing, flag, itn)

    sq_err = comm.reduce(norm(x-x_opt)**2)
    sq_x_opt = comm.reduce(norm(x_opt)**2)
    if rank == 0:
        relerr = sqrt(sq_err)/sqrt(sq_x_opt)
        if relerr < tol*cond:
            print("Under-determined test passed with relerr %G." % (relerr,))
        else:
            print("Under-determined test failed with relerr %G." % (relerr,))

if __name__ == '__main__':
    # Run the MPI self-test when this module is executed as a script.
    _test()
