import sys

from math import sqrt

import numpy as np
from numpy.random import seed, randn
from numpy.linalg import norm, svd, lstsq

from mpi4py import MPI

def _split_count(total, parts, index):
    """Return how many of `total` items part `index` owns when split over `parts`.

    The first ``total % parts`` parts receive one extra item, so the counts
    over all parts sum to `total`.  The result is always a Python int.
    """
    return total // parts + (1 if index < total % parts else 0)


def _split_bounds(total, parts):
    """Return the ``parts + 1`` integer offsets delimiting each part's slice."""
    bounds = [0]
    for i in range(parts):
        bounds.append(bounds[-1] + _split_count(total, parts, i))
    return bounds


def _rand_orth(rows, cols):
    """Return the thin left singular vectors of a rows-by-cols Gaussian matrix.

    For ``rows >= cols`` this is a rows-by-cols matrix with orthonormal
    columns.
    """
    return svd(randn(rows, cols), full_matrices=False)[0]


def _dense_problem(m, n, s, theta):
    """Build a full (undistributed) problem with singular values `s`.

    Returns ``(A, b, x_opt)`` where ``A = U*diag(s)*V.T``, ``b`` is a vector in
    the range of A plus noise of relative size `theta`, and
    ``x_opt = V*diag(1/s)*U.T*b``.
    """
    r = len(s)
    U = _rand_orth(m, r)
    V = _rand_orth(n, r)

    A = (U*s).dot(V.T)
    b = A.dot(randn(n))
    res = randn(m)
    b += theta*norm(b) * res/norm(res)

    x_opt = V.dot(U.T.dot(b)/s)
    return A, b, x_opt


def _gen_prob( m, n, cond, r = None, comm = None ):
    """
    Generate a random linear least squares problem, distributed over `comm`.

    The global matrix is ``A = U*diag(s)*V.T`` with `r` singular values
    linearly spaced between 1 and `cond`; the reference solution is
    ``x_opt = V*diag(1/s)*U.T*b``.  Each rank seeds NumPy's global random
    generator with its own rank number, so results are reproducible for a
    fixed communicator size.

    Parameters
    ----------
    m, n : number of rows / columns of the global matrix.  Integer-valued
        floats (e.g. ``5e2``) are accepted and converted with ``int()``.
    cond : largest singular value (the smallest is 1).
    r    : rank of the matrix; defaults to ``min(m, n)``.
    comm : communicator providing ``Get_rank``, ``Get_size``, ``bcast``,
        ``scatter`` and ``allreduce``; defaults to ``MPI.COMM_WORLD``.

    Returns
    -------
    A, b, x_opt : the local pieces on the calling rank.
        * ``m >= n``: A and b hold this rank's block of rows, x_opt is the
          full solution on every rank.
        * ``n > m``:  A and x_opt hold this rank's block of columns /
          entries, b is the full right-hand side on every rank.

    Raises
    ------
    ValueError : if `r` is not in ``1..min(m, n)``.
    """

    if comm is None:
        # Resolved at call time so any communicator-like object can be passed.
        comm = MPI.COMM_WORLD

    rank = comm.Get_rank()
    size = comm.Get_size()

    # Dimensions are used as array shapes and slice bounds, which must be
    # true integers; callers may hand in floats such as 5e2 or np.ceil(...).
    m = int(m)
    n = int(n)
    r = min(m, n) if r is None else int(r)
    if r < 1 or r > min(m, n):
        # Checked before any collective call so every rank fails together.
        raise ValueError(
            "r must satisfy 1 <= r <= min(m, n); got r=%d, m=%d, n=%d"
            % (r, m, n))

    theta = 0.25                    # noise norm relative to norm(b)
    s     = np.linspace(1,cond,r)   # prescribed singular values

    seed(rank)

    if m >= n and m >= r*size: # strongly over-determined

        # Every rank builds its own row block.  Each local U has orthonormal
        # columns, so scaling by 1/sqrt(size) makes the stacked U orthonormal.
        mi    = _split_count(m, size, rank)

        U     = _rand_orth(mi, r) / sqrt(size)
        V     = _rand_orth(n, r) if rank == 0 else None
        V     = comm.bcast(V)
        A     = (U*s).dot(V.T)
        b     = A.dot(randn(n))
        res   = randn(mi)
        b    += theta*norm(b) * res/norm(res)

        # U.T*b is a sum over row blocks, hence the allreduce.
        UTb   = comm.allreduce(U.T.dot(b))
        x_opt = V.dot(UTb/s)

    elif m >= n: # over-determined but near-squared

        # Row blocks would be shorter than r, so rank 0 generates the whole
        # problem and hands out row slices.
        if rank == 0:
            A_full, b_full, x_opt = _dense_problem(m, n, s, theta)
            bounds = _split_bounds(m, size)
            A_pile = [ A_full[lo:hi,:] for lo, hi in zip(bounds[:-1], bounds[1:]) ]
            b_pile = [ b_full[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:]) ]
        else:
            x_opt  = None
            A_pile = None
            b_pile = None

        A     = comm.scatter( A_pile )
        b     = comm.scatter( b_pile )
        x_opt = comm.bcast( x_opt )

    elif n < r*size: # under-determined but near-squared

        # Same idea as above, but A and x_opt are split by columns / entries.
        if rank == 0:
            A_full, b, x_full = _dense_problem(m, n, s, theta)
            bounds = _split_bounds(n, size)
            A_pile     = [ A_full[:,lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:]) ]
            x_opt_pile = [ x_full[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:]) ]
        else:
            b          = None
            A_pile     = None
            x_opt_pile = None

        A     = comm.scatter( A_pile )
        b     = comm.bcast( b )
        x_opt = comm.scatter( x_opt_pile )

    else: # under-determined

        # Mirror image of the strongly over-determined case: V is distributed
        # by rows (columns of A) and U is shared.
        ni    = _split_count(n, size, rank)

        V     = _rand_orth(ni, r) / sqrt(size)
        U     = _rand_orth(m, r) if rank == 0 else None
        U     = comm.bcast(U)
        A     = (U*s).dot(V.T)
        # A*x is a sum over column blocks, hence the allreduce.
        b     = comm.allreduce(A.dot(randn(ni)))
        if rank == 0:
            res   = randn(m)
            b    += theta*norm(b) * res/norm(res)
        # Only rank 0 added the noise; make b identical everywhere.
        b     = comm.bcast(b)

        x_opt = V.dot(U.T.dot(b)/s)

    return A, b, x_opt

def _test():
    """
    Self-check for _gen_prob on MPI.COMM_WORLD.

    Generates one over-determined (m-by-n) and one under-determined (n-by-m)
    problem, gathers the distributed pieces on rank 0, and compares the
    solution returned by numpy's lstsq with the x_opt produced by the
    generator.  On a mismatch every rank exits with a non-zero status (rank 0
    prints the message); otherwise rank 0 prints a success line.
    """

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Sizes are true integers because they end up as array shapes.
    m     = 500
    n     = 100
    r     = (n + 1) // 2            # ceil(n/2)
    cond  = 10.0
    # Singular values lie in [1, cond]; anything below cond*rcond = 0.1 is
    # treated as zero by lstsq, which discards the numerically-zero ones.
    rcond = 1.0/(10.0*cond)

    def _require(ok, message):
        # Rank 0 owns the verdict.  Share it so that all ranks leave together
        # instead of the others waiting in the next collective call.
        ok = comm.bcast(ok)
        if not ok:
            sys.exit(message if rank == 0 else 1)

    # Over-determined: A and b are distributed by rows, x_opt is replicated.
    A, b, x_opt = _gen_prob(m,n,cond,r,comm)

    A_pile = comm.gather( A )
    b_pile = comm.gather( b )

    ok = None
    if rank == 0:
        A = np.vstack(A_pile)
        b = np.hstack(b_pile)
        x_lstsq = lstsq(A,b,rcond)[0]
        ok = bool(np.allclose(x_lstsq, x_opt))
    _require(ok, "Over-determined test failed.")

    # Under-determined: A and x_opt are distributed by columns, b is replicated.
    A, b, x_opt = _gen_prob(n,m,cond,r,comm)

    A_pile = comm.gather( A )
    x_opt_pile = comm.gather( x_opt )

    ok = None
    if rank == 0:
        A     = np.hstack(A_pile)
        x_opt = np.hstack(x_opt_pile)
        x_lstsq = lstsq(A,b,rcond)[0]
        ok = bool(np.allclose(x_lstsq, x_opt))
    _require(ok, "Under-determined test failed.")

    if rank == 0:
        # Parenthesised single string prints identically on Python 2 and 3.
        print("All tests passed.")

if __name__ == "__main__":
    # Run under MPI with both -np 2 and -np 12.  With the sizes used in
    # _test (500 x 100, rank 50), 2 processes satisfy r*size <= 500 and take
    # the distributed-generation branches of _gen_prob, while 12 processes
    # give r*size = 600 > 500 and take the "near-squared" branches in which
    # rank 0 generates the whole problem and scatters it.
    _test()
