From c589c3aabaae7f3867bdfed994c8179a87f42675 Mon Sep 17 00:00:00 2001 From: Qi Liu Date: Tue, 29 Mar 2016 10:05:29 +0800 Subject: fix bug of momentum & update mse layer --- nerv/layer/mse.lua | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) (limited to 'nerv/layer/mse.lua') diff --git a/nerv/layer/mse.lua b/nerv/layer/mse.lua index 458d086..c1ea596 100644 --- a/nerv/layer/mse.lua +++ b/nerv/layer/mse.lua @@ -9,23 +9,28 @@ function MSELayer:bind_params() -- do nothing end -function MSELayer:init(batch_size) +function MSELayer:init(batch_size, chunk_size) if self.dim_in[1] ~= self.dim_in[2] then nerv.error("mismatching dimensions of previous network output and labels") end - self.scale = 1 / self.dim_in[1] + self.scale = 1.0 / self.dim_in[1] self.total_mse = 0.0 self.total_frames = 0 self.mse = self.mat_type(batch_size, self.dim_in[1]) self.mse_sum = self.mat_type(batch_size, 1) - self.diff = self.mse:create() + self.diff = {} + for t = 1, chunk_size do + self.diff[t] = self.mse:create() + end end function MSELayer:batch_resize(batch_size) if self.mse:nrow() ~= batch_resize then self.mse = self.mat_type(batch_size, self.dim_in[1]) self.mse_sum = self.mat_type(batch_size, 1) - self.diff = self.mse:create() + for t = 1, chunk_size do + self.diff[t] = self.mse:create() + end end end @@ -33,24 +38,32 @@ function MSELayer:update(bp_err, input, output) -- no params, therefore do nothing end -function MSELayer:propagate(input, output) +function MSELayer:propagate(input, output, t) + if t == nil then + t = 1 + end local mse = self.mse local mse_sum = self.mse_sum + local diff = self.diff[t] mse:add(input[1], input[2], 1.0, -1.0) - self.diff:copy_from(mse) + mse:set_values_by_mask(self.gconf.mask[t], 0) + diff:copy_from(mse) mse:mul_elem(mse, mse) - mse_sum:add(mse_sum, mse:rowsum(mse), 0.0, self.scale) + mse_sum:add(mse_sum, mse:rowsum(), 0.0, self.scale * 0.5) if output[1] ~= nil then output[1]:copy_from(mse_sum) end self.total_mse = self.total_mse + mse_sum:colsum()[0][0] - self.total_frames = self.total_frames + mse_sum:nrow() + self.total_frames = self.total_frames + self.gconf.mask[t]:colsum()[0][0] end -- NOTE: must call propagate before back_propagate -function MSELayer:back_propagate(bp_err, next_bp_err, input, output) +function MSELayer:back_propagate(bp_err, next_bp_err, input, output, t) + if t == nil then + t = 1 + end local nbe = next_bp_err[1] - nbe:add(nbe, self.diff, 0.0, 2 * self.scale) + nbe:add(nbe, self.diff[t], 0.0, self.scale) if bp_err[1] ~= nil then nbe:scale_rows_by_col(bp_err[1]) end -- cgit v1.2.3-70-g09d2