aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <ted.sybil@gmail.com>2016-04-21 12:50:24 +0800
committerDeterminant <ted.sybil@gmail.com>2016-04-21 12:50:24 +0800
commite4c6f4a6b7c537969fdea4abcebdfda884fa3bbc (patch)
tree169125d936cc2e7c24ecc58cfb411cb988140f30
parent8fecbb8e488569cd8e2f930075120e5f1b1b54fb (diff)
remove redundant parameters in nerv.ParamRepo:add
-rw-r--r--nerv/init.lua4
-rw-r--r--nerv/nn/param_repo.lua28
2 files changed, 16 insertions, 16 deletions
diff --git a/nerv/init.lua b/nerv/init.lua
index 320987e..ba6a08d 100644
--- a/nerv/init.lua
+++ b/nerv/init.lua
@@ -195,9 +195,9 @@ end
-- value and description of the option.
--
-- An example of specification:
--- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"},
+-- ```{{"aaa", "a", "boolean", default = false, desc = "an option called aaa"},
-- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no does not present"},
--- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}`
+-- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}```
--
-- @return args, opts The non-option arguments and parsed options. `opts` is
-- again a list of tables, each of which corresponds to one table in parameter
diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua
index 1e7a366..932ed2a 100644
--- a/nerv/nn/param_repo.lua
+++ b/nerv/nn/param_repo.lua
@@ -37,19 +37,19 @@ function ParamRepo:__init(plist, loc_type)
end
end
-function ParamRepo:add(pid, p)
- if self.params[pid] ~= nil then
- nerv.error("duplicate params with the same id: %s", pid)
+function ParamRepo:add(p)
+ if self.params[p.id] ~= nil then
+ nerv.error("duplicate params with the same id: %s", p.id)
end
p:check(self.checker)
- self.params[pid] = p
+ self.params[p.id] = p
end
-function ParamRepo:remove(pid, p)
+function ParamRepo:remove(pid)
if self.params[pid] == nil then
nerv.error("param %s does not exit", pid)
end
- table.remove(self.params, pid)
+ self.params[pid] = nil
end
function ParamRepo.merge(repos, loc_type)
@@ -59,7 +59,7 @@ function ParamRepo.merge(repos, loc_type)
nerv.error("nerv.ParamRepo objects expected, got %s", repo)
end
for pid, p in pairs(repo.params) do
- self:add(pid, p)
+ self:add(p)
end
end
return self
@@ -77,7 +77,7 @@ function ParamRepo:import(param_files, gconf, pids)
if not nerv.is_type(p, "nerv.Param") then
nerv.error("param chunk is expected")
end
- self:add(cid, p)
+ self:add(p)
end
end
end
@@ -85,12 +85,12 @@ end
function ParamRepo:export(param_file, pids)
cf = nerv.ChunkFile(param_file, "w")
- if pids == nil then
- for id, p in pairs(self.params) do
- cf:write_chunk(p)
- end
- else
- for i, pid in ipairs(pids) do
+if pids == nil then
+ for id, p in pairs(self.params) do
+ cf:write_chunk(p)
+ end
+else
+ for i, pid in ipairs(pids) do
cf:write_chunk(self:get_param(pid))
end
end