from math import sqrt, log

import numpy as np
from numpy.linalg import norm, lstsq

from scipy.sparse.linalg import LinearOperator, aslinearoperator

from mpi4py import MPI

from _gen_prob import _gen_prob

def ls_chebyshev( A, b, over, s_max, s_min, tol = 1e-8, iter_lim = None,
                  comm = MPI.COMM_WORLD ):
    """
    Chebyshev iteration for linear least squares problems.

    Implement Algorithm 5 in The Chebyshev iteration revisited with modification for
    solving least squares problems.

    Parameters
    ----------

    A : m-by-n {ndarray, matrix, sparse, LinearOperator}. If over is True, A contains a
        subset of rows of the full matrix. Otherwise, A contains a subset of columns.

    b : (m,) ndarray. If over is True, b is partitioned the same way as A. Otherwise b
        is the full RHS vector.

    over : bool. If truthy, A.rmatvec(r) is summed across comm at every step. Otherwise
        A.matvec(v) is summed across comm instead.

    s_max, s_min : float. Estimates of the largest and smallest singular values. Their
        squares define the interval [s_min**2, s_max**2] whose center and half-width
        drive the iteration.

    tol : float. Target accuracy used to estimate the iteration count. Values below
        machine epsilon are raised to machine epsilon.

    iter_lim : int or None. Number of iterations to run. If None, or smaller than the
        a-priori estimate derived from tol, the estimate is used instead.

    comm : communicator providing allreduce. Defaults to MPI.COMM_WORLD.

    Returns
    -------

    x : (n,) ndarray with n = A.shape[1]. The computed iterate.

    flag : int. Always 0.

    k : int. Index of the last iteration performed (iterations run minus one), or -1
        if no iteration was run.
    """

    A     = aslinearoperator(A)
    n     = A.shape[1]

    # Center and half-width of the interval [s_min^2, s_max^2].
    d     = (s_max*s_max+s_min*s_min)/2.0
    c     = (s_max*s_max-s_min*s_min)/2.0

    eps   = np.finfo(float).eps
    tol   = np.max( [tol, eps] )
    theta = (1.0-s_min/s_max)/(1.0+s_min/s_max) # convergence rate
    if theta == 0.0:
        # s_min == s_max, so c == 0 and log(theta) is undefined; the estimate
        # degenerates to a single step.
        itn_e = 1
    else:
        itn_e = np.ceil((log(tol)-log(2))/log(theta))
    if (iter_lim is None) or (iter_lim < itn_e):
        iter_lim = itn_e

    r     = b.copy()
    x     = np.zeros(n)
    v     = np.zeros(n)

    # Keeps k defined for the return value when the loop body never runs.
    k     = -1

    for k in range(int(iter_lim)):

        # beta carries a negative sign here and is subtracted in the update of v,
        # so the previous direction is effectively added with weight |beta|.
        if k == 0:
            beta  = 0.0
            alpha = 1.0/d
        elif k == 1:
            beta  = -1.0/2.0*(c*c)/(d*d)
            # Reciprocal form: alpha_1 = 1/(d - c^2/(2d)), the step size that makes
            # the degree-2 residual polynomial the scaled Chebyshev polynomial.
            alpha = 1.0/(d-c*c/(2.0*d))
        else:
            beta  = -(c*c)/4.0*(alpha*alpha)
            alpha = 1.0/(d-(c*c)/4.0*alpha)

        if over:
            # Each rank contributes A^T r for its rows; the sum is the full product.
            v  = comm.allreduce(A.rmatvec(r)) - beta*v
            x += alpha*v
            r -= alpha*A.matvec(v)
        else:
            # Each rank contributes A v for its columns; the sum is the full product.
            v  = A.rmatvec(r) - beta*v
            x += alpha*v
            r -= alpha*comm.allreduce(A.matvec(v))

    flag = 0

    return x, flag, k

def _test():
    """
    Self-test: solve a generated problem with ls_chebyshev and print the relative
    error against the reference solution on rank 0.
    """

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # NOTE(review): sizes are given as floats; confirm _gen_prob accepts them.
    m    = 1e4
    n    = 1e2
    cond = 4.0
    r    = 1e2

    A, b, x_opt = _gen_prob(m,n,cond,r,comm)

    # Presumably the generated matrix has singular values in [1, cond] -- TODO confirm
    # against _gen_prob.
    s_min = 1.0
    s_max = cond
    tol   = 1e-14

    # Take the solution itself (element 0) rather than a 1-tuple slice of the result.
    x     = ls_chebyshev( A, b, True, s_max, s_min, tol, None, comm )[0]

    if rank == 0:
        relerr = norm(x-x_opt)/norm(x_opt)
        # Parenthesized single-argument form prints identically on Python 2 and 3.
        print("relerr = %e." % (relerr,))

# Run the self-test only when this file is executed as a script.
if __name__ == "__main__":
    _test()
