summaryrefslogtreecommitdiff
path: root/tnet_io/KaldiLib/Vector.cc
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-05-29 23:06:58 +0800
committerDeterminant <[email protected]>2015-05-29 23:06:58 +0800
commit74b9f7cb88cd21cfac3c2e50c8efb802485df0c5 (patch)
treebd6e583088a086144acc2d8af3eaca59691194ff /tnet_io/KaldiLib/Vector.cc
init
Diffstat (limited to 'tnet_io/KaldiLib/Vector.cc')
-rw-r--r--tnet_io/KaldiLib/Vector.cc110
1 files changed, 110 insertions, 0 deletions
diff --git a/tnet_io/KaldiLib/Vector.cc b/tnet_io/KaldiLib/Vector.cc
new file mode 100644
index 0000000..020bae2
--- /dev/null
+++ b/tnet_io/KaldiLib/Vector.cc
@@ -0,0 +1,110 @@
+#ifndef TNet_Vector_cc
+#define TNet_Vector_cc
+
+#include <cstdlib>
+#include <cmath>
+#include <cstring>
+#include <fstream>
+#include <iomanip>
+#include "Common.h"
+
+#ifdef HAVE_ATLAS
+extern "C"{
+ #include <cblas.h>
+}
+#endif
+
+#include "Common.h"
+#include "Matrix.h"
+#include "Vector.h"
+
+namespace TNet
+{
+
+#ifdef HAVE_ATLAS
+ template<>
+ float
+ BlasDot<>(const Vector<float>& rA, const Vector<float>& rB)
+ {
+ assert(rA.mDim == rB.mDim);
+ return cblas_sdot(rA.mDim, rA.pData(), 1, rB.pData(), 1);
+ }
+
+ template<>
+ double
+ BlasDot<>(const Vector<double>& rA, const Vector<double>& rB)
+ {
+ assert(rA.mDim == rB.mDim);
+ return cblas_ddot(rA.mDim, rA.pData(), 1, rB.pData(), 1);
+ }
+
+ template<>
+ Vector<float>&
+ Vector<float>::
+ BlasAxpy(const float alpha, const Vector<float>& rV)
+ {
+ assert(mDim == rV.mDim);
+ cblas_saxpy(mDim, alpha, rV.pData(), 1, mpData, 1);
+ return *this;
+ }
+
+ template<>
+ Vector<double>&
+ Vector<double>::
+ BlasAxpy(const double alpha, const Vector<double>& rV)
+ {
+ assert(mDim == rV.mDim);
+ cblas_daxpy(mDim, alpha, rV.pData(), 1, mpData, 1);
+ return *this;
+ }
+
+ template<>
+ Vector<int>&
+ Vector<int>::
+ BlasAxpy(const int alpha, const Vector<int>& rV)
+ {
+ assert(mDim == rV.mDim);
+ for(int i=0; i<Dim(); i++) {
+ (*this)[i] += rV[i];
+ }
+ return *this;
+ }
+
+
+ template<>
+ Vector<float>&
+ Vector<float>::
+ BlasGemv(const float alpha, const Matrix<float>& rM, MatrixTrasposeType trans, const Vector<float>& rV, const float beta)
+ {
+ assert((trans == NO_TRANS && rM.Cols() == rV.mDim && rM.Rows() == mDim)
+ || (trans == TRANS && rM.Rows() == rV.mDim && rM.Cols() == mDim));
+
+ cblas_sgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), rM.Rows(), rM.Cols(), alpha, rM.pData(), rM.Stride(),
+ rV.pData(), 1, beta, mpData, 1);
+ return *this;
+ }
+
+
+
+ template<>
+ Vector<double>&
+ Vector<double>::
+ BlasGemv(const double alpha, const Matrix<double>& rM, MatrixTrasposeType trans, const Vector<double>& rV, const double beta)
+ {
+ assert((trans == NO_TRANS && rM.Cols() == rV.mDim && rM.Rows() == mDim)
+ || (trans == TRANS && rM.Rows() == rV.mDim && rM.Cols() == mDim));
+
+ cblas_dgemv(CblasRowMajor, static_cast<CBLAS_TRANSPOSE>(trans), rM.Rows(), rM.Cols(), alpha, rM.pData(), rM.Stride(),
+ rV.pData(), 1, beta, mpData, 1);
+ return *this;
+ }
+
+
+#else
+ #error Routines in this section are not implemented yet without BLAS
+#endif
+
+} // namespace TNet
+
+
+#endif // TNet_Vector_tcc