diff options
author | Determinant <ted.sybil@gmail.com> | 2016-03-10 13:40:11 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2016-03-10 13:40:11 +0800 |
commit | a32195e3e2ae9ca0f0c7a82e73e6bddb64568c05 (patch) | |
tree | a19f21f8cbadecff7357f9a102f160f5fe699b65 /nerv/layer/init.lua | |
parent | 4a6872601f05e9ecc059f83fb64a0a4887992b99 (diff) |
major change: clearer param binding semantics; permit rebinding; enable
resuming from previous training
Diffstat (limited to 'nerv/layer/init.lua')
-rw-r--r-- | nerv/layer/init.lua | 60 |
1 files changed, 39 insertions, 21 deletions
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua index 54f33ae..146ad8c 100644 --- a/nerv/layer/init.lua +++ b/nerv/layer/init.lua @@ -30,7 +30,18 @@ end local Layer = nerv.class('nerv.Layer') function Layer:__init(id, global_conf, layer_conf) - nerv.error_method_not_implemented() + self.id = id + self.gconf = global_conf + self.lconf = layer_conf + if self.gconf.use_cpu then + self.mat_type = self.gconf.mmat_type + self.loc_type = nerv.ParamRepo.LOC_TYPES.ON_HOST + else + self.mat_type = self.gconf.cumat_type + self.loc_type = nerv.ParamRepo.LOC_TYPES.ON_DEVICE + end + self.dim_in = layer_conf.dim_in + self.dim_out = layer_conf.dim_out end function Layer:init(batch_size) @@ -66,34 +77,41 @@ function Layer:get_params() nerv.error_method_not_implemented() end +function Layer:bind_params() + nerv.error_method_not_implemented() +end + function Layer:get_dim() return self.dim_in, self.dim_out end -function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim) - if type(pid_list) == "string" then - pid_list = {pid_list} +function Layer:find_param(plist, lconf, gconf, p_type, p_dim) + if type(plist) == "string" then + plist = {plist} end - pid_list_str = table.tostring(pid_list) - for i, pid in ipairs(pid_list) do - if lconf[pid] ~= nil then - nerv.info("param [%s] of layer [%s] found in `layer_conf`.", pid, self.id) - return lconf[pid] + if lconf.params == nil then + lconf.params = {} + end + plist_str = table.tostring(plist) + local pid + for i, pname in ipairs(plist) do + if lconf.params[pname] ~= nil then + nerv.info("param id for [%s] of layer [%s] specified in `layer_conf.params`.", pname, self.id) + pid = lconf.params[pname] end - local pid_g = self.id .. '_' .. pid --global identifier - local pr = lconf.pr - local p - if pr ~= nil and pr:has_param(pid_g) == true then - nerv.info("param [%s] of layer [%s] found in `layer_conf.pr`.", pid_list_str, self.id) - p = pr:get_param(pid_g) - return p + if lconf.pr:has_param(pid) then + return lconf.pr:get_param(pid) end end - nerv.info("param [%s] of layer [%s] is not found in `layer_conf` or `layer_conf.pr`, " .. - "switch to auto-generate", pid_list_str, self.id) - local pid_g = self.id .. '_' .. pid_list[1] - p = p_type(pid_g, gconf) - p.trans = gconf.cumat_type(unpack(p_dim)) + pid = self.id .. '_' .. plist[1] + if lconf.pr:has_param(pid) then + nerv.info("param id for [%s] of layer [%s] is generated automatically.", pname, self.id) + return lconf.pr:get_param(pid) + end + nerv.info("param id for [%s] of layer [%s] is not found in the specified param repo, " .. + "switch to auto-generate", plist_str, self.id) + local p = p_type(pid, gconf) + p.trans = self.mat_type(unpack(p_dim)) if type(gconf.param_random) ~= "function" then nerv.error("a param generate function is needed") end |