aboutsummaryrefslogtreecommitdiff
path: root/matrix/generic
diff options
context:
space:
mode:
Diffstat (limited to 'matrix/generic')
-rw-r--r--matrix/generic/cukernel.cu31
-rw-r--r--matrix/generic/cumatrix.c24
2 files changed, 45 insertions, 10 deletions
diff --git a/matrix/generic/cukernel.cu b/matrix/generic/cukernel.cu
index ffae5ed..d6c8adc 100644
--- a/matrix/generic/cukernel.cu
+++ b/matrix/generic/cukernel.cu
@@ -237,9 +237,18 @@ __global__ void cudak_(rearrange_frm)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
b[j + i * stride] = a[j / step + (j % step) * orig_dim + i * stride];
}
-__global__ void cudak_(scale_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
- int nrow, int ncol,
- int stride) {
+__global__ void cudak_(scale_rows_by_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int astride, int bstride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[j + i * bstride] *= a[i * astride];
+}
+
+__global__ void cudak_(scale_rows_by_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol,
+ int stride) {
int j = blockIdx.x * blockDim.x + threadIdx.x;
int i = blockIdx.y * blockDim.y + threadIdx.y;
if (i >= nrow || j >= ncol) return;
@@ -526,11 +535,23 @@ extern "C" {
cudaStreamSynchronize(0);
}
- void cudak_(cuda_scale_row)(const Matrix *a, Matrix *b) {
+ void cudak_(cuda_scale_rows_by_col)(const Matrix *a, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(scale_rows_by_col)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
+ b->nrow, b->ncol,
+ a->stride / sizeof(MATRIX_ELEM),
+ b->stride / sizeof(MATRIX_ELEM));
+ cudaStreamSynchronize(0);
+ }
+
+ void cudak_(cuda_scale_rows_by_row)(const Matrix *a, Matrix *b) {
dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
CEIL_DIV(b->nrow, threadsPerBlock.y));
- cudak_(scale_row)<<<numBlocks, threadsPerBlock>>> \
+ cudak_(scale_rows_by_row)<<<numBlocks, threadsPerBlock>>> \
(MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b),
b->nrow, b->ncol, b->stride / sizeof(MATRIX_ELEM));
cudaStreamSynchronize(0);
diff --git a/matrix/generic/cumatrix.c b/matrix/generic/cumatrix.c
index a8e18e0..b5d1a35 100644
--- a/matrix/generic/cumatrix.c
+++ b/matrix/generic/cumatrix.c
@@ -37,8 +37,8 @@ static int nerv_matrix_(add)(lua_State *L) {
Matrix *c = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 3, nerv_matrix_(tname));
- MATRIX_ELEM alpha = luaL_checknumber(L, 4); /* alpha */
- MATRIX_ELEM beta = luaL_checknumber(L, 5); /* alpha */
+ MATRIX_ELEM alpha = luaL_checknumber(L, 4);
+ MATRIX_ELEM beta = luaL_checknumber(L, 5);
CHECK_SAME_DIMENSION(a, b);
CHECK_SAME_DIMENSION(a, c);
nerv_matrix_(add_)(L, a, b, c, alpha, beta);
@@ -396,7 +396,20 @@ static int nerv_matrix_(rearrange_frm)(lua_State *L) {
return 0;
}
-static int nerv_matrix_(scale_row)(lua_State *L) {
+static int nerv_matrix_(scale_rows_by_col)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ if (a->nrow != b->nrow)
+ nerv_error(L, "the number of rows is not the same");
+ if (b->ncol != 1)
+ nerv_error(L, "a column vector is expected");
+ PROFILE_START
+ cudak_(cuda_scale_rows_by_col)(b, a);
+ PROFILE_STOP
+ return 0;
+}
+
+static int nerv_matrix_(scale_rows_by_row)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_matrix_(tname));
Matrix *b = luaT_checkudata(L, 2, nerv_matrix_(tname));
if (a->ncol != b->ncol)
@@ -404,7 +417,7 @@ static int nerv_matrix_(scale_row)(lua_State *L) {
if (b->nrow != 1)
nerv_error(L, "a row vector is expected");
PROFILE_START
- cudak_(cuda_scale_row)(b, a);
+ cudak_(cuda_scale_rows_by_row)(b, a);
PROFILE_STOP
return 0;
}
@@ -434,7 +447,8 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"copy_rows_fromh_by_idx", nerv_matrix_(copy_rows_fromh_by_idx)},
{"expand_frm", nerv_matrix_(expand_frm)},
{"rearrange_frm", nerv_matrix_(rearrange_frm)},
- {"scale_row", nerv_matrix_(scale_row)},
+ {"scale_rows_by_row", nerv_matrix_(scale_rows_by_row)},
+ {"scale_rows_by_col", nerv_matrix_(scale_rows_by_col)},
{NULL, NULL}
};