diff options
author | Determinant <[email protected]> | 2015-05-29 23:06:58 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-05-29 23:06:58 +0800 |
commit | 74b9f7cb88cd21cfac3c2e50c8efb802485df0c5 (patch) | |
tree | bd6e583088a086144acc2d8af3eaca59691194ff /tnet_io/KaldiLib/Matrix.h |
init
Diffstat (limited to 'tnet_io/KaldiLib/Matrix.h')
-rw-r--r-- | tnet_io/KaldiLib/Matrix.h | 677 |
1 files changed, 677 insertions, 0 deletions
diff --git a/tnet_io/KaldiLib/Matrix.h b/tnet_io/KaldiLib/Matrix.h new file mode 100644 index 0000000..d33cb0c --- /dev/null +++ b/tnet_io/KaldiLib/Matrix.h @@ -0,0 +1,677 @@ +#ifndef TNet_Matrix_h +#define TNet_Matrix_h + +#include <stddef.h> +#include <stdlib.h> +#include <stdexcept> +#include <iostream> + +#ifdef HAVE_ATLAS +extern "C"{ + #include <cblas.h> + #include <clapack.h> +} +#endif + +#include "Common.h" +#include "MathAux.h" +#include "Types.h" +#include "Error.h" + +//#define TRACE_MATRIX_OPERATIONS +#define CHECKSIZE + +namespace TNet +{ + + + // class matrix_error : public std::logic_error {}; + // class matrix_sizes_error : public matrix_error {}; + + // declare the class so the header knows about it + template<typename _ElemT> class Vector; + template<typename _ElemT> class SubVector; + template<typename _ElemT> class Matrix; + template<typename _ElemT> class SubMatrix; + + // we need to declare the friend << operator here + template<typename _ElemT> + std::ostream & operator << (std::ostream & rOut, const Matrix<_ElemT> & rM); + + // we need to declare the friend << operator here + template<typename _ElemT> + std::istream & operator >> (std::istream & rIn, Matrix<_ElemT> & rM); + + // we need to declare this friend function here + template<typename _ElemT> + _ElemT TraceOfProduct(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B) + + // we need to declare this friend function here + template<typename _ElemT> + _ElemT TraceOfProductT(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B^T)==tr(A^T B) + + + /** ************************************************************************** + ** ************************************************************************** + * @brief Provides a matrix class + * + * This class provides a way to work with matrices in TNet. + * It encapsulates basic operations and memory optimizations. + * + */ + template<typename _ElemT> + class Matrix + { + public: + /// defines a transpose type + + struct HtkHeader + { + INT_32 mNSamples; + INT_32 mSamplePeriod; + INT_16 mSampleSize; + UINT_16 mSampleKind; + }; + + + /** + * @brief Extension of the HTK header + */ + struct HtkHeaderExt + { + INT_32 mHeaderSize; + INT_32 mVersion; + INT_32 mSampSize; + }; + + + + + /// defines a type of this + typedef Matrix<_ElemT> ThisType; + + // Constructors + + /// Empty constructor + Matrix<_ElemT> (): + mpData(NULL), mMCols(0), mMRows(0), mStride(0) +#ifdef STK_MEMALIGN_MANUAL + , mpFreeData(NULL) +#endif + {} + + /// Copy constructor + Matrix<_ElemT> (const Matrix<_ElemT> & rM, MatrixTrasposeType trans=NO_TRANS): + mpData(NULL) + { if(trans==NO_TRANS){ Init(rM.mMRows, rM.mMCols); Copy(rM); } else { Init(rM.mMCols,rM.mMRows); Copy(rM,TRANS); } } + + /// Copy constructor from another type. + template<typename _ElemU> + explicit Matrix<_ElemT> (const Matrix<_ElemU> & rM, MatrixTrasposeType trans=NO_TRANS): + mpData(NULL) + { if(trans==NO_TRANS){ Init(rM.Rows(), rM.Cols()); Copy(rM); } else { Init(rM.Cols(),rM.Rows()); Copy(rM,TRANS); } } + + /// Basic constructor + Matrix(const size_t r, const size_t c, bool clear=true) + { mpData=NULL; Init(r, c, clear); } + + + Matrix<_ElemT> &operator = (const Matrix <_ElemT> &other) { Init(other.Rows(), other.Cols()); Copy(other); return *this; } // Needed for inclusion in std::vector + + /// Destructor + ~Matrix() + { Destroy(); } + + + /// Initializes matrix (if not done by constructor) + ThisType & + Init(const size_t r, + const size_t c, bool clear=true); + + /** + * @brief Dealocates the matrix from memory and resets the dimensions to (0, 0) + */ + void + Destroy(); + + + ThisType & + Zero(); + + ThisType & + Unit(); // set to unit. + + /** + * @brief Copies the contents of a matrix + * @param rM Source data matrix + * @return Returns reference to this + */ + template<typename _ElemU> ThisType & + Copy(const Matrix<_ElemU> & rM, MatrixTrasposeType Trans=NO_TRANS); + + + + /** + * @brief Copies the elements of a vector row-by-row into a matrix + * @param rV Source vector + * @param nRows Number of rows of returned matrix + * @param nCols Number of columns of returned matrix + * + * Note that rV.Dim() must equal nRows*nCols + */ + ThisType & + CopyVectorSplicedRows(const Vector<_ElemT> &rV, const size_t nRows, const size_t nCols); + + /** + * @brief Returns @c true if matrix is initialized + */ + bool + IsInitialized() const + { return mpData != NULL; } + + /// Returns number of rows in the matrix + inline size_t + Rows() const + { + return mMRows; + } + + /// Returns number of columns in the matrix + inline size_t + Cols() const + { + return mMCols; + } + + /// Returns number of columns in the matrix memory + inline size_t + Stride() const + { + return mStride; + } + + + /** + * @brief Gives access to a specified matrix row without range check + * @return Pointer to the const array + */ + inline const _ElemT* __attribute__((aligned(16))) + pData () const + { + return mpData; + } + + + /** + * @brief Gives access to a specified matrix row without range check + * @return Pointer to the non-const data array + */ + inline _ElemT* __attribute__((aligned(16))) + pData () + { + return mpData; + } + + + /** + * @brief pData_workaround is a workaround that allows SubMatrix to get a + * @return pointer to non-const data even though the Matrix is const... + */ + protected: + inline _ElemT* __attribute__((aligned(16))) + pData_workaround () const + { + return mpData; + } + public: + + + /// Returns size of matrix in memory + size_t + MSize() const + { + return mMRows * mStride * sizeof(_ElemT); + } + + /// Checks the content of the matrix for nan and inf values + void + CheckData(const std::string file = "") const + { + for(size_t row=0; row<Rows(); row++) { + for(size_t col=0; col<Cols(); col++) { + if(isnan((*this)(row,col)) || isinf((*this)(row,col))) { + std::ostringstream os; + os << "Invalid value: " << (*this)(row,col) + << " in matrix row: " << row + << " col: " << col + << " file: " << file; + Error(os.str()); + } + } + } + } + + /** + * ********************************************************************** + * ********************************************************************** + * @defgroup RESHAPE Matrix reshaping rutines + * ********************************************************************** + * ********************************************************************** + * @{ + */ + + /** + * @brief Removes one row from the matrix. The memory is not reallocated. + */ + ThisType & + RemoveRow(size_t i); + + /** @} */ + + /** + * ********************************************************************** + * ********************************************************************** + * @defgroup ACCESS Access functions and operators + * ********************************************************************** + * ********************************************************************** + * @{ + */ + + /** + * @brief Gives access to a specified matrix row without range check + * @return Subvector object representing the row + */ + inline const SubVector<_ElemT> + operator [] (size_t i) const + { + assert(i < mMRows); + return SubVector<_ElemT>(mpData + (i * mStride), Cols()); + } + + inline SubVector<_ElemT> + operator [] (size_t i) + { + assert(i < mMRows); + return SubVector<_ElemT>(mpData + (i * mStride), Cols()); + } + + /** + * @brief Gives access to a specified matrix row without range check + * @return pointer to the first field of the row + */ + inline _ElemT* + pRowData(size_t i) + { + assert(i < mMRows); + return mpData + i * mStride; + } + + /** + * @brief Gives access to a specified matrix row without range check + * @return pointer to the first field of the row (const version) + */ + inline const _ElemT* + pRowData(size_t i) const + { + assert(i < mMRows); + return mpData + i * mStride; + } + + /** + * @brief Gives access to matrix elements (row, col) + * @return reference to the desired field + */ + inline _ElemT& + operator () (size_t r, size_t c) + { +#ifdef PARANOID + assert(r < mMRows && c < mMCols); +#endif + return *(mpData + r * mStride + c); + } + + /** + * @brief Gives access to matrix elements (row, col) + * @return pointer to the desired field (const version) + */ + inline const _ElemT + operator () (size_t r, size_t c) const + { +#ifdef PARANOID + assert(r < mMRows && c < mMCols); +#endif + return *(mpData + r * mStride + c); + } + + /** + * @brief Returns a matrix sub-range + * @param ro Row offset + * @param r Rows in range + * @param co Column offset + * @param c Coluns in range + * See @c SubMatrix class for details + */ + SubMatrix<_ElemT> + Range(const size_t ro, const size_t r, + const size_t co, const size_t c) + { return SubMatrix<_ElemT>(*this, ro, r, co, c); } + + const SubMatrix<_ElemT> + Range(const size_t ro, const size_t r, + const size_t co, const size_t c) const + { return SubMatrix<_ElemT>(*this, ro, r, co, c); } + /** @} */ + + + /** + * ********************************************************************** + * ********************************************************************** + * @defgroup MATH ROUTINES + * ********************************************************************** + * ********************************************************************** + * @{ + **/ + + /** + * @brief Returns sum of all elements + */ + _ElemT& + Sum() const; + + ThisType & + DotMul(const ThisType& a); + + ThisType & + Scale(_ElemT alpha); + + ThisType & + ScaleCols(const Vector<_ElemT> &scale); // Equivalent to (*this) = (*this) * diag(scale). + + ThisType & + ScaleRows(const Vector<_ElemT> &scale); // Equivalent to (*this) = diag(scale) * (*this); + + /// Sum another matrix rMatrix with this matrix + ThisType& + Add(const Matrix<_ElemT>& rMatrix); + + + /// Sum scaled matrix rMatrix with this matrix + ThisType& + AddScaled(_ElemT alpha, const Matrix<_ElemT>& rMatrix); + + /// Apply log to all items of the matrix + ThisType& + ApplyLog(); + + /** + * @brief Computes the determinant of this matrix + * @return Returns the determinant of a matrix + * @ingroup MATH + * + */ + _ElemT LogAbsDeterminant(_ElemT *DetSign=NULL); + + + /** + * @brief Performs matrix inplace inversion + */ + ThisType & + Invert(_ElemT *LogDet=NULL, _ElemT *DetSign=NULL, bool inverse_needed=true); + + /** + * @brief Performs matrix inplace inversion in double precision, even if this object is not double precision. + */ + ThisType & + InvertDouble(_ElemT *LogDet=NULL, _ElemT *DetSign=NULL, bool inverse_needed=true){ + double LogDet_tmp, DetSign_tmp; + Matrix<double> dmat(*this); dmat.Invert(&LogDet_tmp, &DetSign_tmp, inverse_needed); if(inverse_needed) (*this).Copy(dmat); + if(LogDet) *LogDet = LogDet_tmp; if(DetSign) *DetSign = DetSign_tmp; + return *this; + } + + + /** + * @brief Inplace matrix transposition. Applicable only to square matrices + */ + ThisType & + Transpose() + { + assert(Rows()==Cols()); + size_t M=Rows(); + for(size_t i=0;i<M;i++) + for(size_t j=0;j<i;j++){ + _ElemT &a = (*this)(i,j), &b = (*this)(j,i); + std::swap(a,b); + } + return *this; + } + + + + + + bool IsSymmetric(_ElemT cutoff = 1.0e-05) const; + + bool IsDiagonal(_ElemT cutoff = 1.0e-05) const; + + bool IsUnit(_ElemT cutoff = 1.0e-05) const; + + bool IsZero(_ElemT cutoff = 1.0e-05) const; + + _ElemT FrobeniusNorm() const; // sqrt of sum of square elements. + + _ElemT LargestAbsElem() const; // largest absolute value. + + + friend _ElemT TNet::TraceOfProduct<_ElemT>(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B) + friend _ElemT TNet::TraceOfProductT<_ElemT>(const Matrix<_ElemT> &A, const Matrix<_ElemT> &B); // tr(A B^T)==tr(A^T B) + friend class SubMatrix<_ElemT>; // so it can get around const restrictions on the pointer to mpData. + + /** ********************************************************************** + * ********************************************************************** + * @defgroup BLAS_ROUTINES BLAS ROUTINES + * @ingroup MATH + * ********************************************************************** + * ********************************************************************** + **/ + + ThisType & + BlasGer(const _ElemT alpha, const Vector<_ElemT>& rA, const Vector<_ElemT>& rB); + + ThisType & + Axpy(const _ElemT alpha, const Matrix<_ElemT> &rM, MatrixTrasposeType transA=NO_TRANS); + + ThisType & + BlasGemm(const _ElemT alpha, + const ThisType& rA, MatrixTrasposeType transA, + const ThisType& rB, MatrixTrasposeType transB, + const _ElemT beta = 0.0); + + + /** @} */ + + + /** ********************************************************************** + * ********************************************************************** + * @defgroup IO Input/Output ROUTINES + * ********************************************************************** + * ********************************************************************** + * @{ + **/ + + friend std::ostream & + operator << <> (std::ostream & out, const ThisType & m); + + void PrintOut(char *file); + void ReadIn(char *file); + + + bool + LoadHTK(const char* pFileName); + + /** @} */ + + + protected: +// inline void swap4b(void *a); +// inline void swap2b(void *a); + + + protected: + /// data memory area + _ElemT* mpData; + + /// these atributes store the real matrix size as it is stored in memory + /// including memalignment + size_t mMCols; ///< Number of columns + size_t mMRows; ///< Number of rows + size_t mStride; ///< true number of columns for the internal matrix. + ///< This number may differ from M_cols as memory + ///< alignment might be used + +#ifdef STK_MEMALIGN_MANUAL + /// data to be freed (in case of manual memalignment use, see Common.h) + _ElemT* mpFreeData; +#endif + }; // class Matrix + + template<> Matrix<float> & Matrix<float>::Invert(float *LogDet, float *DetSign, bool inverse_needed); // state that we will implement separately for float and double. + template<> Matrix<double> & Matrix<double>::Invert(double *LogDet, double *DetSign, bool inverse_needed); + + + + /** ************************************************************************** + ** ************************************************************************** + * @brief Sub-matrix representation + * + * This class provides a way to work with matrix cutouts in STK. + * + * + */ + template<typename _ElemT> + class SubMatrix : public Matrix<_ElemT> + { + typedef SubMatrix<_ElemT> ThisType; + + public: + /// Constructor + SubMatrix(const Matrix<_ElemT>& rT, // Input matrix cannot be const because SubMatrix can change its contents. + const size_t ro, + const size_t r, + const size_t co, + const size_t c); + + + /// The destructor + ~SubMatrix<_ElemT>() + { +#ifndef STK_MEMALIGN_MANUAL + Matrix<_ElemT>::mpData = NULL; +#else + Matrix<_ElemT>::mpFreeData = NULL; +#endif + } + + /// Assign operator + ThisType& operator=(const ThisType& rSrc) + { + //std::cout << "[PERFORMing operator= SubMatrix&^2]" << std::flush; + this->mpData = rSrc.mpData; + this->mMCols = rSrc.mMCols; + this->mMRows = rSrc.mMRows; + this->mStride = rSrc.mStride; + this->mpFreeData = rSrc.mpFreeData; + return *this; + } + + + + /// Initializes matrix (if not done by constructor) + ThisType & + Init(const size_t r, + const size_t c, bool clear=true) + { Error("Submatrix cannot do Init"); return *this; } + + /** + * @brief Dealocates the matrix from memory and resets the dimensions to (0, 0) + */ + void + Destroy() + { Error("Submatrix cannot do Destroy"); } + + + + }; + + + + //Create useful shortcuts + typedef Matrix<BaseFloat> BfMatrix; + typedef SubMatrix<BaseFloat> BfSubMatrix; + + /** + * Function for summing matrices of different types + */ + template<typename _ElemT, typename _ElemU> + void Add(Matrix<_ElemT>& rDst, const Matrix<_ElemU>& rSrc) { + assert(rDst.Cols() == rSrc.Cols()); + assert(rDst.Rows() == rSrc.Rows()); + + for(size_t i=0; i<rDst.Rows(); i++) { + const _ElemU* p_src = rSrc.pRowData(i); + _ElemT* p_dst = rDst.pRowData(i); + for(size_t j=0; j<rDst.Cols(); j++) { + *p_dst++ += (_ElemT)*p_src++; + } + } + } + + /** + * Function for summing matrices of different types + */ + template<typename _ElemT, typename _ElemU> + void AddScaled(Matrix<_ElemT>& rDst, const Matrix<_ElemU>& rSrc, _ElemT scale) { + assert(rDst.Cols() == rSrc.Cols()); + assert(rDst.Rows() == rSrc.Rows()); + + Vector<_ElemT> tmp(rDst[0]); + + for(size_t i=0; i<rDst.Rows(); i++) { + tmp.Copy(rSrc[i]); + rDst[i].BlasAxpy(scale, tmp); + + /* + const _ElemU* p_src = rSrc.pRowData(i); + _ElemT* p_dst = rDst.pRowData(i); + for(size_t j=0; j<rDst.Cols(); j++) { + *p_dst++ += (_ElemT)(*p_src++) * scale; + } + */ + } + } + + + + + +} // namespace STK + + + +//***************************************************************************** +//***************************************************************************** +// we need to include the implementation +#include "Matrix.tcc" +//***************************************************************************** +//***************************************************************************** + + +/****************************************************************************** + ****************************************************************************** + * The following section contains specialized template definitions + * whose implementation is in Matrix.cc + */ + + +//#ifndef TNet_Matrix_h +#endif |