diff options
Diffstat (limited to 'nerv/layer/softmax_ce.lua')
-rw-r--r-- | nerv/layer/softmax_ce.lua | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/nerv/layer/softmax_ce.lua b/nerv/layer/softmax_ce.lua index f878a2f..9071e86 100644 --- a/nerv/layer/softmax_ce.lua +++ b/nerv/layer/softmax_ce.lua @@ -23,6 +23,13 @@ function SoftmaxCELayer:init(batch_size) self.ce = self.softmax:create() end +function SoftmaxCELayer:batch_resize(batch_size) + if self.softmax:nrow() ~= batch_resize then + self.softmax = self.gconf.cumat_type(batch_size, self.dim_in[1]) + self.ce = self.softmax:create() + end +end + function SoftmaxCELayer:update(bp_err, input, output) -- no params, therefore do nothing end |