aboutsummaryrefslogtreecommitdiff
path: root/examples/test_dnn_layers.lua
diff options
context:
space:
mode:
Diffstat (limited to 'examples/test_dnn_layers.lua')
-rw-r--r--examples/test_dnn_layers.lua34
1 files changed, 19 insertions, 15 deletions
diff --git a/examples/test_dnn_layers.lua b/examples/test_dnn_layers.lua
index 866e685..9be9d71 100644
--- a/examples/test_dnn_layers.lua
+++ b/examples/test_dnn_layers.lua
@@ -11,10 +11,14 @@ bp = pf:read_chunk("b", global_conf)
-- print(bp.trans)
-af = nerv.AffineLayer("test", global_conf, ltp, bp)
-sg = nerv.SigmoidLayer("test2", global_conf)
-sm = nerv.SoftmaxCELayer("test3", global_conf)
-
+af = nerv.AffineLayer("test", global_conf, {["ltp"] = ltp,
+ ["bp"] = bp,
+ dim_in = {429},
+ dim_out = {2048}})
+sg = nerv.SigmoidLayer("test2", global_conf, {dim_in = {2048},
+ dim_out = {2048}})
+sm = nerv.SoftmaxCELayer("test3", global_conf, {dim_in = {2048, 2048},
+ dim_out = {}})
af:init()
sg:init()
sm:init()
@@ -27,18 +31,18 @@ for i = 0, 9 do
label[i][i] = 1.0
end
-input1 = {[0] = df:read_chunk("input", global_conf).trans}
-output1 = {[0] = nerv.CuMatrixFloat(10, 2048)}
+input1 = {df:read_chunk("input", global_conf).trans}
+output1 = {nerv.CuMatrixFloat(10, 2048)}
input2 = output1
-output2 = {[0] = nerv.CuMatrixFloat(10, 2048)}
-input3 = {[0] = output2[0], [1] = label}
+output2 = {nerv.CuMatrixFloat(10, 2048)}
+input3 = {output2[1], label}
output3 = nil
err_input1 = nil
-err_output1 = {[0] = nerv.CuMatrixFloat(10, 2048)}
+err_output1 = {nerv.CuMatrixFloat(10, 2048)}
err_input2 = err_output1
-err_output2 = {[0] = nerv.CuMatrixFloat(10, 2048)}
+err_output2 = {nerv.CuMatrixFloat(10, 2048)}
err_input3 = err_output2
-err_output3 = {[0] = input1[0]:create()}
+err_output3 = {input1[1]:create()}
for i = 0, 3 do
-- propagate
@@ -59,13 +63,13 @@ for i = 0, 3 do
print("output1")
- print(output1[0])
+ print(output1[1])
print("output2")
- print(output2[0])
+ print(output2[1])
print("err_output1")
- print(err_output1[0])
+ print(err_output1[1])
print("err_output2")
- print(err_output2[0])
+ print(err_output2[1])
nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce)
nerv.utils.printf("frames: %.8f\n", sm.total_frames)
end