"""
LSNE computes the minimum-norm solution of a linear least-squares problem via the normal equations.

Xiangrui Meng <mengxr@stanford.edu>
iCME, Stanford University
"""

from math import sqrt

import numpy as np
from numpy.linalg import LinAlgError

from scipy.linalg import cho_factor, cho_solve, eigh
from scipy.sparse import issparse
from scipy.sparse.linalg import LinearOperator

def _dense_gram(G, k):
    """Return the k-by-k Gram product ``G`` as a dense 2-D ndarray.

    ``G`` is whatever ``A.T.dot(A)`` / ``A.dot(A.T)`` produced: an ndarray, an
    ``np.matrix``, a sparse matrix, or (when ``A`` is a LinearOperator) a product
    LinearOperator, which is materialized by applying it to the identity.
    """
    if issparse(G):
        return G.toarray()
    if isinstance(G, LinearOperator):
        return np.asarray(G.matmat(np.eye(k)))
    return np.asarray(G)


def _match_rhs(v, rhs):
    """Convert ``v`` to an ndarray, flattening it when ``rhs`` is 1-D.

    Products involving ``np.matrix`` operands come back as (1, k) matrices even
    for a 1-D vector; flattening restores the (k,) shape.
    """
    v = np.asarray(v)
    return v.reshape(-1) if np.ndim(rhs) == 1 else v


def _chol_solve(gram, rhs, scale):
    """Solve ``gram @ z = rhs`` by Cholesky; raise LinAlgError if ``gram`` is
    (numerically) singular."""
    c, lower = cho_factor(gram)
    pivots = np.diag(c) ** 2
    # Every Cholesky pivot is >= lambda_min(gram) and every diagonal entry of
    # gram is <= lambda_max(gram).  A pivot at or below this threshold therefore
    # implies lambda_min <= scale*eps*lambda_max, i.e. the eigen-based rank test
    # would drop an eigenvalue; fall back instead of returning a solution
    # amplified by roundoff.
    if pivots.min() <= scale * np.finfo(float).eps * np.diag(gram).max():
        raise LinAlgError("Gram matrix is numerically rank deficient")
    return cho_solve((c, lower), rhs)


def _eig_solve(gram, rhs, scale):
    """Minimum-norm solution of the PSD system ``gram @ z = rhs`` via an
    eigendecomposition; eigenvalues not above ``scale * eps * lambda_max`` are
    treated as zero.  Returns ``(z, rank)``."""
    s2, V = eigh(gram)
    # Clamp so roundoff making lambda_max slightly negative (gram ~ 0) cannot
    # yield a negative tolerance that would keep spurious eigenvalues.
    tol = scale * max(s2[-1], 0.0) * np.finfo(float).eps
    keep = s2 > tol
    s2, V = s2[keep], V[:, keep]
    coef = V.T.dot(rhs)
    # Scale row i of coef by 1/s2[i]; the transposes make the broadcast act on
    # rows for both 1-D and 2-D right-hand sides.
    coef = (coef.T * (1.0 / s2)).T
    return V.dot(coef), s2.size


def _solve_psd(gram, rhs, scale, try_chol):
    """Solve the PSD system ``gram @ z = rhs``; return ``(z, rank)``."""
    if try_chol:
        try:
            return _chol_solve(gram, rhs, scale), gram.shape[0]
        except LinAlgError:
            pass  # singular or numerically rank deficient: use the eigen path
    return _eig_solve(gram, rhs, scale)


def lsne(A, b, try_chol=True):
    """
    LSNE computes the min-length solution of linear least squares via the normal
    equation.

    For m >= n it solves (A^T A) x = A^T b; for m < n it solves (A A^T) y = b
    and returns x = A^T y.

    Parameters
    ----------
    A        : {ndarray, matrix, sparse matrix, LinearOperator} of size m-by-n
    b        : (m,) or (m, k) ndarray
    try_chol : if True, LSNE first tries a Cholesky factorization of the normal
               equation; if it fails or reveals the system to be numerically
               rank deficient, an eigen-decomposition is used instead. If
               False, the eigen-decomposition is always used.

    Returns
    -------
    x        : (n,) or (n, k) ndarray, the min-length solution
    r        : int, numerical rank of A as determined by the solver used
    """
    m, n = A.shape
    # The rank tolerance is scaled by the larger dimension of A.
    scale = max(m, n)

    if m >= n:
        gram = _dense_gram(A.T.dot(A), n)
        rhs = _match_rhs(A.T.dot(b), b)
        x, r = _solve_psd(gram, rhs, scale, try_chol)
    else:
        gram = _dense_gram(A.dot(A.T), m)
        y, r = _solve_psd(gram, b, scale, try_chol)
        x = _match_rhs(A.T.dot(y), b)

    return x, r
