From df737041e4a9f3f55978cc74db9a9cea27fa9fa0 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 10:58:57 +0800 Subject: add profiling; add ce accurarcy; several other changes --- matrix/cuda_helper.h | 26 ++++++++- matrix/cukernel.h | 2 + matrix/cumatrix.c | 34 +++++++++++ matrix/generic/cukernel.cu | 138 +++++++++++++++++++++++++++++++++++++++++---- matrix/generic/cumatrix.c | 97 ++++++++++++++++++++++++++----- matrix/generic/mmatrix.c | 4 ++ matrix/init.c | 5 ++ matrix/mmatrix.c | 29 ++++++++++ 8 files changed, 310 insertions(+), 25 deletions(-) (limited to 'matrix') diff --git a/matrix/cuda_helper.h b/matrix/cuda_helper.h index cedc643..5e5f2ad 100644 --- a/matrix/cuda_helper.h +++ b/matrix/cuda_helper.h @@ -1,17 +1,23 @@ #ifndef NERV_CUDA_HELPER_H #define NERV_CUDA_HELPER_H +#include "cuda.h" +#include "cuda_runtime.h" +#include "driver_types.h" +#include "cublas_v2.h" #define CUBLAS_SAFE_CALL(call) \ do { \ cublasStatus_t err = (call); \ if (err != CUBLAS_STATUS_SUCCESS) \ - nerv_error(L, "cumatrix cublas error: %s", cublasGetErrorString(err)); \ + nerv_error(L, "cumatrix cublas error: %s at %s:%d", \ + cublasGetErrorString(err), __FILE__, __LINE__); \ } while (0) #define CUDA_SAFE_CALL(call) \ do { \ cudaError_t err = (call); \ if (err != cudaSuccess) \ - nerv_error(L, "cumatrix CUDA error: %s", cudaGetErrorString(err)); \ + nerv_error(L, "cumatrix CUDA error: %s at %s:%d", \ + cudaGetErrorString(err), __FILE__, __LINE__); \ } while (0) #define CUDA_SAFE_SYNC_CALL(call) \ @@ -52,4 +58,20 @@ static const char *cublasGetErrorString(cublasStatus_t err) { } return ""; } + +#define PROFILE_START \ + do { \ + cudaEvent_t start, stop; \ + cudaEventCreate(&start); \ + cudaEventCreate(&stop); \ + cudaEventRecord(start, 0); +#define PROFILE_STOP \ + cudaEventRecord(stop, 0); \ + cudaEventSynchronize(stop); \ + float milliseconds = 0; \ + cudaEventElapsedTime(&milliseconds, start, stop); \ + accu_profile(__func__, milliseconds / 1000); \ + } while (0); + +#define PROFILE_END #endif diff --git a/matrix/cukernel.h b/matrix/cukernel.h index 7d2168e..23398c8 100644 --- a/matrix/cukernel.h +++ b/matrix/cukernel.h @@ -5,7 +5,9 @@ void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b); void cudak_(cuda_sigmoid_grad)(const Matrix *output, const Matrix *err, Matrix *nerr); void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b); void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b); +void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *idx); void cudak_(cuda_colsum)(const Matrix *a, Matrix *b); +void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b); void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b); void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max, const Matrix *deno, Matrix *b); void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta); diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c index 51a3681..4ebc5ff 100644 --- a/matrix/cumatrix.c +++ b/matrix/cumatrix.c @@ -1,4 +1,38 @@ #define NERV_GENERIC_CUMATRIX +#include "../common.h" +#include "cuda_helper.h" +static cublasHandle_t cublas_handle; +static HashMap *profile; + +int print_profile(lua_State *L) { + size_t i; + fprintf(stderr, "*** [nerv cumatrix profile] **\n"); + for (i = 0; i < profile->size; i++) + { + HashNode *ptr; + for (ptr = profile->bucket[i]; ptr; ptr = ptr->next) + { + fprintf(stderr, "%s:\t%.6f\n", ptr->key, *(float *)ptr->val); + } + } + return 0; +} + +int clear_profile(lua_State *L) { + hashmap_clear(profile); + return 0; +} + +void accu_profile(const char *name, float delta) { + float *val = hashmap_getval(profile, name); + if (!val) + { + val = malloc(sizeof(float)); + *val = 0; + hashmap_setval(profile, name, val); + } + *val += delta; +} #define MATRIX_USE_FLOAT #define cuda_matrix_(NAME) cuda_matrix_float_##NAME diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu index 05a1e78..fdab356 100644 --- a/matrix/generic/cukernel.cu +++ b/matrix/generic/cukernel.cu @@ -3,6 +3,7 @@ #include #include "matrix.h" #include "cuda.h" +#include "float.h" #define CUDA_THREADS_N 16 #define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N)) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) @@ -11,9 +12,12 @@ __global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b, int j = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.y * blockDim.y + threadIdx.y; long idx; + MATRIX_ELEM tmp; if (i >= nrow || j >= ncol) return; idx = j + i * stride; - b[idx] = log(a[idx]); + tmp = a[idx]; + if(tmp < FLT_MIN) tmp = FLT_MIN; + b[idx] = log(tmp); } __global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b, @@ -61,9 +65,9 @@ __global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b, } __global__ void cudak_(block_reduce_rowsum)(const MATRIX_ELEM *input, - MATRIX_ELEM *output, - const int istride, const int ostride, - const int n) { + MATRIX_ELEM *output, + const int istride, const int ostride, + const int n) { extern __shared__ MATRIX_ELEM cudak_(arr)[]; int j = blockIdx.x * blockDim.x + threadIdx.x; cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0; @@ -96,6 +100,26 @@ __global__ void cudak_(block_reduce_colsum)(const MATRIX_ELEM *input, output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0]; } +__global__ void cudak_(block_reduce_colsame)(const MATRIX_ELEM *input, + const MATRIX_ELEM *ref_input, + MATRIX_ELEM *output, + const int istride, const int ostride, + const int n) { + extern __shared__ MATRIX_ELEM cudak_(arr)[]; + int i = blockIdx.y * blockDim.y + threadIdx.y; + cudak_(arr)[threadIdx.y] = (i < n && input[blockIdx.x + istride * i] == \ + ref_input[blockIdx.x + istride * i]) ? 1.0 : 0; + __syncthreads(); + for (int offset = blockDim.y >> 1; offset; offset >>= 1) + { + if (threadIdx.y < offset) + cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset]; + __syncthreads(); + } + if (threadIdx.y == 0) + output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0]; +} + __global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input, MATRIX_ELEM *output, const MATRIX_ELEM *max, @@ -117,9 +141,9 @@ __global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input, } __global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input, - MATRIX_ELEM *output, - const int istride, const int ostride, - const int n) { + MATRIX_ELEM *output, + const int istride, const int ostride, + const int n) { extern __shared__ MATRIX_ELEM cudak_(arr)[]; int j = blockIdx.x * blockDim.x + threadIdx.x; cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0; @@ -129,8 +153,9 @@ __global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input, if (threadIdx.x < offset) { MATRIX_ELEM l = cudak_(arr)[threadIdx.x], - r = cudak_(arr)[threadIdx.x + offset]; - if (r > l) cudak_(arr)[threadIdx.x] = r; + r = cudak_(arr)[threadIdx.x + offset]; + if (r > l) + cudak_(arr)[threadIdx.x] = r; } __syncthreads(); } @@ -138,6 +163,40 @@ __global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input, output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0]; } +__global__ void cudak_(block_reduce_rowmax_idx)(const MATRIX_ELEM *input, + const MATRIX_ELEM *idx_input, + MATRIX_ELEM *output, + MATRIX_ELEM *idx_output, + const int istride, const int ostride, + const int n) { + extern __shared__ MATRIX_ELEM cudak_(arr)[]; + MATRIX_ELEM *arr_val = cudak_(arr); + MATRIX_ELEM *arr_idx = arr_val + blockDim.x; + int j = blockIdx.x * blockDim.x + threadIdx.x; + arr_val[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0; + arr_idx[threadIdx.x] = j < n ? idx_input[j + istride * blockIdx.y] : 0; + __syncthreads(); + for (int offset = blockDim.x >> 1; offset; offset >>= 1) + { + if (threadIdx.x < offset) + { + MATRIX_ELEM l = arr_val[threadIdx.x], + r = arr_val[threadIdx.x + offset]; + if (r > l) + { + arr_val[threadIdx.x] = r; + arr_idx[threadIdx.x] = arr_idx[threadIdx.x + offset]; + } + } + __syncthreads(); + } + if (threadIdx.x == 0) + { + output[blockIdx.x + ostride * blockIdx.y] = arr_val[0]; + idx_output[blockIdx.x + ostride * blockIdx.y] = arr_idx[0]; + } +} + __global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b, int nrow, int ncol, int stride, double beta) { int j = blockIdx.x * blockDim.x + threadIdx.x; @@ -196,6 +255,14 @@ __global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b, b[lrintf(a[j + i * stride_a]) + i * stride_b] = 1.0; } +__global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b, + int nrow, int ncol, int stride) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= nrow || j >= ncol) return; + b[j + i * stride] = j; +} + extern "C" { #include "../cukernel.h" void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) { @@ -261,10 +328,32 @@ extern "C" { cudaFree(res); } + void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b) { + dim3 block(1, CUDA_THREADS_NN); + int nrow = a->nrow; + int blocks_per_col = CEIL_DIV(nrow, block.y); + dim3 grid(a->ncol, blocks_per_col); + MATRIX_ELEM *res; + size_t stride; + cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col); + cudak_(block_reduce_colsame)<<>> \ + (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(ref), res, + a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM), + nrow); + nrow = blocks_per_col; + assert((unsigned long)nrow <= block.y); + grid.y = 1; + cudak_(block_reduce_colsum)<<>> \ + (res, MATRIX_ELEM_PTR(b), + stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM), + nrow); + cudaFree(res); + } + void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) { dim3 block(1, CUDA_THREADS_NN); int nrow = a->nrow; - int blocks_per_col = CEIL_DIV(nrow, block.x); + int blocks_per_col = CEIL_DIV(nrow, block.y); dim3 grid(a->ncol, blocks_per_col); MATRIX_ELEM *res; size_t stride; @@ -344,6 +433,35 @@ extern "C" { cudaFree(res); } + void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *b_idx) { + dim3 block(CUDA_THREADS_NN, 1); + int ncol = a->ncol; + int blocks_per_row = CEIL_DIV(ncol, block.x); + dim3 grid(blocks_per_row, a->nrow); + MATRIX_ELEM *a_idx, *res, *res_idx; + size_t stride; + cudaMallocPitch(&a_idx, &stride, a->stride, a->nrow); + cudak_(gen_col_idx)<<>>(a_idx, a->nrow, ncol, stride / sizeof(MATRIX_ELEM)); + cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow); + cudaMallocPitch(&res_idx, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow); + cudak_(block_reduce_rowmax_idx)<<>> \ + (MATRIX_ELEM_PTR(a), a_idx, res, res_idx, + a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM), + ncol); + cudaFree(a_idx); + ncol = blocks_per_row; + assert((unsigned long)ncol <= block.x); + grid.x = 1; + cudak_(block_reduce_rowmax_idx)<<>> \ + (res, res_idx, MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(b_idx), + stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM), + ncol); + cudaFree(res); + cudaFree(res_idx); + } + /* in-place calc */ void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) { dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c index 373fc42..8e7d34f 100644 --- a/matrix/generic/cumatrix.c +++ b/matrix/generic/cumatrix.c @@ -11,15 +11,11 @@ #define MATRIX_BASE_TNAME nerv_matrix_cuda_tname #define NERV_GENERIC_MATRIX #define NERV_GENERIC_CUKERNEL +#define PROFILE_HASHMAP_SIZE 123457 #include "../../common.h" #include "../cukernel.h" -#include "cuda.h" -#include "cuda_runtime.h" -#include "driver_types.h" -#include "cublas_v2.h" #include "../cuda_helper.h" - -static cublasHandle_t cublas_handle; +#include Matrix *nerv_matrix_(new_)(lua_State *L, long nrow, long ncol); void nerv_matrix_(data_free)(lua_State *L, Matrix *self); @@ -27,6 +23,7 @@ void nerv_matrix_(data_free)(lua_State *L, Matrix *self); static void nerv_matrix_(add_)(lua_State *L, const Matrix *a, const Matrix *b, const Matrix *c, MATRIX_ELEM alpha, MATRIX_ELEM beta) { + PROFILE_START CUBLAS_SAFE_CALL( NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, a->ncol, a->nrow, @@ -35,6 +32,7 @@ static void nerv_matrix_(add_)(lua_State *L, const Matrix *a, const Matrix *b, &beta, MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM), MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM))); + PROFILE_STOP } static int nerv_matrix_(add)(lua_State *L) { @@ -75,6 +73,7 @@ static int nerv_matrix_(mul)(lua_State *L) { nerv_error(L, "Wrong dimension of multipliers"); /* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */ /* Because matrix in Nerv is row-major, here b comes first */ + PROFILE_START CUBLAS_SAFE_CALL( NERV_CUBLAS_(gemm)(cublas_handle, tb, ta, bn, am, bm, @@ -83,6 +82,7 @@ static int nerv_matrix_(mul)(lua_State *L) { MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), &beta, MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM))); + PROFILE_STOP return 0; } @@ -97,7 +97,9 @@ static int nerv_matrix_(sigmoid)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname)); CHECK_SAME_DIMENSION(a, b); + PROFILE_START cudak_(cuda_sigmoid)(b, a); + PROFILE_STOP return 0; } @@ -107,30 +109,38 @@ static int nerv_matrix_(sigmoid_grad)(lua_State *L) { Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname)); CHECK_SAME_DIMENSION(nerr, err); CHECK_SAME_DIMENSION(nerr, output); + PROFILE_START cudak_(cuda_sigmoid_grad)(output, err, nerr); + PROFILE_STOP return 0; } static int nerv_matrix_(softmax)(lua_State *L) { Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); - Matrix *max; + Matrix *max, *max_idx; Matrix *dno; CHECK_SAME_DIMENSION(a, b); max = nerv_matrix_(new_)(L, a->nrow, 1); + max_idx = nerv_matrix_(new_)(L, a->nrow, 1); dno = nerv_matrix_(new_)(L, a->nrow, 1); - cudak_(cuda_rowmax)(a, max); + PROFILE_START + cudak_(cuda_rowmax_idx)(a, max, max_idx); cudak_(cuda_softmax_denominator)(a, max, dno); cudak_(cuda_softmax_final)(a, max, dno, b); + PROFILE_STOP nerv_matrix_(data_free)(L, max); nerv_matrix_(data_free)(L, dno); - return 0; + luaT_pushudata(L, max_idx, nerv_matrix_(tname)); + return 1; } static int nerv_matrix_(rowsum)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1); + PROFILE_START cudak_(cuda_rowsum)(a, b); + PROFILE_STOP luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; } @@ -138,7 +148,21 @@ static int nerv_matrix_(rowsum)(lua_State *L) { static int nerv_matrix_(colsum)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = nerv_matrix_(new_)(L, 1, a->ncol); + PROFILE_START cudak_(cuda_colsum)(a, b); + PROFILE_STOP + luaT_pushudata(L, b, nerv_matrix_(tname)); + return 1; +} + +static int nerv_matrix_(colsame)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(new_)(L, 1, a->ncol); + CHECK_SAME_DIMENSION(a, ref); + PROFILE_START + cudak_(cuda_colsame)(a, ref, b); + PROFILE_STOP luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; } @@ -146,11 +170,24 @@ static int nerv_matrix_(colsum)(lua_State *L) { static int nerv_matrix_(rowmax)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1); + PROFILE_START cudak_(cuda_rowmax)(a, b); + PROFILE_STOP luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; } +static int nerv_matrix_(rowmax_idx)(lua_State *L) { + Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); + Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1); + Matrix *idx = nerv_matrix_(new_)(L, a->nrow, 1); + PROFILE_START + cudak_(cuda_rowmax_idx)(a, b, idx); + PROFILE_STOP + luaT_pushudata(L, b, nerv_matrix_(tname)); + luaT_pushudata(L, idx, nerv_matrix_(tname)); + return 2; +} static int nerv_matrix_(add_row)(lua_State *L) { Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); @@ -160,14 +197,18 @@ static int nerv_matrix_(add_row)(lua_State *L) { nerv_error(L, "the number of columns is not the same"); if (a->nrow != 1) nerv_error(L, "a row vector is expected"); + PROFILE_START cudak_(cuda_add_row)(a, b, beta); + PROFILE_STOP return 0; } static int nerv_matrix_(fill)(lua_State *L) { Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname)); double val = luaL_checknumber(L, 2); + PROFILE_START cudak_(cuda_fill)(self, val); + PROFILE_STOP return 0; } @@ -183,11 +224,13 @@ static int nerv_matrix_(copy_fromd)(lua_State *L) { nerv_error(L, "invalid copy interval"); if (a->ncol != b->ncol) nerv_error(L, "matrices should be of the same dimension"); + PROFILE_START CUDA_SAFE_SYNC_CALL( cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride, MATRIX_ROW_PTR(b, b_begin), b->stride, sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin, cudaMemcpyDeviceToDevice)); + PROFILE_STOP return 0; } @@ -204,11 +247,13 @@ static int nerv_matrix_(copy_fromh)(lua_State *L) { nerv_error(L, "invalid copy interval"); if (a->ncol != b->ncol) nerv_error(L, "matrices should be of the same dimension"); + PROFILE_START CUDA_SAFE_SYNC_CALL( cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride, MATRIX_ROW_PTR(b, b_begin), b->stride, sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin, cudaMemcpyHostToDevice)); + PROFILE_STOP return 0; } @@ -224,11 +269,13 @@ static int nerv_matrix_(copy_toh)(lua_State *L) { nerv_error(L, "invalid copy interval"); if (b->ncol != a->ncol) nerv_error(L, "matrices should be of the same dimension"); + PROFILE_START CUDA_SAFE_SYNC_CALL( cudaMemcpy2D(MATRIX_ROW_PTR(b, b_begin), b->stride, MATRIX_ROW_PTR(a, a_begin), a->stride, sizeof(MATRIX_ELEM) * a->ncol, a_end - a_begin, cudaMemcpyDeviceToHost)); + PROFILE_STOP return 0; } @@ -237,6 +284,7 @@ static int nerv_matrix_(trans)(lua_State *L) { Matrix *b = nerv_matrix_(new_)(L, a->ncol, a->nrow); MATRIX_ELEM alpha = 1, beta = 0; /* FIXME: possible memory leak when lua error is raised */ + PROFILE_START CUBLAS_SAFE_CALL( NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, a->nrow, a->ncol, @@ -245,6 +293,7 @@ static int nerv_matrix_(trans)(lua_State *L) { &beta, MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM), MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM))); + PROFILE_STOP luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; } @@ -255,7 +304,9 @@ static int nerv_matrix_(mul_elem)(lua_State *L) { Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname)); CHECK_SAME_DIMENSION(a, b); CHECK_SAME_DIMENSION(a, c); + PROFILE_START cudak_(cuda_mul_elem)(a, b, c); + PROFILE_STOP return 0; } @@ -263,7 +314,9 @@ static int nerv_matrix_(log_elem)(lua_State *L) { Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname)); CHECK_SAME_DIMENSION(a, b); + PROFILE_START cudak_(cuda_log_elem)(a, b); + PROFILE_STOP return 0; } @@ -274,8 +327,10 @@ static int nerv_matrix_(decompress)(lua_State *L) { if (a->ncol != 1) nerv_error(L, "the compressed matrix must be a column vector"); b = nerv_matrix_(new_)(L, a->nrow, orig_col); + PROFILE_START cudak_(cuda_fill)(b, 0.0); cudak_(cuda_decompress)(a, b); + PROFILE_STOP luaT_pushudata(L, b, nerv_matrix_(tname)); return 1; } @@ -285,21 +340,25 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) { Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname)); Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME); Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_int_tname); + long nrow = a->nrow; + int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0; + if (!(0 <= b_begin && b_begin + nrow <= idx->ncol)) + nerv_error(L, "invalid copy interval"); long *idx_ptr = idx->data.i; int i; - long nrow = a->nrow; if (idx->nrow != 1) nerv_error(L, "index should be a vector"); - if (idx->ncol != nrow) - nerv_error(L, "index dimension mismatch"); if (a->ncol != b->ncol) nerv_error(L, "source/destination dimension mismatch"); cudaStream_t *streams = (cudaStream_t*)malloc(sizeof(cudaStream_t) * nrow); for (i = 0; i < nrow; i++) { + int src_row = idx_ptr[b_begin + i]; + if (!(0 <= src_row && src_row < b->nrow)) + nerv_error(L, "invalid index"); CUDA_SAFE_CALL(cudaStreamCreate(streams + i)); CUDA_SAFE_CALL(cudaMemcpyAsync(MATRIX_ROW_PTR(a, i), - MATRIX_ROW_PTR(b, idx_ptr[i]), + MATRIX_ROW_PTR(b, src_row), b->stride, cudaMemcpyHostToDevice, streams[i])); } @@ -308,6 +367,7 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) { CUDA_SAFE_CALL(cudaStreamSynchronize(streams[i])); CUDA_SAFE_CALL(cudaStreamDestroy(streams[i])); } + free(streams); return 0; } @@ -319,7 +379,9 @@ static int nerv_matrix_(expand_frm)(lua_State *L) { nerv_error(L, "mismatching number of frames"); if (a->ncol != b->ncol * (context * 2 + 1)) nerv_error(L, "the width should be 2 * context + 1"); + PROFILE_START cudak_(cuda_expand_frm)(b, a, context); + PROFILE_STOP return 0; } @@ -330,7 +392,9 @@ static int nerv_matrix_(rearrange_frm)(lua_State *L) { CHECK_SAME_DIMENSION(a, b); if (b->ncol % step) nerv_error(L, "the dimension of columns is not divisible by step"); + PROFILE_START cudak_(cuda_rearrange_frm)(b, a, step); + PROFILE_STOP return 0; } @@ -341,15 +405,19 @@ static int nerv_matrix_(scale_row)(lua_State *L) { nerv_error(L, "the number of columns is not the same"); if (b->nrow != 1) nerv_error(L, "a row vector is expected"); + PROFILE_START cudak_(cuda_scale_row)(b, a); + PROFILE_STOP return 0; } static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"create", nerv_matrix_(create)}, {"colsum", nerv_matrix_(colsum)}, + {"colsame", nerv_matrix_(colsame)}, {"rowsum", nerv_matrix_(rowsum)}, {"rowmax", nerv_matrix_(rowmax)}, + {"rowmax_idx", nerv_matrix_(rowmax_idx)}, {"trans", nerv_matrix_(trans)}, {"decompress", nerv_matrix_(decompress)}, /* in-place calc */ @@ -375,6 +443,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = { static void cuda_matrix_(init)(lua_State *L) { luaN_append_methods(L, nerv_matrix_(extra_methods)); cublasCreate(&cublas_handle); + profile = hashmap_create(PROFILE_HASHMAP_SIZE, bkdr_hash, strcmp); } static void cuda_matrix_(free)(lua_State *L, MATRIX_ELEM *ptr) { @@ -383,7 +452,9 @@ static void cuda_matrix_(free)(lua_State *L, MATRIX_ELEM *ptr) { static void cuda_matrix_(alloc)(lua_State *L, MATRIX_ELEM **dptr, size_t *stride, long width, long height) { + PROFILE_START CUDA_SAFE_SYNC_CALL(cudaMallocPitch((void **)dptr, stride, width, height)); + PROFILE_STOP } static MATRIX_ELEM cuda_matrix_(read)(lua_State *L, MATRIX_ELEM *data, diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c index 4b722f3..75d1eb1 100644 --- a/matrix/generic/mmatrix.c +++ b/matrix/generic/mmatrix.c @@ -43,6 +43,9 @@ int nerv_matrix_(set_elem)(lua_State *L) { static const luaL_Reg nerv_matrix_(extra_methods)[]; static void host_matrix_(init)(lua_State *L) { luaN_append_methods(L, nerv_matrix_(extra_methods)); +#ifdef MMATRIX_INIT + MMATRIX_INIT(L); +#endif } #include "matrix.c" @@ -114,6 +117,7 @@ static int nerv_matrix_(copy_from)(lua_State *L) { sizeof(MATRIX_ELEM) * b->ncol * (b_end - b_begin)); return 0; } + static const luaL_Reg nerv_matrix_(extra_methods)[] = { {"load", nerv_matrix_(load)}, {"save", nerv_matrix_(save)}, diff --git a/matrix/init.c b/matrix/init.c index b54cd12..7b7f478 100644 --- a/matrix/init.c +++ b/matrix/init.c @@ -9,12 +9,17 @@ void nerv_matrix_host_float_init(lua_State *L); void nerv_matrix_cuda_float_init(lua_State *L); void nerv_matrix_host_double_init(lua_State *L); void nerv_matrix_cuda_double_init(lua_State *L); +void nerv_matrix_host_int_init(lua_State *L); +int print_profile(lua_State *L); +int clear_profile(lua_State *L); static const luaL_Reg matrix_methods[] = { {"__tostring__", nerv_error_method_not_implemented }, {"__add__", nerv_error_method_not_implemented }, {"__sub__", nerv_error_method_not_implemented }, {"__mul__", nerv_error_method_not_implemented }, + {"print_profile", print_profile}, + {"clear_profile", clear_profile}, {NULL, NULL} }; diff --git a/matrix/mmatrix.c b/matrix/mmatrix.c index ab15197..81f8dfc 100644 --- a/matrix/mmatrix.c +++ b/matrix/mmatrix.c @@ -29,5 +29,34 @@ const char *nerv_matrix_(tname) = "nerv.MMatrixDouble"; #define host_matrix_(NAME) host_matrix_int_##NAME #define nerv_matrix_(NAME) nerv_matrix_host_int_##NAME const char *nerv_matrix_(tname) = "nerv.MMatrixInt"; +#define MMATRIX_INIT(L) host_matrix_(init_extra)(L) + +static const luaL_Reg nerv_matrix_(extra_methods_int)[]; +static void host_matrix_(init_extra)(lua_State *L) { + luaN_append_methods(L, nerv_matrix_(extra_methods_int)); +} + #include "generic/mmatrix.c" +static int nerv_matrix_(perm_gen)(lua_State *L) { + int i, ncol = luaL_checkinteger(L, 1); + Matrix *self = nerv_matrix_(new_)(L, 1, ncol); + long *prow = self->data.i; + for (i = 0; i < ncol; i++) + prow[i] = i; + for (i = ncol - 1; i >= 0; i--) + { + size_t j = rand() % (i + 1); + long tmp = prow[i]; + prow[i] = prow[j]; + prow[j] = tmp; + } + luaT_pushudata(L, self, nerv_matrix_(tname)); + return 1; +} + +static const luaL_Reg nerv_matrix_(extra_methods_int)[] = { + {"perm_gen", nerv_matrix_(perm_gen)}, + {NULL, NULL} +}; + -- cgit v1.2.3