aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2015-11-20 23:57:09 +0800
committertxh18 <[email protected]>2015-11-20 23:57:09 +0800
commite7a45e14d75959a3d4095ac34158a8abc3e995cf (patch)
treec2d0da26005034233a80277243003ea9e5006823
parentddcb0a8f3ee045910acc618177dc5baf7adb8bf3 (diff)
added has_param api for param_repo
-rw-r--r--nerv/layer/init.lua4
-rw-r--r--nerv/nn/layer_repo.lua3
-rw-r--r--nerv/nn/param_repo.lua15
3 files changed, 15 insertions, 7 deletions
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index d268caa..67ebe1e 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -78,9 +78,9 @@ function Layer:find_param(pid, l_conf, gconf, p_type, p_dim)
local pid_g = self.id .. '_' .. pid --global identifier
local pr = gconf.paramRepo
local p
- p = pr:get_param(pid_g)
- if p ~= nil then
+ if pr:has_param(pid_g) == true then
nerv.printf("Param [%s] of layer [%s] found in paramRepo.\n", pid, self.id)
+ p = pr:get_param(pid_g)
return p
end
nerv.printf("Param [%s] of layer [%s] is not found in layer_conf or paramRepo, switch to auto-generate.\n", pid, self.id)
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index ec0f80a..ef333a7 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -23,9 +23,6 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
end
for pname, pid in pairs(spec[1]) do
layer_config[pname] = param_repo:get_param(pid)
- if layer_config[pname] == nil then
- nerv.error("did not find parameter in paramRepo")
- end
end
layers[id] = layer_type(id, global_conf, layer_config)
end
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index 7fc0498..6d52691 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -67,7 +67,18 @@ function ParamRepo:export(param_file, pids)
cf:close()
end
+function ParamRepo:has_param(pid)
+ if self.params[pid] ~= nil then
+ return true
+ else
+ return false
+ end
+end
+
function ParamRepo:get_param(pid)
- --if pid does not exist, return nil
- return self.params[pid]
+ local p = self.params[pid]
+ if p == nil then
+ nerv.error("param with id %s not found", pid)
+ end
+ return p
end