diff options
author | Determinant <ted.sybil@gmail.com> | 2015-08-26 14:26:54 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2015-08-26 14:26:54 +0800 |
commit | e81e9832ec4f2ad031fd42b5018cea134e8cda7e (patch) | |
tree | ed49289619399a99c80f47398ccc4de9ae4cedf6 /nerv/examples/asr_trainer.lua | |
parent | ed2a4148dbb9c18f428571b3e2970d7b2adfb058 (diff) |
move global_transf to asr_trainer.lua
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r-- | nerv/examples/asr_trainer.lua | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index dcadfa3..5a50542 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -3,6 +3,7 @@ function build_trainer(ifname) param_repo:import(ifname, nil, gconf) local layer_repo = make_layer_repo(param_repo) local network = get_network(layer_repo) + local global_transf = get_global_transf(layer_repo) local input_order = get_input_order() local iterative_trainer = function (prefix, scp_file, bp) gconf.randomize = bp @@ -24,15 +25,29 @@ function build_trainer(ifname) -- break end local input = {} --- if gconf.cnt == 100 then break end - for i, id in ipairs(input_order) do +-- if gconf.cnt == 1000 then break end + for i, e in ipairs(input_order) do + local id = e.id if data[id] == nil then nerv.error("input data %s not found", id) end - table.insert(input, data[id]) + local transformed + if e.global_transf then + transformed = nerv.speech_utils.global_transf(data[id], + global_transf, + gconf.frm_ext or 0, + gconf.frm_trim or 0, + gconf) + else + transformed = data[id] + end + table.insert(input, transformed) end local output = {nerv.CuMatrixFloat(gconf.batch_size, 1)} - err_output = {input[1]:create()} + err_output = {} + for i = 1, #input do + table.insert(err_output, input[i]:create()) + end network:propagate(input, output) if bp then network:back_propagate(err_input, err_output, input, output) |