diff options
Diffstat (limited to 'examples')
-rw-r--r-- | examples/test_dnn_layers.lua | 2 | ||||
-rw-r--r-- | examples/test_nn_lib.lua | 91 |
2 files changed, 71 insertions, 22 deletions
diff --git a/examples/test_dnn_layers.lua b/examples/test_dnn_layers.lua index 6e4d98d..f306807 100644 --- a/examples/test_dnn_layers.lua +++ b/examples/test_dnn_layers.lua @@ -3,7 +3,7 @@ require 'layer.sigmoid' require 'layer.softmax_ce' global_conf = {lrate = 0.8, wcost = 1e-6, - momentum = 0.9, mat_type = nerv.CuMatrixFloat} + momentum = 0.9, cumat_type = nerv.CuMatrixFloat} pf = nerv.ChunkFile("affine.param", "r") ltp = pf:read_chunk("a", global_conf) diff --git a/examples/test_nn_lib.lua b/examples/test_nn_lib.lua index ec338fe..9600917 100644 --- a/examples/test_nn_lib.lua +++ b/examples/test_nn_lib.lua @@ -1,14 +1,24 @@ --- require 'layer.affine' --- require 'layer.sigmoid' --- require 'layer.softmax_ce' - +require 'speech.init' gconf = {lrate = 0.8, wcost = 1e-6, momentum = 0.9, - mat_type = nerv.CuMatrixFloat, - batch_size = 10} + cumat_type = nerv.CuMatrixFloat, + mmat_type = nerv.MMatrixFloat, + batch_size = 256} -param_repo = nerv.ParamRepo({"converted.nerv"}) +param_repo = nerv.ParamRepo({"converted.nerv", "global_transf.nerv"}) sublayer_repo = nerv.LayerRepo( { + -- global transf + ["nerv.BiasLayer"] = + { + blayer1 = {{bias = "bias1"}, {dim_in = {429}, dim_out = {429}}}, + blayer2 = {{bias = "bias2"}, {dim_in = {429}, dim_out = {429}}} + }, + ["nerv.WindowLayer"] = + { + wlayer1 = {{window = "window1"}, {dim_in = {429}, dim_out = {429}}}, + wlayer2 = {{window = "window2"}, {dim_in = {429}, dim_out = {429}}} + }, + -- biased linearity ["nerv.AffineLayer"] = { affine0 = {{ltp = "affine0_ltp", bp = "affine0_bp"}, @@ -40,7 +50,7 @@ sublayer_repo = nerv.LayerRepo( }, ["nerv.SoftmaxCELayer"] = { - softmax_ce0 = {{}, {dim_in = {3001, 3001}, dim_out = {}}} + softmax_ce0 = {{}, {dim_in = {3001, 1}, dim_out = {}, compressed = true}} } }, param_repo, gconf) @@ -48,8 +58,19 @@ layer_repo = nerv.LayerRepo( { ["nerv.DAGLayer"] = { + global_transf = {{}, { + dim_in = {429}, dim_out = {429}, + sub_layers = sublayer_repo, + connections = { + ["<input>[1]"] = "blayer1[1]", + ["blayer1[1]"] = "wlayer1[1]", + ["wlayer1[1]"] = "blayer2[1]", + ["blayer2[1]"] = "wlayer2[1]", + ["wlayer2[1]"] = "<output>[1]" + } + }}, main = {{}, { - dim_in = {429, 3001}, dim_out = {}, + dim_in = {429, 1}, dim_out = {}, sub_layers = sublayer_repo, connections = { ["<input>[1]"] = "affine0[1]", @@ -74,24 +95,52 @@ layer_repo = nerv.LayerRepo( } }, param_repo, gconf) -df = nerv.ChunkFile("input.param", "r") -label = nerv.CuMatrixFloat(10, 3001) -label:fill(0) -for i = 0, 9 do - label[i][i] = 1.0 -end +tnet_reader = nerv.TNetReader(gconf, + { + id = "main_scp", + scp_file = "/slfs1/users/mfy43/swb_ivec/train_bp.scp", +-- scp_file = "t.scp", + conf_file = "/slfs1/users/mfy43/swb_ivec/plp_0_d_a.conf", + frm_ext = 5, + mlfs = { + ref = { + file = "/slfs1/users/mfy43/swb_ivec/ref.mlf", + format = "map", + format_arg = "/slfs1/users/mfy43/swb_ivec/dict", + dir = "*/", + ext = "lab" + } + }, + global_transf = layer_repo:get_layer("global_transf") + }) + +buffer = nerv.SGDBuffer(gconf, + { + buffer_size = 8192, + readers = { + { reader = tnet_reader, + data = {main_scp = 429, ref = 1}} + } + }) -input = {df:read_chunk("input", gconf).trans, label} -output = {} -err_input = {} -err_output = {input[1]:create()} sm = sublayer_repo:get_layer("softmax_ce0") main = layer_repo:get_layer("main") -main:init() -for i = 0, 3 do +main:init(gconf.batch_size) +cnt = 0 +for data in buffer.get_data, buffer do + if cnt == 1000 then break end + cnt = cnt + 1 + input = {data.main_scp, data.ref} + output = {} + err_input = {} + err_output = {input[1]:create()} + main:propagate(input, output) main:back_propagate(err_output, err_input, input, output) main:update(err_input, input, output) + nerv.utils.printf("cross entropy: %.8f\n", sm.total_ce) nerv.utils.printf("frames: %.8f\n", sm.total_frames) + nerv.utils.printf("err/frm: %.8f\n", sm.total_ce / sm.total_frames) + collectgarbage("collect") end |