aboutsummaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/cuda_helper.h75
-rw-r--r--matrix/cukernel.cu17
-rw-r--r--matrix/cukernel.h20
-rw-r--r--matrix/cumatrix.c87
-rw-r--r--matrix/generic/cukernel.cu571
-rw-r--r--matrix/generic/cumatrix.c493
-rw-r--r--matrix/generic/elem_type.h22
-rw-r--r--matrix/generic/matrix.c155
-rw-r--r--matrix/generic/matrix.h19
-rw-r--r--matrix/generic/mmatrix.c122
-rw-r--r--matrix/init.c35
-rw-r--r--matrix/init.lua77
-rw-r--r--matrix/mmatrix.c77
13 files changed, 0 insertions, 1770 deletions
diff --git a/matrix/cuda_helper.h b/matrix/cuda_helper.h
deleted file mode 100644
index fde6f18..0000000
--- a/matrix/cuda_helper.h
+++ /dev/null
@@ -1,75 +0,0 @@
-#ifndef NERV_CUDA_HELPER_H
-#define NERV_CUDA_HELPER_H
-#include "cuda.h"
-#include "cuda_runtime.h"
-#include "driver_types.h"
-#include "cublas_v2.h"
-#define CUBLAS_SAFE_SYNC_CALL(call) \
- do { \
- cublasStatus_t err = (call); \
- if (err != CUBLAS_STATUS_SUCCESS) \
- nerv_error(L, "cumatrix cublas error: %s at %s:%d", \
- cublasGetErrorString(err), __FILE__, __LINE__); \
- cudaDeviceSynchronize(); \
- } while (0)
-
-#define CUDA_SAFE_CALL(call) \
- do { \
- cudaError_t err = (call); \
- if (err != cudaSuccess) \
- nerv_error(L, "cumatrix CUDA error: %s at %s:%d", \
- cudaGetErrorString(err), __FILE__, __LINE__); \
- } while (0)
-
-#define CUDA_SAFE_SYNC_CALL(call) \
- do { \
- CUDA_SAFE_CALL(call); \
- cudaDeviceSynchronize(); \
- } while (0)
-
-#define CHECK_SAME_DIMENSION(a, b) \
- do { \
- if (!(a->nrow == b->nrow && a->ncol == b->ncol)) \
- nerv_error(L, "matrices should be of the same dimension"); \
- } while (0)
-
-static const char *cublasGetErrorString(cublasStatus_t err) {
- switch (err)
- {
- case CUBLAS_STATUS_SUCCESS:
- return "CUBLAS_STATUS_SUCCESS";
- case CUBLAS_STATUS_NOT_INITIALIZED:
- return "CUBLAS_STATUS_NOT_INITIALIZED";
- case CUBLAS_STATUS_ALLOC_FAILED:
- return "CUBLAS_STATUS_ALLOC_FAILED";
- case CUBLAS_STATUS_INVALID_VALUE:
- return "CUBLAS_STATUS_INVALID_VALUE";
- case CUBLAS_STATUS_ARCH_MISMATCH:
- return "CUBLAS_STATUS_ARCH_MISMATCH";
- case CUBLAS_STATUS_MAPPING_ERROR:
- return "CUBLAS_STATUS_MAPPING_ERROR";
- case CUBLAS_STATUS_EXECUTION_FAILED:
- return "CUBLAS_STATUS_EXECUTION_FAILED";
- case CUBLAS_STATUS_INTERNAL_ERROR:
- return "CUBLAS_STATUS_INTERNAL_ERROR";
-/* case CUBLAS_STATUS_NOT_SUPPORTED:
- return "CUBLAS_STATUS_NOT_SUPPORTED";
- case CUBLAS_STATUS_LICENSE_ERROR:
- return "CUBLAS_STATUS_LICENSE_ERROR"; */
- }
- return "<unknown>";
-}
-
-#define PROFILE_START \
- do { \
- cudaEventRecord(profile_start, 0);
-#define PROFILE_STOP \
- cudaEventRecord(profile_stop, 0); \
- cudaEventSynchronize(profile_stop); \
- float milliseconds = 0; \
- cudaEventElapsedTime(&milliseconds, profile_start, profile_stop); \
- accu_profile(__func__, milliseconds / 1000); \
- } while (0);
-
-#define PROFILE_END
-#endif
diff --git a/matrix/cukernel.cu b/matrix/cukernel.cu
deleted file mode 100644
index a19030a..0000000
--- a/matrix/cukernel.cu
+++ /dev/null
@@ -1,17 +0,0 @@
-#define NERV_GENERIC_CUKERNEL
-
-#define cudak_(NAME) cudak_float_ ## NAME
-#define MATRIX_USE_FLOAT
-#include "generic/elem_type.h"
-#include "generic/cukernel.cu"
-#undef cudak_
-#undef MATRIX_USE_FLOAT
-#undef MATRIX_ELEM
-#undef MATRIX_ELEM_PTR
-#undef MATRIX_ELEM_FMT
-#undef MATRIX_ELEM_WRITE_FMT
-
-#define cudak_(NAME) cudak_double_ ## NAME
-#define MATRIX_USE_DOUBLE
-#include "generic/elem_type.h"
-#include "generic/cukernel.cu"
diff --git a/matrix/cukernel.h b/matrix/cukernel.h
deleted file mode 100644
index 8a1494f..0000000
--- a/matrix/cukernel.h
+++ /dev/null
@@ -1,20 +0,0 @@
-#ifdef NERV_GENERIC_CUKERNEL
-void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b, Matrix *c);
-void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b);
-void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b);
-void cudak_(cuda_sigmoid_grad)(const Matrix *output, const Matrix *err, Matrix *nerr);
-void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b);
-void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b);
-void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *idx);
-void cudak_(cuda_colsum)(const Matrix *a, Matrix *b);
-void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b);
-void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b);
-void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max, const Matrix *deno, Matrix *b);
-void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta);
-void cudak_(cuda_fill)(Matrix *a, double val);
-void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context);
-void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step);
-void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b);
-void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b);
-void cudak_(cuda_decompress)(const Matrix *a, Matrix *b);
-#endif
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
deleted file mode 100644
index af34fb4..0000000
--- a/matrix/cumatrix.c
+++ /dev/null
@@ -1,87 +0,0 @@
-#define NERV_GENERIC_CUMATRIX
-#include "../common.h"
-#include "cuda_helper.h"
-#include <string.h>
-#define PROFILE_HASHMAP_SIZE 123457
-static cublasHandle_t cublas_handle;
-static cudaEvent_t profile_start, profile_stop;
-static HashMap *profile;
-
-static int print_profile(lua_State *L) {
- (void)L;
- size_t i;
- fprintf(stderr, "*** [nerv cumatrix profile] **\n");
- for (i = 0; i < profile->size; i++)
- {
- HashNode *ptr;
- for (ptr = profile->bucket[i]; ptr; ptr = ptr->next)
- {
- fprintf(stderr, "%s:\t%.6f\n", ptr->key, *(float *)ptr->val);
- }
- }
- return 0;
-}
-
-static int clear_profile(lua_State *L) {
- (void)L;
- hashmap_clear(profile);
- return 0;
-}
-
-void accu_profile(const char *name, float delta) {
- float *val = hashmap_getval(profile, name);
- if (!val)
- {
- val = malloc(sizeof(float));
- *val = 0;
- hashmap_setval(profile, name, val);
- }
- *val += delta;
-}
-
-static const luaL_Reg cumatrix_methods[] = {
- {"print_profile", print_profile},
- {"clear_profile", clear_profile},
- {NULL, NULL}
-};
-
-extern void nerv_matrix_cuda_float_init(lua_State *L);
-extern void nerv_matrix_cuda_double_init(lua_State *L);
-
-void nerv_cumatrix_init(lua_State *L) {
- luaL_register(L, NULL, cumatrix_methods);
- cublasCreate(&cublas_handle);
- cudaEventCreate(&profile_start);
- cudaEventCreate(&profile_stop);
- profile = hashmap_create(PROFILE_HASHMAP_SIZE, bkdr_hash, strcmp);
- nerv_matrix_cuda_float_init(L);
- nerv_matrix_cuda_double_init(L);
-}
-
-#define MATRIX_USE_FLOAT
-#define cuda_matrix_(NAME) cuda_matrix_float_##NAME
-#define nerv_matrix_(NAME) nerv_matrix_cuda_float_##NAME
-#define cudak_(NAME) cudak_float_ ## NAME
-#define NERV_CUBLAS_(NAME) cublasS##NAME
-#define MATRIX_CUMATRIX_HOST_TNAME nerv_matrix_host_float_tname
-const char *nerv_matrix_(tname) = "nerv.CuMatrixFloat";
-#include "generic/cumatrix.c"
-#undef NERV_CUBLAS_
-#undef cudak_
-#undef nerv_matrix_
-#undef cuda_matrix_
-#undef MATRIX_USE_FLOAT
-#undef MATRIX_ELEM
-#undef MATRIX_ELEM_PTR
-#undef MATRIX_ELEM_FMT
-#undef MATRIX_ELEM_WRITE_FMT
-#undef MATRIX_CUMATRIX_HOST_TNAME
-
-#define MATRIX_USE_DOUBLE
-#define cuda_matrix_(NAME) cuda_matrix_double_##NAME
-#define nerv_matrix_(NAME) nerv_matrix_cuda_double_##NAME
-#define cudak_(NAME) cudak_double_ ## NAME
-#define NERV_CUBLAS_(NAME) cublasD##NAME
-#define MATRIX_CUMATRIX_HOST_TNAME nerv_matrix_host_double_tname
-const char *nerv_matrix_(tname) = "nerv.CuMatrixDouble";
-#include "generic/cumatrix.c"
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
deleted file mode 100644
index d6c8adc..0000000
--- a/matrix/generic/cukernel.cu
+++ /dev/null
@@ -1,571 +0,0 @@
-#ifdef NERV_GENERIC_CUKERNEL
-#include <assert.h>
-#include <stdio.h>
-#include "matrix.h"
-#include "cuda.h"
-#include "float.h"
-#define CUDA_THREADS_N 16
-#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
-__global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- MATRIX_ELEM tmp;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- tmp = a[idx];
- if(tmp < FLT_MIN) tmp = FLT_MIN;
- b[idx] = log(tmp);
-}
-
-__global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b,
- MATRIX_ELEM *c,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- c[idx] = a[idx] * b[idx];
-}
-
-__global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- b[idx] = 1.0 / (1.0 + exp(-a[idx]));
-}
-
-__global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output,
- const MATRIX_ELEM *err,
- MATRIX_ELEM *nerr,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- nerr[idx] = output[idx] * (1.0 - output[idx]) * err[idx];
-}
-
-__global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
- int nrow, int ncol, int stride, int mstride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- b[idx] = exp(a[idx] - max[0 + i * mstride]) / deno[0 + i * mstride];
-}
-
-__global__ void cudak_(block_reduce_rowsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_colsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- cudak_(arr)[threadIdx.y] = i < n ? input[blockIdx.x + istride * i] : 0;
- __syncthreads();
- for (int offset = blockDim.y >> 1; offset; offset >>= 1)
- {
- if (threadIdx.y < offset)
- cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
- __syncthreads();
- }
- if (threadIdx.y == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_colsame)(const MATRIX_ELEM *input,
- const MATRIX_ELEM *ref_input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- cudak_(arr)[threadIdx.y] = (i < n && input[blockIdx.x + istride * i] == \
- ref_input[blockIdx.x + istride * i]) ? 1.0 : 0;
- __syncthreads();
- for (int offset = blockDim.y >> 1; offset; offset >>= 1)
- {
- if (threadIdx.y < offset)
- cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
- __syncthreads();
- }
- if (threadIdx.y == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const MATRIX_ELEM *max,
- const int istride, const int ostride,
- const int mstride, const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- cudak_(arr)[threadIdx.x] = j < n ? exp(input[j + istride * blockIdx.y] - \
- max[0 + mstride * blockIdx.y]) : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : -FLT_MAX;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- {
- MATRIX_ELEM l = cudak_(arr)[threadIdx.x],
- r = cudak_(arr)[threadIdx.x + offset];
- if (r > l)
- cudak_(arr)[threadIdx.x] = r;
- }
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_rowmax_idx)(const MATRIX_ELEM *input,
- const MATRIX_ELEM *idx_input,
- MATRIX_ELEM *output,
- MATRIX_ELEM *idx_output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- MATRIX_ELEM *arr_val = cudak_(arr);
- MATRIX_ELEM *arr_idx = arr_val + blockDim.x;
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- arr_val[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : -FLT_MAX;
- arr_idx[threadIdx.x] = j < n ? idx_input[j + istride * blockIdx.y] : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- {
- MATRIX_ELEM l = arr_val[threadIdx.x],
- r = arr_val[threadIdx.x + offset];
- if (r > l)
- {
- arr_val[threadIdx.x] = r;
- arr_idx[threadIdx.x] = arr_idx[threadIdx.x + offset];
- }
- }
- __syncthreads();
- }
- if (threadIdx.x == 0)
- {
- output[blockIdx.x + ostride * blockIdx.y] = arr_val[0];
- idx_output[blockIdx.x + ostride * blockIdx.y] = arr_idx[0];
- }
-}
-
-__global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol, int stride, double beta) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] += beta * a[j];
-}
-
-__global__ void cudak_(fill)(MATRIX_ELEM *a,
- int nrow, int ncol, int stride, double val) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- a[j + i * stride] = val;
-}
-
-__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int enrow, int encol,
- int stride, int estride,
- int context) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- int ridx;
- if (i >= enrow || j >= encol) return;
- ridx = i + j / ncol - context;
- if (ridx < 0) ridx = 0;
- else if (ridx >= nrow) ridx = nrow - 1;
- b[j + i * estride] = a[j % ncol + ridx * stride];
-}
-
-__global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride, int step, int orig_dim) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
-}
-
-__global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int astride, int bstride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * bstride] *= a[i * astride];
-}
-
-__global__ void cudak_(scale_rows_by_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] *= a[j];
-}
-
-__global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride_a, int stride_b) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[lrintf(a[j + i * stride_a]) + i * stride_b] = 1.0;
-}
-
-__global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] = j;
-}
-
-extern "C" {
-#include "../cukernel.h"
- void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(log_elem)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b,
- Matrix *c) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(mul_elem)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- MATRIX_ELEM_PTR(c),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(sigmoid)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
- b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_sigmoid_grad)(const Matrix *output,
- const Matrix *err, Matrix *nerr) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
- CEIL_DIV(nerr->nrow, threadsPerBlock.y));
- cudak_(sigmoid_grad)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err),
- MATRIX_ELEM_PTR(nerr),
- nerr->nrow, nerr->ncol,
- nerr->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudak_(block_reduce_rowsum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowsum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b) {
- dim3 block(1, CUDA_THREADS_NN);
- int nrow = a->nrow;
- int blocks_per_col = CEIL_DIV(nrow, block.y);
- dim3 grid(a->ncol, blocks_per_col);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
- cudak_(block_reduce_colsame)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(ref), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- nrow);
- nrow = blocks_per_col;
- assert((unsigned long)nrow <= block.y);
- grid.y = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- nrow);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) {
- dim3 block(1, CUDA_THREADS_NN);
- int nrow = a->nrow;
- int blocks_per_col = CEIL_DIV(nrow, block.y);
- dim3 grid(a->ncol, blocks_per_col);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
- cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- nrow);
- nrow = blocks_per_col;
- assert((unsigned long)nrow <= block.y);
- grid.y = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- nrow);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max,
- const Matrix *deno, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(softmax_final)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- MATRIX_ELEM_PTR(max), MATRIX_ELEM_PTR(deno),
- b->nrow, b->ncol,
- b->stride / sizeof(MATRIX_ELEM),
- max->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *res;
- size_t stride;
- assert(max->ncol == 1);
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudak_(block_reduce_softmax_rowsum) \
- <<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res, MATRIX_ELEM_PTR(max),
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- max->stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowsum) \
- <<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudak_(block_reduce_rowmax)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowmax)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *b_idx) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *a_idx, *res, *res_idx;
- size_t stride;
- cudaMallocPitch(&a_idx, &stride, a->stride, a->nrow);
- cudak_(gen_col_idx)<<<grid, block>>>(a_idx, a->nrow, ncol, stride / sizeof(MATRIX_ELEM));
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudaMallocPitch(&res_idx, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowmax_idx)<<<grid, block,
- 2 * block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), a_idx, res, res_idx,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowmax_idx)<<<grid, block,
- 2 * block.x * sizeof(MATRIX_ELEM)>>> \
- (res, res_idx, MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(b_idx),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(a_idx);
- cudaFree(res);
- cudaFree(res_idx);
- }
-
- /* in-place calc */
- void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(add_row)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
- b->stride / sizeof(MATRIX_ELEM), beta);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_fill)(Matrix *a, double val) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
- CEIL_DIV(a->nrow, threadsPerBlock.y));
- cudak_(fill)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), a->nrow, a->ncol,
- a->stride / sizeof(MATRIX_ELEM), val);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(expand_frm)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- a->nrow, a->ncol,
- b->nrow, b->ncol,
- a->stride / sizeof(MATRIX_ELEM),
- b->stride / sizeof(MATRIX_ELEM),
- context);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(rearrange_frm)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM),
- step, b->ncol / step);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(scale_rows_by_col)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol,
- a->stride / sizeof(MATRIX_ELEM),
- b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(scale_rows_by_row)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_decompress)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(1, CUDA_THREADS_NN);
- dim3 numBlocks(1, CEIL_DIV(a->nrow, threadsPerBlock.y));
- cudak_(decompress)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- a->nrow, a->ncol,
- a->stride / sizeof(MATRIX_ELEM),
- b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-}
-#endif
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
deleted file mode 100644
index b5d1a35..0000000
--- a/matrix/generic/cumatrix.c
+++ /dev/null
@@ -1,493 +0,0 @@
-#ifdef NERV_GENERIC_CUMATRIX
-#include "matrix.h"
-#include "elem_type.h"
-
-#define MATRIX_DATA_FREE(L, ptr) cuda_matrix_(free)(L, ptr)
-#define MATRIX_DATA_ALLOC(L, dptr, stride, width, height) \
- cuda_matrix_(alloc)(L, dptr, stride, width, height)
-#define MATRIX_DATA_WRITE(L, data, idx, val) cuda_matrix_(write)(L, data, idx, val)
-#define MATRIX_DATA_READ(L, data, idx) cuda_matrix_(read)(L, data, idx)
-#define MATRIX_INIT(L) cuda_matrix_(init)(L)
-#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
-#define NERV_GENERIC_MATRIX
-#define NERV_GENERIC_CUKERNEL
-#include "../../common.h"
-#include "../cukernel.h"
-#include "../cuda_helper.h"
-
-Matrix *nerv_matrix_(new_)(lua_State *L, long nrow, long ncol);
-void nerv_matrix_(data_free)(lua_State *L, Matrix *self);
-
-static void nerv_matrix_(add_)(lua_State *L, const Matrix *a, const Matrix *b,
- const Matrix *c,
- MATRIX_ELEM alpha, MATRIX_ELEM beta) {
- PROFILE_START
- CUBLAS_SAFE_SYNC_CALL(
- NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
- a->ncol, a->nrow,
- &alpha,
-