///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
///
/// =========================================================================
//
// CSR: Compressed Sparse Row format
//
// author: Pierre.Saramito@imag.fr
//
// date: 21 january 1997
//

# include "rheolef/csr.h"
# include "rheolef/dns.h"
# include "rheolef/diag.h"
# include "rheolef/ssk.h"
# include "rheolef/ssk-algorithm.h"
# include "rheolef/avec.h"
# include "rheolef/asr.h"
# include "rheolef/vec.h"
# include "rheolef/permutation.h"

# include "num-algorithm.h"
# include "csr-algorithm.h"
# include "csr-algo-trans.h"
# include "csr-algo-cat.h"
# include "csr-algo-trans-mult.h"
# include "csr-algo-dmula.h"
# include "csr-algo-amuld.h"

# include "blas-algorithm.h"
# include "blas3.h"

# include "csr-algo-aplb.h"
# include "csr-algo-amub.h"

using namespace std;

namespace rheolef {
// ===============================[ CONSTRUCTORS ]===============================
template<class T>
csr<T>::csr (const dns<T>& b)
: smart_pointer<csrrep<T> >(new_macro(csrrep<T>(b.nrow(),b.ncol(),b.nnz())))
{
    // Build a CSR matrix from a dense one, keeping only the entries that
    // differ from T(). The dense storage is read as b[nrow*j + i] for the
    // coefficient (i,j).
    const size_type n_row = nrow();
    const size_type n_col = ncol();
    typename Array<T>::iterator         value   = a().begin();
    typename Array<size_type>::iterator col_idx = ja().begin();
    typename Array<size_type>::iterator row_ptr = ia().begin();
    typename Array<T>::const_iterator   dense   = b.begin();

    // number of coefficients stored so far: this is also the row pointer
    // of the next row to be started
    size_type n_stored = 0;
    *row_ptr = n_stored;
    for (size_type i = 0; i < n_row; ++i) {
        for (size_type j = 0; j < n_col; ++j) {
            T entry = dense [n_row*j + i];
            if (entry != T()) {
                *col_idx = j;
                ++col_idx;
                *value = entry;
                ++value;
                ++n_stored;
            }
        }
        // close row i: ia(i+1) = total count of entries in rows 0..i
        ++row_ptr;
        *row_ptr = n_stored;
    }
}
template<class T>
csr<T>::csr (const basic_diag<T>& b)
: smart_pointer<csrrep<T> >(new_macro(csrrep<T>(b.nrow(),b.ncol(),std::min(b.nrow(),b.ncol()))))
{
    // Build a CSR matrix from a diagonal one: row i (for i < min(nrow,ncol))
    // holds the single coefficient (i,i) = b[i].
    typename Array<T>::iterator         a_aa = a().begin();
    typename Array<size_type>::iterator a_ja = ja().begin();
    typename Array<size_type>::iterator a_ia = ia().begin();
    typename Array<T>::const_iterator   val_bb  = b.begin();
    a_ia[0] = 0;
    const size_type nz = nnz();
    for (size_type i = 0; i < nz; i++) {
	a_aa[i] = val_bb[i];
	a_ja[i] = i;
	a_ia[i+1] = i+1;
    }
    // When nrow > ncol the rows past the diagonal are empty: their row
    // pointers must still be set, otherwise ia() is left partly unfilled.
    // This loop is a no-op for square matrices and when nrow <= ncol.
    const size_type nr = nrow();
    for (size_type i = nz; i < nr; i++) {
	a_ia[i+1] = nz;
    }
}
template<class T>
csr<T>::csr (const asr<T>& b)
: smart_pointer<csrrep<T> >(new_macro(csrrep<T>(b.nrow(),b.ncol(),b.nnz())))
{
    // Build a CSR matrix from an asr one: rows of b are traversed in order
    // and their (column, value) pairs are appended to ja() and a().
    typename Array<size_type>::iterator row_ptr = ia().begin();
    typename Array<size_type>::iterator col_idx = ja().begin();
    typename Array<T>::iterator         value   = a().begin();

    // running count of coefficients copied from the rows already processed
    size_type n_stored = 0;
    typename asr<T>::const_iterator last_row = b.end();
    for (typename asr<T>::const_iterator row_it = b.begin(); row_it != last_row; ++row_it) {
        // the current row starts where the previous one ended
        *row_ptr = n_stored;
        ++row_ptr;

        typename asr<T>::row::const_iterator entry      = (*row_it).begin();
        typename asr<T>::row::const_iterator last_entry = (*row_it).end();
        for (; entry != last_entry; ++entry) {
            *col_idx = (*entry).first;
            *value   = (*entry).second;
            ++col_idx;
            ++value;
        }
        n_stored += (*row_it).size();
    }
    // closing pointer: total number of stored coefficients
    *row_ptr = n_stored;
}
#if !defined(_RHEOLEF_HAVE_SPOOLES) && !defined(_RHEOLEF_HAVE_TAUCS) && !defined(_RHEOLEF_HAVE_UMFPACK)
// Build a CSR representation from an ssk<T> matrix, in two passes over
// the b.begin_isky()/b.end_isky()/b.begin_asky() ranges.
// Only compiled when none of the SPOOLES/TAUCS/UMFPACK back-ends is enabled.
template<class T>
csrrep<T>::csrrep (const ssk<T>& b)
: IA(b.nrow()+1), JA(0), A(0), NCOL(b.ncol())
{
    // both counters are declared uninitialized and handed to the helpers
    // below; presumably the helpers fill them in (passed by reference) --
    // TODO confirm against ssk-algorithm.h
    typename csr<T>::size_type nnzsky, _nnz;

    // first pass: writes into IA and reports the entry count through _nnz
    _ssk_to_csr_nnz_and_set_indexes (

        b.begin_isky(),
        b.end_isky(),
        b.begin_asky(),
        IA.begin(),
	_nnz,
	T());
    
    // JA and A were created empty: size them now that _nnz is known
    resize (b.nrow(), b.ncol(), _nnz);

    // second pass: writes the column indexes (JA) and the values (A)
    _ssk_to_csr_set_values (

        b.begin_isky(),
        b.end_isky(),
        b.begin_asky(),
        IA.begin(),
        JA.begin(),
        A.begin(),
	nnzsky,
	int(0),
	T(0));
}
// Build a csr from an ssk<T> matrix: the whole conversion is delegated to
// the csrrep<T>(const ssk<T>&) constructor.
template<class T>
csr<T>::csr (const ssk<T>& b)
: smart_pointer<csrrep<T> >(new_macro(csrrep<T>(b)))
{
}
#endif // ! _RHEOLEF_HAVE_SPOOLES

template <class T>
void
csr<T>::clear()
{
    // drop every stored coefficient and its column index...
    a().resize(0);
    ja().resize(0);
    // ...and make every row empty: all row pointers go back to zero
    // (the size of ia(), hence the row count, is unchanged)
    fill (ia().begin(), ia().end(), size_t(0));
}
// Placeholder constructor taking a file base name: not implemented.
// It builds an empty representation, ignores `basename` and immediately
// reports "not yet implemented" through fatal_macro.
template<class T>
csr<T>::csr (const char* basename)
: smart_pointer<csrrep<T> >(new_macro(csrrep<T>()))
{
    fatal_macro("not yet implemented");
}
// ===============================[ ASSIGNMENTS ]================================
template<class T>
csr<T>
csr<T>::operator = (const T& lambda)
{
    // overwrite every stored coefficient with lambda; only a() is written,
    // ia() and ja() are not touched
    typename Array<T>::iterator coef = data().a().begin();
    typename Array<T>::iterator last = data().a().end();
    for (; coef != last; ++coef) {
        *coef = lambda;
    }
    return *this ;
}
template<class T>
csr<T>
operator *= (csr<T>& a, const T& lambda)
{
    // scale in place each stored coefficient of a by lambda
    typename Array<T>::iterator coef = a.a().begin();
    typename Array<T>::iterator last = a.a().end();
    for (; coef != last; ++coef) {
        (*coef) *= lambda;
    }
    return a;
}
// ===============================[ T-MAT-VEC PRODUCT ]==========================

template<class T>
vec<T>
csr<T>::trans_mult (const vec<T>& x) const
{
    // x is multiplied by the transpose of the matrix, so it must have
    // one component per row
    check_macro (x.n() == nrow(), "trans_mult: incompatible vec(" 
	      << x.n() << ")^T*csr(" 
	      << nrow() << "," << ncol() << ")");

    // the result has one component per column; it is reset before being
    // handed to the cumulating kernel
    vec<T> result (ncol());
    result.reset();
    csr_cumul_trans_mult (
        ia().begin(), ia().end(), ja().begin(), a().begin(),
        x.begin(),
        result.begin(),
        T(0));
    return result;
}
// ===============================[ UNARY OPERATORS ]============================

// Return the transpose of b as a new csr of size (b.ncol(), b.nrow())
// holding b.nnz() coefficients.
// NOTE: transposed matrix has always rows sorted by increasing column indexes
//       even if original matrix has not
template<class T>
csr<T>
trans (const csr<T>& b)
{
    csr<T> a(b.ncol(), b.nrow(), b.nnz());
    
    // first pass: reset ia
    fill (a.ia().begin(), a.ia().end(),  size_t(0));

    // second pass: compute lengths of row(i) of a (= b^T) in ia(i+1)
    column_length (
        b.ia().begin(), b.ia().end(), b.ja().begin(),
        a.ia().begin(), 
        a.nrow());

    // third pass: compute pointers from lengths
    pointer_from_length (a.ia().begin(), a.ia().end());

    // fourth pass: store values
    // (the fifth pass below suggests that this one advances the row
    // pointers of a while filling -- TODO confirm in csr-algo-trans.h)
   trans_copy_values(
     b.ia().begin(), b.ia().end(), b.ja().begin(), b.a().begin(),
     a.ia().begin(),               a.ja().begin(), a.a().begin(), 
     size_t(0));

    // fifth pass: shift pointers
    right_shift (a.ia().rbegin(), a.ia().rend(), size_t(0));

    return a;
}
// Binary predicate: true when the difference i-j of its two arguments
// is at most k. Used by tril() to select the coefficients to keep.
//
// The class no longer derives from std::binary_function: that base is
// deprecated since C++11 and removed in C++17, and it was inherited
// privately, so its nested typedefs were never reachable by callers.
// They are provided here directly instead.
template <class Difference>
class less_equal_k {
    Difference k;
  public:
    typedef Difference first_argument_type;
    typedef Difference second_argument_type;
    typedef bool       result_type;

    // k1: largest accepted value of i-j (defaults to 0)
    less_equal_k (Difference k1 = 0) : k(k1) {}
    bool operator()(const Difference& i, const Difference& j) const
		       { return i-j <= k; }
};
template<class T>
csr<T>
tril (const csr<T>& a, int k)
{
    typedef typename csr<T>::size_type size_type;

    // selection rule shared by the counting and the copying passes;
    // the order in which nnz_trig/trig feed it indexes is fixed by
    // those helpers
    const less_equal_k<int> keep (k);

    // counting pass: how many coefficients of a are selected
    const size_type nnzl = nnz_trig (
        a.ia().begin(), a.ia().end(), a.ja().begin(),
        keep);

    // copying pass: fill a matrix of the same shape with those coefficients
    csr<T> lower (a.nrow(), a.ncol(), nnzl);
    trig (
        a.ia().begin(), a.ia().end(), a.ja().begin(), a.a().begin(),
        lower.ia().begin(), lower.ja().begin(), lower.a().begin(),
        keep);
    return lower;
}
// Binary predicate: true when the difference i-j of its two arguments
// is at least k. Used by triu() to select the coefficients to keep.
//
// The class no longer derives from std::binary_function: that base is
// deprecated since C++11 and removed in C++17, and it was inherited
// privately, so its nested typedefs were never reachable by callers.
// They are provided here directly instead.
template <class Difference>
class greater_equal_k {
    Difference k;
  public:
    typedef Difference first_argument_type;
    typedef Difference second_argument_type;
    typedef bool       result_type;

    // k1: smallest accepted value of i-j (defaults to 0)
    greater_equal_k (Difference k1 = 0) : k(k1) {}
    bool operator()(const Difference& i, const Difference& j) const
		       { return i-j >= k; }
};
template<class T>
csr<T>
triu (const csr<T>& a, int k)
{
    typedef typename csr<T>::size_type size_type;

    // selection rule shared by the counting and the copying passes
    // (note the opposite sign: the threshold is -k);
    // the order in which nnz_trig/trig feed it indexes is fixed by
    // those helpers
    const greater_equal_k<int> keep (-k);

    // counting pass: how many coefficients of a are selected
    const size_type nnzu = nnz_trig (
        a.ia().begin(), a.ia().end(), a.ja().begin(),
        keep);

    // copying pass: fill a matrix of the same shape with those coefficients
    csr<T> upper (a.nrow(), a.ncol(), nnzu);
    trig (
        a.ia().begin(), a.ia().end(), a.ja().begin(), a.a().begin(),
        upper.ia().begin(), upper.ja().begin(), upper.a().begin(),
        keep);

    return upper;
}
template<class T>
csr<T>
operator - (const csr<T>& a)
{
    // start from a copy of a, then flip the sign of each stored coefficient
    csr<T> opposite (a);
    typename Array<T>::iterator last = opposite.a().end();
    typename Array<T>::iterator coef = opposite.a().begin();
    while (coef < last) {
        (*coef) = - (*coef);
        coef++;
    }
    return opposite;
}
// =========================[ MAT AND Float ]===================================
template <class T>
csr<T>
operator * (const csr<T>& a, const Float& lambda)
{
  // the product has the sparsity pattern of a: duplicate the column
  // indexes and the row pointers
  csr<T> scaled (a.nrow(), a.ncol(), a.nnz());
  copy (a.ja().begin(), a.ja().end(), scaled.ja().begin());
  copy (a.ia().begin(), a.ia().end(), scaled.ia().begin());

  // values: apply "multiply by lambda" to the coefficients of a
  zassignopx (
      scaled.a().begin(), scaled.a().end(),
      my_bind2nd(mul_op<T, Float, T>(), lambda),
      a.a().begin());
  return scaled;
}
// Scalar-times-matrix product: forwards to the matrix-times-scalar
// operator with the operands swapped.
template <class T>
csr<T>
operator * (const Float& lambda, const csr<T>& a)
{
  return a*lambda;
}
// Matrix divided by a scalar, computed as a product by the reciprocal
// 1/lambda (a single division, then multiplications).
// NOTE(review): lambda is not tested against zero here.
template <class T>
csr<T>
operator / (const csr<T>& a, const Float& lambda)
{
  return (1/lambda)*a;
}
// ==========================[ MAT AND DIAG ]====================================
// Left product by a diagonal matrix: returns c = d*b.
// c has the same shape and sparsity pattern as b; the values are
// computed by dmula from d and the coefficients of b.
template <class T>
csr<T>
operator * (const basic_diag<T>& d, const csr<T>& b)
{
    check_diag_mat_length(d,b);
    csr<T> c (b.nrow(), b.ncol(), b.nnz());
    // Copy the sparsity pattern first, so that the early exit below also
    // returns a matrix with valid row pointers (it used to return c with
    // ia() and ja() never filled). c and b have identical sizes, so these
    // copies are always in range.
    copy(b.ia().begin(), b.ia().end(), c.ia().begin());
    copy(b.ja().begin(), b.ja().end(), c.ja().begin());
    if (d.size() == 0) {
	// dmula algo does not work with d.size == 0
        return c;
    }
    dmula (d.begin(), b.ia().begin(), b.ia().end(), b.a().begin(), c.a().begin(), T());
    return c;
}
template <class T>
csr<T>
csr<T>::left_mult (const basic_diag<T>& d)
{
    // in-place left product by a diagonal: a() is both the input and the
    // output of dmula
    check_diag_mat_length(d, *this);

    // dmula algo does not work with d.size == 0: skip it in that case
    if (d.size() != 0) {
        dmula (d.begin(), ia().begin(), ia().end(), a().begin(), a().begin(), T());
    }
    return *this;
}
template <class T>
csr<T>
operator * (const csr<T>& a, const basic_diag<T>& d)
{
    check_mat_diag_length(a, d);

    // the product keeps the sparsity pattern of a: duplicate the column
    // indexes and the row pointers
    csr<T> product (a.nrow(), a.ncol(), a.nnz());
    copy (a.ja().begin(), a.ja().end(), product.ja().begin());
    copy (a.ia().begin(), a.ia().end(), product.ia().begin());

    // values: computed by amuld from the coefficients of a and d
    amuld (
        a.ia().begin(), a.ia().end(), a.ja().begin(), a.a().begin(),
        d.begin(),
        product.a().begin());
    return product;
}
// In-place right product by a diagonal matrix: a.a() is passed to amuld
// both as the input and as the output value range; ia() and ja() are
// only read.
template <class T>
csr<T>
operator *= (csr<T>& a, const basic_diag<T>& d)
{
    check_mat_diag_length(a, d);
    amuld (a.ia().begin(), a.ia().end(), a.ja().begin(), a.a().begin(), 
           d.begin(), a.a().begin());
    return a;
}
// ==============================[ BINARY OPERATORS ]============================

// Sparse matrix sum: returns c = a + b.
// Both operands must have the same dimensions: the kernels walk the rows
// of a and read the matching rows of b, so a mismatch would read b.ia()
// out of range or produce column indexes beyond the result width.
// Two passes: aplb_size counts the coefficients of the result, then aplb
// fills it, combining values with plus<T>.
template<class T>
csr<T>
operator+ (const csr<T>& a, const csr<T>& b)
{
  typedef typename csr<T>::size_type size_type;

  check_macro (a.nrow() == b.nrow() && a.ncol() == b.ncol(),
      "csr+csr: incompatible csr(" << a.nrow() << "," << a.ncol()
      << ") and csr(" << b.nrow() << "," << b.ncol() << ")");

  csr<T> c(a.nrow(), b.ncol());

  // first pass: size of the merged sparsity pattern
  size_type nnzc = aplb_size (
    a.ia().begin(),
    a.ia().end(),
    a.ja().begin(),
    b.ia().begin(),
    b.ja().begin(),
    size_type(0));

  c.resize (a.nrow(), b.ncol(), nnzc);

  // second pass: row pointers, column indexes and values of the sum
  aplb (
    plus<T>(),
    a.ia().begin(),
    a.ia().end(),
    a.ja().begin(),
    a.a().begin(),
    b.ia().begin(),
    b.ja().begin(),
    b.a().begin(),
    c.ia().begin(),
    c.ja().begin(),
    c.a().begin(),
    size_type(0));

  return c;
}
// Sparse matrix difference: returns c = a - b.
// Both operands must have the same dimensions: the kernels walk the rows
// of a and read the matching rows of b, so a mismatch would read b.ia()
// out of range or produce column indexes beyond the result width.
// Two passes: aplb_size counts the coefficients of the result, then aplb
// fills it, combining values with minus<T>.
template<class T>
csr<T>
operator- (const csr<T>& a, const csr<T>& b)
{
  typedef typename csr<T>::size_type size_type;

  check_macro (a.nrow() == b.nrow() && a.ncol() == b.ncol(),
      "csr-csr: incompatible csr(" << a.nrow() << "," << a.ncol()
      << ") and csr(" << b.nrow() << "," << b.ncol() << ")");

  csr<T> c(a.nrow(), b.ncol());

  // first pass: size of the merged sparsity pattern
  size_type nnzc = aplb_size (
    a.ia().begin(),
    a.ia().end(),
    a.ja().begin(),
    b.ia().begin(),
    b.ja().begin(),
    size_type(0));

  c.resize (a.nrow(), b.ncol(), nnzc);

  // second pass: row pointers, column indexes and values of the difference
  aplb (
    minus<T>(),
    a.ia().begin(),
    a.ia().end(),
    a.ja().begin(),
    a.a().begin(),
    b.ia().begin(),
    b.ja().begin(),
    b.a().begin(),
    c.ia().begin(),
    c.ja().begin(),
    c.a().begin(),
    size_type(0));

  return c;
}
// Sparse matrix product: returns c = a*b, of size (a.nrow(), b.ncol()).
// The inner dimensions must agree: the column indexes of a are used to
// select rows of b, so a.ncol() != b.nrow() could read b.ia() out of range.
// Two passes: amub_size counts the coefficients of the result, then amub
// fills it.
template<class T>
csr<T>
operator* (const csr<T>& a, const csr<T>& b)
{
  typedef typename csr<T>::size_type size_type;

  check_macro (a.ncol() == b.nrow(),
      "csr*csr: incompatible csr(" << a.nrow() << "," << a.ncol()
      << ") and csr(" << b.nrow() << "," << b.ncol() << ")");

  csr<T> c(a.nrow(), b.ncol());

  // first pass: size of the sparsity pattern of the product
  size_type nnzc = amub_size (
    a.ia().begin(),
    a.ia().end(),
    a.ja().begin(),
    b.ia().begin(),
    b.ja().begin(),
    size_type(0));

  c.resize (a.nrow(), b.ncol(), nnzc);

  // second pass: row pointers, column indexes and values of the product
  amub (
    a.ia().begin(),
    a.ia().end(),
    a.ja().begin(),
    a.a().begin(),
    b.ia().begin(),
    b.ja().begin(),
    b.a().begin(),
    c.ia().begin(),
    c.ja().begin(),
    c.a().begin(),
    size_type(0),
    T(0));

  return c;
}
// =============================[ CONCATENATIONS ]===============================
// Vertical concatenation: returns the matrix made of the rows of a1
// followed by the rows of a2. Both operands must have the same number of
// columns; otherwise fatal_macro is invoked.
template<class T>
csr<T>
vcat (const csr<T>& a1, const csr<T>& a2)
{
    if (a1.ncol() != a2.ncol()) {
        fatal_macro ("vcat: csr("<< a1.nrow() << "," << a1.ncol()
		<< ") and csr("<<a2.nrow()<<","<<a2.ncol()<<") are incompatible.");
    }
    // the result uses the template value type T (it was hard-coded as
    // csr<Float>, which only matched the declared return type when
    // T == Float)
    csr<T> a (a1.nrow() + a2.nrow(), a1.ncol(), a1.nnz()+a2.nnz());
    
    // low-level kernel: same name, overloaded on iterator arguments
    vcat (a1.ia().begin(), a1.ia().end(),
          a1.ja().begin(), a1.ja().end(), 
          a1.a().begin(),  a1.a().end(),
          a2.ia().begin(), a2.ia().end(),
          a2.ja().begin(), a2.ja().end(), 
          a2.a().begin(),  a2.a().end(),
          a.ia().begin(),  a.ja().begin(),  a.a().begin(),
          a1.nnz());
    return a;
}
// Horizontal concatenation: returns the matrix made of the columns of a1
// followed by the columns of a2. Both operands must have the same number
// of rows; otherwise fatal_macro is invoked.
template<class T>
csr<T>
hcat (const csr<T>& a1, const csr<T>& a2)
{
    if (a1.nrow() != a2.nrow()) {
        fatal_macro ("hcat: csr("<< a1.nrow() << "," << a1.ncol()
		<< ") and csr("<<a2.nrow()<<","<<a2.ncol()<<") are incompatible.");
    }
    // the result uses the template value type T (it was hard-coded as
    // csr<Float>, which only matched the declared return type when
    // T == Float)
    csr<T> a (a1.nrow(), a1.ncol() + a2.ncol(), a1.nnz()+a2.nnz());
    
    // low-level kernel: same name, overloaded on iterator arguments;
    // the last argument is the column count of a1
    hcat (a1.ia().begin(),                a1.ja().begin(), a1.a().begin(),  
          a2.ia().begin(),                a2.ja().begin(), a2.a().begin(), 
          a.ia().begin(),  a.ia().end(),  a.ja().begin(),  a.a().begin(),
          a1.ncol());
    return a;
}
// ==================================[ SORTED ]==================================

// Sortedness test on the matrix: returns the result of csr_is_sorted
// applied to the row pointers ia() and the column indexes ja().
template<class T>
bool
csr<T>::is_sorted () const
{
    return csr_is_sorted (ia().begin(), ia().end(), ja().begin());
}
template<class T>
csr<T>
csr<T>::sort ()
{
    trace_macro ("**** SORT CSR MATRIX ****");
    iterator row_it   = begin();
    iterator last_row = end();
    // write cursors: they run once through the whole a() and ja() arrays
    typename Array<T>::iterator         value   = a().begin();
    typename Array<size_type>::iterator col_idx = ja().begin();
    for (; row_it != last_row; ++row_it) {
        // convert the row into an avecrep (red-black sorted sparse vector)...
        // **TODO: csr<T>::const_row(const avec<T>&)
        avecrep<T> sorted_row = *row_it;
        // ...and write its (index, value) pairs back, in traversal order
        typename avecrep<T>::const_iterator entry      = sorted_row.begin();
        typename avecrep<T>::const_iterator last_entry = sorted_row.end();
        for (; entry != last_entry; ++entry) {
            *col_idx = (*entry).first;
            *value   = (*entry).second;
            ++value;
            ++col_idx;
        }
    }
    return *this;
}
// ================[ SLOW ACCESS TO COMPONENT ]==================================

// NOTE: requires rows sorted by increasing column indexes
template <class T>
T
csr<T>::operator() (size_type i, size_type j) const
{
    // look the column j up in the i-th row, then keep only the value
    // part of the (index, value) pair it yields
    typename csr<T>::const_row row_i = (*this)(i);
    pair<size_type,T> entry = row_i(j);
    return entry.second;
}
// =====================[ GIBBS REORDERING  ]====================================
template<class T>
csr<T>
perm (const csr<T>& a, const permutation& p, const permutation& q) 
{
    // only square matrices are accepted
    check_macro (a.nrow() == a.ncol(), 
	"csr(gibbs): square matrix is required while csr("
        << a.nrow() << "," << a.ncol() << ") found.");

    // same shape and same number of coefficients as a
    csr<T> permuted (a.nrow(), a.ncol(), a.nnz());

    // low-level kernel: same name, overloaded on iterator arguments
    perm (
        p.begin(), q.begin(),
        a.ia().begin(), a.ia().end(), a.ja().begin(), a.a().begin(),
        permuted.ia().begin(), permuted.ja().begin(), permuted.a().begin());

    // reorder each row by column index before returning
    permuted.sort();

    return permuted;
}
// =====================[ INSTANTIATION IN LIBRARY ]=============================
// The templates of this file are defined in this source file, so the
// library only provides the instantiations listed below: all of them are
// for the Float value type.
template class csr<Float>;

// unary functions: transpose, triangular parts, permutation
template csr<Float> trans (const csr<Float>&);
template csr<Float> tril  (const csr<Float>&, int);
template csr<Float> triu  (const csr<Float>&, int);
template csr<Float> perm (const csr<Float>&, const permutation&, 
                                               const permutation&);


// opposite and in-place scaling
template csr<Float> operator - (const csr<Float>&);
template csr<Float> operator *= (csr<Float>&, const Float&);

// matrix and scalar
template csr<Float> operator * (const csr<Float>& a, const Float& lambda);
template csr<Float> operator * (const Float&, const csr<Float>&);
template csr<Float> operator / (const csr<Float>&, const Float&);

// matrix and diagonal matrix
template csr<Float> operator * (const basic_diag<Float>&, const csr<Float>&);
template csr<Float> operator * (const csr<Float>&, const basic_diag<Float>&);
template csr<Float> operator *= (csr<Float>&, const basic_diag<Float>&);

// matrix and matrix
template csr<Float> operator + (const csr<Float>&, const csr<Float>&);
template csr<Float> operator - (const csr<Float>&, const csr<Float>&);
template csr<Float> operator * (const csr<Float>&, const csr<Float>&);

// concatenations
template csr<Float> hcat (const csr<Float>&, const csr<Float>&);
template csr<Float> vcat (const csr<Float>&, const csr<Float>&);
}// namespace rheolef
