aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua')
-rw-r--r--nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua16
1 files changed, 14 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua b/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua
index dddb05a..a9ce975 100644
--- a/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua
+++ b/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua
@@ -16,6 +16,9 @@ function SoftmaxCELayer:init(batch_size, chunk_size)
if not self.compressed and (self.dim_in[1] ~= self.dim_in[2]) then
nerv.error("mismatching dimensions of previous network output and labels")
end
+ if chunk_size == nil then
+ chunk_size = 1
+ end
self.total_ce = 0.0
self.total_correct = 0
self.total_frames = 0
@@ -27,9 +30,12 @@ function SoftmaxCELayer:init(batch_size, chunk_size)
end
end
-function SoftmaxCELayer:batch_resize(batch_size)
+function SoftmaxCELayer:batch_resize(batch_size, chunk_size)
+ if chunk_size == nil then
+ chunk_size = 1
+ end
for t = 1, chunk_size do
- if self.softmax_t[t]:nrow() ~= batch_resize then
+ if self.softmax_t[t]:nrow() ~= batch_size then
self.softmax_t[t] = self.gconf.cumat_type(batch_size, self.dim_in[1])
self.ce_t[t] = self.gconf.cumat_type(batch_size, self.dim_in[1])
end
@@ -41,6 +47,9 @@ function SoftmaxCELayer:update(bp_err, input, output, t)
end
function SoftmaxCELayer:propagate(input, output, t)
+ if t == nil then
+ t = 1
+ end
local softmax = self.softmax_t[t]
local ce = self.ce_t[t]
local classified = softmax:softmax(input[1])
@@ -65,6 +74,9 @@ end
function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output, t)
-- softmax output - label
+ if t == nil then
+ t = 1
+ end
local label = input[2]
if self.compressed then
label = label:decompress(input[1]:ncol())