aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer')
-rw-r--r--nerv/layer/affine.lua8
-rw-r--r--nerv/layer/init.lua34
2 files changed, 27 insertions, 15 deletions
diff --git a/nerv/layer/affine.lua b/nerv/layer/affine.lua
index ec13519..d83b5f2 100644
--- a/nerv/layer/affine.lua
+++ b/nerv/layer/affine.lua
@@ -84,14 +84,16 @@ function AffineLayer:__init(id, global_conf, layer_conf)
layer_conf.ltp1 = layer_conf.ltp
end
for i = 1, #self.dim_in do
- self["ltp" .. i] = self:find_param("ltp" .. i, layer_conf, global_conf,
+ local pid = "ltp" .. i
+ local pid_list = i == 1 and {"ltp", pid} or pid
+ self["ltp" .. i] = self:find_param(pid_list, layer_conf, global_conf,
nerv.LinearTransParam,
- {self.dim_in[i], self.dim_out[1]})
+ {self.dim_in[i], self.dim_out[1]}, pid)
end
self.ltp = self.ltp1 -- alias of ltp1
self.bp = self:find_param("bp", layer_conf, global_conf,
nerv.BiasParam,
- {1, self.dim_out[1]})
+ {1, self.dim_out[1]}, "bp")
self.gconf = global_conf
self:check_dim_len(-1, 1) -- exactly one output, allow multiple inputs
end
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index 23606e1..86ea9cf 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -70,22 +70,32 @@ function Layer:get_dim()
return self.dim_in, self.dim_out
end
-function Layer:find_param(pid, l_conf, gconf, p_type, p_dim)
- if l_conf[pid] ~= nil then
- nerv.info("Param [%s] of layer [%s] found in layer_conf.", pid, self.id)
- return l_conf[pid]
+function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim, target_pid)
+ if type(pid_list) == "string" then
+ pid_list = {pid_list}
end
- local pid_g = self.id .. '_' .. pid --global identifier
- local pr = l_conf.pr
- local p
- if pr ~= nil and pr:has_param(pid_g) == true then
- nerv.info("Param [%s] of layer [%s] found in layer_conf.paramRepo.", pid, self.id)
- p = pr:get_param(pid_g)
- return p
+ pid_list_str = table.tostring(pid_list)
+ for i, pid in ipairs(pid_list) do
+ if lconf[pid] ~= nil then
+ nerv.info("param [%s] of layer [%s] found in `layer_conf`.", pid, self.id)
+ return lconf[pid]
+ end
+ local pid_g = self.id .. '_' .. pid --global identifier
+ local pr = lconf.pr
+ local p
+ if pr ~= nil and pr:has_param(pid_g) == true then
+ nerv.info("param [%s] of layer [%s] found in `layer_conf.pr`.", pid_list_str, self.id)
+ p = pr:get_param(pid_g)
+ return p
+ end
end
- nerv.info("Param [%s] of layer [%s] is not found in layer_conf or layer_conf.paramRepo, switch to auto-generate.", pid, self.id)
+ nerv.info("param [%s] of layer [%s] is not found in `layer_conf` or `layer_conf.pr`, " ..
+ "switch to auto-generate.", pid_list_str, self.id)
p = p_type(pid_g, gconf)
p.trans = gconf.cumat_type(unpack(p_dim))
+ if type(gconf.param_random) ~= "function" then
+ nerv.error("a param generate function is needed")
+ end
p.trans:generate(gconf.param_random)
return p
end