aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua22
1 files changed, 18 insertions, 4 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 5bf28bd..6bdf57c 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -20,6 +20,12 @@ local function build_trainer(ifname)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
local input_order = get_input_order()
+
+ network = nerv.Network("nt", gconf, {network = network})
+ network:init(gconf.batch_size, 1)
+ global_transf = nerv.Network("gt", gconf, {network = global_transf})
+ global_transf:init(gconf.batch_size, 1)
+
local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
-- rebind the params if necessary
if rebind_param_repo then
@@ -32,10 +38,11 @@ local function build_trainer(ifname)
-- build buffer
local buffer = make_buffer(make_readers(scp_file, layer_repo))
-- initialize the network
- network:init(gconf.batch_size)
gconf.cnt = 0
err_input = {mat_type(gconf.batch_size, 1)}
err_input[1]:fill(1)
+ network:epoch_init()
+ global_transf:epoch_init()
for data in buffer.get_data, buffer do
-- prine stat periodically
gconf.cnt = gconf.cnt + 1
@@ -69,10 +76,17 @@ local function build_trainer(ifname)
for i = 1, #input do
table.insert(err_output, input[i]:create())
end
- network:propagate(input, output)
+ network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
+ new_seq = {},
+ do_train = bp,
+ input = {input},
+ output = {output},
+ err_input = {err_input},
+ err_output = {err_output}})
+ network:propagate()
if bp then
- network:back_propagate(err_input, err_output, input, output)
- network:update(err_input, input, output)
+ network:back_propagate()
+ network:update()
end
-- collect garbage in-time to save GPU memory
collectgarbage("collect")