aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/mse.lua
diff options
context:
space:
mode:
authorYimmon Zhuang <yimmon.zhuang@gmail.com>2015-09-18 22:17:25 +0800
committerYimmon Zhuang <yimmon.zhuang@gmail.com>2015-09-18 22:17:25 +0800
commit37286a08b40f68b544983d8dde4a77ac0b488397 (patch)
treecc5512ef1c5e9eab3a2f1ba7c6d064a92079dafc /nerv/layer/mse.lua
parent5b99c28961ca223cc35e77a4482eb789d5bef06d (diff)
kaldi mpe training support
Diffstat (limited to 'nerv/layer/mse.lua')
-rw-r--r--nerv/layer/mse.lua8
1 files changed, 8 insertions, 0 deletions
diff --git a/nerv/layer/mse.lua b/nerv/layer/mse.lua
index 2516998..0ee3080 100644
--- a/nerv/layer/mse.lua
+++ b/nerv/layer/mse.lua
@@ -20,6 +20,14 @@ function MSELayer:init(batch_size)
self.diff = self.mse:create()
end
+function MSELayer:batch_resize(batch_size)
+ if self.mse:nrow() ~= batch_resize then
+ self.mse = self.gconf.cumat_type(batch_size, self.dim_in[1])
+ self.mse_sum = self.gconf.cumat_type(batch_size, 1)
+ self.diff = self.mse:create()
+ end
+end
+
function MSELayer:update(bp_err, input, output)
-- no params, therefore do nothing
end