diff options
Diffstat (limited to 'nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua')
-rw-r--r-- | nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua | 16 |
1 file changed, 14 insertions, 2 deletions
diff --git a/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua b/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua index dddb05a..a9ce975 100644 --- a/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua +++ b/nerv/examples/lmptb/tnn/layersT/softmax_ce_t.lua @@ -16,6 +16,9 @@ function SoftmaxCELayer:init(batch_size, chunk_size) if not self.compressed and (self.dim_in[1] ~= self.dim_in[2]) then nerv.error("mismatching dimensions of previous network output and labels") end + if chunk_size == nil then + chunk_size = 1 + end self.total_ce = 0.0 self.total_correct = 0 self.total_frames = 0 @@ -27,9 +30,12 @@ function SoftmaxCELayer:init(batch_size, chunk_size) end end -function SoftmaxCELayer:batch_resize(batch_size) +function SoftmaxCELayer:batch_resize(batch_size, chunk_size) + if chunk_size == nil then + chunk_size = 1 + end for t = 1, chunk_size do - if self.softmax_t[t]:nrow() ~= batch_resize then + if self.softmax_t[t]:nrow() ~= batch_size then self.softmax_t[t] = self.gconf.cumat_type(batch_size, self.dim_in[1]) self.ce_t[t] = self.gconf.cumat_type(batch_size, self.dim_in[1]) end @@ -41,6 +47,9 @@ function SoftmaxCELayer:update(bp_err, input, output, t) end function SoftmaxCELayer:propagate(input, output, t) + if t == nil then + t = 1 + end local softmax = self.softmax_t[t] local ce = self.ce_t[t] local classified = softmax:softmax(input[1]) @@ -65,6 +74,9 @@ end function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output, t) -- softmax output - label + if t == nil then + t = 1 + end local label = input[2] if self.compressed then label = label:decompress(input[1]:ncol()) |