path: root/nerv/layer/init.lua
author     Determinant <ted.sybil@gmail.com>  2016-03-10 13:40:11 +0800
committer  Determinant <ted.sybil@gmail.com>  2016-03-10 13:40:11 +0800
commit     a32195e3e2ae9ca0f0c7a82e73e6bddb64568c05 (patch)
tree       a19f21f8cbadecff7357f9a102f160f5fe699b65 /nerv/layer/init.lua
parent     4a6872601f05e9ecc059f83fb64a0a4887992b99 (diff)
major change: clearer param binding semantics; permit rebinding; enable
resuming from previous training
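
Under the new binding semantics, a layer resolves each of its parameters from two optional fields of its layer_conf: a params table that maps a parameter name to an explicit param id, and pr, a nerv.ParamRepo to search (see the rewritten find_param in the diff below). The following is a minimal sketch of a configuration using these fields; the layer name, param ids and dimensions are illustrative assumptions, not taken from this commit.

-- hypothetical layer_conf for an affine-style layer called "affine0";
-- `pr` is expected to be a nerv.ParamRepo, e.g. one saved by a previous
-- training run, so the same layer can be rebound to resume training
local function make_affine0_conf(pr)
    return {
        dim_in = {429}, dim_out = {2048},
        pr = pr,
        params = {
            ltp = "affine0_ltp",   -- explicit name -> param id binding in `pr`
            bp  = "affine0_bp",    -- omitted names fall back to "<layer_id>_<name>",
        },                         -- and finally to a randomly initialized param
    }
end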
Diffstat (limited to 'nerv/layer/init.lua')
-rw-r--r--  nerv/layer/init.lua  60
1 file changed, 39 insertions(+), 21 deletions(-)
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index 54f33ae..146ad8c 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -30,7 +30,18 @@ end
 local Layer = nerv.class('nerv.Layer')
 
 function Layer:__init(id, global_conf, layer_conf)
-    nerv.error_method_not_implemented()
+    self.id = id
+    self.gconf = global_conf
+    self.lconf = layer_conf
+    if self.gconf.use_cpu then
+        self.mat_type = self.gconf.mmat_type
+        self.loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST
+    else
+        self.mat_type = self.gconf.cumat_type
+        self.loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE
+    end
+    self.dim_in = layer_conf.dim_in
+    self.dim_out = layer_conf.dim_out
 end
 
 function Layer:init(batch_size)
@@ -66,34 +77,41 @@ function Layer:get_params()
     nerv.error_method_not_implemented()
 end
 
+function Layer:bind_params()
+    nerv.error_method_not_implemented()
+end
+
 function Layer:get_dim()
     return self.dim_in, self.dim_out
 end
 
-function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim)
-    if type(pid_list) == "string" then
-        pid_list = {pid_list}
+function Layer:find_param(plist, lconf, gconf, p_type, p_dim)
+    if type(plist) == "string" then
+        plist = {plist}
     end
-    pid_list_str = table.tostring(pid_list)
-    for i, pid in ipairs(pid_list) do
-        if lconf[pid] ~= nil then
-            nerv.info("param [%s] of layer [%s] found in `layer_conf`.", pid, self.id)
-            return lconf[pid]
+    if lconf.params == nil then
+        lconf.params = {}
+    end
+    plist_str = table.tostring(plist)
+    local pid
+    for i, pname in ipairs(plist) do
+        if lconf.params[pname] ~= nil then
+            nerv.info("param id for [%s] of layer [%s] specified in `layer_conf.params`.", pname, self.id)
+            pid = lconf.params[pname]
         end
-        local pid_g = self.id .. '_' .. pid --global identifier
-        local pr = lconf.pr
-        local p
-        if pr ~= nil and pr:has_param(pid_g) == true then
-            nerv.info("param [%s] of layer [%s] found in `layer_conf.pr`.", pid_list_str, self.id)
-            p = pr:get_param(pid_g)
-            return p
+        if lconf.pr:has_param(pid) then
+            return lconf.pr:get_param(pid)
         end
     end
-    nerv.info("param [%s] of layer [%s] is not found in `layer_conf` or `layer_conf.pr`, " ..
-              "switch to auto-generate", pid_list_str, self.id)
-    local pid_g = self.id .. '_' .. pid_list[1]
-    p = p_type(pid_g, gconf)
-    p.trans = gconf.cumat_type(unpack(p_dim))
+    pid = self.id .. '_' .. plist[1]
+    if lconf.pr:has_param(pid) then
+        nerv.info("param id for [%s] of layer [%s] is generated automatically.", pname, self.id)
+        return lconf.pr:get_param(pid)
+    end
+    nerv.info("param id for [%s] of layer [%s] is not found in the specified param repo, " ..
+              "switch to auto-generate", plist_str, self.id)
+    local p = p_type(pid, gconf)
+    p.trans = self.mat_type(unpack(p_dim))
     if type(gconf.param_random) ~= "function" then
         nerv.error("a param generate function is needed")
     end
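
For reference, a concrete layer built on top of this base class would implement the new bind_params() hook via find_param(). The sketch below is a hypothetical example, not code from this commit; it assumes a single input/output pair and the usual nerv.LinearTransParam / nerv.BiasParam classes. Because binding is a step separate from construction, it can be invoked again with a different layer_conf.pr, which is what permits rebinding and resuming from a previously saved parameter repo.

-- hypothetical affine-style layer illustrating the bind_params() contract
local DummyAffine = nerv.class("nerv.DummyAffineLayer", "nerv.Layer")

function DummyAffine:__init(id, global_conf, layer_conf)
    nerv.Layer.__init(self, id, global_conf, layer_conf) -- sets gconf, lconf, dims, mat_type
    self:bind_params()
end

function DummyAffine:bind_params()
    -- each param is looked up via layer_conf.params, then as "<id>_<name>" in
    -- layer_conf.pr, and auto-generated with gconf.param_random as a last resort
    self.ltp = self:find_param("ltp", self.lconf, self.gconf,
                               nerv.LinearTransParam,
                               {self.dim_in[1], self.dim_out[1]})
    self.bp = self:find_param("bp", self.lconf, self.gconf,
                              nerv.BiasParam,
                              {1, self.dim_out[1]})
end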