aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua71
1 files changed, 42 insertions, 29 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 6bdf57c..645f1ef 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -22,9 +22,9 @@ local function build_trainer(ifname)
local input_order = get_input_order()
network = nerv.Network("nt", gconf, {network = network})
- network:init(gconf.batch_size, 1)
+ network:init(gconf.batch_size, gconf.chunk_size)
global_transf = nerv.Network("gt", gconf, {network = global_transf})
- global_transf:init(gconf.batch_size, 1)
+ global_transf:init(gconf.batch_size, gconf.chunk_size)
local iterative_trainer = function (prefix, scp_file, bp, rebind_param_repo)
-- rebind the params if necessary
@@ -39,11 +39,17 @@ local function build_trainer(ifname)
local buffer = make_buffer(make_readers(scp_file, layer_repo))
-- initialize the network
gconf.cnt = 0
- err_input = {mat_type(gconf.batch_size, 1)}
- err_input[1]:fill(1)
+ local err_input = {{}}
+ local output = {{}}
+ for i = 1, gconf.chunk_size do
+ local mini_batch = mat_type(gconf.batch_size, 1)
+ mini_batch:fill(1)
+ table.insert(err_input[1], mini_batch)
+ table.insert(output[1], mat_type(gconf.batch_size, 1))
+ end
network:epoch_init()
global_transf:epoch_init()
- for data in buffer.get_data, buffer do
+ for d in buffer.get_data, buffer do
-- prine stat periodically
gconf.cnt = gconf.cnt + 1
if gconf.cnt == 1000 then
@@ -54,35 +60,39 @@ local function build_trainer(ifname)
-- break
end
local input = {}
+ local err_output = {}
-- if gconf.cnt == 1000 then break end
for i, e in ipairs(input_order) do
local id = e.id
- if data[id] == nil then
+ if d.data[id] == nil then
nerv.error("input data %s not found", id)
end
- local transformed
+ local transformed = {}
+ local err_output_i = {}
if e.global_transf then
- transformed = nerv.speech_utils.global_transf(data[id],
- global_transf,
- gconf.frm_ext or 0, 0,
- gconf)
+ for _, mini_batch in ipairs(d.data[id]) do
+ table.insert(transformed,
+ nerv.speech_utils.global_transf(mini_batch,
+ global_transf,
+ gconf.frm_ext or 0, 0,
+ gconf))
+ end
else
- transformed = data[id]
+ transformed = d.data[id]
+ end
+ for _, mini_batch in ipairs(transformed) do
+ table.insert(err_output_i, mini_batch:create())
end
+ table.insert(err_output, err_output_i)
table.insert(input, transformed)
end
- local output = {mat_type(gconf.batch_size, 1)}
- err_output = {}
- for i = 1, #input do
- table.insert(err_output, input[i]:create())
- end
- network:mini_batch_init({seq_length = table.vector(gconf.batch_size, 1),
- new_seq = {},
+ network:mini_batch_init({seq_length = d.seq_length,
+ new_seq = d.new_seq,
do_train = bp,
- input = {input},
- output = {output},
- err_input = {err_input},
- err_output = {err_output}})
+ input = input,
+ output = output,
+ err_input = err_input,
+ err_output = err_output})
network:propagate()
if bp then
network:back_propagate()
@@ -111,19 +121,21 @@ end
local function check_and_add_defaults(spec, opts)
local function get_opt_val(k)
- return opts[string.gsub(k, '_', '-')].val
+ local k = string.gsub(k, '_', '-')
+ return opts[k].val, opts[k].specified
end
local opt_v = get_opt_val("resume_from")
if opt_v then
+ nerv.info("resuming from previous training state")
gconf = dofile(opt_v)
else
for k, v in pairs(spec) do
- local opt_v = get_opt_val(k)
- if opt_v ~= nil then
+ local opt_v, specified = get_opt_val(k)
+ if (not specified) and gconf[k] ~= nil then
+ nerv.info("using setting in network config file: %s = %s", k, gconf[k])
+ elseif opt_v ~= nil then
+ nerv.info("using setting in options: %s = %s", k, opt_v)
gconf[k] = opt_v
- elseif gconf[k] ~= nil then
- elseif v ~= nil then
- gconf[k] = v
end
end
end
@@ -168,6 +180,7 @@ end
local trainer_defaults = {
lrate = 0.8,
batch_size = 256,
+ chunk_size = 1,
buffer_size = 81920,
wcost = 1e-6,
momentum = 0.9,