aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/init.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/init.lua')
-rw-r--r--nerv/layer/init.lua12
1 files changed, 9 insertions, 3 deletions
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index d266773..c5b7657 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -27,6 +27,10 @@ function Param:update(gradient)
nerv.error_method_not_implemented()
end
+function Param:gen_zero()
+ return 0
+end
+
local Layer = nerv.class('nerv.Layer')
function Layer:__init(id, global_conf, layer_conf)
@@ -93,7 +97,7 @@ function Layer:get_sublayer(id)
nerv.error('primitive layer does not have sublayers')
end
-function Layer:find_param(plist, lconf, gconf, p_type, p_dim)
+function Layer:find_param(plist, lconf, gconf, p_type, p_dim, p_gen)
if type(plist) == "string" then
plist = {plist}
end
@@ -120,10 +124,12 @@ function Layer:find_param(plist, lconf, gconf, p_type, p_dim)
"switch to auto-generate", plist_str, self.id)
local p = p_type(pid, gconf)
p.trans = self.mat_type(unpack(p_dim))
- if type(gconf.param_random) ~= "function" then
+ p_gen = p_gen or gconf.param_gen
+ or gconf.param_random -- obsolete name
+ if type(p_gen) ~= "function" then
nerv.error("a param generate function is needed")
end
- p.trans:generate(gconf.param_random)
+ p.trans:generate(p_gen)
return p
end