diff options
author | Determinant <[email protected]> | 2015-05-27 12:02:44 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2015-05-27 12:02:44 +0800 |
commit | 44752a8175c54effd0c901333a7eb4cc71dd6810 (patch) | |
tree | 9794cfff9382e510298c729bb70b2cdd1ece23f7 /matrix/generic/cukernel.cu | |
parent | f8543464c13584e39bfacee694ee1ed80ac121f4 (diff) |
add implementation for sigmoid layer (not tested)
Diffstat (limited to 'matrix/generic/cukernel.cu')
-rw-r--r-- | matrix/generic/cukernel.cu | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu index 8b929e4..517393e 100644 --- a/matrix/generic/cukernel.cu +++ b/matrix/generic/cukernel.cu @@ -16,6 +16,18 @@ __global__ void cudak_(sigmoid)(const MATRIX_ELEM *a, MATRIX_ELEM *b, b[idx] = 1.0 / (1.0 + exp(-a[idx])); } +__global__ void cudak_(sigmoid_grad)(const MATRIX_ELEM *output, + const MATRIX_ELEM *err, + MATRIX_ELEM *nerr, + int nrow, int ncol, int stride) { + int j = blockIdx.x * blockDim.x + threadIdx.x; + int i = blockIdx.y * blockDim.y + threadIdx.y; + long idx; + if (i >= nrow || j >= ncol) return; + idx = j + i * stride; + nerr[idx] = output[idx] * (1.0 - output[idx]) * err[idx]; +} + __global__ void cudak_(softmax_final)(const MATRIX_ELEM *a, MATRIX_ELEM *b, const MATRIX_ELEM *max, const MATRIX_ELEM *deno, int nrow, int ncol, int stride, int mstride) { @@ -134,6 +146,19 @@ extern "C" { b->stride / sizeof(MATRIX_ELEM)); } + void cudak_(cuda_sigmoid_grad)(const Matrix *output, + const Matrix *err, Matrix *nerr) { + dim3 threadsPerBlock(CUDA_THREADS_N, + CUDA_THREADS_N); + dim3 numBlocks(CEIL_DIV(nerr->ncol, threadsPerBlock.x), + CEIL_DIV(nerr->nrow, threadsPerBlock.y)); + cudak_(sigmoid_grad)<<<numBlocks, threadsPerBlock>>> \ + (MATRIX_ELEM_PTR(output), MATRIX_ELEM_PTR(err), + MATRIX_ELEM_PTR(nerr), + nerr->nrow, nerr->ncol, + nerr->stride / sizeof(MATRIX_ELEM)); + } + void cudak_(cuda_rowsum)(const Matrix *a, Matrix *b) { dim3 block(CUDA_THREADS_NN, 1); int ncol = a->ncol; |