aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2016-02-16 19:42:29 +0800
committertxh18 <[email protected]>2016-02-16 19:42:29 +0800
commit490a10c2130773bd022f05513fa2905b6a6c6e91 (patch)
treea66d613c3c6cfc4054e799606f7fb56278281347
parent1721de3a5f5cd6df731484a8161b537468bea0bd (diff)
fixed some minor problem
-rw-r--r--nerv/layer/affine.lua8
-rw-r--r--nerv/layer/init.lua1
2 files changed, 5 insertions, 4 deletions
diff --git a/nerv/layer/affine.lua b/nerv/layer/affine.lua
index d83b5f2..4156dde 100644
--- a/nerv/layer/affine.lua
+++ b/nerv/layer/affine.lua
@@ -85,15 +85,15 @@ function AffineLayer:__init(id, global_conf, layer_conf)
end
for i = 1, #self.dim_in do
local pid = "ltp" .. i
- local pid_list = i == 1 and {"ltp", pid} or pid
+ local pid_list = i == 1 and {pid, "ltp"} or pid
self["ltp" .. i] = self:find_param(pid_list, layer_conf, global_conf,
nerv.LinearTransParam,
- {self.dim_in[i], self.dim_out[1]}, pid)
+ {self.dim_in[i], self.dim_out[1]})
end
self.ltp = self.ltp1 -- alias of ltp1
self.bp = self:find_param("bp", layer_conf, global_conf,
nerv.BiasParam,
- {1, self.dim_out[1]}, "bp")
+ {1, self.dim_out[1]})
self.gconf = global_conf
self:check_dim_len(-1, 1) -- exactly one output, allow multiple inputs
end
@@ -142,7 +142,7 @@ function AffineLayer:back_propagate(bp_err, next_bp_err, input, output)
end
function AffineLayer:get_params()
- local pr = nerv.ParamRepo({self.ltp, self.bp})
+ local pr = nerv.ParamRepo({self.ltp1, self.bp})
for i = 2, #self.dim_in do
pr:add(self["ltp" .. i].id, self["ltp" .. i])
end
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index d952022..43c2250 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -91,6 +91,7 @@ function Layer:find_param(pid_list, lconf, gconf, p_type, p_dim)
end
nerv.info("param [%s] of layer [%s] is not found in `layer_conf` or `layer_conf.pr`, " ..
"switch to auto-generate.", pid_list_str, self.id)
+ local pid_g = self.id .. '_' .. pid_list[1]
p = p_type(pid_g, gconf)
p.trans = gconf.cumat_type(unpack(p_dim))
if type(gconf.param_random) ~= "function" then