aboutsummaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/cukernel.cu119
-rw-r--r--matrix/cukernel.h5
-rw-r--r--matrix/cumatrix.c28
3 files changed, 142 insertions, 10 deletions
diff --git a/matrix/cukernel.cu b/matrix/cukernel.cu
index d6d7997..dd1ebfc 100644
--- a/matrix/cukernel.cu
+++ b/matrix/cukernel.cu
@@ -1,6 +1,6 @@
#include <assert.h>
-#include "generic/matrix.h"
#include <stdio.h>
+#include "generic/matrix.h"
#include "cuda.h"
#define CUDA_THREADS_N 16
#define CUDA_THREADS_NN (16 * 16)
@@ -15,7 +15,18 @@ __global__ void sigmoid(const float *a, float *b,
b[idx] = 1.0 / (1.0 + exp(-a[idx]));
}
-__global__ void block_sum(const float *input, float *output,
+__global__ void softmax_final(const float *a, float *b,
+ const float *max, const float *deno,
+ int nrow, int ncol, int stride, int mstride) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ long idx;
+ if (i >= nrow || j >= ncol) return;
+ idx = j + i * stride;
+ b[idx] = exp(a[idx] - max[0 + i * mstride]) / deno[0 + i * mstride];
+}
+
+__global__ void block_reduce_sum(const float *input, float *output,
const int istride, const int ostride,
const int n) {
extern __shared__ float arr[];
@@ -29,10 +40,47 @@ __global__ void block_sum(const float *input, float *output,
__syncthreads();
}
if (threadIdx.x == 0)
+ output[blockIdx.x + ostride * blockIdx.y] = arr[0];
+}
+
+__global__ void block_reduce_softmax_sum(const float *input, float *output,
+ const float *max,
+ const int istride, const int ostride,
+ const int mstride, const int n) {
+ extern __shared__ float arr[];
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ arr[threadIdx.x] = j < n ? exp(input[j + istride * blockIdx.y] - \
+ max[0 + mstride * blockIdx.y]) : 0;
+ __syncthreads();
+ for (int offset = blockDim.x >> 1; offset; offset >>= 1)
{
- /* printf("bx: %d by: %d arr: %f\n", blockIdx.x, blockIdx.y, arr[0]); */
+ if (threadIdx.x < offset)
+ arr[threadIdx.x] += arr[threadIdx.x + offset];
+ __syncthreads();
+ }
+ if (threadIdx.x == 0)
output[blockIdx.x + ostride * blockIdx.y] = arr[0];
+}
+
+__global__ void block_reduce_max(const float *input, float *output,
+ const int istride, const int ostride,
+ const int n) {
+ extern __shared__ float arr[];
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ arr[threadIdx.x] = j < n ? input[j + istride * blockIdx.y] : 0;
+ __syncthreads();
+ for (int offset = blockDim.x >> 1; offset; offset >>= 1)
+ {
+ if (threadIdx.x < offset)
+ {
+ float l = arr[threadIdx.x],
+ r = arr[threadIdx.x + offset];
+ if (r > l) arr[threadIdx.x] = r;
+ }
+ __syncthreads();
}
+ if (threadIdx.x == 0)
+ output[blockIdx.x + ostride * blockIdx.y] = arr[0];
}
extern "C" {
@@ -45,7 +93,66 @@ extern "C" {
b->stride / sizeof(float));
}
- void cuda_rowsum(const Matrix *a, Matrix *b) {
+ void cuda_colsum(const Matrix *a, Matrix *b) {
+ dim3 block(CUDA_THREADS_NN, 1);
+ int ncol = a->ncol;
+ int blocks_per_row = CEIL_DIV(ncol, block.x);
+ dim3 grid(blocks_per_row, a->nrow);
+ float *res;
+ size_t stride;
+ cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(float), a->nrow);
+ block_reduce_sum<<<grid, block, block.x * sizeof(float)>>> \
+ (a->data.f, res,
+ a->stride / sizeof(float), stride / sizeof(float),
+ ncol);
+ ncol = blocks_per_row;
+ assert(ncol <= block.x);
+ grid.x = 1;
+ block_reduce_sum<<<grid, block, block.x * sizeof(float)>>> \
+ (res, b->data.f,
+ stride / sizeof(float), b->stride / sizeof(float),
+ ncol);
+ cudaFree(res);
+ }
+
+ void cuda_softmax_final(const Matrix *a, const Matrix *max,
+ const Matrix *deno, Matrix *b) {
+ dim3 threadsPerBlock(CUDA_THREADS_N,
+ CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ softmax_final<<<numBlocks, threadsPerBlock>>>(a->data.f, b->data.f,
+ max->data.f, deno->data.f,
+ b->nrow, b->ncol,
+ b->stride / sizeof(float),
+ max->stride / sizeof(float));
+ }
+
+ void cuda_softmax_denominator(const Matrix *a, const Matrix *max, Matrix *b) {
+ dim3 block(CUDA_THREADS_NN, 1);
+ int ncol = a->ncol;
+ int blocks_per_row = CEIL_DIV(ncol, block.x);
+ dim3 grid(blocks_per_row, a->nrow);
+ float *res;
+ size_t stride;
+ assert(max->ncol == 1);
+ cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(float), a->nrow);
+ block_reduce_softmax_sum<<<grid, block, block.x * sizeof(float)>>> \
+ (a->data.f, res, max->data.f,
+ a->stride / sizeof(float), stride / sizeof(float),
+ max->stride / sizeof(float),
+ ncol);
+ ncol = blocks_per_row;
+ assert(ncol <= block.x);
+ grid.x = 1;
+ block_reduce_sum<<<grid, block, block.x * sizeof(float)>>> \
+ (res, b->data.f,
+ stride / sizeof(float), b->stride / sizeof(float),
+ ncol);
+ cudaFree(res);
+ }
+
+ void cuda_colmax(const Matrix *a, Matrix *b) {
dim3 block(CUDA_THREADS_NN, 1);
int ncol = a->ncol;
int blocks_per_row = CEIL_DIV(ncol, block.x);
@@ -53,14 +160,14 @@ extern "C" {
float *res;
size_t stride;
cudaMallocPitch(&res, &stride, blocks_per_row * sizeof(float), a->nrow);
- block_sum<<<grid, block, block.x * sizeof(float)>>> \
+ block_reduce_max<<<grid, block, block.x * sizeof(float)>>> \
(a->data.f, res,
a->stride / sizeof(float), stride / sizeof(float),
ncol);
ncol = blocks_per_row;
assert(ncol <= block.x);
grid.x = 1;
- block_sum<<<grid, block, block.x * sizeof(float)>>> \
+ block_reduce_max<<<grid, block, block.x * sizeof(float)>>> \
(res, b->data.f,
stride / sizeof(float), b->stride / sizeof(float),
ncol);
diff --git a/matrix/cukernel.h b/matrix/cukernel.h
index f86a69b..9c13558 100644
--- a/matrix/cukernel.h
+++ b/matrix/cukernel.h
@@ -1,5 +1,8 @@
#ifndef NERV_CUKERNEL_H
#define NERV_CUKERNEL_H
void cuda_sigmoid(const Matrix *a, Matrix *b);
-void cuda_rowsum(const Matrix *a, Matrix *b);
+void cuda_colsum(const Matrix *a, Matrix *b);
+void cuda_colmax(const Matrix *a, Matrix *b);
+void cuda_softmax_denominator(const Matrix *a, const Matrix *max, Matrix *b);
+void cuda_softmax_final(const Matrix *a, const Matrix *max, const Matrix *deno, Matrix *b);
#endif
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
index 49b7fbf..aa10571 100644
--- a/matrix/cumatrix.c
+++ b/matrix/cumatrix.c
@@ -66,10 +66,30 @@ static int nerv_float_matrix_(sigmoid)(lua_State *L) {
return 1;
}
-static int nerv_float_matrix_(rowsum)(lua_State *L) {
+static int nerv_float_matrix_(softmax)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *max = nerv_float_matrix_(new_)(a->nrow, 1);
+ Matrix *dno = nerv_float_matrix_(new_)(a->nrow, 1);
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, a->ncol);
+ cuda_colmax(a, max);
+ cuda_softmax_denominator(a, max, dno);
+ cuda_softmax_final(a, max, dno, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
+static int nerv_float_matrix_(colsum)(lua_State *L) {
+ Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
+ Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
+ cuda_colsum(a, b);
+ luaT_pushudata(L, b, nerv_float_matrix_(tname));
+ return 1;
+}
+
+static int nerv_float_matrix_(colmax)(lua_State *L) {
Matrix *a = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
Matrix *b = nerv_float_matrix_(new_)(a->nrow, 1);
- cuda_rowsum(a, b);
+ cuda_colmax(a, b);
luaT_pushudata(L, b, nerv_float_matrix_(tname));
return 1;
}
@@ -78,7 +98,9 @@ static const luaL_Reg nerv_float_matrix_(extra_methods)[] = {
{"__add__", nerv_float_matrix_(add)},
{"__mul__", nerv_float_matrix_(mul)},
{"sigmoid", nerv_float_matrix_(sigmoid)},
- {"rowsum", nerv_float_matrix_(rowsum)},
+ {"softmax", nerv_float_matrix_(softmax)},
+ {"colsum", nerv_float_matrix_(colsum)},
+ {"colmax", nerv_float_matrix_(colmax)},
{NULL, NULL}
};