commit    3622d8315aad9f8438b1cfcb734165de459725a9 (HEAD, master)
author    Qi Liu <[email protected]>  2016-06-21 15:49:07 +0800
committer Qi Liu <[email protected]>  2016-06-21 15:49:07 +0800
tree      ed20795f95854c72137bd7537a4097582e997fb3
parent    bc49910f6f55620a4fb4e7038e751bab52fdafa6
parent    3856e63dab1b28aaec4133b6b0ec2a44ebf8cf46

Merge branch 'master' into 'master'

Add back propagation function for softmax.

See merge request !6
-rw-r--r--  nerv/layer/softmax.lua                7
-rw-r--r--  nerv/lib/matrix/generic/cukernel.cu  19
-rw-r--r--  nerv/lib/matrix/generic/cumatrix.c   12
-rw-r--r--  nerv/lib/matrix/generic/cumatrix.h    2
-rw-r--r--  nerv/lib/matrix/generic/mmatrix.c    20
-rw-r--r--  nerv/lib/matrix/generic/mmatrix.h     2
-rw-r--r--  nerv/matrix/generic/cumatrix.c        1
-rw-r--r--  nerv/matrix/generic/matrix.c         12
-rw-r--r--  nerv/matrix/generic/mmatrix.c         1

9 files changed, 75 insertions(+), 1 deletion(-)
diff --git a/nerv/layer/softmax.lua b/nerv/layer/softmax.lua
index f7a5163..6789ccc 100644
--- a/nerv/layer/softmax.lua
+++ b/nerv/layer/softmax.lua
@@ -28,7 +28,12 @@ function SoftmaxLayer:propagate(input, output)
end
function SoftmaxLayer:back_propagate(bp_err, next_bp_err, input, output)
- nerv.error_method_not_implemented()
+ local nbe = next_bp_err[1]
+ nbe:mul_elem(bp_err[1], output[1])
+ local offset = nbe:rowsum()
+ nbe:copy_from(bp_err[1])
+ nbe:add_col(offset, -1.0)
+ nbe:mul_elem(nbe, output[1])
end
function SoftmaxLayer:get_params()
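
A note on the layer change above: this is the standard softmax Jacobian-vector product. With y = softmax(x) and upstream gradient g = dL/dy, the backward pass is dL/dx_i = y_i * (g_i - sum_j g_j y_j). The Lua hunk computes the per-row dot product sum_j g_j y_j with mul_elem followed by rowsum, subtracts it from g via add_col with beta = -1.0 (the new primitive this merge introduces), and scales by y with the final mul_elem. Below is a minimal self-contained C sketch of the same arithmetic for one row, independent of the NERV matrix API; all names in it are illustrative.

    #include <stdio.h>

    /* Backward pass for one softmax row:
     *   dx[i] = y[i] * (g[i] - dot(g, y))
     * where y is the softmax output and g the gradient w.r.t. y. */
    static void softmax_backward_row(const float *y, const float *g,
                                     float *dx, int n) {
        float dot = 0.0f;
        for (int j = 0; j < n; j++)      /* offset = rowsum(g .* y) */
            dot += g[j] * y[j];
        for (int i = 0; i < n; i++)      /* dx = (g - offset) .* y */
            dx[i] = (g[i] - dot) * y[i];
    }

    int main(void) {
        const float y[3] = {0.2f, 0.3f, 0.5f};  /* a softmax output row */
        const float g[3] = {1.0f, 0.0f, 0.0f};  /* upstream gradient */
        float dx[3];
        softmax_backward_row(y, g, dx, 3);
        for (int i = 0; i < 3; i++)
            printf("%g ", dx[i]);               /* expect 0.16 -0.06 -0.1 */
        printf("\n");
        return 0;
    }

Note that the resulting gradient components sum to zero across the row, as they must for any gradient with respect to softmax inputs (the outputs are constrained to sum to one).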
diff --git a/nerv/lib/matrix/generic/cukernel.cu b/nerv/lib/matrix/generic/cukernel.cu
index 82bea14..4447f7c 100644
--- a/nerv/lib/matrix/generic/cukernel.cu
+++ b/nerv/lib/matrix/generic/cukernel.cu
@@ -263,6 +263,14 @@ __global__ void cudak_(add_row)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
b[j + i * stride] += beta * a[j];
}
+__global__ void cudak_(add_col)(const MATRIX_ELEM *a, MATRIX_ELEM *b,
+ int nrow, int ncol, int astride, int bstride, double beta) {
+ int j = blockIdx.x * blockDim.x + threadIdx.x;
+ int i = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i >= nrow || j >= ncol) return;
+ b[j + i * bstride] += beta * a[i * astride];
+}
+
__global__ void cudak_(fill)(MATRIX_ELEM *a,
int nrow, int ncol, int stride, double val) {
int j = blockIdx.x * blockDim.x + threadIdx.x;
@@ -735,6 +743,17 @@ extern "C" {
cudaStreamSynchronize(0);
}
+ void cudak_(cuda_add_col)(const Matrix *a, Matrix *b, double beta) {
+ dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
+ dim3 numBlocks(CEIL_DIV(b->ncol, threadsPerBlock.x),
+ CEIL_DIV(b->nrow, threadsPerBlock.y));
+ cudak_(add_col)<<<numBlocks, threadsPerBlock>>> \
+ (MATRIX_ELEM_PTR(a), MATRIX_ELEM_PTR(b), b->nrow, b->ncol,
+ a->stride / sizeof(MATRIX_ELEM),
+ b->stride / sizeof(MATRIX_ELEM), beta);
+ cudaStreamSynchronize(0);
+ }
+
void cudak_(cuda_fill)(Matrix *a, double val) {
dim3 threadsPerBlock(CUDA_THREADS_N, CUDA_THREADS_N);
dim3 numBlocks(CEIL_DIV(a->ncol, threadsPerBlock.x),
diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c
index 432222a..52cfe50 100644
--- a/nerv/lib/matrix/generic/cumatrix.c
+++ b/nerv/lib/matrix/generic/cumatrix.c
@@ -244,6 +244,18 @@ void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta,
NERV_SET_STATUS(status, NERV_NORMAL, 0);
}
+void nerv_matrix_(add_col)(Matrix *b, const Matrix *a, double beta,
+ CuContext *context, Status *status) {
+ if (a->nrow != b->nrow)
+ NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
+ if (a->ncol != 1)
+ NERV_EXIT_STATUS(status, MAT_COL_VECTOR_EXP, 0);
+ PROFILE_START
+ cudak_(cuda_add_col)(a, b, beta);
+ PROFILE_STOP
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
void nerv_matrix_(fill)(Matrix *self, double val,
CuContext *context, Status *status) {
PROFILE_START
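
A note on the wrapper above: nerv_matrix_(add_col) validates shapes before launching the kernel; b and a must agree on row count, and a must be a single-column matrix, otherwise the function exits through NERV_EXIT_STATUS with MAT_MISMATCH_DIM or MAT_COL_VECTOR_EXP rather than aborting the process. A hedged C sketch of that set-status-and-return pattern follows; the Status struct and error codes here are simplified stand-ins for NERV's real machinery.

    #include <stdio.h>

    /* Simplified stand-ins for NERV's status machinery; illustrative only. */
    typedef struct { int err; const char *msg; } Status;
    enum { NERV_NORMAL, MAT_MISMATCH_DIM, MAT_COL_VECTOR_EXP };

    typedef struct { int nrow, ncol; } Shape;

    /* Mirror of the checks in nerv_matrix_(add_col): the broadcast source
     * must be a column vector with as many rows as the destination. */
    static int check_add_col(Shape b, Shape a, Status *status) {
        if (a.nrow != b.nrow) {
            status->err = MAT_MISMATCH_DIM;
            status->msg = "matrix dimension mismatch";
            return -1;
        }
        if (a.ncol != 1) {
            status->err = MAT_COL_VECTOR_EXP;
            status->msg = "column vector expected";
            return -1;
        }
        status->err = NERV_NORMAL;
        status->msg = "ok";
        return 0;
    }

    int main(void) {
        Status s;
        Shape b = {4, 8}, good = {4, 1}, bad = {4, 2};
        check_add_col(b, good, &s); printf("%d %s\n", s.err, s.msg);
        check_add_col(b, bad, &s);  printf("%d %s\n", s.err, s.msg);
        return 0;
    }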
diff --git a/nerv/lib/matrix/generic/cumatrix.h b/nerv/lib/matrix/generic/cumatrix.h
index 459513b..9addae3 100644
--- a/nerv/lib/matrix/generic/cumatrix.h
+++ b/nerv/lib/matrix/generic/cumatrix.h
@@ -34,6 +34,8 @@ void nerv_matrix_(rowmax_idx)(Matrix *a, Matrix **b, Matrix **idx,
CuContext *context, Status *status);
void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta,
CuContext *context, Status *status);
+void nerv_matrix_(add_col)(Matrix *b, const Matrix *a, double beta,
+ CuContext *context, Status *status);
void nerv_matrix_(clip)(Matrix *self, double val1, double val2,
CuContext *context, Status *status);
void nerv_matrix_(fill)(Matrix *self, double val,
diff --git a/nerv/lib/matrix/generic/mmatrix.c b/nerv/lib/matrix/generic/mmatrix.c
index e76d4fb..54360d5 100644
--- a/nerv/lib/matrix/generic/mmatrix.c
+++ b/nerv/lib/matrix/generic/mmatrix.c
@@ -243,6 +243,26 @@ void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta,
NERV_SET_STATUS(status, NERV_NORMAL, 0);
}
+void nerv_matrix_(add_col)(Matrix *b, const Matrix *a, double beta,
+ MContext *context, Status *status) {
+ if (a->nrow != b->nrow)
+ NERV_EXIT_STATUS(status, MAT_MISMATCH_DIM, 0);
+ if (a->ncol != 1)
+ NERV_EXIT_STATUS(status, MAT_COL_VECTOR_EXP, 0);
+ MATRIX_ELEM *arow = MATRIX_ELEM_PTR(a);
+ MATRIX_ELEM *brow = MATRIX_ELEM_PTR(b);
+ int i, j;
+ size_t astride = a->stride, bstride = b->stride;
+ for (i = 0; i < b->nrow; i++)
+ {
+ for (j = 0; j < b->ncol; j++)
+ brow[j] += arow[0] * beta;
+ arow = MATRIX_NEXT_ROW_PTR(arow, astride);
+ brow = MATRIX_NEXT_ROW_PTR(brow, bstride);
+ }
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+}
+
void nerv_matrix_(clip)(Matrix *self, double val1, double val2,
MContext *context, Status *status) {
int i, j;
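
A note on the CPU implementation above: stride is a row pitch in bytes, not in elements, because rows may be padded for alignment; MATRIX_NEXT_ROW_PTR therefore has to advance the pointer through a byte-typed cast. A self-contained C sketch of that pointer arithmetic, with a hypothetical 16-byte pitch and a simplified NEXT_ROW_PTR macro standing in for the diff's MATRIX_NEXT_ROW_PTR:

    #include <stdio.h>

    /* Advance a row pointer by a byte stride, as MATRIX_NEXT_ROW_PTR does:
     * strides are in bytes because rows can be padded for alignment. */
    #define NEXT_ROW_PTR(p, stride) ((float *)((char *)(p) + (stride)))

    int main(void) {
        enum { NROW = 3, NCOL = 2, PITCH = 16 };  /* 16-byte rows = 4 floats */
        float buf_a[NROW * 4] = {0}, buf_b[NROW * 4] = {0};

        float *arow = buf_a, *brow = buf_b;
        for (int i = 0; i < NROW; i++) {          /* fill a's single column */
            arow[0] = (float)(i + 1);
            arow = NEXT_ROW_PTR(arow, PITCH);
        }

        /* b[i][j] += beta * a[i][0] -- the add_col broadcast from the diff */
        double beta = 2.0;
        arow = buf_a;
        for (int i = 0; i < NROW; i++) {
            for (int j = 0; j < NCOL; j++)
                brow[j] += (float)(arow[0] * beta);
            arow = NEXT_ROW_PTR(arow, PITCH);
            brow = NEXT_ROW_PTR(brow, PITCH);
        }

        brow = buf_b;
        for (int i = 0; i < NROW; i++) {          /* expect 2 2 / 4 4 / 6 6 */
            printf("%g %g\n", brow[0], brow[1]);
            brow = NEXT_ROW_PTR(brow, PITCH);
        }
        return 0;
    }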
diff --git a/nerv/lib/matrix/generic/mmatrix.h b/nerv/lib/matrix/generic/mmatrix.h
index 7f494d6..2a3fea8 100644
--- a/nerv/lib/matrix/generic/mmatrix.h
+++ b/nerv/lib/matrix/generic/mmatrix.h
@@ -35,6 +35,8 @@ void nerv_matrix_(rowmax_idx)(Matrix *a, Matrix **b, Matrix **idx,
MContext *context, Status *status);
void nerv_matrix_(add_row)(Matrix *b, const Matrix *a, double beta,
MContext *context, Status *status);
+void nerv_matrix_(add_col)(Matrix *b, const Matrix *a, double beta,
+ MContext *context, Status *status);
void nerv_matrix_(clip)(Matrix *self, double val1, double val2,
MContext *context, Status *status);
void nerv_matrix_(diagonalize)(Matrix *self, MContext *context, Status *status);
diff --git a/nerv/matrix/generic/cumatrix.c b/nerv/matrix/generic/cumatrix.c
index 3481ede..ccd07ad 100644
--- a/nerv/matrix/generic/cumatrix.c
+++ b/nerv/matrix/generic/cumatrix.c
@@ -224,6 +224,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"add", nerv_matrix_(lua_add)},
{"mul", nerv_matrix_(lua_mul)},
{"add_row", nerv_matrix_(lua_add_row)},
+ {"add_col", nerv_matrix_(lua_add_col)},
{"clip", nerv_matrix_(lua_clip)},
{"fill", nerv_matrix_(lua_fill)},
{"sigmoid", nerv_matrix_(lua_sigmoid)},
diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c
index 9f31b4b..c679731 100644
--- a/nerv/matrix/generic/matrix.c
+++ b/nerv/matrix/generic/matrix.c
@@ -270,6 +270,18 @@ static int nerv_matrix_(lua_add_row)(lua_State *L) {
return 0;
}
+static int nerv_matrix_(lua_add_col)(lua_State *L) {
+ Status status;
+ MATRIX_CONTEXT *context;
+ MATRIX_GET_CONTEXT(L, 4);
+ const Matrix *a = luaT_checkudata(L, 2, nerv_matrix_(tname));
+ Matrix *b = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ double beta = luaL_checknumber(L, 3);
+ nerv_matrix_(add_col)(b, a, beta, context, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 0;
+}
+
static int nerv_matrix_(lua_fill)(lua_State *L) {
Status status;
MATRIX_CONTEXT *context;
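
A note on the binding above: lua_add_col unpacks its arguments straight off the Lua stack; index 1 is the destination b (the method receiver), index 2 the column vector a, index 3 the scalar beta, and index 4 an optional context consumed by MATRIX_GET_CONTEXT. From Lua the call is therefore b:add_col(a, beta), which is what the softmax layer's nbe:add_col(offset, -1.0) above relies on. A minimal sketch of the same argument layout using the plain Lua 5.1 C API (NERV's luaT userdata checks are replaced by stubs; the module and function names are invented for illustration):

    #include <lua.h>
    #include <lauxlib.h>

    /* Stack layout for b:add_col(a, beta):
     *   1: b    destination matrix (the method receiver)
     *   2: a    single-column matrix to broadcast
     *   3: beta scalar multiplier (-1.0 in the softmax layer) */
    static int lua_add_col_sketch(lua_State *L) {
        luaL_checktype(L, 1, LUA_TUSERDATA);  /* luaT_checkudata in NERV */
        luaL_checktype(L, 2, LUA_TUSERDATA);
        double beta = luaL_checknumber(L, 3);
        (void)beta;  /* the real binding forwards to nerv_matrix_(add_col) */
        return 0;    /* add_col mutates b in place; nothing is returned */
    }

    /* Registration mirrors the extra_methods tables in this merge. */
    static const luaL_Reg sketch_methods[] = {
        {"add_col", lua_add_col_sketch},
        {NULL, NULL}
    };

    int luaopen_addcol_sketch(lua_State *L) {
        luaL_register(L, "addcol_sketch", sketch_methods);  /* Lua 5.1 */
        return 1;
    }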
diff --git a/nerv/matrix/generic/mmatrix.c b/nerv/matrix/generic/mmatrix.c
index 530888b..ca8d73f 100644
--- a/nerv/matrix/generic/mmatrix.c
+++ b/nerv/matrix/generic/mmatrix.c
@@ -114,6 +114,7 @@ static const luaL_Reg nerv_matrix_(extra_methods)[] = {
{"add", nerv_matrix_(lua_add)},
{"mul", nerv_matrix_(lua_mul)},
{"add_row", nerv_matrix_(lua_add_row)},
+ {"add_col", nerv_matrix_(lua_add_col)},
{"clip", nerv_matrix_(lua_clip)},
{"fill", nerv_matrix_(lua_fill)},
{"diagonalize", nerv_matrix_(lua_diagonalize)},