/**
 * @file   RandSolver.cpp
 * @author Xiangrui Meng <mengxr@stanford.edu>
 * @date   Tue Oct  4 23:43:30 2011
 * 
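 * @brief  Randomized preconditioned least-squares routines (lsne, lsrn) and
 *         their random-projection helpers (rnpj, rnpre).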
 * 
 * 
 */

#include <cstddef>
#include <cassert>
#include <cmath>

#include <stdexcept>
#include <algorithm>

#include "Utils.hpp"
#include "LinOp.hpp"
#include "LinAlg.hpp"
#include "Random.hpp"
#include "IterSolver.hpp"
#include "RandSolver.hpp"

void lsne( LinOp<double>& A, Vec_d& b, double rcond,
           Vec_d& x, ptrdiff_t& r )
{
  assert( A.m() == b.n() );
  
  typedef ptrdiff_t idx_t;

  idx_t m = A.m();
  idx_t n = A.n();

  if( m >= n )
  {
    idx_t s = n+4;
    Mat_d N = rnpre( A, s, rcond );
    r = N.n();
    Mat_d AN_t_AN(r,r);
    idx_t blk_sz = 512;
    blk_sz = std::min( blk_sz, r );
    Mat_d AN;
    while( AN.empty() )
    {
      try
      {
        AN = Mat_d( m, blk_sz );
      }
      catch( std::bad_alloc& )
      {
        blk_sz /= 2;
        if( blk_sz == 0 )
          throw;
      }
    }
    Mat_d A_t_AN(n,blk_sz);
    for( idx_t j=0; j<r; j+=blk_sz )
    {
      idx_t len = std::min( j+blk_sz, r ) - j;

      Mat_d N_j( n, len, N, 0, j );
      Mat_d AN_j( m, len, AN );
      Mat_d A_t_AN_j( n, len, A_t_AN );
      Mat_d AN_t_AN_j( r, len, AN_t_AN, 0, j );

      // N'*(A'*(A*N_j))
      A.mm( 'n', 'n', 1.0, N_j, 0.0, AN_j );
      A.mm( 't', 'n', 1.0, AN_j, 0.0, A_t_AN_j );
      N.mm( 't', 'n', 1.0, A_t_AN_j, 0.0, AN_t_AN_j );
    }
    Vec_d A_t_b(n);
    Vec_d AN_t_b(r);
    // N*AN_t_AN\(A'*N'*b)
    A.mv( 't', 1.0, b, 0.0, A_t_b );
    N.mv( 't', 1.0, A_t_b, 0.0, AN_t_b );
    chol( AN_t_AN );
    chol_sol( AN_t_AN, AN_t_b );
    N.mv( 'n', 1.0, AN_t_b, 0.0, x );
  }
  else
  {
    idx_t s = m+4;
    Mat_d M = rnpre( A, s, rcond );
    r = M.n();
    Mat_d AtM_t_AtM(r,r);
    idx_t blk_sz = 512;
    blk_sz = std::min( blk_sz, r );
    Mat_d AtM;
    while( AtM.empty() )
    {
      try
      {
        AtM = Mat_d( n, blk_sz );
      }
      catch( std::bad_alloc& )
      {
        blk_sz /= 2;
        if( blk_sz == 0 )
          throw;
      }
    }
    Mat_d A_AtM(m,blk_sz);
    for( idx_t j=0; j<r; j+=blk_sz )
    {
      idx_t len = std::min( j+blk_sz, r ) - j;

      Mat_d M_j( m, len, M, 0, j );
      Mat_d AtM_j( n, len, AtM );
      Mat_d A_AtM_j( m, len, A_AtM );
      Mat_d AtM_t_AtM_j( r, len, AtM_t_AtM, 0, j );

      // M'*(A*(A'*M_j))
      A.mm( 't', 'n', 1.0, M_j, 0.0, AtM_j );
      A.mm( 'n', 'n', 1.0, AtM_j, 0.0, A_AtM_j );
      M.mm( 't', 'n', 1.0, A_AtM_j, 0.0, AtM_t_AtM_j );
    }
    Vec_d Mtb(r);
    // A'*M*(AtM_t_AtM\(M'*b))
    M.mv( 't', 1.0, b, 0.0, Mtb );
    chol( AtM_t_AtM );
    chol_sol( AtM_t_AtM, Mtb );
    Vec_d tmp(m);
    M.mv( 'n', 1.0, Mtb, 0.0, tmp );
    A.mv( 't', 1.0, tmp, 0.0, x );
  }
}

/**
 * Solve a least-squares problem with A and b by running lsqr() on a
 * preconditioned operator built from rnpre().
 *
 * Only strongly rectangular operators are handled (oversampling factor 2):
 *  - m > 2n: lsqr() runs on A*N with right-hand side b, then x = N*y;
 *  - n > 2m: lsqr() runs on M'*A with right-hand side M'*b, writing x.
 *
 * The iteration cap handed to lsqr() is
 *   ceil( 2*(log(2) - log(tol)) / log(s/rank) )
 * with tol = 1e-16 and s the sketch size passed to rnpre().
 *
 * @param A      linear operator of size m-by-n
 * @param b      right-hand side, length m
 * @param rcond  relative cut-off forwarded to rnpre()
 * @param x      output vector, length n
 * @param rank   output: number of columns of the preconditioner
 *
 * @throws std::runtime_error if neither m > 2n nor n > 2m holds
 */
void lsrn( LinOp<double>& A, Vec_d& b, double rcond, Vec_d& x, long& rank )
{
  assert( A.m() == b.n() );
  assert( A.n() == x.n() );

  typedef ptrdiff_t idx_t;

  const idx_t n_rows = A.m();
  const idx_t n_cols = A.n();

  const double oversampling = 2.0;
  double       tol          = 1e-16;

  const bool is_tall = ( n_rows > oversampling*n_cols );
  const bool is_wide = !is_tall && ( n_cols > oversampling*n_rows );

  // Reject shapes that are neither tall nor wide before doing any work.
  if( !is_tall && !is_wide )
    throw std::runtime_error( "Matrix is not thin enough or fat enough." );

  // Numerator of the iteration bound; identical for both shapes.
  const double log_term = 2.0 * ( std::log(2.0) - std::log(tol) );

  if( is_tall )
  {
    const idx_t sketch = static_cast<idx_t>( std::ceil(oversampling*n_cols) );
    Mat_d N = rnpre( A, sketch, rcond );
    rank = N.n();

    Composite_LinOP<double> AN( '*', 'n', A, 'n', N );

    const double ratio = 1.0*sketch/rank;
    long max_iter = static_cast<long>( std::ceil( log_term / std::log(ratio) ) );

    Vec_d y(rank);
    lsqr( AN, b, tol, max_iter, y );
    // Map the preconditioned solution back: x = N*y.
    mv( 'n', 1.0, N, y, 0.0, x );
    return;
  }

  // Wide case.
  const idx_t sketch = static_cast<idx_t>( std::ceil(oversampling*n_rows) );
  Mat_d M = rnpre( A, sketch, rcond );
  rank = M.n();

  Composite_LinOP<double> MtA( '*', 't', M, 'n', A );

  Vec_d Mtb(rank);
  M.mv( 't', 1.0, b, 0.0, Mtb );

  const double ratio = 1.0*sketch/rank;
  long max_iter = static_cast<long>( std::ceil( log_term / std::log(ratio) ) );

  lsqr( MtA, Mtb, tol, max_iter, x );
}

/**
 * Allocate a rows-by-blk_sz buffer for rnpj(), halving blk_sz after every
 * std::bad_alloc until the allocation succeeds.
 *
 * The loop is driven by the allocation outcome only, so it also terminates
 * when blk_sz is 0 (a zero-column buffer is simply returned).
 *
 * @param rows    number of rows of the buffer
 * @param blk_sz  in: requested number of columns; out: number of columns
 *                actually allocated
 *
 * @return the allocated buffer
 *
 * @throws std::bad_alloc (rethrown) if the allocation still fails once the
 *         block width cannot be halved any further
 */
static Mat_d rnpj_alloc_block( ptrdiff_t rows, ptrdiff_t& blk_sz )
{
  for( ;; )
  {
    try
    {
      return Mat_d( rows, blk_sz );
    }
    catch( std::bad_alloc& )
    {
      // Halving a width of 1 (or less) gives 0: nothing left to try.
      if( blk_sz <= 1 )
        throw;
      blk_sz /= 2;
    }
  }
}

/**
 * Apply A (or its transpose) to s columns filled by randn().
 *
 *  - m >= n: returns the n-by-s matrix A'*G, with G of size m-by-s;
 *  - m <  n: returns the m-by-s matrix A*G,  with G of size n-by-s.
 *
 * G is never stored in full: it is generated and consumed in column blocks
 * of at most 512 columns, fewer if the block buffer cannot be allocated.
 *
 * @param A  linear operator of size m-by-n
 * @param s  number of random columns (sketch size), expected >= 0
 *
 * @return the min(m,n)-by-s projected matrix
 *
 * @throws std::bad_alloc if the block buffer cannot be allocated even with a
 *         single column
 */
Mat_d rnpj( LinOp<double>& A, ptrdiff_t s )
{
  typedef ptrdiff_t idx_t;

  assert( s >= 0 );

  const idx_t m = A.m();
  const idx_t n = A.n();

  // Both shapes run the same loop; only the dimensions and the transpose
  // flag handed to A.mm() differ.
  const bool  tall   = ( m >= n );
  const idx_t g_rows = tall ? m : n;     // rows of the random block G
  const idx_t p_rows = tall ? n : m;     // rows of the projected result
  const char  trans  = tall ? 't' : 'n';

  idx_t blk_sz = std::min<idx_t>( 512, s );
  Mat_d G      = rnpj_alloc_block( g_rows, blk_sz );
  Mat_d P( p_rows, s );

  // idx_t (not int) index: the bound and the stride are ptrdiff_t.
  // s > 0 implies blk_sz >= 1 here, so the stride always advances.
  for( idx_t j=0; j<s; j+=blk_sz )
  {
    const idx_t len = std::min( j+blk_sz, s ) - j;

    // Refill the leading g_rows*len entries of the buffer for this block.
    randn( g_rows*len, G.data() );

    Mat_d P_j( p_rows, len, P, 0, j );
    Mat_d G_j( g_rows, len, G, 0, 0 );

    A.mm( trans, 'n', 1.0, G_j, 0.0, P_j );
  }

  return P;
}

/**
 * Build a preconditioner for A from the random projection rnpj( A, s ).
 *
 * With k = min(m,n), the k-by-s projection is passed to svd(), which fills
 * k values sgm and a k-by-k matrix.  The leading values strictly greater
 * than rcond*sgm(0) are counted (counting stops at the first one that is
 * not), the matching columns are scaled by 1/sgm(j), and those columns are
 * returned.
 *
 * NOTE(review): stopping at the first small value presumably relies on
 * svd() returning the values in non-increasing order -- confirm.
 *
 * @param A      linear operator of size m-by-n
 * @param s      sketch size forwarded to rnpj()
 * @param rcond  relative cut-off applied to the values returned by svd()
 *
 * @return a min(m,n)-by-rank matrix of scaled columns
 */
Mat_d rnpre( LinOp<double>& A, ptrdiff_t s, double rcond )
{
  typedef ptrdiff_t idx_t;

  const idx_t n_rows = A.m();
  const idx_t n_cols = A.n();

  // Both shapes work on the smaller of the two dimensions.
  const idx_t k = ( n_rows >= n_cols ) ? n_cols : n_rows;

  Mat_d sketch = rnpj( A, s );
  Vec_d sv(k);
  Mat_d basis(k,k);

  svd( sketch, sv, basis );

  // Count the leading values above the relative threshold.
  const double cutoff = rcond*sv(0);
  idx_t rank = 0;
  while( rank < k && sv(rank) > cutoff )
    ++rank;

  // Scale each retained column by the reciprocal of its value.
  for( idx_t col=0; col<rank; ++col )
    scal( 1.0/sv(col), basis.col(col) );

  Mat_d precond( k, rank, basis );

  return precond;
}
