aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-01 17:37:13 +0800
committerDeterminant <[email protected]>2015-06-01 17:37:13 +0800
commita309ce5e33b22030bcac348c63576187676abee3 (patch)
treeb43714a53b8f78a12a52cb0ee88c6ed7be786cac /matrix/generic
parentab12a9583bdd39884fde9bc2444e6fd1bc5f518e (diff)
add expand_frm, rearrange_frm, scale_row
Diffstat (limited to 'matrix/generic')
-rw-r--r--matrix/generic/cukernel.cu87
-rw-r--r--matrix/generic/cumatrix.c37
2 files changed, 109 insertions, 15 deletions
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
index 0e3d3cf..1d8b983 100644
--- a/matrix/generic/cukernel.cu
+++ b/matrix/generic/cukernel.cu
@@ -4,7 +4,7 @@
#include "matrix.h"
#include "cuda.h"
#define CUDA_THREADS_N 16
-#define CUDA_THREADS_NN (16 * 16)
+#define CUDA_THREADS_NN ((CUDA_THREADS_N) * (CUDA_THREADS_N))
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
__global__ void cudak_(log_elem)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
int nrow, int ncol, int stride) {
@@ -154,12 +154,43 @@ __global__ void cudak_(fill)(MATRIX_ELEM *a,
a[j + i * stride] = val;
}
+__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int enrow, int encol,
+ int stride, int estride,
+ int context) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ int ridx;
+ if (i >= enrow || j >= encol) return;
+ ridx = i + j / ncol - context;
+ if (ridx < 0) ridx = 0;
+ else if (ridx >= nrow) ridx = nrow - 1;
+ b[j + i * estride] = a[j % ncol + ridx * stride];
+}
+
+__global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int stride, int step, int orig_dim) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
+}
+
+__global__ void cudak_(scale_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int stride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[j + i * stride] *= a[j];
+}
extern "C" {
#include "../cukernel.h"
void cudak_(cuda_log_elem)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
cudak_(log_elem)<<<numBlocks, threadsPerBlock>>> \
@@ -169,8 +200,7 @@ extern "C" {
void cudak_(cuda_mul_elem)(const Matrix *a, const Matrix *b,
Matrix *c) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
cudak_(mul_elem)<<<numBlocks, threadsPerBlock>>> \
@@ -180,8 +210,7 @@ extern "C" {
}
void cudak_(cuda_sigmoid)(const Matrix *a, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
cudak_(sigmoid)<<<numBlocks, threadsPerBlock>>> \
@@ -191,8 +220,7 @@ extern "C" {
void cudak_(cuda_sigmoid_grad)(const Matrix *output,
const Matrix *err, Matrix *nerr) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
CEIL_DIV(nerr->nrow, threadsPerBlock.y));
cudak_(sigmoid_grad)<<<numBlocks, threadsPerBlock>>> \
@@ -248,8 +276,7 @@ extern "C" {
void cudak_(cuda_softmax_final)(const Matrix *a, const Matrix *max,
const Matrix *deno, Matrix *b) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
cudak_(softmax_final)<<<numBlocks, threadsPerBlock>>> \
@@ -310,8 +337,7 @@ extern "C" {
/* in-place calc */
void cudak_(cuda_add_row)(const Matrix *a, Matrix *b, double beta) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
cudak_(add_row)<<<numBlocks, threadsPerBlock>>> \
@@ -320,13 +346,44 @@ extern "C" {
}
void cudak_(cuda_fill)(Matrix *a, double val) {
- dim3 threadsPerBlock(CUDA_THREADS_N,
- CUDA_THREADS_N);
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
CEIL_DIV(a->nrow, threadsPerBlock.y));
cudak_(fill)<<<numBlocks, threadsPerBlock>>> \
(MATRIX_ELEM_PTR(a), a->nrow, a->ncol,
a->stride / sizeof(MATRIX_ELEM), val);
}
+
+ void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(expand_frm)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ a->nrow, a->ncol,
+ b->nrow, b->ncol,
+ a->stride / sizeof(MATRIX_ELEM),
+ b->stride / sizeof(MATRIX_ELEM),
+ context);
+ }
+
+ void cudak_(cuda_rearrange_frm)(const Matrix *a, Matrix *b, int step) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(rearrange_frm)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM),
+ step, b->ncol / step);
+ }
+
+ void cudak_(cuda_scale_row)(const Matrix *a, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(scale_row)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
+ }
}
#endif
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index 3bc58d7..58f3679 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -282,6 +282,40 @@ static int nerv_matrix_(copy_rows_fromh_by_idx)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(expand_frm)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ int context = luaL_checkinteger(L, 3);
+ if (a->nrow != b->nrow)
+ nerv_error(L, "mismatching number of frames");
+ if (a->ncol != b->ncol * (context * 2 + 1))
+ nerv_error(L, "the width should be 2 * context + 1");
+ cudak_(cuda_expand_frm)(b, a, context);
+ return 0;
+}
+
+static int nerv_matrix_(rearrange_frm)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ int step = luaL_checkinteger(L, 3);
+ CHECK_SAME_DIMENSION(a, b);
+ if (b->ncol % step)
+ nerv_error(L, "the dimension of columns is not divisible by step");
+ cudak_(cuda_rearrange_frm)(b, a, step);
+ return 0;
+}
+
+static int nerv_matrix_(scale_row)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ if (a->ncol != b->ncol)
+ nerv_error(L, "the number of columns is not the same");
+ if (b->nrow != 1)
+ nerv_error(L, "a row vector is expected");
+ cudak_(cuda_scale_row)(b, a);
+ return 0;
+}
+
static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"create", nerv_matrix_(create)},
{"colsum", nerv_matrix_(colsum)},
@@ -303,6 +337,9 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"mul_elem", nerv_matrix_(mul_elem)},
{"log_elem", nerv_matrix_(log_elem)},
{"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)},
+ {"expand_frm", nerv_matrix_(expand_frm)},
+ {"rearrange_frm", nerv_matrix_(rearrange_frm)},
+ {"scale_row", nerv_matrix_(scale_row)},
{NULL, NULL}
};