diff options
author | Ted Yin <ted.sybil@gmail.com> | 2015-07-10 00:24:27 -0400 |
---|---|---|
committer | Ted Yin <ted.sybil@gmail.com> | 2015-07-10 00:24:27 -0400 |
commit | b385d55268b7b327534e227065907a5ea2d2b731 (patch) | |
tree | c15e4e785a49d94eb24d4d61d237eced74560446 /nerv/matrix/generic/cukernel.cu | |
parent | 375f7c6e90d30d332178f0da18700991b2a44fff (diff) | |
parent | 1972c47c4b78e26a1e57f5001fe030c37d360a49 (diff) |
Merge pull request #37 from cloudygoose/master
add matrix:clip & affine_recurrent layer
Diffstat (limited to 'nerv/matrix/generic/cukernel.cu')
-rw-r--r-- | nerv/matrix/generic/cukernel.cu | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/nerv/matrix/generic/cukernel.cu b/nerv/matrix/generic/cukernel.cu index d6c8adc..2ae5e62 100644 --- a/nerv/matrix/generic/cukernel.cu +++ b/nerv/matrix/generic/cukernel.cu @@ -213,6 +213,17 @@ __global__ void cudak_(fill)(MATRIX_ELEM *a, a[j + i * stride] = val; } +__global__ void cudak_(clip)(MATRIX_ELEM *a, + int nrow, int ncol, int stride, double val_1, double val_2) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= nrow || j >= ncol) return; + if (a[j + i * stride] > val_2) + a[j + i * stride] = val_2; + else if (a[j + i * stride] < val_1) + a[j + i * stride] = val_1; +} + __global__ void cudak_(expand_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b, int nrow, int ncol, int enrow, int encol, @@ -510,6 +521,16 @@ extern "C" { cudaStreamSynchronize(0); } + void cudak_(cuda_clip)(Matrix *a, double val_1, double val_2) { + dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); + dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x), + CEIL_DIV(a->nrow, threadsPerBlock.y)); + cudak_(clip)<<<numBlocks, threadsPerBlock>>> \ + (MATRIX_ELEM_PTR(a), a->nrow, a->ncol, + a->stride / sizeof(MATRIX_ELEM), val_1, val_2); + cudaStreamSynchronize(0); + } + void cudak_(cuda_expand_frm)(const Matrix *a, Matrix *b, int context) { dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N); dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x), |