diff options
author | Determinant <[email protected]> | 2016-02-16 17:04:44 +0800 |
---|---|---|
committer | Determinant <[email protected]> | 2016-02-16 17:04:44 +0800 |
commit | 7dee5871f8f67a78ee704c9efd5d4708e8a27740 (patch) | |
tree | 19d672ba48d4805f3fe7e5c8ba47936dd8a2225c /nerv | |
parent | 9e7171e2da3e4edba303f5c2bdaef416fb62e81a (diff) |
improve parameter auto-detection
Diffstat (limited to 'nerv')
-rw-r--r-- | nerv/layer/affine.lua | 8 | ||||
-rw-r--r-- | nerv/layer/init.lua | 34 | ||||
-rw-r--r-- | nerv/nn/layer_repo.lua | 3 |
3 files changed, 30 insertions, 15 deletions
diff --git a/nerv/layer/affine.lua b/nerv/layer/affine.lua index ec13519..d83b5f2 100644 --- a/nerv/layer/affine.lua +++ b/nerv/layer/affine.lua @@ -84,14 +84,16 @@ function AffineLayer:__init(id, global_conf, layer_conf) layer_conf.ltp1 = layer_conf.ltp end for i = 1, #self.dim_in do - self["ltp" .. i] = self:find_param("ltp" .. i, layer_conf, global_conf, + local pid = "ltp" .. i + local pid_list = i == 1 and {"ltp", pid} or pid + self["ltp" .. i] = self:find_param(pid_list, layer_conf, global_conf, nerv.LinearTransParam, - {self.dim_in[i], self.dim_out[1]}) + {self.dim_in[i], self.dim_out[1]}, pid) end self.ltp = self.ltp1 -- alias of ltp1 self.bp = self:find_param("bp", layer_conf, global_conf, nerv.BiasParam, - {1, self.dim_out[1]}) + {1, self.dim_out[1]}, "bp") self.gconf = global_conf self:check_dim_len(-1, 1) -- exactly one output, allow multiple inputs end diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua index 23606e1..86ea9cf 100644 --- a/nerv/layer/init.lua +++ b/nerv/layer/init.lua @@ -70,22 +70,32 @@ function Layer:get_dim() return self.dim_in, self.dim_out end -function Layer:find_param(pid, l_conf, gconf, p_type, p_dim) - if l_conf[pid] ~= nil then - nerv.info("Param [%s] of layer [%s] found in layer_conf.", pid, self.id) - return l_conf[pid] +function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim, target_pid) + if type(pid_list) == "string" then + pid_list = {pid_list} end - local pid_g = self.id .. '_' .. pid --global identifier - local pr = l_conf.pr - local p - if pr ~= nil and pr:has_param(pid_g) == true then - nerv.info("Param [%s] of layer [%s] found in layer_conf.paramRepo.", pid, self.id) - p = pr:get_param(pid_g) - return p + pid_list_str = table.tostring(pid_list) + for i, pid in ipairs(pid_list) do + if lconf[pid] ~= nil then + nerv.info("param [%s] of layer [%s] found in `layer_conf`.", pid, self.id) + return lconf[pid] + end + local pid_g = self.id .. '_' .. pid --global identifier + local pr = lconf.pr + local p + if pr ~= nil and pr:has_param(pid_g) == true then + nerv.info("param [%s] of layer [%s] found in `layer_conf.pr`.", pid_list_str, self.id) + p = pr:get_param(pid_g) + return p + end end - nerv.info("Param [%s] of layer [%s] is not found in layer_conf or layer_conf.paramRepo, switch to auto-generate.", pid, self.id) + nerv.info("param [%s] of layer [%s] is not found in `layer_conf` or `layer_conf.pr`, " .. + "switch to auto-generate.", pid_list_str, self.id) p = p_type(pid_g, gconf) p.trans = gconf.cumat_type(unpack(p_dim)) + if type(gconf.param_random) ~= "function" then + nerv.error("a param generate function is needed") + end p.trans:generate(gconf.param_random) return p end diff --git a/nerv/nn/layer_repo.lua b/nerv/nn/layer_repo.lua index ef333a7..2f8de08 100644 --- a/nerv/nn/layer_repo.lua +++ b/nerv/nn/layer_repo.lua @@ -24,6 +24,9 @@ function LayerRepo:add_layers(layer_spec, param_repo, global_conf) for pname, pid in pairs(spec[1]) do layer_config[pname] = param_repo:get_param(pid) end + if layer_config.pr == nil then + layer_config.pr = param_repo + end layers[id] = layer_type(id, global_conf, layer_config) end end |