aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/mse.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/layer/mse.lua')
-rw-r--r--nerv/layer/mse.lua8
1 files changed, 8 insertions, 0 deletions
diff --git a/nerv/layer/mse.lua b/nerv/layer/mse.lua
index 2516998..0ee3080 100644
--- a/nerv/layer/mse.lua
+++ b/nerv/layer/mse.lua
@@ -20,6 +20,14 @@ function MSELayer:init(batch_size)
self.diff = self.mse:create()
end
+function MSELayer:batch_resize(batch_size)
+ if self.mse:nrow() ~= batch_resize then
+ self.mse = self.gconf.cumat_type(batch_size, self.dim_in[1])
+ self.mse_sum = self.gconf.cumat_type(batch_size, 1)
+ self.diff = self.mse:create()
+ end
+end
+
function MSELayer:update(bp_err, input, output)
-- no params, therefore do nothing
end