diff options
Diffstat (limited to 'nerv/lib')
-rw-r--r-- | nerv/lib/cblas.h | 596 | ||||
-rw-r--r-- | nerv/lib/matrix/cumatrix.c | 52 | ||||
-rw-r--r-- | nerv/lib/matrix/cumatrix.h | 3 | ||||
-rw-r--r-- | nerv/lib/matrix/generic/mmatrix.c | 2 |
4 files changed, 645 insertions, 8 deletions
diff --git a/nerv/lib/cblas.h b/nerv/lib/cblas.h new file mode 100644 index 0000000..4087ffb --- /dev/null +++ b/nerv/lib/cblas.h @@ -0,0 +1,596 @@ +#ifndef CBLAS_H + +#ifndef CBLAS_ENUM_DEFINED_H + #define CBLAS_ENUM_DEFINED_H + enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; + enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, + AtlasConj=114}; + enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; + enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; + enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; +#endif + +#ifndef CBLAS_ENUM_ONLY +#define CBLAS_H +#define CBLAS_INDEX int + +int cblas_errprn(int ierr, int info, char *form, ...); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS functions (complex are recast as routines) + * =========================================================================== + */ +float cblas_sdsdot(const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY); +double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, + const int incY); +float cblas_sdot(const int N, const float *X, const int incX, + const float *Y, const int incY); +double cblas_ddot(const int N, const double *X, const int incX, + const double *Y, const int incY); +/* + * Functions having prefixes Z and C only + */ +void cblas_cdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_cdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + +void cblas_zdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_zdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + + +/* + * Functions having prefixes S D SC DZ + */ +float cblas_snrm2(const int N, const float *X, const int incX); +float cblas_sasum(const int N, const float *X, const int incX); + +double cblas_dnrm2(const int N, const double *X, const int incX); +double cblas_dasum(const int N, const double *X, const int incX); + +float cblas_scnrm2(const int N, const void *X, const int incX); +float cblas_scasum(const int N, const void *X, const int incX); + +double cblas_dznrm2(const int N, const void *X, const int incX); +double cblas_dzasum(const int N, const void *X, const int incX); + + +/* + * Functions having standard 4 prefixes (S D C Z) + */ +CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); +CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); +CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); +CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS routines + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (s, d, c, z) + */ +void cblas_sswap(const int N, float *X, const int incX, + float *Y, const int incY); +void cblas_scopy(const int N, const float *X, const int incX, + float *Y, const int incY); +void cblas_saxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); +void catlas_saxpby(const int N, const float alpha, const float *X, + const int incX, const float beta, float *Y, const int incY); +void catlas_sset + (const int N, const float alpha, float *X, const int incX); + +void cblas_dswap(const int N, double *X, const int incX, + double *Y, const int incY); +void cblas_dcopy(const int N, const double *X, const int incX, + double *Y, const int incY); +void cblas_daxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); +void catlas_daxpby(const int N, const double alpha, const double *X, + const int incX, const double beta, double *Y, const int incY); +void catlas_dset + (const int N, const double alpha, double *X, const int incX); + +void cblas_cswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_ccopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_caxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); +void catlas_caxpby(const int N, const void *alpha, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void catlas_cset + (const int N, const void *alpha, void *X, const int incX); + +void cblas_zswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_zcopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_zaxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); +void catlas_zaxpby(const int N, const void *alpha, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void catlas_zset + (const int N, const void *alpha, void *X, const int incX); + + +/* + * Routines with S and D prefix only + */ +void cblas_srotg(float *a, float *b, float *c, float *s); +void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); +void cblas_srot(const int N, float *X, const int incX, + float *Y, const int incY, const float c, const float s); +void cblas_srotm(const int N, float *X, const int incX, + float *Y, const int incY, const float *P); + +void cblas_drotg(double *a, double *b, double *c, double *s); +void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); +void cblas_drot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s); +void cblas_drotm(const int N, double *X, const int incX, + double *Y, const int incY, const double *P); + + +/* + * Routines with S D C Z CS and ZD prefixes + */ +void cblas_sscal(const int N, const float alpha, float *X, const int incX); +void cblas_dscal(const int N, const double alpha, double *X, const int incX); +void cblas_cscal(const int N, const void *alpha, void *X, const int incX); +void cblas_zscal(const int N, const void *alpha, void *X, const int incX); +void cblas_csscal(const int N, const float alpha, void *X, const int incX); +void cblas_zdscal(const int N, const double alpha, void *X, const int incX); + +/* + * Extra reference routines provided by ATLAS, but not mandated by the standard + */ +void cblas_crotg(void *a, void *b, void *c, void *s); +void cblas_zrotg(void *a, void *b, void *c, void *s); +void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY, + const float c, const float s); +void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY, + const double c, const double s); + +/* + * =========================================================================== + * Prototypes for level 2 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void cblas_sgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const float alpha, + const float *A, const int lda, const float *X, + const int incX, const float beta, float *Y, const int incY); +void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, + float *X, const int incX); +void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); +void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, float *X, + const int incX); +void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); + +void cblas_dgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void cblas_dgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const double alpha, + const double *A, const int lda, const double *X, + const int incX, const double beta, double *Y, const int incY); +void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, + double *X, const int incX); +void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); +void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, double *X, + const int incX); +void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); + +void cblas_cgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_cgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + +void cblas_zgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_zgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + + +/* + * Routines with S and D prefixes only + */ +void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *Ap, + const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda); +void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *A, const int lda); +void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *Ap); +void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, + const int lda); +void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A); + +void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *Ap, + const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); +void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *A, const int lda); +void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *Ap); +void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, + const int lda); +void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A); + + +/* + * Routines with C and Z prefixes only + */ +void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, + const int incX, void *A); +void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, + const int incX, void *A); +void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +/* + * =========================================================================== + * Prototypes for level 3 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, float *C, const int ldc); +void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); +void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); + +void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, double *C, const int ldc); +void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); +void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); + +void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + +void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + + +/* + * Routines with prefixes C and Z only + */ +void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const void *A, const int lda, + const float beta, void *C, const int ldc); +void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const float beta, + void *C, const int ldc); +void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const void *A, const int lda, + const double beta, void *C, const int ldc); +void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const double beta, + void *C, const int ldc); + +int cblas_errprn(int ierr, int info, char *form, ...); + +#endif /* end #ifdef CBLAS_ENUM_ONLY */ +#endif diff --git a/nerv/lib/matrix/cumatrix.c b/nerv/lib/matrix/cumatrix.c index ff2ea22..aec4d60 100644 --- a/nerv/lib/matrix/cumatrix.c +++ b/nerv/lib/matrix/cumatrix.c @@ -37,7 +37,9 @@ void nerv_cuda_context_accu_profile(CuContext *context, *val += delta; } -static void new_cuda_handles(CuContext *context, Status *status) { +static void new_cuda_handles(CuContext *context, int dev, Status *status) { + if (context->has_handle) return; + CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status); CUBLAS_SAFE_SYNC_CALL(cublasCreate(&(context->cublas_handle)), status); CURAND_SAFE_SYNC_CALL(curandCreateGenerator(&(context->curand_gen), CURAND_RNG_PSEUDO_DEFAULT), status); @@ -47,9 +49,12 @@ static void new_cuda_handles(CuContext *context, Status *status) { CUDA_SAFE_SYNC_CALL(cudaEventCreate(&(context->profile_start)), status); CUDA_SAFE_SYNC_CALL(cudaEventCreate(&(context->profile_stop)), status); NERV_SET_STATUS(status, NERV_NORMAL, 0); + context->has_handle = 1; } static void free_cuda_handles(CuContext *context, Status *status) { + if (!context->has_handle) return; + context->has_handle = 0; CUBLAS_SAFE_SYNC_CALL(cublasDestroy(context->cublas_handle), status); CURAND_SAFE_SYNC_CALL(curandDestroyGenerator(context->curand_gen), status); CUDA_SAFE_SYNC_CALL(cudaEventDestroy(context->profile_start), status); @@ -57,9 +62,41 @@ static void free_cuda_handles(CuContext *context, Status *status) { NERV_SET_STATUS(status, NERV_NORMAL, 0); } -CuContext *nerv_cuda_context_create(Status *status) { +static int choose_best_gpu(Status *status) { + int i, n, dev = 0; + float best_ratio = 0; + fprintf(stderr, "*** select a GPU based on available space\n"); + CUDA_SAFE_CALL_RET(cudaGetDeviceCount(&n), status); + for (i = 0; i < n; i++) + { + size_t avail, total; + float ratio; + CUDA_SAFE_SYNC_CALL_RET(cudaSetDevice(i), status); + CUDA_SAFE_SYNC_CALL_RET(cuMemGetInfo(&avail, &total), status); + ratio = (float)avail/total * 100; + fprintf(stderr, "* card %d: %.2f%%\n", i, ratio); + if (ratio > best_ratio) + { + best_ratio = ratio; + dev = i; + } + CUDA_SAFE_SYNC_CALL_RET(cudaDeviceReset(), status); + } + fprintf(stderr, "*** final decision: GPU %d\n", dev); + NERV_SET_STATUS(status, NERV_NORMAL, 0); + return dev; +} + +CuContext *nerv_cuda_context_create(int dev, Status *status) { CuContext *context = (CuContext *)malloc(sizeof(CuContext)); - new_cuda_handles(context, status); + context->has_handle = 0; /* this line must come first */ + if (dev == -1) + { + dev = choose_best_gpu(status); + if (status->err_code != NERV_NORMAL) + return NULL; + } + new_cuda_handles(context, dev, status); if (status->err_code != NERV_NORMAL) return NULL; context->profile = nerv_hashmap_create(PROFILE_HASHMAP_SIZE, bkdr_hash, strcmp); @@ -78,11 +115,14 @@ void nerv_cuda_context_destroy(CuContext *context, Status *status) { void nerv_cuda_context_select_gpu(CuContext *context, int dev, Status *status) { - free_cuda_handles(context, status); + /* free_cuda_handles(context, status); if (status->err_code != NERV_NORMAL) return; - CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status); - new_cuda_handles(context, status); + */ + /* because of cudaDeviceReset */ + context->has_handle = 0; + CUDA_SAFE_SYNC_CALL(cudaDeviceReset(), status); + new_cuda_handles(context, dev, status); if (status->err_code != NERV_NORMAL) return; NERV_SET_STATUS(status, NERV_NORMAL, 0); diff --git a/nerv/lib/matrix/cumatrix.h b/nerv/lib/matrix/cumatrix.h index 280035b..fd2a5ce 100644 --- a/nerv/lib/matrix/cumatrix.h +++ b/nerv/lib/matrix/cumatrix.h @@ -5,6 +5,7 @@ #include "cuda_helper.h" typedef struct CuContext { + int has_handle; cublasHandle_t cublas_handle; cudaEvent_t profile_start, profile_stop; curandGenerator_t curand_gen; @@ -15,6 +16,6 @@ void nerv_cuda_context_print_profile(CuContext *context); void nerv_cuda_context_clear_profile(CuContext *context); void nerv_cuda_context_accu_profile(CuContext *context, const char *name, float delta); void nerv_cuda_context_select_gpu(CuContext *context, int dev, Status *status); -CuContext *nerv_cuda_context_create(Status *status); +CuContext *nerv_cuda_context_create(int dev, Status *status); void nerv_cuda_context_destroy(CuContext *contex, Status *status); #endif diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c index 485d778..fb99b53 100644 --- a/nerv/lib/matrix/generic/mmatrix.c +++ b/nerv/lib/matrix/generic/mmatrix.c @@ -8,10 +8,10 @@ context, status) #define NERV_GENERIC_MATRIX #include "../../common.h" +#include "../../cblas.h" #include "../../io/chunk_file.h" #include <string.h> #include <math.h> -#include <cblas.h> |