From 620c1971c3c821337cd16cca20cddd27f7bc6085 Mon Sep 17 00:00:00 2001 From: Determinant Date: Thu, 18 Feb 2016 18:04:06 +0800 Subject: generalize softmax_ce.lua (to softmax_ce_t.lua) --- nerv/layer/softmax_ce.lua | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/nerv/layer/softmax_ce.lua b/nerv/layer/softmax_ce.lua index 31a2ad7..d7d650e 100644 --- a/nerv/layer/softmax_ce.lua +++ b/nerv/layer/softmax_ce.lua @@ -17,31 +17,46 @@ function SoftmaxCELayer:__init(id, global_conf, layer_conf) self:check_dim_len(2, -1) -- two inputs: nn output and label end -function SoftmaxCELayer:init(batch_size) +function SoftmaxCELayer:init(batch_size, chunk_size) if not self.compressed and (self.dim_in[1] ~= self.dim_in[2]) then nerv.error("mismatching dimensions of previous network output and labels") end + if chunk_size == nil then + chunk_size = 1 + end self.total_ce = 0.0 self.total_correct = 0 self.total_frames = 0 - self.softmax = self.mat_type(batch_size, self.dim_in[1]) - self.ce = self.softmax:create() + self.softmax = {} + self.ce = {} + for t = 1, chunk_size do + self.softmax[t] = self.mat_type(batch_size, self.dim_in[1]) + self.ce[t] = self.softmax[t]:create() + end end -function SoftmaxCELayer:batch_resize(batch_size) - if self.softmax:nrow() ~= batch_resize then - self.softmax = self.mat_type(batch_size, self.dim_in[1]) - self.ce = self.softmax:create() +function SoftmaxCELayer:batch_resize(batch_size, chunk_size) + if chunk_size == nil then + chunk_size = 1 + end + for t = 1, chunk_size do + if self.softmax[t]:nrow() ~= batch_size then + self.softmax[t] = self.mat_type(batch_size, self.dim_in[1]) + self.ce[t] = self.softmax:create() + end end end -function SoftmaxCELayer:update(bp_err, input, output) +function SoftmaxCELayer:update(bp_err, input, output, t) -- no params, therefore do nothing end -function SoftmaxCELayer:propagate(input, output) - local softmax = self.softmax - local ce = self.ce +function SoftmaxCELayer:propagate(input, output, t) + if t == nil then + t = 1 + end + local softmax = self.softmax[t] + local ce = self.ce[t] local classified = softmax:softmax(input[1]) local label = input[2] ce:log_elem(softmax) @@ -62,14 +77,17 @@ function SoftmaxCELayer:propagate(input, output) end end -function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output) +function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output, t) -- softmax output - label + if t == nil then + t = 1 + end local label = input[2] if self.compressed then label = label:decompress(input[1]:ncol()) end local nbe = next_bp_err[1] - nbe:add(self.softmax, label, 1.0, -1.0) + nbe:add(self.softmax[t], label, 1.0, -1.0) if bp_err[1] ~= nil then nbe:scale_rows_by_col(bp_err[1]) end -- cgit v1.2.3