author | Determinant <[email protected]> | 2015-05-19 15:01:38 +0800
---|---|---
committer | Determinant <[email protected]> | 2015-05-19 15:01:38 +0800
commit | e9b8855c894daa4e6749acfe891f68b3ed8ed481 (patch) |
tree | 5a3ea5e89bd475dc4312d379ffc7bf9121862dbb /matrix/cumatrix.c |
parent | 9b6606504241f27a9d42b96f535bf5f2c2918161 (diff) |
add double precision matrix implementation
Diffstat (limited to 'matrix/cumatrix.c')
-rw-r--r-- | matrix/cumatrix.c | 163
1 file changed, 24 insertions, 139 deletions
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index aa10571..90a6703 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -1,139 +1,24 @@
-#define MATRIX_DATA_FREE(ptr) cuda_float_array_free(ptr)
-#define MATRIX_DATA_ALLOC(dptr, stride, width, height) cuda_float_array_alloc(dptr, stride, width, height)
-#define MATRIX_DATA_WRITE(data, idx, val) cuda_float_array_write(data, idx, val)
-#define MATRIX_DATA_READ(data, idx) cuda_float_array_read(data, idx)
-#define MATRIX_INIT(L) cuda_float_init(L)
-#define NERV_GENERIC_MATRIX
-#define nerv_float_matrix_(NAME) nerv_float_matrix_cuda_ ## NAME
-#include "../common.h"
-#include "generic/matrix.h"
-#include "cukernel.h"
-#include "cuda.h"
-#include "cuda_runtime.h"
-#include "driver_types.h"
-#include "cublas_v2.h"
-
-const char *nerv_float_matrix_(tname) = "nerv.FloatCuMatrix";
-static cublasHandle_t cublas_handle;
-
-Matrix *nerv_float_matrix_(new_)(long nrow, long ncol);
-static int nerv_float_matrix_(add)(lua_State *L) {
-    Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
-    Matrix *b = luaT_checkudata(L, 2, nerv_float_matrix_(tname));
-    Matrix *c;
-    long nrow, ncol;
-    if (!(a->nrow == b->nrow && a->ncol == b->ncol))
-        nerv_error(L, "Matrices should be of the same dimension");
-    nrow = a->nrow;
-    ncol = a->ncol;
-    c = nerv_float_matrix_(new_)(nrow, ncol);
-    float alpha = 1.0f, beta = 1.0f;
-    cublasSgeam(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
-                ncol, nrow,
-                &alpha,
-                a->data.f, a->stride / sizeof(float),
-                &beta,
-                b->data.f, b->stride / sizeof(float),
-                c->data.f, c->stride / sizeof(float));
-    luaT_pushudata(L, c, nerv_float_matrix_(tname));
-    return 1;
-}
-
-static int nerv_float_matrix_(mul)(lua_State *L) {
-    Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
-    Matrix *b = luaT_checkudata(L, 2, nerv_float_matrix_(tname));
-    Matrix *c;
-    if (a->ncol != b->nrow)
-        nerv_error(L, "Wrong dimension of multipliers");
-    c = nerv_float_matrix_(new_)(a->nrow, b->ncol);
-    float alpha = 1.0f, beta = 0.0f;
-    cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
-                b->ncol, a->nrow, b->nrow,
-                &alpha,
-                b->data.f, b->stride / sizeof(float),
-                a->data.f, a->stride / sizeof(float),
-                &beta,
-                c->data.f, c->stride / sizeof(float));
-    luaT_pushudata(L, c, nerv_float_matrix_(tname));
-    return 1;
-}
-
-static int nerv_float_matrix_(sigmoid)(lua_State *L) {
-    Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
-    Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
-    cuda_sigmoid(a, b);
-    luaT_pushudata(L, b, nerv_float_matrix_(tname));
-    return 1;
-}
-
-static int nerv_float_matrix_(softmax)(lua_State *L) {
-    Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
-    Matrix *max = nerv_float_matrix_(new_)(a->nrow, 1);
-    Matrix *dno = nerv_float_matrix_(new_)(a->nrow, 1);
-    Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
-    cuda_colmax(a, max);
-    cuda_softmax_denominator(a, max, dno);
-    cuda_softmax_final(a, max, dno, b);
-    luaT_pushudata(L, b, nerv_float_matrix_(tname));
-    return 1;
-}
-
-static int nerv_float_matrix_(colsum)(lua_State *L) {
-    Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
-    Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
-    cuda_colsum(a, b);
-    luaT_pushudata(L, b, nerv_float_matrix_(tname));
-    return 1;
-}
-
-static int nerv_float_matrix_(colmax)(lua_State *L) {
-    Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
-    Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
-    cuda_colmax(a, b);
-    luaT_pushudata(L, b, nerv_float_matrix_(tname));
-    return 1;
-}
-
-static const luaL_Reg nerv_float_matrix_(extra_methods)[] = {
-    {"__add__", nerv_float_matrix_(add)},
-    {"__mul__", nerv_float_matrix_(mul)},
-    {"sigmoid", nerv_float_matrix_(sigmoid)},
-    {"softmax", nerv_float_matrix_(softmax)},
-    {"colsum", nerv_float_matrix_(colsum)},
-    {"colmax", nerv_float_matrix_(colmax)},
-    {NULL, NULL}
-};
-
-static void cuda_float_init(lua_State *L) {
-    luaN_append_methods(L, nerv_float_matrix_(extra_methods));
-    cublasCreate(&cublas_handle);
-}
-
-static void cuda_float_array_free(float *ptr) {
-    cudaFree(ptr);
-}
-
-static void cuda_float_array_alloc(float **dptr, size_t *stride,
-                                   long width, long height) {
-    cudaMallocPitch((void **)dptr, stride, width, height);
-}
-
-static float cuda_float_array_read(float *data, int idx) {
-    float res;
-    cudaMemcpy(&res, data + idx, sizeof(float), cudaMemcpyDeviceToHost);
-    return res;
-}
-
-static void cuda_float_array_write(float *data, int idx, float val) {
-    cudaMemcpy(data + idx, &val, sizeof(float), cudaMemcpyHostToDevice);
-}
-
-int nerv_float_matrix_(get_elem)(lua_State *L) {
-    return nerv_error_method_not_implemented(L);
-}
-
-int nerv_float_matrix_(set_elem)(lua_State *L) {
-    return nerv_error_method_not_implemented(L);
-}
-
-#include "generic/matrix.c"
+#define NERV_GENERIC_CUMATRIX
+
+#define MATRIX_USE_FLOAT
+#define cuda_matrix_(NAME) cuda_matrix_float_ ## NAME
+#define nerv_matrix_(NAME) nerv_matrix_float_cuda_ ## NAME
+#define cudak_(NAME) cudak_float_ ## NAME
+#define NERV_CUBLAS_(NAME) cublasS##NAME
+const char *nerv_matrix_(tname) = "nerv.FloatCuMatrix";
+#include "generic/cumatrix.c"
+#undef NERV_CUBLAS_
+#undef cudak_
+#undef nerv_matrix_
+#undef cuda_matrix_
+#undef MATRIX_USE_FLOAT
+#undef MATRIX_ELEM
+#undef MATRIX_ELEM_PTR
+
+#define MATRIX_USE_DOUBLE
+#define cuda_matrix_(NAME) cuda_matrix_double_ ## NAME
+#define nerv_matrix_(NAME) nerv_matrix_double_cuda_ ## NAME
+#define cudak_(NAME) cudak_double_ ## NAME
+#define NERV_CUBLAS_(NAME) cublasD##NAME
+const char *nerv_matrix_(tname) = "nerv.DoubleCuMatrix";
+#include "generic/cumatrix.c"
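The rewritten cumatrix.c relies on the include-as-template idiom: the including file defines the element type (MATRIX_USE_FLOAT or MATRIX_USE_DOUBLE) and the name-mangling macros (cuda_matrix_, nerv_matrix_, cudak_, NERV_CUBLAS_), then includes generic/cumatrix.c, so every binding and cuBLAS call is expanded once per precision. The sketch below is only an illustration of the consumer side of that contract, not the repository's actual generic/cumatrix.c; MATRIX_ELEM, MATRIX_ELEM_PTR, and the add() body are assumptions reconstructed from the macros defined above and from the float-only code being removed.

```c
/* Hypothetical sketch of an included "template" body -- NOT the real
 * generic/cumatrix.c.  It assumes the including file has already defined
 * MATRIX_USE_FLOAT or MATRIX_USE_DOUBLE, nerv_matrix_(), NERV_CUBLAS_()
 * and tname, exactly as the new matrix/cumatrix.c does, and that the
 * Matrix struct carries a data.f/data.d union member per element type. */

#ifdef MATRIX_USE_FLOAT
#define MATRIX_ELEM float                 /* element type for this expansion */
#define MATRIX_ELEM_PTR(m) ((m)->data.f)
#elif defined(MATRIX_USE_DOUBLE)
#define MATRIX_ELEM double
#define MATRIX_ELEM_PTR(m) ((m)->data.d)
#endif

/* Element-wise addition via cuBLAS geam: NERV_CUBLAS_(geam) expands to
 * cublasSgeam in the float instantiation and cublasDgeam in the double one. */
static int nerv_matrix_(add)(lua_State *L) {
    Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
    Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
    Matrix *c;
    if (!(a->nrow == b->nrow && a->ncol == b->ncol))
        nerv_error(L, "Matrices should be of the same dimension");
    c = nerv_matrix_(new_)(a->nrow, a->ncol);
    MATRIX_ELEM alpha = 1, beta = 1;
    NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       a->ncol, a->nrow,
                       &alpha,
                       MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
                       &beta,
                       MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
                       MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM));
    luaT_pushudata(L, c, nerv_matrix_(tname));
    return 1;
}
```

Because the generic body only ever refers to MATRIX_ELEM and the mangling macros, adding the double-precision variant reduces to a second block of #defines plus a second #include, which is why the diff removes 139 lines but adds only 24.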