aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn')
-rw-r--r--nerv/nn/layer_repo.lua3
-rw-r--r--nerv/nn/param_repo.lua7
2 files changed, 5 insertions, 5 deletions
diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua
index ef333a7..ec0f80a 100644
--- a/nerv/nn/layer_repo.lua
+++ b/nerv/nn/layer_repo.lua
@@ -23,6 +23,9 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf)
end
for pname, pid in pairs(spec[1]) do
layer_config[pname] = param_repo:get_param(pid)
+ if layer_config[pname] == nil then
+ nerv.error("did not find parameter in paramRepo")
+ end
end
layers[id] = layer_type(id, global_conf, layer_config)
end
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index ab971ba..7fc0498 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -68,9 +68,6 @@ function ParamRepo:export(param_file, pids)
end
function ParamRepo:get_param(pid)
- local p = self.params[pid]
- if p == nil then
- nerv.error("param with id %s not found", pid)
- end
- return p
+ --if pid does not exist, return nil
+ return self.params[pid]
end