From e4c6f4a6b7c537969fdea4abcebdfda884fa3bbc Mon Sep 17 00:00:00 2001 From: Determinant Date: Thu, 21 Apr 2016 12:50:24 +0800 Subject: remove redundant parameters in nerv.ParamRepo:add --- nerv/init.lua | 4 ++-- nerv/nn/param_repo.lua | 28 ++++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/nerv/init.lua b/nerv/init.lua index 320987e..ba6a08d 100644 --- a/nerv/init.lua +++ b/nerv/init.lua @@ -195,9 +195,9 @@ end -- value and description of the option. -- -- An example of specification: --- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"}, +-- ```{{"aaa", "a", "boolean", default = false, desc = "an option called aaa"}, -- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no does not present"}, --- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}` +-- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}``` -- -- @return args, opts The non-option arguments and parsed options. `opts` is -- again a list of tables, each of which corresponds to one table in parameter diff --git a/nerv/nn/param_repo.lua b/nerv/nn/param_repo.lua index 1e7a366..932ed2a 100644 --- a/nerv/nn/param_repo.lua +++ b/nerv/nn/param_repo.lua @@ -37,19 +37,19 @@ function ParamRepo:__init(plist, loc_type) end end -function ParamRepo:add(pid, p) - if self.params[pid] ~= nil then - nerv.error("duplicate params with the same id: %s", pid) +function ParamRepo:add(p) + if self.params[p.id] ~= nil then + nerv.error("duplicate params with the same id: %s", p.id) end p:check(self.checker) - self.params[pid] = p + self.params[p.id] = p end -function ParamRepo:remove(pid, p) +function ParamRepo:remove(pid) if self.params[pid] == nil then nerv.error("param %s does not exit", pid) end - table.remove(self.params, pid) + self.params[pid] = nil end function ParamRepo.merge(repos, loc_type) @@ -59,7 +59,7 @@ function ParamRepo.merge(repos, loc_type) nerv.error("nerv.ParamRepo objects expected, got %s", repo) end for pid, p in pairs(repo.params) do - self:add(pid, p) + self:add(p) end end return self @@ -77,7 +77,7 @@ function ParamRepo:import(param_files, gconf, pids) if not nerv.is_type(p, "nerv.Param") then nerv.error("param chunk is expected") end - self:add(cid, p) + self:add(p) end end end @@ -85,12 +85,12 @@ end function ParamRepo:export(param_file, pids) cf = nerv.ChunkFile(param_file, "w") - if pids == nil then - for id, p in pairs(self.params) do - cf:write_chunk(p) - end - else - for i, pid in ipairs(pids) do +if pids == nil then + for id, p in pairs(self.params) do + cf:write_chunk(p) + end +else + for i, pid in ipairs(pids) do cf:write_chunk(self:get_param(pid)) end end -- cgit v1.2.3