| author | Determinant <[email protected]> | 2015-06-10 20:42:10 +0800 |
| committer | Determinant <[email protected]> | 2015-06-10 20:42:10 +0800 |
| commit | b818c2562d07a69083377cbc34f2add108e9fa66 | (patch) |
| tree | a595ce4f269035951715334d2942d91d42ae236e | /layer/combiner.lua |
| parent | c20af45d0756d5d3004105da10e51d42a382ad66 | (diff) |
add CombinerLayer to support branches in NN; add MSELayer
Diffstat (limited to 'layer/combiner.lua')
| -rw-r--r-- | layer/combiner.lua | 55 |
1 file changed, 55 insertions, 0 deletions
```diff
diff --git a/layer/combiner.lua b/layer/combiner.lua
new file mode 100644
index 0000000..2eac83c
--- /dev/null
+++ b/layer/combiner.lua
@@ -0,0 +1,55 @@
+local CombinerLayer = nerv.class('nerv.CombinerLayer', 'nerv.Layer')
+
+function CombinerLayer:__init(id, global_conf, layer_conf)
+    self.id = id
+    self.lambda = layer_conf.lambda
+    self.dim_in = layer_conf.dim_in
+    self.dim_out = layer_conf.dim_out
+    self.gconf = global_conf
+    self:check_dim_len(#self.lambda, -1)
+end
+
+function CombinerLayer:init()
+    local dim = self.dim_in[1]
+    for i = 2, #self.dim_in do
+        if self.dim_in[i] ~= dim then
+            nerv.error("mismatching dimensions of inputs")
+        end
+    end
+    for i = 1, #self.dim_out do
+        if self.dim_out[i] ~= dim then
+            nerv.error("mismatching dimensions of inputs/outputs")
+        end
+    end
+end
+
+function CombinerLayer:update(bp_err, input, output)
+end
+
+function CombinerLayer:propagate(input, output)
+    output[1]:fill(0)
+    for i = 1, #self.dim_in do
+        output[1]:add(output[1], input[i], 1.0, self.lambda[i])
+    end
+    for i = 2, #self.dim_out do
+        output[i]:copy_fromd(output[1])
+    end
+end
+
+function CombinerLayer:back_propagate(next_bp_err, bp_err, input, output)
+    local sum = bp_err[1]:create()
+    sum:fill(0)
+    for i = 1, #self.dim_out do
+        sum:add(sum, bp_err[i], 1.0, 1.0)
+    end
+    for i = 1, #self.dim_in do
+        local scale = nerv.CuMatrixFloat(sum:nrow(), 1)
+        scale:fill(self.lambda[i])
+        next_bp_err[i]:copy_fromd(sum)
+        next_bp_err[i]:scale_rows_by_col(scale)
+    end
+end
+
+function CombinerLayer:get_params()
+    return {self.lambda}
+end
```
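Read off the diff above: `propagate` forms the weighted sum `output[1] = sum_i lambda[i] * input[i]` and copies it to any extra outputs, which is what lets a network branch out from a single point; `back_propagate` hands each input branch `next_bp_err[i] = lambda[i] * sum_j bp_err[j]`, i.e. the pooled output error rescaled by that branch's weight. A minimal usage sketch follows. Only the constructor signature and the `layer_conf` fields (`dim_in`, `dim_out`, `lambda`) come from the patch; `gconf`, the id string, and the dimensions are made-up placeholders:

```lua
-- Hypothetical sketch: combine two 429-dim branches into one output.
-- gconf stands in for the real nerv global configuration table.
local gconf = {}
local combiner = nerv.CombinerLayer('combiner1', gconf,
                                    {dim_in  = {429, 429},  -- two input branches
                                     dim_out = {429},       -- one combined output
                                     lambda  = {0.5, 0.5}}) -- per-input weights
combiner:init() -- verifies that every dim_in/dim_out entry matches
```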