aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/asr_trainer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/asr_trainer.lua')
-rw-r--r--nerv/examples/asr_trainer.lua12
1 files changed, 8 insertions, 4 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 645f1ef..52cb754 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -5,7 +5,7 @@ local function build_trainer(ifname)
local mat_type
local src_loc_type
local train_loc_type
- host_param_repo:import(ifname, nil, gconf)
+ host_param_repo:import(ifname, gconf)
if gconf.use_cpu then
mat_type = gconf.mmat_type
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
@@ -15,7 +15,7 @@ local function build_trainer(ifname)
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
end
- local param_repo = host_param_repo:copy(train_loc_type)
+ local param_repo = host_param_repo:copy(train_loc_type, gconf)
local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
@@ -30,7 +30,7 @@ local function build_trainer(ifname)
-- rebind the params if necessary
if rebind_param_repo then
host_param_repo = rebind_param_repo
- param_repo = host_param_repo:copy(train_loc_type)
+ param_repo = host_param_repo:copy(train_loc_type, gconf)
layer_repo:rebind(param_repo)
rebind_param_repo = nil
end
@@ -106,7 +106,11 @@ local function build_trainer(ifname)
mat_type.clear_profile()
local fname
if (not bp) then
- host_param_repo = param_repo:copy(src_loc_type)
+-- host_param_repo = param_repo:copy(src_loc_type)
+ host_param_repo = nerv.ParamRepo.merge({network:get_params(),
+ global_transf:get_params()},
+ train_loc_type)
+ :copy(src_loc_type, gconf)
if prefix ~= nil then
nerv.info("writing back...")
fname = string.format("%s_cv%.3f.nerv",