/**
 * @file   IterSolver.cpp
 * @author Xiangrui Meng <mengxr@stanford.edu>
 * @date   Tue Oct  4 22:51:24 2011
 * 
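 * @brief  Iterative solvers for linear least-squares problems: LSQR and the
 *         Chebyshev semi-iterative method.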
 * 
 * 
 */

#include <cassert>
#include <cmath>

#include "LinOp.hpp"
#include "LinAlg.hpp"

#include "IterSolver.hpp"

/**
 * LSQR (Paige & Saunders) for the least-squares problem min ||A x - b||_2.
 *
 * Golub-Kahan bidiagonalization of A is started from b.  The solution
 * estimate is updated with one Givens rotation per iteration.  The solver
 * stops after @p maxit iterations or as soon as the estimated normal-equation
 * residual satisfies ||A^T r|| <= tol * ||A|| * ||r||.
 *
 * @param A      linear operator of size A.m() x A.n(); only A.mv('n',...)
 *               and A.mv('t',...) are used.
 * @param b      right-hand side, length A.m() (not modified; a copy is used).
 * @param tol    relative tolerance of the ||A^T r|| stopping test.
 * @param maxit  maximum number of iterations (<= 0 performs none).
 * @param x      output vector of length A.n(); overwritten with the solution.
 */
void lsqr( LinOp<double>& A, const Vec_d b,
           const double tol, const long maxit,
           Vec_d& x )
{
  assert( A.m() == b.n() );
  assert( A.n() == x.n() );

  // beta_1 * u_1 = b.  Copy b before clearing x.
  Vec_d u = b.copy();
  x = 0.0;

  double beta = nrm2(u);
  if( beta == 0.0 )
    return;                       // b == 0: x = 0 is the exact solution
  scal( 1.0/beta, u );

  // alpha_1 * v_1 = A^T u_1
  Vec_d v( A.n() );
  A.mv( 't', u, v );
  double alpha = nrm2(v);
  if( alpha == 0.0 )
    return;                       // A^T b == 0: x = 0 already solves the LS problem
  scal( 1.0/alpha, v );

  Vec_d w = v.copy();

  double phibar = beta;
  double rhobar = alpha;
  double nrm_a  = 0.0;            // running Frobenius-norm estimate of A

  for( long k = 0; k < maxit; ++k )
  {
    // beta * u = A v - alpha * u.  A zero vector cannot be normalized; it
    // means the Krylov subspace is exhausted, and the rotation below then
    // drives phibar to 0.
    A.mv( 'n', 1.0, v, -alpha, u );
    beta = nrm2(u);
    if( beta > 0.0 )
      scal( 1.0/beta, u );

    nrm_a = sqrt( nrm_a*nrm_a + alpha*alpha + beta*beta );

    // alpha * v = A^T u - beta * v
    A.mv( 't', 1.0, u, -beta, v );
    alpha = nrm2(v);
    if( alpha > 0.0 )
      scal( 1.0/alpha, v );

    // Givens rotation that eliminates the subdiagonal beta.
    const double rho   = sqrt( rhobar*rhobar + beta*beta );
    const double cs    = rhobar/rho;
    const double sn    = beta/rho;
    const double theta = sn*alpha;
    rhobar             = -cs*alpha;
    const double phi   = cs*phibar;
    phibar             = sn*phibar;

    // x <- x + (phi/rho) w;  w <- v - (theta/rho) w
    axpy( phi/rho, w, x );
    scal( -theta/rho, w );
    axpy( 1.0, v, w );

    const double nrm_r  = phibar;                    // estimate of ||r||
    const double nrm_ar = phibar*alpha*fabs(cs);     // estimate of ||A^T r||

    // '<=' (not '<') so that an exact solution (nrm_r == nrm_ar == 0) stops
    // the iteration instead of dividing by rho == 0 on the next pass.
    if( nrm_ar <= tol*nrm_a*nrm_r )
      break;
  }
}

/**
 * Chebyshev semi-iterative method for the least-squares problem
 * min ||A x - b||_2.
 *
 * The iteration count is fixed in advance from the bound 2*theta^k on the
 * error reduction after k steps, with
 * theta = (s_max - s_min)/(s_max + s_min).  Each step costs one A^T and one
 * A product.
 *
 * @param A      linear operator of size A.m() x A.n().
 * @param b      right-hand side, length A.m() (not modified; a copy is used).
 * @param s_max  upper bound on the singular values of A.
 * @param s_min  lower bound on the singular values of A, 0 < s_min <= s_max.
 *               Assumed to bracket all nonzero singular values of A; the
 *               convergence bound only holds if they do.
 * @param tol    target relative error reduction, tol > 0.
 * @param x      output vector of length A.n(); overwritten with the solution.
 */
void ls_chebyshev( LinOp<double>& A, const Vec_d b,
                   const double s_max, const double s_min,
                   const double tol, Vec_d& x )
{
  assert( A.m() == b.n() );
  assert( A.n() == x.n() );
  assert( 0.0 < s_min && s_min <= s_max );
  assert( 0.0 < tol );

  const LONG n = A.n();

  // The squared singular values lie in [d - c, d + c].
  const double d = (s_max*s_max + s_min*s_min)/2.0;
  const double c = (s_max*s_max - s_min*s_min)/2.0;

  // 2*theta^k <= tol  <=>  k >= log(tol/2)/log(theta).
  // With theta == 0 (s_min == s_max) a single step is exact.
  // Always take at least one step.
  const double theta = (1.0 - s_min/s_max)/(1.0 + s_min/s_max);
  long maxit = 1;
  if( theta > 0.0 )
  {
    const double bound = ceil( (log(tol) - log(2.0)) / log(theta) );
    if( bound > 1.0 )
      maxit = static_cast<long>( bound );
  }

  Vec_d r = b.copy();             // residual r = b - A x, with x = 0
  x = 0.0;

  Vec_d v( n );
  v = 0.0;

  double alpha = 0.0;
  double beta  = 0.0;

  // Exactly maxit steps, as required by the bound above.
  for( long k = 0; k < maxit; ++k )
  {
    if( k == 0 )
    {
      beta  = 0.0;
      alpha = 1.0/d;
    }
    else if( k == 1 )
    {
      beta  = 0.5*(c*c)/(d*d);
      alpha = 1.0/(d - c*c/(2.0*d));
    }
    else
    {
      // beta depends on the previous alpha, so it must be updated first.
      beta  = (c*c)/4.0*(alpha*alpha);
      alpha = 1.0/(d - (c*c)/4.0*alpha);
    }

    // v <- A^T r + beta v;  x <- x + alpha v;  r <- r - alpha A v
    A.mv( 't', 1.0, r, beta, v );
    axpy( alpha, v, x );
    A.mv( 'n', -alpha, v, 1.0, r );
  }
}
