"""
CGLS without shift.

Xiangrui Meng <mengxr@stanford.edu>
iCME, Stanford University
"""

from math import sqrt, log

import numpy as np
from numpy.linalg import lstsq, norm, svd
from numpy.random import randn

from scipy.sparse.linalg import aslinearoperator

from _gen_prob import _gen_prob

__all__ = ['cgls']

def cgls( A, b, tol = 1e-6, iter_lim = None ):
    """
    CGLS without shift: iteratively minimize ||A x - b||_2.

    Parameters
    ----------

    A : m-by-n {ndarray, matrix, sparse, LinearOperator}
        Wrapped with aslinearoperator, so only products with A and its
        adjoint are used.

    b : (m,) array_like
        Right-hand side. It is copied into a flat float vector; the
        caller's data is never modified.

    tol      : tolerance on the relative residual of the normal equations,
               ||A'r|| / ||A'b|| with r = b - A x

    iter_lim : max number of iterations, defaults to 2*min(m,n)

    Returns
    -------

    x : (n,) ndarray, the iterate with the smallest relative residual seen

    flag : int, 0 means cgls converged, 1 means cgls didn't converge

    itn : int, number of iterations
    """

    A         = aslinearoperator(A)
    m, n      = A.shape

    x         = np.zeros(n)
    # Work on a private float copy of b: the in-place update of r below
    # would fail on an integer array and must not touch the caller's data.
    r         = np.array(b, dtype=float).ravel()
    s         = A.rmatvec(r)
    p         = s.copy()
    gamma     = np.dot(s,s)
    nrm_s_0   = sqrt(gamma)

    if nrm_s_0 == 0.0:
        # A'b = 0, so x = 0 already satisfies the normal equations.
        # Iterating would only divide by zero.
        return x, 0, 0

    x_best      = x.copy()
    relres_best = 1.0

    itn       = 0
    stag      = 0
    stag_lim  = 10
    converged = False

    if iter_lim is None:
        iter_lim = 2*min(m,n)

    while (not converged) and (itn < iter_lim):

        q       = A.matvec(p)
        sq_q    = np.dot(q,q)
        if sq_q == 0:
            # A p = 0: the step length is undefined and no further
            # progress is possible along p, so stop with the best iterate.
            break

        itn    += 1

        alpha   = gamma / sq_q
        x      += alpha*p
        r      -= alpha*q

        s       = A.rmatvec(r)
        gamma_0 = gamma
        gamma   = np.dot(s,s)
        beta    = gamma / gamma_0
        if beta > 1:
            # ||A'r|| grew in this step; discard it and keep x_best.
            break
        p       = s + beta*p

        relres  = sqrt(gamma)/nrm_s_0

        if relres < relres_best:
            relres_best = relres
            x_best      = x.copy()
            stag        = 0
        else:
            stag += 1

        # Give up after more than stag_lim consecutive iterations
        # without a new best relative residual.
        if stag > stag_lim:
            break

        if relres < tol:
            converged = True

    flag = 0 if converged else 1

    return x_best, flag, itn

def _test():
    """
    Self-test: solve a random dense 300-by-100 least squares problem with
    cgls and compare the result with numpy's lstsq solution.

    Prints whether cgls reported convergence and whether the relative
    error against the lstsq solution is below 1e-14.
    """

    # Array dimensions must be ints; randn rejects float sizes.
    m = 300
    n = 100

    A = randn(m,n)
    b = randn(m)

    x_opt, = lstsq(A,b)[:1]

    tol      = 1e-14
    # float() keeps n/m a true division under Python 2 as well; an integer
    # quotient of 0 would make the inner log fail.
    iter_lim = int(2*np.ceil(log(tol)/log(float(n)/m)))

    # NOTE: cgls is called with tol = 0, so its "relres < tol" test can
    # never succeed and flag is always 1 here. The run ends through the
    # other stopping rules in cgls, and accuracy is judged by relerr.
    x, flag, itn = cgls(A,b,0,10*iter_lim)
    relerr       = norm(x-x_opt)/norm(x_opt)

    # print with a single parenthesised argument behaves the same under
    # Python 2 and Python 3.
    if flag == 0:
        print("CGLS converged in %d/%d iterations." % (itn,iter_lim))
    else:
        print("CGLS didn't converge in %d/%d iterations." % (itn,iter_lim))

    if relerr < tol:
        print("CGLS test passed with relerr %G." % (relerr,))
    else:
        print("CGLS test failed with relerr %G." % (relerr,))

# Run the self-test when this file is executed as a script.
if __name__ == '__main__':
    _test()
