From e7a45e14d75959a3d4095ac34158a8abc3e995cf Mon Sep 17 00:00:00 2001 From: txh18 Date: Fri, 20 Nov 2015 23:57:09 +0800 Subject: added has_param api for param_repo --- nerv/layer/init.lua | 4 ++-- nerv/nn/layer_repo.lua | 3 --- nerv/nn/param_repo.lua | 15 +++++++++++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua index d268caa..67ebe1e 100644 --- a/nerv/layer/init.lua +++ b/nerv/layer/init.lua @@ -78,9 +78,9 @@ function Layer:find_param(pid, l_conf, gconf, p_type, p_dim) local pid_g = self.id .. '_' .. pid --global identifier local pr = gconf.paramRepo local p - p = pr:get_param(pid_g) - if p ~= nil then + if pr:has_param(pid_g) == true then nerv.printf("Param [%s] of layer [%s] found in paramRepo.\n", pid, self.id) + p = pr:get_param(pid_g) return p end nerv.printf("Param [%s] of layer [%s] is not found in layer_conf or paramRepo, switch to auto-generate.\n", pid, self.id) diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua index ec0f80a..ef333a7 100644 --- a/nerv/nn/layer_repo.lua +++ b/nerv/nn/layer_repo.lua @@ -23,9 +23,6 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf) end for pname, pid in pairs(spec[1]) do layer_config[pname] = param_repo:get_param(pid) - if layer_config[pname] == nil then - nerv.error("did not find parameter in paramRepo") - end end layers[id] = layer_type(id, global_conf, layer_config) end diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua index 7fc0498..6d52691 100644 --- a/nerv/nn/param_repo.lua +++ b/nerv/nn/param_repo.lua @@ -67,7 +67,18 @@ function ParamRepo:export(param_file, pids) cf:close() end +function ParamRepo:has_param(pid) + if self.params[pid] ~= nil then + return true + else + return false + end +end + function ParamRepo:get_param(pid) - --if pid does not exist, return nil - return self.params[pid] + local p = self.params[pid] + if p == nil then + nerv.error("param with id %s not found", pid) + end + return p end -- cgit v1.2.3-70-g09d2