diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | examples/test_dnn_layers.lua | 74 | ||||
-rw-r--r-- | layer/sigmoid.lua | 3 | ||||
-rw-r--r-- | layer/softmax_ce.lua | 32 |
4 files changed, 110 insertions, 1 deletions
@@ -6,7 +6,7 @@ OBJS := nerv.o luaT.o common.o \ LIBS := libnerv.so LUA_LIBS := matrix/init.lua io/init.lua nerv.lua \ pl/utils.lua pl/compat.lua \ - layer/init.lua layer/affine.lua layer/sigmoid.lua + layer/init.lua layer/affine.lua layer/sigmoid.lua layer/softmax_ce.lua INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK CUDA_BASE := /usr/local/cuda-6.5 CUDA_INCLUDE := -I $(CUDA_BASE)/include/ diff --git a/examples/test_dnn_layers.lua b/examples/test_dnn_layers.lua new file mode 100644 index 0000000..c57de6d --- /dev/null +++ b/examples/test_dnn_layers.lua @@ -0,0 +1,74 @@ +require 'layer.affine' +require 'layer.sigmoid' +require 'layer.softmax_ce' + +global_conf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9} + +pf = nerv.ParamFile("affine.param", "r") +ltp = pf:read_param("a") +bp = pf:read_param("b") + +-- print(bp.trans) + +af = nerv.AffineLayer("test", global_conf, ltp, bp) +sg = nerv.SigmoidLayer("test2", global_conf) +sm = nerv.SoftmaxCELayer("test3", global_conf) + +af:init() +sg:init() +sm:init() + +df = nerv.ParamFile("input.param", "r") + +label = nerv.CuMatrixFloat(10, 2048) +label:fill(0) +for i = 0, 9 do + label[i][i] = 1.0 +end + +input1 = {[0] = df:read_param("input").trans} +output1 = {[0] = nerv.CuMatrixFloat(10, 2048)} +input2 = output1 +output2 = {[0] = nerv.CuMatrixFloat(10, 2048)} +input3 = {[0] = output2[0], [1] = label} +output3 = nil +err_input1 = nil +err_output1 = {[0] = nerv.CuMatrixFloat(10, 2048)} +err_input2 = err_output1 +err_output2 = {[0] = nerv.CuMatrixFloat(10, 2048)} +err_input3 = err_output2 +err_output3 = {[0] = input1[0]:create()} + +for i = 0, 3 do + -- propagate + af:propagate(input1, output1) + sg:propagate(input2, output2) + sm:propagate(input3, output3) + + + -- back_propagate + sm:back_propagate(err_output1, err_input1, input3, output3) + sm:update(err_input1, input3, output3) + + sg:back_propagate(err_output2, err_input2, input2, output2) + sg:update(err_input2, input2, output2) + + af:back_propagate(err_output3, err_input3, input1, output1) + af:update(err_input3, input1, output1) + + + print("output1") + print(output1[0]) + print("output2") + print(output2[0]) + print("err_output1") + print(err_output1[0]) + print("err_output2") + print(err_output2[0]) + nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) + nerv.utils.printf("frames: %.8f\n", sm.total_frames) +end +print("linear") +print(af.ltp.trans) +print("linear2") +print(af.bp.trans) diff --git a/layer/sigmoid.lua b/layer/sigmoid.lua index 41a6ef7..ca34419 100644 --- a/layer/sigmoid.lua +++ b/layer/sigmoid.lua @@ -5,6 +5,9 @@ function SigmoidLayer:__init(id, global_conf) self.gconf = global_conf end +function SigmoidLayer:init() +end + function SigmoidLayer:update(bp_err, input, output) -- no params, therefore do nothing end diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua new file mode 100644 index 0000000..37d2864 --- /dev/null +++ b/layer/softmax_ce.lua @@ -0,0 +1,32 @@ +local SoftmaxCELayer = nerv.class("nerv.SoftmaxCELayer", "nerv.Layer") + +function SoftmaxCELayer:__init(id, global_conf) + self.id = id + self.gconf = global_conf +end + +function SoftmaxCELayer:init() + self.total_ce = 0.0 + self.total_frames = 0 +end + +function SoftmaxCELayer:update(bp_err, input, output) + -- no params, therefore do nothing +end + +function SoftmaxCELayer:propagate(input, output) + local soutput = input[0]:create() -- temporary value for calc softmax + self.soutput = soutput + soutput:softmax(input[0]) + local ce = soutput:create() + ce:log_elem(soutput) + ce:mul_elem(ce, input[1]) + -- add total ce + self.total_ce = self.total_ce - ce:rowsum():colsum()[0] + self.total_frames = self.total_frames + soutput:nrow() +end + +function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output) + -- softmax output - label + next_bp_err[0]:add(self.soutput, input[1], 1.0, -1.0) +end |