aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic')
-rw-r--r--matrix/generic/cukernel.cu571
-rw-r--r--matrix/generic/cumatrix.c493
-rw-r--r--matrix/generic/elem_type.h22
-rw-r--r--matrix/generic/matrix.c155
-rw-r--r--matrix/generic/matrix.h19
-rw-r--r--matrix/generic/mmatrix.c122
6 files changed, 0 insertions, 1382 deletions
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
deleted file mode 100644
index d6c8adc..0000000
--- a/matrix/generic/cukernel.cu
+++ /dev/null
@@ -1,571 +0,0 @@
-#ifdef NERV_GENERIC_CUKERNEL
-#include <assert.h>
-#include <stdio.h>
-#include "matrix.h"
-#include "cuda.h"
-#include "float.h"
-#define CUDA_THREADS_N 16
-#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
-__global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- MATRIX_ELEM tmp;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- tmp = a[idx];
- if(tmp < FLT_MIN) tmp = FLT_MIN;
- b[idx] = log(tmp);
-}
-
-__global__ void cudak_(mul_elem)(const MATRIX_ELEM *a, const MATRIX_ELEM *b,
- MATRIX_ELEM *c,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- c[idx] = a[idx] * b[idx];
-}
-
-__global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- b[idx] = 1.0 / (1.0 + exp(-a[idx]));
-}
-
-__global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output,
- const MATRIX_ELEM *err,
- MATRIX_ELEM *nerr,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- nerr[idx] = output[idx] * (1.0 - output[idx]) * err[idx];
-}
-
-__global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
- int nrow, int ncol, int stride, int mstride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- long idx;
- if (i >= nrow || j >= ncol) return;
- idx = j + i * stride;
- b[idx] = exp(a[idx] - max[0 + i * mstride]) / deno[0 + i * mstride];
-}
-
-__global__ void cudak_(block_reduce_rowsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_colsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- cudak_(arr)[threadIdx.y] = i < n ? input[blockIdx.x + istride * i] : 0;
- __syncthreads();
- for (int offset = blockDim.y >> 1; offset; offset >>= 1)
- {
- if (threadIdx.y < offset)
- cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
- __syncthreads();
- }
- if (threadIdx.y == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_colsame)(const MATRIX_ELEM *input,
- const MATRIX_ELEM *ref_input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- cudak_(arr)[threadIdx.y] = (i < n && input[blockIdx.x + istride * i] == \
- ref_input[blockIdx.x + istride * i]) ? 1.0 : 0;
- __syncthreads();
- for (int offset = blockDim.y >> 1; offset; offset >>= 1)
- {
- if (threadIdx.y < offset)
- cudak_(arr)[threadIdx.y] += cudak_(arr)[threadIdx.y + offset];
- __syncthreads();
- }
- if (threadIdx.y == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_softmax_rowsum)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const MATRIX_ELEM *max,
- const int istride, const int ostride,
- const int mstride, const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- cudak_(arr)[threadIdx.x] = j < n ? exp(input[j + istride * blockIdx.y] - \
- max[0 + mstride * blockIdx.y]) : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- cudak_(arr)[threadIdx.x] += cudak_(arr)[threadIdx.x + offset];
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_rowmax)(const MATRIX_ELEM *input,
- MATRIX_ELEM *output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- cudak_(arr)[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : -FLT_MAX;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- {
- MATRIX_ELEM l = cudak_(arr)[threadIdx.x],
- r = cudak_(arr)[threadIdx.x + offset];
- if (r > l)
- cudak_(arr)[threadIdx.x] = r;
- }
- __syncthreads();
- }
- if (threadIdx.x == 0)
- output[blockIdx.x + ostride * blockIdx.y] = cudak_(arr)[0];
-}
-
-__global__ void cudak_(block_reduce_rowmax_idx)(const MATRIX_ELEM *input,
- const MATRIX_ELEM *idx_input,
- MATRIX_ELEM *output,
- MATRIX_ELEM *idx_output,
- const int istride, const int ostride,
- const int n) {
- extern __shared__ MATRIX_ELEM cudak_(arr)[];
- MATRIX_ELEM *arr_val = cudak_(arr);
- MATRIX_ELEM *arr_idx = arr_val + blockDim.x;
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- arr_val[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : -FLT_MAX;
- arr_idx[threadIdx.x] = j < n ? idx_input[j + istride * blockIdx.y] : 0;
- __syncthreads();
- for (int offset = blockDim.x >> 1; offset; offset >>= 1)
- {
- if (threadIdx.x < offset)
- {
- MATRIX_ELEM l = arr_val[threadIdx.x],
- r = arr_val[threadIdx.x + offset];
- if (r > l)
- {
- arr_val[threadIdx.x] = r;
- arr_idx[threadIdx.x] = arr_idx[threadIdx.x + offset];
- }
- }
- __syncthreads();
- }
- if (threadIdx.x == 0)
- {
- output[blockIdx.x + ostride * blockIdx.y] = arr_val[0];
- idx_output[blockIdx.x + ostride * blockIdx.y] = arr_idx[0];
- }
-}
-
-__global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol, int stride, double beta) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] += beta * a[j];
-}
-
-__global__ void cudak_(fill)(MATRIX_ELEM *a,
- int nrow, int ncol, int stride, double val) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- a[j + i * stride] = val;
-}
-
-__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int enrow, int encol,
- int stride, int estride,
- int context) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- int ridx;
- if (i >= enrow || j >= encol) return;
- ridx = i + j / ncol - context;
- if (ridx < 0) ridx = 0;
- else if (ridx >= nrow) ridx = nrow - 1;
- b[j + i * estride] = a[j % ncol + ridx * stride];
-}
-
-__global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride, int step, int orig_dim) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
-}
-
-__global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int astride, int bstride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * bstride] *= a[i * astride];
-}
-
-__global__ void cudak_(scale_rows_by_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] *= a[j];
-}
-
-__global__ void cudak_(decompress)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride_a, int stride_b) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[lrintf(a[j + i * stride_a]) + i * stride_b] = 1.0;
-}
-
-__global__ void cudak_(gen_col_idx)(MATRIX_ELEM *b,
- int nrow, int ncol, int stride) {
- int j = blockIdx.x * blockDim.x + threadIdx.x;
- int i = blockIdx.y * blockDim.y + threadIdx.y;
- if (i >= nrow || j >= ncol) return;
- b[j + i * stride] = j;
-}
-
-extern "C" {
-#include "../cukernel.h"
- void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(log_elem)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b,
- Matrix *c) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(mul_elem)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- MATRIX_ELEM_PTR(c),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(sigmoid)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
- b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_sigmoid_grad)(const Matrix *output,
- const Matrix *err, Matrix *nerr) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
- CEIL_DIV(nerr->nrow, threadsPerBlock.y));
- cudak_(sigmoid_grad)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err),
- MATRIX_ELEM_PTR(nerr),
- nerr->nrow, nerr->ncol,
- nerr->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudak_(block_reduce_rowsum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowsum)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_colsame)(const Matrix *a, const Matrix *ref, Matrix *b) {
- dim3 block(1, CUDA_THREADS_NN);
- int nrow = a->nrow;
- int blocks_per_col = CEIL_DIV(nrow, block.y);
- dim3 grid(a->ncol, blocks_per_col);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
- cudak_(block_reduce_colsame)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(ref), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- nrow);
- nrow = blocks_per_col;
- assert((unsigned long)nrow <= block.y);
- grid.y = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- nrow);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_colsum)(const Matrix *a, Matrix *b) {
- dim3 block(1, CUDA_THREADS_NN);
- int nrow = a->nrow;
- int blocks_per_col = CEIL_DIV(nrow, block.y);
- dim3 grid(a->ncol, blocks_per_col);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, a->ncol * sizeof(MATRIX_ELEM), blocks_per_col);
- cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- nrow);
- nrow = blocks_per_col;
- assert((unsigned long)nrow <= block.y);
- grid.y = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_colsum)<<<grid, block, block.y * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- nrow);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max,
- const Matrix *deno, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(softmax_final)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- MATRIX_ELEM_PTR(max), MATRIX_ELEM_PTR(deno),
- b->nrow, b->ncol,
- b->stride / sizeof(MATRIX_ELEM),
- max->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_softmax_denominator)(const Matrix *a, const Matrix *max, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *res;
- size_t stride;
- assert(max->ncol == 1);
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudak_(block_reduce_softmax_rowsum) \
- <<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res, MATRIX_ELEM_PTR(max),
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- max->stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowsum) \
- <<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_rowmax)(const Matrix *a, Matrix *b) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *res;
- size_t stride;
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudak_(block_reduce_rowmax)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), res,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowmax)<<<grid, block, block.x * sizeof(MATRIX_ELEM)>>> \
- (res, MATRIX_ELEM_PTR(b),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(res);
- }
-
- void cudak_(cuda_rowmax_idx)(const Matrix *a, Matrix *b, Matrix *b_idx) {
- dim3 block(CUDA_THREADS_NN, 1);
- int ncol = a->ncol;
- int blocks_per_row = CEIL_DIV(ncol, block.x);
- dim3 grid(blocks_per_row, a->nrow);
- MATRIX_ELEM *a_idx, *res, *res_idx;
- size_t stride;
- cudaMallocPitch(&a_idx, &stride, a->stride, a->nrow);
- cudak_(gen_col_idx)<<<grid, block>>>(a_idx, a->nrow, ncol, stride / sizeof(MATRIX_ELEM));
- cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudaMallocPitch(&res_idx, &stride, blocks_per_row * sizeof(MATRIX_ELEM), a->nrow);
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowmax_idx)<<<grid, block,
- 2 * block.x * sizeof(MATRIX_ELEM)>>> \
- (MATRIX_ELEM_PTR(a), a_idx, res, res_idx,
- a->stride / sizeof(MATRIX_ELEM), stride / sizeof(MATRIX_ELEM),
- ncol);
- ncol = blocks_per_row;
- assert((unsigned long)ncol <= block.x);
- grid.x = 1;
- cudaStreamSynchronize(0);
- cudak_(block_reduce_rowmax_idx)<<<grid, block,
- 2 * block.x * sizeof(MATRIX_ELEM)>>> \
- (res, res_idx, MATRIX_ELEM_PTR(b), MATRIX_ELEM_PTR(b_idx),
- stride / sizeof(MATRIX_ELEM), b->stride / sizeof(MATRIX_ELEM),
- ncol);
- cudaStreamSynchronize(0);
- cudaFree(a_idx);
- cudaFree(res);
- cudaFree(res_idx);
- }
-
- /* in-place calc */
- void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(add_row)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
- b->stride / sizeof(MATRIX_ELEM), beta);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_fill)(Matrix *a, double val) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
- CEIL_DIV(a->nrow, threadsPerBlock.y));
- cudak_(fill)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), a->nrow, a->ncol,
- a->stride / sizeof(MATRIX_ELEM), val);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(expand_frm)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- a->nrow, a->ncol,
- b->nrow, b->ncol,
- a->stride / sizeof(MATRIX_ELEM),
- b->stride / sizeof(MATRIX_ELEM),
- context);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(rearrange_frm)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM),
- step, b->ncol / step);
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(scale_rows_by_col)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol,
- a->stride / sizeof(MATRIX_ELEM),
- b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
- dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
- CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(scale_rows_by_row)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-
- void cudak_(cuda_decompress)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(1, CUDA_THREADS_NN);
- dim3 numBlocks(1, CEIL_DIV(a->nrow, threadsPerBlock.y));
- cudak_(decompress)<<<numBlocks, threadsPerBlock>>> \
- (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
- a->nrow, a->ncol,
- a->stride / sizeof(MATRIX_ELEM),
- b->stride / sizeof(MATRIX_ELEM));
- cudaStreamSynchronize(0);
- }
-}
-#endif
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
deleted file mode 100644
index b5d1a35..0000000
--- a/matrix/generic/cumatrix.c
+++ /dev/null
@@ -1,493 +0,0 @@
-#ifdef NERV_GENERIC_CUMATRIX
-#include "matrix.h"
-#include "elem_type.h"
-
-#define MATRIX_DATA_FREE(L, ptr) cuda_matrix_(free)(L, ptr)
-#define MATRIX_DATA_ALLOC(L, dptr, stride, width, height) \
- cuda_matrix_(alloc)(L, dptr, stride, width, height)
-#define MATRIX_DATA_WRITE(L, data, idx, val) cuda_matrix_(write)(L, data, idx, val)
-#define MATRIX_DATA_READ(L, data, idx) cuda_matrix_(read)(L, data, idx)
-#define MATRIX_INIT(L) cuda_matrix_(init)(L)
-#define MATRIX_BASE_TNAME nerv_matrix_cuda_tname
-#define NERV_GENERIC_MATRIX
-#define NERV_GENERIC_CUKERNEL
-#include "../../common.h"
-#include "../cukernel.h"
-#include "../cuda_helper.h"
-
-Matrix *nerv_matrix_(new_)(lua_State *L, long nrow, long ncol);
-void nerv_matrix_(data_free)(lua_State *L, Matrix *self);
-
-static void nerv_matrix_(add_)(lua_State *L, const Matrix *a, const Matrix *b,
- const Matrix *c,
- MATRIX_ELEM alpha, MATRIX_ELEM beta) {
- PROFILE_START
- CUBLAS_SAFE_SYNC_CALL(
- NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
- a->ncol, a->nrow,
- &alpha,
- MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
- &beta,
- MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
- MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)));
- PROFILE_STOP
-}
-
-static int nerv_matrix_(add)(lua_State *L) {
- Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
- MATRIX_ELEM alpha = luaL_checknumber(L, 4);
- MATRIX_ELEM beta = luaL_checknumber(L, 5);
- CHECK_SAME_DIMENSION(a, b);
- CHECK_SAME_DIMENSION(a, c);
- nerv_matrix_(add_)(L, a, b, c, alpha, beta);
- return 0;
-}
-
-static int nerv_matrix_(get_cublas_op)(char ch) {
- return (ch == 'T' || ch == 't') ? CUBLAS_OP_T : CUBLAS_OP_N;
-}
-
-static int nerv_matrix_(mul)(lua_State *L) {
-#define SWAP(a, b) \
- do { int t = (a); (a) = (b); (b) = t; } while (0)
-
- Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
- MATRIX_ELEM alpha = luaL_checknumber(L, 4);
- MATRIX_ELEM beta = luaL_checknumber(L, 5);
- int nargs = lua_gettop(L);
- int ta = nargs > 5 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 6)) \
- : CUBLAS_OP_N;
- int tb = nargs > 6 ? nerv_matrix_(get_cublas_op)(*luaL_checkstring(L, 7)) \
- : CUBLAS_OP_N;
- int am = a->nrow, an = a->ncol;
- int bm = b->nrow, bn = b->ncol;
- if (ta == CUBLAS_OP_T) SWAP(am, an);
- if (tb == CUBLAS_OP_T) SWAP(bm, bn);
- if (an != bm)
- nerv_error(L, "Wrong dimension of multipliers");
-/* MATRIX_ELEM alpha = 1.0f, beta = 0.0f; */
- /* Because matrix in Nerv is row-major, here b comes first */
- PROFILE_START
- CUBLAS_SAFE_SYNC_CALL(
- NERV_CUBLAS_(gemm)(cublas_handle, tb, ta,
- bn, am, bm,
- &alpha,
- MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM),
- MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
- &beta,
- MATRIX_ELEM_PTR(c), c->stride / sizeof(MATRIX_ELEM)));
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(create)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, a->nrow, a->ncol);
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(sigmoid)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- CHECK_SAME_DIMENSION(a, b);
- PROFILE_START
- cudak_(cuda_sigmoid)(b, a);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(sigmoid_grad)(lua_State *L) {
- Matrix *nerr = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *err = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *output = luaT_checkudata(L, 3, nerv_matrix_(tname));
- CHECK_SAME_DIMENSION(nerr, err);
- CHECK_SAME_DIMENSION(nerr, output);
- PROFILE_START
- cudak_(cuda_sigmoid_grad)(output, err, nerr);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(softmax)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *max, *max_idx;
- Matrix *dno;
- CHECK_SAME_DIMENSION(a, b);
- max = nerv_matrix_(new_)(L, a->nrow, 1);
- max_idx = nerv_matrix_(new_)(L, a->nrow, 1);
- dno = nerv_matrix_(new_)(L, a->nrow, 1);
- PROFILE_START
- cudak_(cuda_rowmax_idx)(a, max, max_idx);
- cudak_(cuda_softmax_denominator)(a, max, dno);
- cudak_(cuda_softmax_final)(a, max, dno, b);
- PROFILE_STOP
- nerv_matrix_(data_free)(L, max);
- nerv_matrix_(data_free)(L, dno);
- luaT_pushudata(L, max_idx, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(rowsum)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1);
- PROFILE_START
- cudak_(cuda_rowsum)(a, b);
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(colsum)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, 1, a->ncol);
- PROFILE_START
- cudak_(cuda_colsum)(a, b);
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(colsame)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *ref = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, 1, a->ncol);
- CHECK_SAME_DIMENSION(a, ref);
- PROFILE_START
- cudak_(cuda_colsame)(a, ref, b);
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(rowmax)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1);
- PROFILE_START
- cudak_(cuda_rowmax)(a, b);
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(rowmax_idx)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, a->nrow, 1);
- Matrix *idx = nerv_matrix_(new_)(L, a->nrow, 1);
- PROFILE_START
- cudak_(cuda_rowmax_idx)(a, b, idx);
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- luaT_pushudata(L, idx, nerv_matrix_(tname));
- return 2;
-}
-
-static int nerv_matrix_(add_row)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- double beta = luaL_checknumber(L, 3);
- if (a->ncol != b->ncol)
- nerv_error(L, "the number of columns is not the same");
- if (a->nrow != 1)
- nerv_error(L, "a row vector is expected");
- PROFILE_START
- cudak_(cuda_add_row)(a, b, beta);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(fill)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- double val = luaL_checknumber(L, 2);
- PROFILE_START
- cudak_(cuda_fill)(self, val);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(copy_fromd)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- int nargs = lua_gettop(L);
- int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
- int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
- int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- if (!(0 <= b_begin && b_begin < b_end && b_end <= b->nrow &&
- a_begin + b_end - b_begin <= a->nrow))
- nerv_error(L, "invalid copy interval");
- if (a->ncol != b->ncol)
- nerv_error(L, "matrices should be of the same dimension");
- PROFILE_START
- CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride,
- MATRIX_ROW_PTR(b, b_begin), b->stride,
- sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin,
- cudaMemcpyDeviceToDevice));
- PROFILE_STOP
- return 0;
-}
-
-extern const char *MATRIX_CUMATRIX_HOST_TNAME;
-static int nerv_matrix_(copy_fromh)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
- int nargs = lua_gettop(L);
- int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
- int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
- int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- if (!(0 <= b_begin && b_begin < b_end && b_end <= b->nrow &&
- a_begin + b_end - b_begin <= a->nrow))
- nerv_error(L, "invalid copy interval");
- if (a->ncol != b->ncol)
- nerv_error(L, "matrices should be of the same dimension");
- PROFILE_START
- CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ROW_PTR(a, a_begin), a->stride,
- MATRIX_ROW_PTR(b, b_begin), b->stride,
- sizeof(MATRIX_ELEM) * b->ncol, b_end - b_begin,
- cudaMemcpyHostToDevice));
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(copy_toh)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
- int nargs = lua_gettop(L);
- int a_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
- int a_end = nargs > 3 ? luaL_checkinteger(L, 4) : a->nrow;
- int b_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- if (!(0 <= a_begin && a_begin < a_end && a_end <= a->nrow &&
- b_begin + a_end - a_begin <= b->nrow))
- nerv_error(L, "invalid copy interval");
- if (b->ncol != a->ncol)
- nerv_error(L, "matrices should be of the same dimension");
- PROFILE_START
- CUDA_SAFE_SYNC_CALL(
- cudaMemcpy2D(MATRIX_ROW_PTR(b, b_begin), b->stride,
- MATRIX_ROW_PTR(a, a_begin), a->stride,
- sizeof(MATRIX_ELEM) * a->ncol, a_end - a_begin,
- cudaMemcpyDeviceToHost));
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(trans)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = nerv_matrix_(new_)(L, a->ncol, a->nrow);
- MATRIX_ELEM alpha = 1, beta = 0;
- /* FIXME: possible memory leak when lua error is raised */
- PROFILE_START
- CUBLAS_SAFE_SYNC_CALL(
- NERV_CUBLAS_(geam)(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T,
- a->nrow, a->ncol,
- &alpha,
- MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
- &beta,
- MATRIX_ELEM_PTR(a), a->stride / sizeof(MATRIX_ELEM),
- MATRIX_ELEM_PTR(b), b->stride / sizeof(MATRIX_ELEM)));
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-static int nerv_matrix_(mul_elem)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
- Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
- CHECK_SAME_DIMENSION(a, b);
- CHECK_SAME_DIMENSION(a, c);
- PROFILE_START
- cudak_(cuda_mul_elem)(a, b, c);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(log_elem)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
- CHECK_SAME_DIMENSION(a, b);
- PROFILE_START
- cudak_(cuda_log_elem)(a, b);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(decompress)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b;
- int orig_col = luaL_checkinteger(L, 2);
- if (a->ncol != 1)
- nerv_error(L, "the compressed matrix must be a column vector");
- b = nerv_matrix_(new_)(L, a->nrow, orig_col);
- PROFILE_START
- cudak_(cuda_fill)(b, 0.0);
- cudak_(cuda_decompress)(a, b);
- PROFILE_STOP
- luaT_pushudata(L, b, nerv_matrix_(tname));
- return 1;
-}
-
-extern const char *nerv_matrix_host_int_tname;
-static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, MATRIX_CUMATRIX_HOST_TNAME);
- Matrix *idx = luaT_checkudata(L, 3, nerv_matrix_host_int_tname);
- long nrow = a->nrow;
- int b_begin = lua_gettop(L) > 3 ? luaL_checkinteger(L, 4) : 0;
- if (!(0 <= b_begin && b_begin + nrow <= idx->ncol))
- nerv_error(L, "invalid copy interval");
- long *idx_ptr = idx->data.i;
- int i;
- if (idx->nrow != 1)
- nerv_error(L, "index should be a vector");
- if (a->ncol != b->ncol)
- nerv_error(L, "source/destination dimension mismatch");
- cudaStream_t *streams = (cudaStream_t*)malloc(sizeof(cudaStream_t) * nrow);
- for (i = 0; i < nrow; i++)
- {
- int src_row = idx_ptr[b_begin + i];
- if (!(0 <= src_row && src_row < b->nrow))
- nerv_error(L, "invalid index");
- CUDA_SAFE_CALL(cudaStreamCreate(streams + i));
- CUDA_SAFE_CALL(cudaMemcpyAsync(MATRIX_ROW_PTR(a, i),
- MATRIX_ROW_PTR(b, src_row),
- b->stride,
- cudaMemcpyHostToDevice, streams[i]));
- }
- for (i = 0; i < nrow; i++)
- {
- CUDA_SAFE_CALL(cudaStreamSynchronize(streams[i]));
- CUDA_SAFE_CALL(cudaStreamDestroy(streams[i]));
- }
- free(streams);
- return 0;
-}
-
-static int nerv_matrix_(expand_frm)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- int context = luaL_checkinteger(L, 3);
- if (a->nrow != b->nrow)
- nerv_error(L, "mismatching number of frames");
- if (a->ncol != b->ncol * (context * 2 + 1))
- nerv_error(L, "the width should be 2 * context + 1");
- PROFILE_START
- cudak_(cuda_expand_frm)(b, a, context);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(rearrange_frm)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- int step = luaL_checkinteger(L, 3);
- CHECK_SAME_DIMENSION(a, b);
- if (b->ncol % step)
- nerv_error(L, "the dimension of columns is not divisible by step");
- PROFILE_START
- cudak_(cuda_rearrange_frm)(b, a, step);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(scale_rows_by_col)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- if (a->nrow != b->nrow)
- nerv_error(L, "the number of rows is not the same");
- if (b->ncol != 1)
- nerv_error(L, "a column vector is expected");
- PROFILE_START
- cudak_(cuda_scale_rows_by_col)(b, a);
- PROFILE_STOP
- return 0;
-}
-
-static int nerv_matrix_(scale_rows_by_row)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- if (a->ncol != b->ncol)
- nerv_error(L, "the number of columns is not the same");
- if (b->nrow != 1)
- nerv_error(L, "a row vector is expected");
- PROFILE_START
- cudak_(cuda_scale_rows_by_row)(b, a);
- PROFILE_STOP
- return 0;
-}
-
-static const luaL_Reg nerv_matrix_(extra_methods)[] = {
- {"create", nerv_matrix_(create)},
- {"colsum", nerv_matrix_(colsum)},
- {"colsame", nerv_matrix_(colsame)},
- {"rowsum", nerv_matrix_(rowsum)},
- {"rowmax", nerv_matrix_(rowmax)},
- {"rowmax_idx", nerv_matrix_(rowmax_idx)},
- {"trans", nerv_matrix_(trans)},
- {"decompress", nerv_matrix_(decompress)},
- /* in-place calc */
- {"copy_fromh", nerv_matrix_(copy_fromh)},
- {"copy_fromd", nerv_matrix_(copy_fromd)},
- {"copy_toh", nerv_matrix_(copy_toh)},
- {"add", nerv_matrix_(add)},
- {"mul", nerv_matrix_(mul)},
- {"add_row", nerv_matrix_(add_row)},
- {"fill", nerv_matrix_(fill)},
- {"sigmoid", nerv_matrix_(sigmoid)},
- {"sigmoid_grad", nerv_matrix_(sigmoid_grad)},
- {"softmax", nerv_matrix_(softmax)},
- {"mul_elem", nerv_matrix_(mul_elem)},
- {"log_elem", nerv_matrix_(log_elem)},
- {"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)},
- {"expand_frm", nerv_matrix_(expand_frm)},
- {"rearrange_frm", nerv_matrix_(rearrange_frm)},
- {"scale_rows_by_row", nerv_matrix_(scale_rows_by_row)},
- {"scale_rows_by_col", nerv_matrix_(scale_rows_by_col)},
- {NULL, NULL}
-};
-
-static void cuda_matrix_(init)(lua_State *L) {
- luaN_append_methods(L, nerv_matrix_(extra_methods));
-}
-
-static void cuda_matrix_(free)(lua_State *L, MATRIX_ELEM *ptr) {
- CUDA_SAFE_SYNC_CALL(cudaFree(ptr));
-}
-
-static void cuda_matrix_(alloc)(lua_State *L, MATRIX_ELEM **dptr,
- size_t *stride, long width, long height) {
- PROFILE_START
- CUDA_SAFE_SYNC_CALL(cudaMallocPitch((void **)dptr, stride, width, height));
- PROFILE_STOP
-}
-
-static MATRIX_ELEM cuda_matrix_(read)(lua_State *L, MATRIX_ELEM *data,
- int idx) {
- MATRIX_ELEM res;
- CUDA_SAFE_SYNC_CALL(cudaMemcpy(&res, data + idx,
- sizeof(MATRIX_ELEM), cudaMemcpyDeviceToHost));
- return res;
-}
-
-static void cuda_matrix_(write)(lua_State *L, MATRIX_ELEM *data,
- int idx, MATRIX_ELEM val) {
- CUDA_SAFE_SYNC_CALL(cudaMemcpy(data + idx, &val,
- sizeof(MATRIX_ELEM), cudaMemcpyHostToDevice));
-}
-
-int nerv_matrix_(get_elem)(lua_State *L) {
- return nerv_error_method_not_implemented(L);
-}
-
-int nerv_matrix_(set_elem)(lua_State *L) {
- return nerv_error_method_not_implemented(L);
-}
-
-#include "matrix.c"
-#endif
diff --git a/matrix/generic/elem_type.h b/matrix/generic/elem_type.h
deleted file mode 100644
index bffe940..0000000
--- a/matrix/generic/elem_type.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#ifdef MATRIX_USE_FLOAT
-
-#define MATRIX_ELEM float
-#define MATRIX_ELEM_FMT "%f"
-#define MATRIX_ELEM_WRITE_FMT "%.8f"
-#define MATRIX_ELEM_PTR(self) ((self)->data.f)
-
-#elif defined(MATRIX_USE_DOUBLE)
-
-#define MATRIX_ELEM double
-#define MATRIX_ELEM_FMT "%lf"
-#define MATRIX_ELEM_WRITE_FMT "%.8lf"
-#define MATRIX_ELEM_PTR(self) ((self)->data.d)
-
-#elif defined(MATRIX_USE_INT)
-
-#define MATRIX_ELEM long
-#define MATRIX_ELEM_FMT "%ld"
-#define MATRIX_ELEM_WRITE_FMT "%ld"
-#define MATRIX_ELEM_PTR(self) ((self)->data.i)
-
-#endif
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c
deleted file mode 100644
index e17fb42..0000000
--- a/matrix/generic/matrix.c
+++ /dev/null
@@ -1,155 +0,0 @@
-#ifdef NERV_GENERIC_MATRIX
-#include "../../common.h"
-#include "matrix.h"
-
-extern const char *nerv_matrix_(tname);
-extern const char *MATRIX_BASE_TNAME;
-
-void nerv_matrix_(data_free)(lua_State *L, Matrix *self) {
- (void)L;
- assert(*self->data_ref > 0);
- if (--(*self->data_ref) == 0)
- {
- /* free matrix data */
- MATRIX_DATA_FREE(L, MATRIX_ELEM_PTR(self));
- free(self->data_ref);
- free(self);
- }
-}
-
-void nerv_matrix_(data_retain)(Matrix *self) {
- (*self->data_ref)++;
-}
-
-Matrix *nerv_matrix_(new_)(lua_State *L, long nrow, long ncol) {
- Matrix *self = (Matrix *)malloc(sizeof(Matrix));
- self->nrow = nrow;
- self->ncol = ncol;
- self->nmax = self->nrow * self->ncol;
- MATRIX_DATA_ALLOC(L, &MATRIX_ELEM_PTR(self), &self->stride,
- sizeof(MATRIX_ELEM) * self->ncol, self->nrow);
- self->data_ref = (long *)malloc(sizeof(long));
- *self->data_ref = 0;
- nerv_matrix_(data_retain)(self);
- return self;
-}
-
-int nerv_matrix_(new)(lua_State *L) {
- luaT_pushudata(L, nerv_matrix_(new_)(L, luaL_checkinteger(L, 1),
- luaL_checkinteger(L, 2)),
- nerv_matrix_(tname));
- return 1;
-}
-
-int nerv_matrix_(destroy)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- nerv_matrix_(data_free)(L, self);
- return 1;
-}
-
-int nerv_matrix_(get_elem)(lua_State *L);
-int nerv_matrix_(set_elem)(lua_State *L);
-
-static Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
- Matrix *prow = (Matrix *)malloc(sizeof(Matrix));
- prow->ncol = self->ncol;
- prow->nrow = 1;
- prow->stride = self->stride;
- prow->nmax = prow->ncol;
- MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row);
- prow->data_ref = self->data_ref;
- nerv_matrix_(data_retain)(prow);
- return prow;
-}
-
-static int nerv_matrix_(newindex)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- if (lua_isnumber(L, 2))
- {
- int idx = luaL_checkinteger(L, 2);
- if (self->nrow == 1)
- {
- if (idx < 0 || idx >= self->ncol)
- nerv_error(L, "index must be within range [0, %d)", self->ncol);
- MATRIX_DATA_WRITE(L, MATRIX_ELEM_PTR(self), idx,
- luaL_checknumber(L, 3));
- }
- else
- nerv_error(L, "cannot assign to row vector");
- lua_pushboolean(L, 1);
- return 1;
- }
- else
- {
- lua_pushboolean(L, 0);
- return 1;
- }
-}
-
-
-static int nerv_matrix_(index)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- if (lua_isnumber(L, 2))
- {
- int idx = luaL_checkinteger(L, 2);
- if (self->nrow == 1)
- {
- if (idx < 0 || idx >= self->ncol)
- nerv_error(L, "index must be within range [0, %d)", self->ncol);
- lua_pushnumber(L, MATRIX_DATA_READ(L, MATRIX_ELEM_PTR(self), idx));
- }
- else
- {
- if (idx < 0 || idx >= self->nrow)
- nerv_error(L, "index must be within range [0, %d)", self->nrow);
- luaT_pushudata(L, nerv_matrix_(getrow)(self, idx), nerv_matrix_(tname));
- }
- lua_pushboolean(L, 1);
- return 2;
- }
- else
- {
- lua_pushboolean(L, 0);
- return 1;
- }
-}
-
-static int nerv_matrix_(ncol)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- lua_pushinteger(L, self->ncol);
- return 1;
-}
-
-static int nerv_matrix_(nrow)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- lua_pushinteger(L, self->nrow);
- return 1;
-}
-
-static int nerv_matrix_(get_dataref_value)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- lua_pushinteger(L, *(self->data_ref));
- return 1;
-}
-
-static const luaL_Reg nerv_matrix_(methods)[] = {
- {"get_elem", nerv_matrix_(get_elem)},
- {"set_elem", nerv_matrix_(set_elem)},
- {"ncol", nerv_matrix_(ncol)},
- {"nrow", nerv_matrix_(nrow)},
- {"get_dataref_value", nerv_matrix_(get_dataref_value)},
- {"__index__", nerv_matrix_(index)},
- {"__newindex__", nerv_matrix_(newindex)},
- {NULL, NULL}
-};
-
-void nerv_matrix_(init)(lua_State *L) {
- luaT_newmetatable(L, nerv_matrix_(tname), MATRIX_BASE_TNAME,
- nerv_matrix_(new), nerv_matrix_(destroy), NULL);
- luaL_register(L, NULL, nerv_matrix_(methods));
-#ifdef MATRIX_INIT
- MATRIX_INIT(L);
-#endif
- lua_pop(L, 1);
-}
-#endif
diff --git a/matrix/generic/matrix.h b/matrix/generic/matrix.h
deleted file mode 100644
index 833724b..0000000
--- a/matrix/generic/matrix.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#ifndef NERV_GENERIC_MATRIX_H
-#define NERV_GENERIC_MATRIX_H
-
-#include <stddef.h>
-typedef struct Matrix {
- size_t stride; /* size of a row */
- long ncol, nrow, nmax; /* dimension of the matrix */
- union {
- float *f;
- double *d;
- long *i;
- } data; /* pointer to actual storage */
- long *data_ref;
-} Matrix;
-
-#define MATRIX_ROW_PTR(self, row) \
- (MATRIX_ELEM *)((char *)MATRIX_ELEM_PTR(self) + (row) * (self)->stride)
-
-#endif
diff --git a/matrix/generic/mmatrix.c b/matrix/generic/mmatrix.c
deleted file mode 100644
index b0f0791..0000000
--- a/matrix/generic/mmatrix.c
+++ /dev/null
@@ -1,122 +0,0 @@
-#ifdef NERV_GENERIC_MMATRIX
-#include "matrix.h"
-#include "elem_type.h"
-#define MATRIX_DATA_FREE(L, ptr) free(ptr)
-#define MATRIX_DATA_ALLOC(L, dptr, stride, width, height) \
- host_matrix_(alloc)(L, dptr, stride, width, height)
-#define MATRIX_DATA_WRITE(L, data, idx, val) (data[idx] = val)
-#define MATRIX_DATA_READ(L, data, idx) (data[idx])
-#define MATRIX_INIT(L) host_matrix_(init)(L)
-#define MATRIX_BASE_TNAME nerv_matrix_host_tname
-#define NERV_GENERIC_MATRIX
-#include "../../common.h"
-#include "../../io/chunk_file.h"
-#include "string.h"
-
-static void host_matrix_(alloc)(lua_State *L,
- MATRIX_ELEM **dptr, size_t *stride,
- long width, long height) {
- if ((*dptr = (MATRIX_ELEM *)malloc(width * height)) == NULL)
- nerv_error(L, "mmatrix insufficient memory");
- *stride = width;
-}
-
-int nerv_matrix_(get_elem)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- int idx = luaL_checkinteger(L, 2);
- if (idx < 0 || idx >= self->nmax)
- nerv_error(L, "index must be within range [0, %d)", self->nmax);
- lua_pushnumber(L, MATRIX_ELEM_PTR(self)[idx]);
- return 1;
-}
-
-int nerv_matrix_(set_elem)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- int idx = luaL_checkinteger(L, 2);
- MATRIX_ELEM v = luaL_checknumber(L, 3);
- if (idx < 0 || idx >= self->nmax)
- nerv_error(L, "index must be within range [0, %d)", self->nmax);
- MATRIX_ELEM_PTR(self)[idx] = v;
- return 0;
-}
-
-static const luaL_Reg nerv_matrix_(extra_methods)[];
-static void host_matrix_(init)(lua_State *L) {
- luaN_append_methods(L, nerv_matrix_(extra_methods));
-#ifdef MMATRIX_INIT
- MMATRIX_INIT(L);
-#endif
-}
-
-#include "matrix.c"
-
-int nerv_matrix_(load)(lua_State *L) {
- ChunkData *chunk = luaT_checkudata(L, 1, nerv_chunk_data_tname);
- Matrix *self;
- int i, j;
- long nrow, ncol;
- FILE *fp = chunk->fp;
- if (fscanf(fp, "%ld %ld", &nrow, &ncol) != 2)
- return 0;
- self = nerv_matrix_(new_)(L, nrow, ncol);
- for (i = 0; i < nrow; i++)
- {
- MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i);
- for (j = 0; j < ncol; j++)
- if (fscanf(fp, MATRIX_ELEM_FMT, row + j) != 1)
- {
- free(self);
- return 0;
- }
- }
- luaT_pushudata(L, self, nerv_matrix_(tname));
- return 1;
-}
-
-int nerv_matrix_(save)(lua_State *L) {
- ChunkFileHandle *chunk = luaT_checkudata(L, 2,
- nerv_chunk_file_handle_tname);
- Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
- int i, j;
- long nrow = self->nrow, ncol = self->ncol;
- FILE *fp = chunk->fp;
- if (fprintf(fp, "%ld %ld\n", nrow, ncol) < 0)
- return 0;
- for (i = 0; i < nrow; i++)
- {
- MATRIX_ELEM *row = MATRIX_ROW_PTR(self, i);
- for (j = 0; j < ncol; j++)
- if (fprintf(fp, MATRIX_ELEM_WRITE_FMT " ", row[j]) < 0)
- return 0;
- if (fprintf(fp, "\n") < 0)
- return 0;
- }
- return 0;
-}
-
-static int nerv_matrix_(copy_from)(lua_State *L) {
- Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
- Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
- int nargs = lua_gettop(L);
- int b_begin = nargs > 2 ? luaL_checkinteger(L, 3) : 0;
- int b_end = nargs > 3 ? luaL_checkinteger(L, 4) : b->nrow;
- int a_begin = nargs > 4 ? luaL_checkinteger(L, 5) : 0;
- if (!(0 <= b_begin && b_begin < b_end && b_end <= b->nrow &&
- a_begin + b_end - b_begin <= a->nrow))
- nerv_error(L, "invalid copy interval");
- if (a->ncol != b->ncol)
- nerv_error(L, "matrices should be of the same dimension");
- memmove(MATRIX_ROW_PTR(a, a_begin),
- MATRIX_ROW_PTR(b, b_begin),
- sizeof(MATRIX_ELEM) * b->ncol * (b_end - b_begin));
- return 0;
-}
-
-static const luaL_Reg nerv_matrix_(extra_methods)[] = {
- {"load", nerv_matrix_(load)},
- {"save", nerv_matrix_(save)},
- {"copy_from", nerv_matrix_(copy_from)},
- {NULL, NULL}
-};
-
-#endif