aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic/cukernel.cu
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2015-06-10 20:42:10 +0800
committerDeterminant <ted.sybil@gmail.com>2015-06-10 20:42:10 +0800
commitb818c2562d07a69083377cbc34f2add108e9fa66 (patch)
treea595ce4f269035951715334d2942d91d42ae236e /matrix/generic/cukernel.cu
parentc20af45d0756d5d3004105da10e51d42a382ad66 (diff)
add CombinerLayer to support branches in NN; add MSELayer
Diffstat (limited to 'matrix/generic/cukernel.cu')
-rw-r--r--matrix/generic/cukernel.cu31
1 files changed, 26 insertions, 5 deletions
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
index ffae5ed..d6c8adc 100644
--- a/matrix/generic/cukernel.cu
+++ b/matrix/generic/cukernel.cu
@@ -237,9 +237,18 @@ __global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
}
-__global__ void cudak_(scale_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride) {
+__global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int astride, int bstride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[j + i * bstride] *= a[i * astride];
+}
+
+__global__ void cudak_(scale_rows_by_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int stride) {
int j = blockIdx.x * blockDim.x + threadIdx.x;
int i = blockIdx.y * blockDim.y + threadIdx.y;
if (i >= nrow || j >= ncol) return;
@@ -526,11 +535,23 @@ extern "C" {
cudaStreamSynchronize(0);
}
- void cudak_(cuda_scale_row)(const Matrix *a, Matrix *b) {
+ void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(scale_rows_by_col)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ b->nrow, b->ncol,
+ a->stride / sizeof(MATRIX_ELEM),
+ b->stride / sizeof(MATRIX_ELEM));
+ cudaStreamSynchronize(0);
+ }
+
+ void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(scale_row)<<<numBlocks, threadsPerBlock>>> \
+ cudak_(scale_rows_by_row)<<<numBlocks, threadsPerBlock>>> \
(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
cudaStreamSynchronize(0);