#ifndef TNet_Matrix_h
#define TNet_Matrix_h
#include <stddef.h>
#include <stdlib.h>
#include <stdexcept>
#include <iostream>
#include <sstream>
#include <string>
#ifdef HAVE_ATLAS
extern "C"{
#include <cblas.h>
#include <clapack.h>
}
#endif
#include "Common.h"
#include "MathAux.h"
#include "Types.h"
#include "Error.h"
//#define TRACE_MATRIX_OPERATIONS
#define CHECKSIZE
namespace TNet
{
// class matrix_error : public std::logic_error {};
// class matrix_sizes_error : public matrix_error {};
// Forward-declare the vector/matrix class templates so this header can refer to them before they are defined
template<typename _ElemT> class Vector;
template<typename _ElemT> class SubVector;
template<typename _ElemT> class Matrix;
template<typename _ElemT> class SubMatrix;
// Output operator: declared ahead of the class so it can be named as a friend of Matrix
template<typename _ElemT>
std::ostream & operator << (std::ostream & rOut, const Matrix<_ElemT> & rM);
// Input operator: declared ahead of the class so it can be named as a friend of Matrix
template<typename _ElemT>
std::istream & operator >> (std::istream & rIn, Matrix<_ElemT> & rM);
// Trace of the product A*B; declared ahead of the class so it can be named as a friend of Matrix
template<typename _ElemT>
_ElemT TraceOfProduct(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B)
// Trace of the product A*B^T; declared ahead of the class so it can be named as a friend of Matrix
template<typename _ElemT>
_ElemT TraceOfProductT(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B^T)==tr(A^T B)
/** **************************************************************************
** **************************************************************************
* @brief Provides a matrix class
*
* This class provides a way to work with matrices in TNet.
* It encapsulates basic operations and memory optimizations.
*
*/
template<typename _ElemT>
class Matrix
{
public:
/**
 * @brief Header of an HTK parameter (feature) file
 *
 * The member order and the fixed-width types define the struct layout;
 * do not reorder or retype the fields.  NOTE(review): presumably this
 * struct is read/written directly from/to HTK files elsewhere — confirm.
 */
struct HtkHeader
{
INT_32 mNSamples;     // presumably the number of samples (frames) in the file
INT_32 mSamplePeriod; // presumably the sample period (HTK uses 100 ns units) — TODO confirm
INT_16 mSampleSize;   // presumably the number of bytes per sample
UINT_16 mSampleKind;  // presumably the HTK parameter-kind code
};
/**
* @brief Extension of the HTK header
*
* NOTE(review): the meaning of these fields is not visible in this part of
* the file; the names suggest a header size, a format version and a sample
* size — confirm against the reader/writer code before relying on it.
*/
struct HtkHeaderExt
{
INT_32 mHeaderSize; // presumably the size of the (extended) header
INT_32 mVersion;    // presumably a format version number
INT_32 mSampSize;   // presumably the sample size (wider than HtkHeader::mSampleSize)
};
/// Shorthand for this class's own type; used as the return type of Init(), Zero(), Unit(), Copy(), ...
typedef Matrix ThisType;
// Constructors
/// Empty constructor
Matrix<_ElemT> ():
mpData(NULL), mMCols(0), mMRows(0), mStride(0)
#ifdef STK_MEMALIGN_MANUAL
, mpFreeData(NULL)
#endif
{}
/// Copy constructor
Matrix<_ElemT> (const Matrix<_ElemT> & rM, MatrixTrasposeType trans=NO_TRANS):
mpData(NULL)
{ if(trans==NO_TRANS){ Init(rM.mMRows, rM.mMCols); Copy(rM); } else { Init(rM.mMCols,rM.mMRows); Copy(rM,TRANS); } }
/// Copy constructor from another type.
template<typename _ElemU>
explicit Matrix<_ElemT> (const Matrix<_ElemU> & rM, MatrixTrasposeType trans=NO_TRANS):
mpData(NULL)
{ if(trans==NO_TRANS){ Init(rM.Rows(), rM.Cols()); Copy(rM); } else { Init(rM.Cols(),rM.Rows()); Copy(rM,TRANS); } }
/// Basic constructor
Matrix(const size_t r, const size_t c, bool clear=true)
{ mpData=NULL; Init(r, c, clear); }
/**
 * @brief Copy assignment (needed for inclusion in std::vector)
 * @param other Source matrix
 * @return Reference to this
 *
 * Self-assignment is a no-op.  Without the guard, Init() could reallocate,
 * and so discard, this matrix's storage before Copy() reads from it.
 */
Matrix<_ElemT> &operator = (const Matrix <_ElemT> &other)
{
  if (this != &other) {
    Init(other.Rows(), other.Cols());
    Copy(other);
  }
  return *this;
}
/// Destructor: releases the matrix storage through Destroy()
~Matrix() { Destroy(); }
/**
 * @brief Initializes the matrix to r x c (when not already done by a constructor)
 * @param r     Number of rows
 * @param c     Number of columns
 * @param clear Presumably zero-fills the storage when true — TODO confirm against the definition
 * @return Reference to this matrix type
 */
ThisType & Init(const size_t r, const size_t c, bool clear = true);
/// Deallocates the matrix from memory and resets the dimensions to (0, 0)
void Destroy();
/// Presumably sets every element to zero (definition not shown here); returns a reference, presumably *this
ThisType & Zero();
/// Sets the matrix to unit (identity) form; returns a reference, presumably *this
ThisType & Unit();
/**
 * @brief Copies the contents of another matrix, optionally transposed
 * @param rM    Source data matrix; its element type may differ from this one's
 * @param Trans NO_TRANS copies rM as is, TRANS copies its transpose
 * @return Returns reference to this
 *
 * The constructors in this file size the destination with Init() before calling Copy().
 */
template<typename _ElemU> ThisType & Copy(const Matrix<_ElemU> & rM, MatrixTrasposeType Trans = NO_TRANS);
/**
 * @brief Fills the matrix row by row from the elements of a vector
 * @param rV    Source vector; rV.Dim() must equal nRows*nCols
 * @param nRows Number of rows of the resulting matrix
 * @param nCols Number of columns of the resulting matrix
 * @return Reference to this matrix type
 */
ThisType & CopyVectorSplicedRows(const Vector<_ElemT> &rV, const size_t nRows, const size_t nCols);
/**
 * @brief Tells whether the matrix has storage attached
 * @return @c true when the data pointer is non-NULL
 */
bool IsInitialized() const { return NULL != mpData; }
/// Number of rows of the matrix (implicitly inline: defined in the class body)
size_t Rows() const { return mMRows; }
/// Number of (logical) columns of the matrix (implicitly inline: defined in the class body)
size_t Cols() const { return mMCols; }
/// Allocated length of one row in memory, counted in elements (see MSize())
size_t Stride() const { return mStride; }
/**
 * @brief Read-only access to the raw data array (no range checking)
 * @return Pointer to the first element; presumably rows are laid out
 *         Stride() elements apart (cf. MSize())
 */
inline const _ElemT* __attribute__((aligned(16))) pData () const { return mpData; }
/**
 * @brief Writable access to the raw data array (no range checking)
 * @return Non-const pointer to the first element; presumably rows are laid
 *         out Stride() elements apart (cf. MSize())
 */
inline _ElemT* __attribute__((aligned(16))) pData () { return mpData; }
protected:
/**
 * @brief Returns a non-const data pointer even when called on a const Matrix
 *
 * Workaround that lets SubMatrix obtain writable data from a const Matrix.
 * NOTE(review): this bypasses const-correctness; use with care.
 * @return Non-const pointer to the first element
 */
inline _ElemT* __attribute__((aligned(16))) pData_workaround () const { return mpData; }
public:
/// Size of the matrix storage in bytes: Rows() * Stride() elements, so the full stride is counted, not just Cols()
size_t MSize() const
{
  const size_t n_elements = mMRows * mStride;
  return n_elements * sizeof(_ElemT);
}
/// Checks the content of the matrix for nan and inf values
void
CheckData(const std::string file = "") const
{
for(size_t row=0; row<Rows(); row++) {
for(size_t col=0; col<Cols(); col++) {
if(isnan((*this)(row,col)) || isinf((*this)(row,col))) {
std::ostringstream os;
os << "Invalid value: " << (*this)(row,col)
<< " in matrix row: " << row
<< " col: " << col
<< " file: " << file;
Error(os.str());
}
}
}
}
/**
* **********************************************************************
* **************************************************************