author    txh18 <[email protected]>  2015-11-25 18:42:26 +0800
committer txh18 <[email protected]>  2015-11-25 18:42:26 +0800
commit    ca3500f01ea7ce695a4dbf70d2be8244827097c9 (patch)
tree      5fa85c778c16a40279cd2bb331f8511aae2b5dca /nerv/lib/matrix/generic/cukernel.cu
parent    8e590ba284bfee414659f1845e175b41cac05d45 (diff)
added tanh operation for matrix
Diffstat (limited to 'nerv/lib/matrix/generic/cukernel.cu')
-rw-r--r--  nerv/lib/matrix/generic/cukernel.cu  45
1 file changed, 45 insertions, 0 deletions
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index e58c488..c82041f 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -53,6 +53,28 @@ __global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output,
     nerr[idx] = output[idx] * (1.0 - output[idx]) * err[idx];
 }
 
+__global__ void cudak_(tanh)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+                             int nrow, int ncol, int stride) {
+    int j = blockIdx.x * blockDim.x + threadIdx.x;
+    int i = blockIdx.y * blockDim.y + threadIdx.y;
+    long idx;
+    if (i >= nrow || j >= ncol) return;
+    idx = j + i * stride;
+    b[idx] = (exp(a[idx]) - exp(-a[idx])) / (exp(a[idx]) + exp(-a[idx]));
+}
+
+__global__ void cudak_(tanh_grad)(const MATRIX_ELEM *output,
+                                  const MATRIX_ELEM *err,
+                                  MATRIX_ELEM *nerr,
+                                  int nrow, int ncol, int stride) {
+    int j = blockIdx.x * blockDim.x + threadIdx.x;
+    int i = blockIdx.y * blockDim.y + threadIdx.y;
+    long idx;
+    if (i >= nrow || j >= ncol) return;
+    idx = j + i * stride;
+    nerr[idx] = (1.0 - output[idx] * output[idx]) * err[idx];
+}
+
 __global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
                                       const MATRIX_ELEM *max, const MATRIX_ELEM *deno,
                                       int nrow, int ncol, int stride, int mstride) {
@@ -353,6 +375,29 @@ extern "C" {
         cudaStreamSynchronize(0);
     }
 
+    void cudak_(cuda_tanh)(const Matrix *a, Matrix *b) {
+        dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+        dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+                       CEIL_DIV(b->nrow, threadsPerBlock.y));
+        cudak_(tanh)<<<numBlocks, threadsPerBlock>>> \
+            (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
+             b->stride / sizeof(MATRIX_ELEM));
+        cudaStreamSynchronize(0);
+    }
+
+    void cudak_(cuda_tanh_grad)(const Matrix *output,
+                                const Matrix *err, Matrix *nerr) {
+        dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+        dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x),
+                       CEIL_DIV(nerr->nrow, threadsPerBlock.y));
+        cudak_(tanh_grad)<<<numBlocks, threadsPerBlock>>> \
+            (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err),
+             MATRIX_ELEM_PTR(nerr),
+             nerr->nrow, nerr->ncol,
+             nerr->stride / sizeof(MATRIX_ELEM));
+        cudaStreamSynchronize(0);
+    }
+
     void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) {
         dim3 block(CUDA_THREADS_NN, 1);
         int ncol = a->ncol;
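
Note (not part of the commit): the backward kernel uses the identity d/dx tanh(x) = 1 - tanh(x)^2, so only the forward output and the incoming error are needed to form the propagated error. The sketch below is a minimal standalone illustration of the same forward/backward pair on plain contiguous float buffers; the file name, kernel names, and variables are hypothetical and do not use the nerv macros (cudak_, MATRIX_ELEM, CEIL_DIV). It calls CUDA's built-in tanhf() instead of the exp-based expression above, since exp(a) can overflow for large |a| and yield inf/inf.

    // tanh_sketch.cu -- standalone sketch, hypothetical names
    #include <cstdio>
    #include <cuda_runtime.h>

    // Forward: b[i] = tanh(a[i]), using the built-in device tanhf().
    __global__ void tanh_forward(const float *a, float *b, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            b[i] = tanhf(a[i]);
    }

    // Backward: nerr[i] = (1 - out[i]^2) * err[i], i.e. d tanh(x)/dx = 1 - tanh(x)^2
    // applied to the stored forward output.
    __global__ void tanh_backward(const float *out, const float *err,
                                  float *nerr, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            nerr[i] = (1.0f - out[i] * out[i]) * err[i];
    }

    int main() {
        const int n = 8;
        float h_in[n] = {-3, -1, -0.5f, 0, 0.5f, 1, 2, 3}, h_out[n];
        float *d_in, *d_out;
        cudaMalloc(&d_in, n * sizeof(float));
        cudaMalloc(&d_out, n * sizeof(float));
        cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
        tanh_forward<<<1, 32>>>(d_in, d_out, n);
        cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
        for (int i = 0; i < n; ++i)
            printf("tanh(%g) = %g\n", h_in[i], h_out[i]);
        cudaFree(d_in);
        cudaFree(d_out);
        return 0;
    }

Build and run with, e.g., nvcc tanh_sketch.cu -o tanh_sketch && ./tanh_sketch (file name assumed). The 1-D launch is enough for a flat buffer; the commit's kernels use a 2-D grid plus a row stride because nerv matrices may be padded per row.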