// matrix/cblas-wrappers.h
// Copyright 2012 Johns Hopkins University (author: Daniel Povey);
// Haihua Xu; Wei Shi
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_MATRIX_CBLAS_WRAPPERS_H_
#define KALDI_MATRIX_CBLAS_WRAPPERS_H_ 1
#include <limits>
#include "matrix/sp-matrix.h"
#include "matrix/kaldi-vector.h"
#include "matrix/kaldi-matrix.h"
#include "matrix/matrix-functions.h"
// Internal header: do not include it from code outside this directory.
// It is meant to be included only by the .cc files in matrix/.
namespace kaldi {
/// Copies N floats from X (read with stride incX) into Y (written with
/// stride incY). Float overload; delegates to cblas_scopy.
inline void cblas_Xcopy(const int N, const float *X, const int incX,
                        float *Y, const int incY) {
  cblas_scopy(N,
              X, incX,
              Y, incY);
}
/// Copies N doubles from X (read with stride incX) into Y (written with
/// stride incY). Double overload; delegates to cblas_dcopy.
inline void cblas_Xcopy(const int N, const double *X, const int incX,
                        double *Y, const int incY) {
  cblas_dcopy(N,
              X, incX,
              Y, incY);
}
/// Returns the BLAS "asum" of N floats read from X with stride incX, i.e.
/// the sum of their absolute values, as computed by cblas_sasum.
inline float cblas_Xasum(const int N, const float *X, const int incX) {
  const float abs_sum = cblas_sasum(N, X, incX);
  return abs_sum;
}
/// Returns the BLAS "asum" of N doubles read from X with stride incX, i.e.
/// the sum of their absolute values, as computed by cblas_dasum.
inline double cblas_Xasum(const int N, const double *X, const int incX) {
  const double abs_sum = cblas_dasum(N, X, incX);
  return abs_sum;
}
/// Applies the BLAS plane rotation with coefficients (c, s) in place to the
/// N-element float vectors X (stride incX) and Y (stride incY), via cblas_srot.
inline void cblas_Xrot(const int N, float *X, const int incX,
                       float *Y, const int incY,
                       const float c, const float s) {
  cblas_srot(N,
             X, incX,
             Y, incY,
             c, s);
}
/// Applies the BLAS plane rotation with coefficients (c, s) in place to the
/// N-element double vectors X (stride incX) and Y (stride incY), via
/// cblas_drot.
inline void cblas_Xrot(const int N, double *X, const int incX,
                       double *Y, const int incY,
                       const double c, const double s) {
  cblas_drot(N,
             X, incX,
             Y, incY,
             c, s);
}
/// Returns the dot product of the N-element float vectors X (stride incX)
/// and Y (stride incY), as computed by cblas_sdot.
inline float cblas_Xdot(const int N,
                        const float *const X, const int incX,
                        const float *const Y, const int incY) {
  const float dot = cblas_sdot(N, X, incX, Y, incY);
  return dot;
}
/// Returns the dot product of the N-element double vectors X (stride incX)
/// and Y (stride incY), as computed by cblas_ddot.
inline double cblas_Xdot(const int N,
                         const double *const X, const int incX,
                         const double *const Y, const int incY) {
  const double dot = cblas_ddot(N, X, incX, Y, incY);
  return dot;
}
/// BLAS axpy on floats: Y += alpha * X over N elements, with X read at
/// stride incX and Y updated at stride incY. Delegates to cblas_saxpy.
inline void cblas_Xaxpy(const int N, const float alpha,
                        const float *X, const int incX,
                        float *Y, const int incY) {
  cblas_saxpy(N, alpha,
              X, incX,
              Y, incY);
}
/// BLAS axpy on doubles: Y += alpha * X over N elements, with X read at
/// stride incX and Y updated at stride incY. Delegates to cblas_daxpy.
inline void cblas_Xaxpy(const int N, const double alpha,
                        const double *X, const int incX,
                        double *Y, const int incY) {
  cblas_daxpy(N, alpha,
              X, incX,
              Y, incY);
}
/// Scales N floats in place (data *= alpha), stepping through `data` with
/// stride `inc`. Delegates to cblas_sscal.
inline void cblas_Xscal(const int N, const float alpha,
                        float *data, const int inc) {
  cblas_sscal(N, alpha,
              data, inc);
}
/// Scales N doubles in place (data *= alpha), stepping through `data` with
/// stride `inc`. Delegates to cblas_dscal.
inline void cblas_Xscal(const int N, const double alpha,
                        double *data, const int inc) {
  cblas_dscal(N, alpha,
              data, inc);
}
/// y = alpha * M * v + beta * y, where M is a num_rows x num_rows symmetric
/// matrix held in Mdata in row-major packed lower-triangular storage
/// (CblasRowMajor, CblasLower). v is read with stride v_inc; y is updated in
/// place with stride y_inc. Float overload; delegates to cblas_sspmv.
inline void cblas_Xspmv(const float alpha, const int num_rows,
                        const float *Mdata,
                        const float *v, const int v_inc,
                        const float beta, float *y, const int y_inc) {
  cblas_sspmv(CblasRowMajor, CblasLower, num_rows,
              alpha, Mdata,
              v, v_inc,
              beta, y, y_inc);
}
/// y = alpha * M * v + beta * y, where M is a num_rows x num_rows symmetric
/// matrix held in Mdata in row-major packed lower-triangular storage
/// (CblasRowMajor, CblasLower). v is read with stride v_inc; y is updated in
/// place with stride y_inc. Double overload; delegates to cblas_dspmv.
inline void cblas_Xspmv(const double alpha, const int num_rows,
                        const double *Mdata,
                        const double *v, const int v_inc,
                        const double beta, double *y, const int y_inc) {
  cblas_dspmv(CblasRowMajor, CblasLower, num_rows,
              alpha, Mdata,
              v, v_inc,
              beta, y, y_inc);
}
/// In-place triangular matrix-vector product y = op(M) * y. M is a
/// num_rows x num_rows lower-triangular matrix with a non-unit diagonal,
/// held in Mdata in row-major packed storage. op() is chosen by `trans`.
/// y is accessed with stride y_inc. Float overload; delegates to cblas_stpmv.
/// NOTE(review): the cast assumes MatrixTransposeType's enumerator values
/// coincide with CBLAS_TRANSPOSE's; confirm where MatrixTransposeType is
/// declared.
inline void cblas_Xtpmv(MatrixTransposeType trans, const float *Mdata,
                        const int num_rows, float *y, const int y_inc) {
  const CBLAS_TRANSPOSE cblas_trans = static_cast<CBLAS_TRANSPOSE>(trans);
  cblas_stpmv(CblasRowMajor, CblasLower, cblas_trans, CblasNonUnit,
              num_rows, Mdata, y, y_inc);
}
/// In-place triangular matrix-vector product y = op(M) * y. M is a
/// num_rows x num_rows lower-triangular matrix with a non-unit diagonal,
/// held in Mdata in row-major packed storage. op() is chosen by `trans`.
/// y is accessed with stride y_inc. Double overload; delegates to cblas_dtpmv.
/// NOTE(review): the cast assumes MatrixTransposeType's enumerator values
/// coincide with CBLAS_TRANSPOSE's; confirm where MatrixTransposeType is
/// declared.
inline void cblas_Xtpmv(MatrixTransposeType trans, const double *Mdata,
                        const int num_rows, double *y, const int y_inc) {
  const CBLAS_TRANSPOSE cblas_trans = static_cast<CBLAS_TRANSPOSE>(trans);
  cblas_dtpmv(CblasRowMajor, CblasLower, cblas_trans, CblasNonUnit,
              num_rows, Mdata, y, y_inc);
}
/// Triangular solve op(M) * x = y, overwriting y with the solution x. M is a
/// num_rows x num_rows lower-triangular matrix with a non-unit diagonal,
/// held in Mdata in row-major packed storage. op() is chosen by `trans`.
/// y is accessed with stride y_inc. Float overload; delegates to cblas_stpsv.
/// NOTE(review): the cast assumes MatrixTransposeType's enumerator values
/// coincide with CBLAS_TRANSPOSE's; confirm where MatrixTransposeType is
/// declared.
inline void cblas_Xtpsv(MatrixTransposeType trans, const float *Mdata,
                        const int num_rows, float *y, const int y_inc) {
  const CBLAS_TRANSPOSE cblas_trans = static_cast<CBLAS_TRANSPOSE>(trans);
  cblas_stpsv(CblasRowMajor, CblasLower, cblas_trans, CblasNonUnit,
              num_rows, Mdata, y, y_inc);
}
/// Triangular solve op(M) * x = y, overwriting y with the solution x. M is a
/// num_rows x num_rows lower-triangular matrix with a non-unit diagonal,
/// held in Mdata in row-major packed storage. op() is chosen by `trans`.
/// y is accessed with stride y_inc. Double overload; delegates to cblas_dtpsv.
/// NOTE(review): the cast assumes MatrixTransposeType's enumerator values
/// coincide with CBLAS_TRANSPOSE's; confirm where MatrixTransposeType is
/// declared.
inline void cblas_Xtpsv(MatrixTransposeType trans, const double *Mdata,
                        const int num_rows, double *y, const int y_inc) {
  const CBLAS_TRANSPOSE cblas_trans = static_cast<CBLAS_TRANSPOSE>(trans);
  cblas_dtpsv(CblasRowMajor, CblasLower, cblas_trans, CblasNonUnit,
              num_rows, Mdata, y, y_inc);
}
// x = alpha * M * y + beta * x, where M is a dim x dim symmetric matrix held
// in Mdata in row-major packed lower-triangular storage. y is read with
// stride ystride; x is updated in place with stride xstride.
// Float overload; delegates to cblas_sspmv.
inline void cblas_Xspmv(MatrixIndexT dim, float alpha, const float *Mdata,
                        const float *ydata, MatrixIndexT ystride,
                        float beta, float *xdata, MatrixIndexT xstride) {
  cblas_sspmv(CblasRowMajor, CblasLower, dim,
              alpha, Mdata,
              ydata, ystride,
              beta, xdata, xstride);
}
// x = alpha * M * y + beta * x, where M is a dim x dim symmetric matrix held
// in Mdata in row-major packed lower-triangular storage. y is read with
// stride ystride; x is updated in place with stride xstride.
// Double overload; delegates to cblas_dspmv.
inline void cblas_Xspmv(MatrixIndexT dim, double alpha, const double *Mdata,
                        const double *ydata, MatrixIndexT ystride,
                        double beta, double *xdata, MatrixIndexT xstride) {
  cblas_dspmv(CblasRowMajor, CblasLower, dim,
              alpha, Mdata,
              ydata, ystride,
              beta, xdata, xstride);
}
// Symmetric rank-2 update A += alpha * (x y' + y x'). A is a dim x dim
// symmetric matrix stored in Adata in row-major packed lower-triangular form.
// x is read from Xdata with stride incX and y from Ydata with stride incY.
// Float overload; delegates to cblas_sspr2.
inline void cblas_Xspr2(MatrixIndexT dim, float alpha,
                        const float *Xdata, MatrixIndexT incX,
                        const float *Ydata, MatrixIndexT incY,
                        float *Adata) {
  cblas_sspr2(CblasRowMajor, CblasLower, dim, alpha,
              Xdata, incX,
              Ydata, incY,
              Adata);
}
inline void cblas_Xspr2