diff options
-rw-r--r-- | nerv/examples/asr_trainer.lua | 12 | ||||
-rw-r--r-- | nerv/examples/seq_trainer.lua | 2 | ||||
-rw-r--r-- | nerv/init.lua | 2 | ||||
-rw-r--r-- | nerv/nn/param_repo.lua | 4 |
4 files changed, 12 insertions, 8 deletions
diff --git a/nerv/examples/asr_trainer.lua b/nerv/examples/asr_trainer.lua index 645f1ef..52cb754 100644 --- a/nerv/examples/asr_trainer.lua +++ b/nerv/examples/asr_trainer.lua @@ -5,7 +5,7 @@ local function build_trainer(ifname) local mat_type local src_loc_type local train_loc_type - host_param_repo:import(ifname, nil, gconf) + host_param_repo:import(ifname, gconf) if gconf.use_cpu then mat_type = gconf.mmat_type src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST @@ -15,7 +15,7 @@ local function build_trainer(ifname) src_loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST train_loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE end - local param_repo = host_param_repo:copy(train_loc_type) + local param_repo = host_param_repo:copy(train_loc_type, gconf) local layer_repo = make_layer_repo(param_repo) local network = get_network(layer_repo) local global_transf = get_global_transf(layer_repo) @@ -30,7 +30,7 @@ local function build_trainer(ifname) -- rebind the params if necessary if rebind_param_repo then host_param_repo = rebind_param_repo - param_repo = host_param_repo:copy(train_loc_type) + param_repo = host_param_repo:copy(train_loc_type, gconf) layer_repo:rebind(param_repo) rebind_param_repo = nil end @@ -106,7 +106,11 @@ local function build_trainer(ifname) mat_type.clear_profile() local fname if (not bp) then - host_param_repo = param_repo:copy(src_loc_type) +-- host_param_repo = param_repo:copy(src_loc_type) + host_param_repo = nerv.ParamRepo.merge({network:get_params(), + global_transf:get_params()}, + train_loc_type) + :copy(src_loc_type, gconf) if prefix ~= nil then nerv.info("writing back...") fname = string.format("%s_cv%.3f.nerv", diff --git a/nerv/examples/seq_trainer.lua b/nerv/examples/seq_trainer.lua index b8ed3eb..a8411bd 100644 --- a/nerv/examples/seq_trainer.lua +++ b/nerv/examples/seq_trainer.lua @@ -1,6 +1,6 @@ function build_trainer(ifname) local param_repo = nerv.ParamRepo() - param_repo:import(ifname, nil, gconf) + param_repo:import(ifname, gconf) local layer_repo = make_layer_repo(param_repo) local network = get_network(layer_repo) local global_transf = get_global_transf(layer_repo) diff --git a/nerv/init.lua b/nerv/init.lua index d017f82..320987e 100644 --- a/nerv/init.lua +++ b/nerv/init.lua @@ -109,7 +109,7 @@ function table.val_to_str(v) (("number" == type(v) or "string" == type(v) or "boolean" == type(v)) and tostring(v)) or - "" -- failed to serialize + "nil" -- failed to serialize end end diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua index aba7765..1e7a366 100644 --- a/nerv/nn/param_repo.lua +++ b/nerv/nn/param_repo.lua @@ -65,7 +65,7 @@ function ParamRepo.merge(repos, loc_type) return self end -function ParamRepo:import(param_files, pids, gconf) +function ParamRepo:import(param_files, gconf, pids) if type(param_files) ~= "table" then nerv.error("param file table is need") end @@ -109,7 +109,7 @@ function ParamRepo:get_param(pid) return p end -function ParamRepo:copy(loc_type, pids) +function ParamRepo:copy(loc_type, gconf, pids) local copier local target = nerv.ParamRepo(nil, loc_type) if loc_type == nil then |