aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/asr_trainer.lua12
-rw-r--r--nerv/examples/seq_trainer.lua2
-rw-r--r--nerv/init.lua2
-rw-r--r--nerv/nn/param_repo.lua4
4 files changed, 12 insertions, 8 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua
index 645f1ef..52cb754 100644
--- a/nerv/examples/asr_trainer.lua
+++ b/nerv/examples/asr_trainer.lua
@@ -5,7 +5,7 @@ local function build_trainer(ifname)
local mat_type
local src_loc_type
local train_loc_type
- host_param_repo:import(ifname, nil, gconf)
+ host_param_repo:import(ifname, gconf)
if gconf.use_cpu then
mat_type = gconf.mmat_type
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
@@ -15,7 +15,7 @@ local function build_trainer(ifname)
src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
end
- local param_repo = host_param_repo:copy(train_loc_type)
+ local param_repo = host_param_repo:copy(train_loc_type, gconf)
local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
@@ -30,7 +30,7 @@ local function build_trainer(ifname)
-- rebind the params if necessary
if rebind_param_repo then
host_param_repo = rebind_param_repo
- param_repo = host_param_repo:copy(train_loc_type)
+ param_repo = host_param_repo:copy(train_loc_type, gconf)
layer_repo:rebind(param_repo)
rebind_param_repo = nil
end
@@ -106,7 +106,11 @@ local function build_trainer(ifname)
mat_type.clear_profile()
local fname
if (not bp) then
- host_param_repo = param_repo:copy(src_loc_type)
+-- host_param_repo = param_repo:copy(src_loc_type)
+ host_param_repo = nerv.ParamRepo.merge({network:get_params(),
+ global_transf:get_params()},
+ train_loc_type)
+ :copy(src_loc_type, gconf)
if prefix ~= nil then
nerv.info("writing back...")
fname = string.format("%s_cv%.3f.nerv",
diff --git a/nerv/examples/seq_trainer.lua b/nerv/examples/seq_trainer.lua
index b8ed3eb..a8411bd 100644
--- a/nerv/examples/seq_trainer.lua
+++ b/nerv/examples/seq_trainer.lua
@@ -1,6 +1,6 @@
function build_trainer(ifname)
local param_repo = nerv.ParamRepo()
- param_repo:import(ifname, nil, gconf)
+ param_repo:import(ifname, gconf)
local layer_repo = make_layer_repo(param_repo)
local network = get_network(layer_repo)
local global_transf = get_global_transf(layer_repo)
diff --git a/nerv/init.lua b/nerv/init.lua
index d017f82..320987e 100644
--- a/nerv/init.lua
+++ b/nerv/init.lua
@@ -109,7 +109,7 @@ function table.val_to_str(v)
(("number" == type(v) or
"string" == type(v) or
"boolean" == type(v)) and tostring(v)) or
- "" -- failed to serialize
+ "nil" -- failed to serialize
end
end
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index aba7765..1e7a366 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -65,7 +65,7 @@ function ParamRepo.merge(repos, loc_type)
return self
end
-function ParamRepo:import(param_files, pids, gconf)
+function ParamRepo:import(param_files, gconf, pids)
if type(param_files) ~= "table" then
nerv.error("param file table is need")
end
@@ -109,7 +109,7 @@ function ParamRepo:get_param(pid)
return p
end
-function ParamRepo:copy(loc_type, pids)
+function ParamRepo:copy(loc_type, gconf, pids)
local copier
local target = nerv.ParamRepo(nil, loc_type)
if loc_type == nil then