aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/test_dnn_layers.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/test_dnn_layers.lua')
-rw-r--r--nerv/examples/test_dnn_layers.lua78
1 files changed, 78 insertions, 0 deletions
diff --git a/nerv/examples/test_dnn_layers.lua b/nerv/examples/test_dnn_layers.lua
new file mode 100644
index 0000000..64c0dec
--- /dev/null
+++ b/nerv/examples/test_dnn_layers.lua
@@ -0,0 +1,78 @@
+require 'layer.affine'
+require 'layer.sigmoid'
+require 'layer.softmax_ce'
+
+global_conf = {lrate = 0.8, wcost = 1e-6,
+ momentum = 0.9, cumat_type = nerv.CuMatrixFloat}
+
+pf = nerv.ChunkFile("affine.param", "r")
+ltp = pf:read_chunk("a", global_conf)
+bp = pf:read_chunk("b", global_conf)
+
+-- print(bp.trans)
+
+af = nerv.AffineLayer("test", global_conf, {["ltp"] = ltp,
+ ["bp"] = bp,
+ dim_in = {429},
+ dim_out = {2048}})
+sg = nerv.SigmoidLayer("test2", global_conf, {dim_in = {2048},
+ dim_out = {2048}})
+sm = nerv.SoftmaxCELayer("test3", global_conf, {dim_in = {2048, 2048},
+ dim_out = {}})
+af:init()
+sg:init()
+sm:init()
+
+df = nerv.ChunkFile("input.param", "r")
+
+label = nerv.CuMatrixFloat(10, 2048)
+label:fill(0)
+for i = 0, 9 do
+ label[i][i] = 1.0
+end
+
+input1 = {df:read_chunk("input", global_conf).trans}
+output1 = {nerv.CuMatrixFloat(10, 2048)}
+input2 = output1
+output2 = {nerv.CuMatrixFloat(10, 2048)}
+input3 = {output2[1], label}
+output3 = {}
+err_input1 = {}
+err_output1 = {nerv.CuMatrixFloat(10, 2048)}
+err_input2 = err_output1
+err_output2 = {nerv.CuMatrixFloat(10, 2048)}
+err_input3 = err_output2
+err_output3 = {input1[1]:create()}
+
+for i = 0, 3 do
+ -- propagate
+ af:propagate(input1, output1)
+ sg:propagate(input2, output2)
+ sm:propagate(input3, output3)
+
+ -- back_propagate
+ sm:back_propagate(err_output1, err_input1, input3, output3)
+ sg:back_propagate(err_output2, err_input2, input2, output2)
+ af:back_propagate(err_output3, err_input3, input1, output1)
+
+ -- update
+ sm:update(err_input1, input3, output3)
+ sg:update(err_input2, input2, output2)
+ af:update(err_input3, input1, output1)
+
+
+ print("output1")
+ print(output1[1])
+ print("output2")
+ print(output2[1])
+ print("err_output1")
+ print(err_output1[1])
+ print("err_output2")
+ print(err_output2[1])
+ nerv.printf("cross entropy: %.8f\n", sm.total_ce)
+ nerv.printf("frames: %.8f\n", sm.total_frames)
+end
+print("linear")
+print(af.ltp.trans)
+print("linear2")
+print(af.bp.trans)