aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix/generic
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix/generic')
-rw-r--r--nerv/matrix/generic/cukernel.cu21
-rw-r--r--nerv/matrix/generic/cumatrix.c11
2 files changed, 32 insertions, 0 deletions
diff --git a/nerv/matrix/generic/cukernel.cu b/nerv/matrix/generic/cukernel.cu
index d6c8adc..2ae5e62 100644
--- a/nerv/matrix/generic/cukernel.cu
+++ b/nerv/matrix/generic/cukernel.cu
@@ -213,6 +213,17 @@ __global__ void cudak_(fill)(MATRIX_ELEM *a,
a[j + i * stride] = val;
}
+__global__ void cudak_(clip)(MATRIX_ELEM *a,
+ int nrow, int ncol, int stride, double val_1, double val_2) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ if (a[j + i * stride] > val_2)
+ a[j + i * stride] = val_2;
+ else if (a[j + i * stride] < val_1)
+ a[j + i * stride] = val_1;
+}
+
__global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
int nrow, int ncol,
int enrow, int encol,
@@ -510,6 +521,16 @@ extern "C" {
cudaStreamSynchronize(0);
}
+ void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
+ CEIL_DIV(a->nrow, threadsPerBlock.y));
+ cudak_(clip)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), a->nrow, a->ncol,
+ a->stride / sizeof(MATRIX_ELEM), val_1, val_2);
+ cudaStreamSynchronize(0);
+ }
+
void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) {
dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index 311b503..4bdf5f0 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -149,6 +149,16 @@ static int nerv_matrix_(lua_fill)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(lua_clip)(lua_State *L) {
+ Status status;
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ double val_1 = luaL_checknumber(L, 2);
+ double val_2 = luaL_checknumber(L, 3);
+ nerv_matrix_(clip)(self, val_1, val_2, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 0;
+}
+
static int nerv_matrix_(lua_copy_fromd)(lua_State *L) {
Status status;
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
@@ -294,6 +304,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"add", nerv_matrix_(lua_add)},
{"mul", nerv_matrix_(lua_mul)},
{"add_row", nerv_matrix_(lua_add_row)},
+ {"clip", nerv_matrix_(lua_clip)},
{"fill", nerv_matrix_(lua_fill)},
{"sigmoid", nerv_matrix_(lua_sigmoid)},
{"sigmoid_grad", nerv_matrix_(lua_sigmoid_grad)},