path: root/nerv/layer
author    Yimmon Zhuang <yimmon.zhuang@gmail.com>    2015-10-08 22:27:58 +0800
committer Yimmon Zhuang <yimmon.zhuang@gmail.com>    2015-10-08 22:27:58 +0800
commit    7975592b94d65b6f356093694a76201de62a7a6a (patch)
tree      cf1eb9e8726cb016166129c51a3b8078cd9c78fd /nerv/layer
parent    37286a08b40f68b544983d8dde4a77ac0b488397 (diff)
MMI support
Diffstat (limited to 'nerv/layer')
-rw-r--r--  nerv/layer/init.lua   1
-rw-r--r--  nerv/layer/mmi.lua   50
2 files changed, 51 insertions, 0 deletions
diff --git a/nerv/layer/init.lua b/nerv/layer/init.lua
index b74422f..25dfebb 100644
--- a/nerv/layer/init.lua
+++ b/nerv/layer/init.lua
@@ -80,3 +80,4 @@ nerv.include('combiner.lua')
nerv.include('affine_recurrent.lua')
nerv.include('softmax.lua')
nerv.include('mpe.lua')
+nerv.include('mmi.lua')
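
For reference, a minimal instantiation sketch (not part of the patch; gconf and every concrete value below are assumptions) based on the fields MMILayer:__init reads from layer_conf:

    -- hypothetical instantiation of the new layer; the cmd table mirrors
    -- the fields MMILayer:__init reads, but all concrete values here are
    -- placeholders, not values taken from the patch
    local layer = nerv.MMILayer("mmi_crit", gconf,
        {
            dim_in = {3001, 1},   -- nn output dim and utterance key
            dim_out = {1},        -- assumed single error output
            cmd = {
                arg = "--acoustic-scale=0.1", -- kaldi argument string
                mdl = "final.mdl",            -- acoustic model
                lat = "ark:den.lat.ark",      -- denominator lattices
                ali = "ark:num.ali.ark",      -- numerator alignments
            }
        })
    layer:init(gconf.batch_size)
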
diff --git a/nerv/layer/mmi.lua b/nerv/layer/mmi.lua
new file mode 100644
index 0000000..ecc7f48
--- /dev/null
+++ b/nerv/layer/mmi.lua
@@ -0,0 +1,50 @@
+require 'libkaldiseq'
+local MMILayer = nerv.class("nerv.MMILayer", "nerv.Layer")
+
+function MMILayer:__init(id, global_conf, layer_conf)
+    self.id = id
+    self.gconf = global_conf
+    self.dim_in = layer_conf.dim_in
+    self.dim_out = layer_conf.dim_out
+    self.arg = layer_conf.cmd.arg -- kaldi command-line argument string
+    self.mdl = layer_conf.cmd.mdl -- acoustic model
+    self.lat = layer_conf.cmd.lat -- denominator lattices
+    self.ali = layer_conf.cmd.ali -- numerator alignments
+    self:check_dim_len(2, -1) -- two inputs: nn output and utt key
+end
+
+function MMILayer:init(batch_size)
+    self.total_frames = 0
+    self.kaldi_mmi = nerv.KaldiMMI(self.arg, self.mdl, self.lat, self.ali)
+    if self.kaldi_mmi == nil then
+        nerv.error("kaldi arguments are expected: %s %s %s %s", self.arg,
+                   self.mdl, self.lat, self.ali)
+    end
+end
+
+function MMILayer:batch_resize(batch_size)
+    -- do nothing
+end
+
+function MMILayer:update(bp_err, input, output)
+    -- no params, therefore do nothing
+end
+
+function MMILayer:propagate(input, output)
+    -- ask kaldi whether this utterance's lattice/alignment is usable
+    self.valid = self.kaldi_mmi:check(input[1], input[2])
+    return self.valid
+end
+
+function MMILayer:back_propagate(bp_err, next_bp_err, input, output)
+    if self.valid ~= true then
+        nerv.error("kaldi sequence training back_propagate failed")
+    end
+    local mmat = input[1]:new_to_host() -- copy nn output to host memory
+    next_bp_err[1]:copy_fromh(self.kaldi_mmi:calc_diff(mmat, input[2])) -- MMI error signal from kaldi
+    self.total_frames = self.total_frames + self.kaldi_mmi:get_num_frames()
+end
+
+function MMILayer:get_params()
+    return nerv.ParamRepo({})
+end
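
The new layer's contract: propagate returns whether kaldi accepted the utterance's lattice and alignment, and back_propagate raises an error unless the preceding propagate succeeded. A hypothetical driver fragment (only the two method calls are taken from the patch; everything else is assumed):

    -- hypothetical fragment of a sequence-training loop
    if layer:propagate(input, output) then
        -- kaldi accepted the utterance; pull the MMI error signal back
        layer:back_propagate(bp_err, next_bp_err, input, output)
    else
        -- kaldi rejected the utterance; skip it, since back_propagate
        -- would raise an error while self.valid is false
    end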