Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--    nerv/examples/asr_trainer.lua    22
1 file changed, 18 insertions, 4 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 5bf28bd..6bdf57c 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -20,6 +20,12 @@ local function build_trainer(ifname)
     local network = get_network(layer_repo)
     local global_transf = get_global_transf(layer_repo)
     local input_order = get_input_order()
+
+    network = nerv.Network("nt", gconf, {network = network})
+    network:init(gconf.batch_size, 1)
+    global_transf = nerv.Network("gt", gconf, {network = global_transf})
+    global_transf:init(gconf.batch_size, 1)
+
     local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
         -- rebind the params if necessary
         if rebind_param_repo then
@@ -32,10 +38,11 @@ local function build_trainer(ifname)
         -- build buffer
         local buffer = make_buffer(make_readers(scp_file, layer_repo))
         -- initialize the network
-        network:init(gconf.batch_size)
         gconf.cnt = 0
         err_input = {mat_type(gconf.batch_size, 1)}
         err_input[1]:fill(1)
+        network:epoch_init()
+        global_transf:epoch_init()
         for data in buffer.get_data, buffer do
             -- print stat periodically
             gconf.cnt = gconf.cnt + 1
@@ -69,10 +76,17 @@ local function build_trainer(ifname)
             for i = 1, #input do
                 table.insert(err_output, input[i]:create())
             end
-            network:propagate(input, output)
+            network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
+                                     new_seq = {},
+                                     do_train = bp,
+                                     input = {input},
+                                     output = {output},
+                                     err_input = {err_input},
+                                     err_output = {err_output}})
+            network:propagate()
             if bp then
-                network:back_propagate()
-                network:update()
+                network:back_propagate()
+                network:update()
             end
             -- collect garbage in time to save GPU memory
             collectgarbage("collect")
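The commit replaces the old per-call argument passing (propagate(input, output), back_propagate(err_input, err_output, input, output), update(...)) with a scheduling wrapper that is told about the whole mini-batch up front. A minimal sketch of the resulting training lifecycle, assembled only from the calls this diff introduces (reading init's second argument as a chunk/unroll length of 1 is an assumption, not something the diff states):

    -- one-time setup: wrap the raw layer graph in a nerv.Network,
    -- which takes over buffer binding and pass scheduling
    local net = nerv.Network("nt", gconf, {network = get_network(layer_repo)})
    net:init(gconf.batch_size, 1)   -- batch size; 1 assumed to be chunk length

    -- once per epoch: reset per-epoch state
    net:epoch_init()

    -- once per mini-batch: declare lengths and buffers, then run the passes
    net:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
                         new_seq = {},        -- no sequence restarts here
                         do_train = true,
                         input = {input},
                         output = {output},
                         err_input = {err_input},
                         err_output = {err_output}})
    net:propagate()
    net:back_propagate()
    net:update()

Note that the input/output and error buffers are now handed over once in mini_batch_init rather than repeated on every propagate/back_propagate/update call, which is why those three calls become argumentless.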
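table.vector is a nerv extension to Lua's table library rather than stock Lua; judging from its use here, it builds a constant-filled array, marking every row of the batch as a length-1 sequence. A hedged stand-in with the assumed semantics:

    -- assumed behavior of table.vector(n, v): {v, v, ..., v} with n entries
    local function vector(n, v)
        local t = {}
        for i = 1, n do
            t[i] = v
        end
        return t
    end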