From 96a32415ab43377cf1575bd3f4f2980f58028209 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 14 Aug 2015 11:51:42 +0800 Subject: add implementation for kaldi io (by ymz) --- .../tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h | 188 +++++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h (limited to 'kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h') diff --git a/kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h b/kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h new file mode 100644 index 0000000..118d3de --- /dev/null +++ b/kaldi_io/src/tools/ATLAS/include/contrib/ATL_gemv_ger_SSE.h @@ -0,0 +1,188 @@ +#ifdef GER +#undef NO_TRANSPOSE +#define NO_TRANSPOSE +#endif + + +#if NDPM > 4 +#error Max NDPM is 4 +#endif + +#if !defined(ATL_SSE1) && ( defined(SREAL) || defined(SCPLX) ) +#error This routine needs ATL_SSE1 defined +#endif + +#if !defined(ATL_SSE2) && ( defined(DREAL) || defined(DCPLX) ) +#error This routine needs ATL_SSE2 defined +#endif + +#include +#include + +#include "camm_util.h" + +#ifndef GER +#if defined(BETAX) || defined(BETAXI0) +#include "camm_scale.h" +#endif +#endif + +#if NDPM >= 4 +#define EXT4 Mjoin(4dp,BLC) +#undef NDP +#define NDP 4 +#undef EXT +#define EXT EXT4 +#include "camm_dpa.h" +#endif + +#if NDPM >= 3 +#define EXT3 Mjoin(3dp,BLC) +#undef NDP +#define NDP 3 +#undef EXT +#define EXT EXT3 +#include "camm_dpa.h" +#endif + +#if NDPM >= 2 +#define EXT2 Mjoin(2dp,BLC) +#undef NDP +#define NDP 2 +#undef EXT +#define EXT EXT2 +#include "camm_dpa.h" +#endif + +#define EXT1 Mjoin(1dp,BLC) +#undef NDP +#define NDP 1 +#undef EXT +#define EXT EXT1 +#include "camm_dpa.h" + +#undef NDP +#define NDP NDPM +#undef EXT +#define EXT Mjoin(Mjoin(NDP,Mjoin(dp,BLC)),m) +#include "camm_dpa.h" + +#ifdef GER +#if defined(SCPLX) || defined(DCPLX) +#ifdef Conj_ +#define IM 1c +#else +#define IM 1u +#endif +#else +#define IM 1 +#endif + + +#define FN Mjoin(Mjoin(Mjoin(ATL_,PREC),Mjoin(ger,IM)),_a1_x1_yX) + +#undef MY_FUNCTION +#define MY_FUNCTION FN + +void +MY_FUNCTION(int m,int n, const SCALAR alpha,const TYPE *c, + int cinc,const TYPE *b,int binc, + TYPE *a,int lda) { + +#else + + +#define FN Mjoin(Mjoin(Mjoin(ATL_,PREC),gemv),Mjoin(FEXT,Mjoin(_a1_x1_,Mjoin(BL,_y1)))) + +#undef MY_FUNCTION +#define MY_FUNCTION FN + +void +MY_FUNCTION(int m,int n, const SCALAR alpha,const TYPE *a, + int lda,const TYPE *b,int binc, + const SCALAR beta,TYPE *c,int cinc) { + +#endif + + int i,mm,nn; + const TYPE *ae; +#ifdef NO_TRANSPOSE + int len=m,w=n; +#define zz b +#else + int len=n,w=m; +#define zz c +#endif + +#ifdef GER +#define zzinc binc +#else +#define zzinc 1 + + +#if defined(NO_TRANSPOSE) && defined(BETA0) + memset(c,0,m*sizeof(*c)); +#endif + +#if defined(BETAX) || defined(BETAXI0) +#if defined(SCPLX) || defined(DCPLX) + SCALE(beta,c,m); +#endif +#if defined(SREAL) || defined(DREAL) + SCALE(&beta,c,m); +#endif +#endif + +#endif + + ae=a+w*lda; + nn=STRIDE*lda; + + +#if NDPM == 1 + for (;a 1 + if (((ae-a)/lda)%STRIDE) + mm++; +#endif + + if (mm == 1) + Mjoin(dp,EXT1)(a,nn,b,c,STRIDE*zzinc,len); + +#if ( NDPM == 2 && STRIDE > 1 ) || NDPM > 2 + else if (mm == 2) + Mjoin(dp,EXT2)(a,nn,b,c,STRIDE*zzinc,len); +#endif + +#if ( NDPM == 3 && STRIDE > 1 ) || NDPM > 3 + else if (mm == 3) + Mjoin(dp,EXT3)(a,nn,b,c,STRIDE*zzinc,len); +#endif + +#if ( NDPM == 4 && STRIDE > 1 ) || NDPM > 4 + else if (mm == 4) + Mjoin(dp,EXT4)(a,nn,b,c,STRIDE*zzinc,len); +#endif + + + } + +#endif + +} + -- cgit v1.2.3