"""
LSRN computes the min-length solution of linear least squares via LSQR with randomized
preconditioning. LSRN works best when m >> n or m << n.

Xiangrui Meng <mengxr@stanford.edu>
iCME, Stanford University
"""

from exceptions import NotImplementedError

from time import time, clock
from math import log, sqrt

import numpy as np
from numpy.linalg import norm, svd, lstsq

from scipy.sparse.linalg import aslinearoperator, LinearOperator, lsqr

from ziggurat import seed, randn

from cgls import cgls
from lsqr import lsqr
from nesterov import ls_nesterov

from _gen_prob import _gen_prob

def lsrn( A, b, gamma=2.0, tol=np.finfo(float).eps, rcond=-1, ls_solver=lsqr ):
    """
    LSRN computes the min-length solution of linear least squares via an
    iterative solver with randomized preconditioning.

    A Gaussian sketch G*A of size s-by-n (s = ceil(gamma*n)) is formed block
    by block, its SVD yields the preconditioner N, and the preconditioned
    system A*N is handed to the iterative solver. Only the over-determined
    case (m > n) is implemented.

    Parameters
    ----------

    A       : matrix-like of size m-by-n. It must support ``G.dot(A)`` for a
              dense ndarray G, as well as ``A.dot(v)`` and ``A.T.dot(v)``.

    b       : (m,) ndarray

    gamma : float (>1), oversampling factor

    tol : float, tolerance such that norm(A*x-A*x_opt)<tol*norm(A*x_opt)

    rcond : float, reciprocal condition number. Singular values of the sketch
            that are not larger than S[0]*rcond are treated as zero. A
            negative value selects min(m,n)*eps.

    ls_solver : callable, one of lsqr, cgls or ls_nesterov

    Returns
    -------
    x      : (n,) ndarray, the min-length solution

    r      : int, the numerical rank detected from the sketch

    flag : int, status flag returned by the iterative solver

    itn : int, iteration number returned by the iterative solver

    timing : dict, wall-clock seconds spent in each phase, with keys
             'rand' (random number generation), 'mult' (sketch multiply),
             'svd' (SVD and rank detection) and 'cg' (preconditioner setup
             plus the iterative solve)

    Raises
    ------
    NotImplementedError : if m <= n (under-determined case)

    ValueError : if ls_solver is not one of the supported solvers
    """

    m, n = A.shape

    if rcond < 0:
        rcond = min(m, n)*np.finfo(float).eps

    timing = { 'rand': 0.0, 'mult': 0.0, 'svd': 0.0, 'cg': 0.0 }

    if m <= n:
        raise NotImplementedError('lsrn only supports the over-determined case (m > n)')

    # Reject an unknown solver before doing the expensive sketch; otherwise
    # the failure would only surface as an unbound result after the SVD.
    if not ((ls_solver == lsqr) or (ls_solver == cgls) or (ls_solver == ls_nesterov)):
        raise ValueError('unsupported ls_solver: %r' % (ls_solver,))

    # s is used as an array dimension and as a loop bound, so it has to be
    # an integer; np.ceil returns a float.
    s      = int(np.ceil(gamma*n))
    As     = np.zeros([s,n])
    blk_sz = 128

    # Build the sketch As = G*A in row blocks so that only a blk_sz-by-m
    # slab of random numbers is alive at any time.
    for blk_begin in range(0, s, blk_sz):
        blk_end = min(blk_begin+blk_sz, s)
        blk_len = blk_end-blk_begin
        t = time()
        G = randn(blk_len,m)
        timing['rand'] += time() - t
        t = time()
        As[blk_begin:blk_end,:] = G.dot(A)
        timing['mult'] += time() - t

    t = time()
    U, S, V = svd(As, False)
    # determine the rank
    r_tol = S[0]*rcond
    r     = np.sum(S>r_tol)
    timing['svd'] += time() - t

    t = time()
    # Preconditioner: the first r rows of V (as returned by svd), transposed
    # and with each resulting column scaled by the reciprocal singular value.
    N = V[:r,:].T/S[:r]
    def AN_op_matvec(v):
        return A.dot(N.dot(v))
    def AN_op_rmatvec(v):
        return N.T.dot(A.T.dot(v))
    AN_op = LinearOperator((m, r),
                           matvec = AN_op_matvec,
                           rmatvec = AN_op_rmatvec)
    gamma    = 1.0*s/r                             # re-estimate gamma
    condest  = (sqrt(gamma)+1.0)/(sqrt(gamma)-1.0) # condition number of AN
    iter_lim = int(np.ceil(-2*(log(tol)-log(2))/log(gamma)))
    if ls_solver == ls_nesterov:
        y, flag, itn = ls_nesterov( AN_op, b, 1.0/(sqrt(s)-sqrt(r)), 1.0/(sqrt(s)+sqrt(r)), tol=tol )
    else:
        # Dispatch to the solver that was actually requested (lsqr or cgls).
        # NOTE(review): assumes cgls accepts the same tol/iter_lim keywords
        # as lsqr -- confirm against the cgls module.
        y, flag, itn = ls_solver( AN_op, b, tol=tol/condest, iter_lim=iter_lim )
    # Map the solution of the preconditioned problem back to x = N*y.
    x              = N.dot(y)
    timing['cg'] += time() - t

    return x, r, flag, itn, timing

def _test():
    """
    Smoke test: generate a problem with _gen_prob, solve it with lsrn using
    the lsqr solver, and print the detected rank, the iteration count, the
    solver flag and two relative errors against the reference solution.
    """

    # Problem dimensions are sizes, so keep them integral.
    m      = 10000
    n      = 100
    r_true = 50     # presumably the rank of the generated A -- see _gen_prob
    c      = 1e6    # presumably the condition number of A -- see _gen_prob
    gamma  = 3
    tol    = 1e-14

    A, b, x_opt = _gen_prob( m, n, c, r_true )

    x, r, flag, itn, timing = lsrn( A, b, gamma=gamma, tol=tol, ls_solver=lsqr )

    # Parenthesised single-argument print emits the same text on Python 2
    # and Python 3.
    print('rank: %d' % (r,))
    print('iter: %d' % (itn,))
    print('flag: %d' % (flag,))

    # Relative error of the solution itself ...
    relerr = norm(x-x_opt)/norm(x_opt)
    print('relerr: %G' % (relerr,))
    # ... and of its image under A.
    relerr_AtA = norm(A.dot(x-x_opt))/norm(A.dot(x_opt))
    print('relerr_AtA: %G' % (relerr_AtA,))

# Run the smoke test only when this file is executed as a script.
if "__main__" == __name__:
    _test()
