-rw-r--r--	nerv/layer/softmax_ce.lua	44
1 file changed, 31 insertions(+), 13 deletions(-)
diff --git a/nerv/layer/softmax_ce.lua b/nerv/layer/softmax_ce.lua
index 31a2ad7..d7d650e 100644
--- a/nerv/layer/softmax_ce.lua
+++ b/nerv/layer/softmax_ce.lua
@@ -17,31 +17,46 @@ function SoftmaxCELayer:__init(id, global_conf, layer_conf)
self:check_dim_len(2, -1) -- two inputs: nn output and label
end
-function SoftmaxCELayer:init(batch_size)
+function SoftmaxCELayer:init(batch_size, chunk_size)
if not self.compressed and (self.dim_in[1] ~= self.dim_in[2]) then
nerv.error("mismatching dimensions of previous network output and labels")
end
+ if chunk_size == nil then
+ chunk_size = 1
+ end
self.total_ce = 0.0
self.total_correct = 0
self.total_frames = 0
- self.softmax = self.mat_type(batch_size, self.dim_in[1])
- self.ce = self.softmax:create()
+ self.softmax = {}
+ self.ce = {}
+ for t = 1, chunk_size do
+ self.softmax[t] = self.mat_type(batch_size, self.dim_in[1])
+ self.ce[t] = self.softmax[t]:create()
+ end
end
-function SoftmaxCELayer:batch_resize(batch_size)
- if self.softmax:nrow() ~= batch_resize then
- self.softmax = self.mat_type(batch_size, self.dim_in[1])
- self.ce = self.softmax:create()
+function SoftmaxCELayer:batch_resize(batch_size, chunk_size)
+ if chunk_size == nil then
+ chunk_size = 1
+ end
+ for t = 1, chunk_size do
+ if self.softmax[t]:nrow() ~= batch_size then
+ self.softmax[t] = self.mat_type(batch_size, self.dim_in[1])
+ self.ce[t] = self.softmax[t]:create()
+ end
end
end
-function SoftmaxCELayer:update(bp_err, input, output)
+function SoftmaxCELayer:update(bp_err, input, output, t)
-- no params, therefore do nothing
end
-function SoftmaxCELayer:propagate(input, output)
- local softmax = self.softmax
- local ce = self.ce
+function SoftmaxCELayer:propagate(input, output, t)
+ if t == nil then
+ t = 1
+ end
+ local softmax = self.softmax[t]
+ local ce = self.ce[t]
local classified = softmax:softmax(input[1])
local label = input[2]
ce:log_elem(softmax)
@@ -62,14 +77,17 @@ function SoftmaxCELayer:propagate(input, output)
end
end
-function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output)
+function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output, t)
-- softmax output - label
+ if t == nil then
+ t = 1
+ end
local label = input[2]
if self.compressed then
label = label:decompress(input[1]:ncol())
end
local nbe = next_bp_err[1]
- nbe:add(self.softmax, label, 1.0, -1.0)
+ nbe:add(self.softmax[t], label, 1.0, -1.0)
if bp_err[1] ~= nil then
nbe:scale_rows_by_col(bp_err[1])
end
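
The point of the change is to let a single SoftmaxCELayer serve every timestep of an unrolled recurrent chunk: init now allocates one softmax/ce buffer pair per timestep, and propagate/back_propagate take a timestep index t (defaulting to 1, which preserves the old single-step behaviour). A minimal calling sketch under those assumptions follows — `layer`, `inputs`, `outputs`, `bp_errs`, and `next_bp_errs` are hypothetical driver-side names, not part of this commit:

-- Hypothetical driver sketch: `layer` is a SoftmaxCELayer built with nerv's
-- usual (id, global_conf, layer_conf) triple, and inputs[t] = {nn_output, label}
-- for each timestep of the unrolled chunk.
local batch_size, chunk_size = 128, 8
layer:init(batch_size, chunk_size)  -- one softmax/ce buffer per timestep
for t = 1, chunk_size do
    layer:propagate(inputs[t], outputs[t], t)      -- buffers indexed by t stay independent
end
for t = chunk_size, 1, -1 do                       -- BPTT walks the chunk backwards
    layer:back_propagate(bp_errs[t], next_bp_errs[t], inputs[t], outputs[t], t)
end

Since both chunk_size and t default to 1 when omitted, existing feed-forward callers that never pass a timestep keep working unchanged.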