diff options
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r-- | nerv/examples/asr_trainer.lua | 71 |
1 files changed, 42 insertions, 29 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 6bdf57c..645f1ef 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -22,9 +22,9 @@ local function build_trainer(ifname) local input_order = get_input_order() network = nerv.Network("nt", gconf, {network = network}) - network:init(gconf.batch_size, 1) + network:init(gconf.batch_size, gconf.chunk_size) global_transf = nerv.Network("gt", gconf, {network = global_transf}) - global_transf:init(gconf.batch_size, 1) + global_transf:init(gconf.batch_size, gconf.chunk_size) local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo) -- rebind the params if necessary @@ -39,11 +39,17 @@ local function build_trainer(ifname) local buffer = make_buffer(make_readers(scp_file, layer_repo)) -- initialize the network gconf.cnt = 0 - err_input = {mat_type(gconf.batch_size, 1)} - err_input[1]:fill(1) + local err_input = {{}} + local output = {{}} + for i = 1, gconf.chunk_size do + local mini_batch = mat_type(gconf.batch_size, 1) + mini_batch:fill(1) + table.insert(err_input[1], mini_batch) + table.insert(output[1], mat_type(gconf.batch_size, 1)) + end network:epoch_init() global_transf:epoch_init() - for data in buffer.get_data, buffer do + for d in buffer.get_data, buffer do -- prine stat periodically gconf.cnt = gconf.cnt + 1 if gconf.cnt == 1000 then @@ -54,35 +60,39 @@ local function build_trainer(ifname) -- break end local input = {} + local err_output = {} -- if gconf.cnt == 1000 then break end for i, e in ipairs(input_order) do local id = e.id - if data[id] == nil then + if d.data[id] == nil then nerv.error("input data %s not found", id) end - local transformed + local transformed = {} + local err_output_i = {} if e.global_transf then - transformed = nerv.speech_utils.global_transf(data[id], - global_transf, - gconf.frm_ext or 0, 0, - gconf) + for _, mini_batch in ipairs(d.data[id]) do + table.insert(transformed, + nerv.speech_utils.global_transf(mini_batch, + global_transf, + gconf.frm_ext or 0, 0, + gconf)) + end else - transformed = data[id] + transformed = d.data[id] + end + for _, mini_batch in ipairs(transformed) do + table.insert(err_output_i, mini_batch:create()) end + table.insert(err_output, err_output_i) table.insert(input, transformed) end - local output = {mat_type(gconf.batch_size, 1)} - err_output = {} - for i = 1, #input do - table.insert(err_output, input[i]:create()) - end - network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1), - new_seq = {}, + network:mini_batch_init({seq_length = d.seq_length, + new_seq = d.new_seq, do_train = bp, - input = {input}, - output = {output}, - err_input = {err_input}, - err_output = {err_output}}) + input = input, + output = output, + err_input = err_input, + err_output = err_output}) network:propagate() if bp then network:back_propagate() @@ -111,19 +121,21 @@ end local function check_and_add_defaults(spec, opts) local function get_opt_val(k) - return opts[string.gsub(k, '_', '-')].val + local k = string.gsub(k, '_', '-') + return opts[k].val, opts[k].specified end local opt_v = get_opt_val("resume_from") if opt_v then + nerv.info("resuming from previous training state") gconf = dofile(opt_v) else for k, v in pairs(spec) do - local opt_v = get_opt_val(k) - if opt_v ~= nil then + local opt_v, specified = get_opt_val(k) + if (not specified) and gconf[k] ~= nil then + nerv.info("using setting in network config file: %s = %s", k, gconf[k]) + elseif opt_v ~= nil then + nerv.info("using setting in options: %s = %s", k, opt_v) gconf[k] = opt_v - elseif gconf[k] ~= nil then - elseif v ~= nil then - gconf[k] = v end end end @@ -168,6 +180,7 @@ end local trainer_defaults = { lrate = 0.8, batch_size = 256, + chunk_size = 1, buffer_size = 81920, wcost = 1e-6, momentum = 0.9, |