aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/softmax_ce.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/softmax_ce.lua')
-rw-r--r--nerv/layer/softmax_ce.lua7
1 files changed, 7 insertions, 0 deletions
diff --git a/nerv/layer/softmax_ce.lua b/nerv/layer/softmax_ce.lua
index f878a2f..9071e86 100644
--- a/nerv/layer/softmax_ce.lua
+++ b/nerv/layer/softmax_ce.lua
@@ -23,6 +23,13 @@ function SoftmaxCELayer:init(batch_size)
self.ce = self.softmax:create()
end
+function SoftmaxCELayer:batch_resize(batch_size)
+ if self.softmax:nrow() ~= batch_resize then
+ self.softmax = self.gconf.cumat_type(batch_size, self.dim_in[1])
+ self.ce = self.softmax:create()
+ end
+end
+
function SoftmaxCELayer:update(bp_err, input, output)
-- no params, therefore do nothing
end