diff options
Diffstat (limited to 'examples/test_dnn_layers.lua')
-rw-r--r-- | examples/test_dnn_layers.lua | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/examples/test_dnn_layers.lua b/examples/test_dnn_layers.lua index 866e685..9be9d71 100644 --- a/examples/test_dnn_layers.lua +++ b/examples/test_dnn_layers.lua @@ -11,10 +11,14 @@ bp = pf:read_chunk("b", global_conf) -- print(bp.trans) -af = nerv.AffineLayer("test", global_conf, ltp, bp) -sg = nerv.SigmoidLayer("test2", global_conf) -sm = nerv.SoftmaxCELayer("test3", global_conf) - +af = nerv.AffineLayer("test", global_conf, {["ltp"] = ltp, + ["bp"] = bp, + dim_in = {429}, + dim_out = {2048}}) +sg = nerv.SigmoidLayer("test2", global_conf, {dim_in = {2048}, + dim_out = {2048}}) +sm = nerv.SoftmaxCELayer("test3", global_conf, {dim_in = {2048, 2048}, + dim_out = {}}) af:init() sg:init() sm:init() @@ -27,18 +31,18 @@ for i = 0, 9 do label[i][i] = 1.0 end -input1 = {[0] = df:read_chunk("input", global_conf).trans} -output1 = {[0] = nerv.CuMatrixFloat(10, 2048)} +input1 = {df:read_chunk("input", global_conf).trans} +output1 = {nerv.CuMatrixFloat(10, 2048)} input2 = output1 -output2 = {[0] = nerv.CuMatrixFloat(10, 2048)} -input3 = {[0] = output2[0], [1] = label} +output2 = {nerv.CuMatrixFloat(10, 2048)} +input3 = {output2[1], label} output3 = nil err_input1 = nil -err_output1 = {[0] = nerv.CuMatrixFloat(10, 2048)} +err_output1 = {nerv.CuMatrixFloat(10, 2048)} err_input2 = err_output1 -err_output2 = {[0] = nerv.CuMatrixFloat(10, 2048)} +err_output2 = {nerv.CuMatrixFloat(10, 2048)} err_input3 = err_output2 -err_output3 = {[0] = input1[0]:create()} +err_output3 = {input1[1]:create()} for i = 0, 3 do -- propagate @@ -59,13 +63,13 @@ for i = 0, 3 do print("output1") - print(output1[0]) + print(output1[1]) print("output2") - print(output2[0]) + print(output2[1]) print("err_output1") - print(err_output1[0]) + print(err_output1[1]) print("err_output2") - print(err_output2[0]) + print(err_output2[1]) nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) nerv.utils.printf("frames: %.8f\n", sm.total_frames) end |