1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
require 'libkaldiseq'
--- Kaldi-backed sequence-training layer (registered as "nerv.MMILayer").
-- The Kaldi side is provided by the `libkaldiseq` module required above.
local MMILayer = nerv.class(
    "nerv.MMILayer", -- name the class is registered under
    "nerv.Layer"     -- presumably the parent class; confirm against nerv.class
)
--- Construct the layer and record the Kaldi settings used by `init`.
-- @param id layer identifier (stored as `self.id`)
-- @param global_conf global configuration (stored as `self.gconf`)
-- @param layer_conf layer configuration; `dim_in` and `dim_out` are copied,
--   and `layer_conf.cmd` must be a table whose `arg`, `mdl`, `lat` and `ali`
--   fields are later passed to `nerv.KaldiMMI`
function MMILayer:__init(id, global_conf, layer_conf)
    self.id = id
    self.gconf = global_conf
    self.dim_in = layer_conf.dim_in
    self.dim_out = layer_conf.dim_out
    -- Without this check a missing `cmd` table surfaces as an opaque
    -- "attempt to index a nil value" error; report what is actually missing.
    local cmd = layer_conf.cmd
    if cmd == nil then
        nerv.error("layer_conf.cmd (kaldi arguments) is required for layer %s",
                   tostring(id))
    end
    self.arg = cmd.arg
    self.mdl = cmd.mdl
    self.lat = cmd.lat
    self.ali = cmd.ali
    self:check_dim_len(2, -1) -- two inputs: nn output and utt key
end
--- Create the Kaldi MMI backend and reset the frame counter.
-- Raises via `nerv.error` when `nerv.KaldiMMI` returns nil.
-- @param batch_size not used by this layer
function MMILayer:init(batch_size)
    self.total_frames = 0
    self.kaldi_mmi = nerv.KaldiMMI(self.arg, self.mdl, self.lat, self.ali)
    if self.kaldi_mmi == nil then
        -- Any of these settings may be nil (e.g. absent from layer_conf.cmd).
        -- On Lua 5.1 / LuaJIT 2.0, string.format's %s rejects nil, so
        -- (assuming nerv.error formats its arguments that way) a raw nil would
        -- replace this message with a format error. tostring() avoids that.
        nerv.error("kaldi arguments is expected: %s %s %s %s",
                   tostring(self.arg), tostring(self.mdl),
                   tostring(self.lat), tostring(self.ali))
    end
end
--- Batch-resize hook; intentionally empty for this layer.
-- @param batch_size ignored
function MMILayer:batch_resize(batch_size)
end
--- Parameter-update hook; intentionally empty because this layer owns no
-- parameters (its get_params returns an empty nerv.ParamRepo).
-- @param bp_err ignored
-- @param input ignored
-- @param output ignored
function MMILayer:update(bp_err, input, output)
end
--- Forward pass: run the Kaldi-side check on the inputs and remember the
-- outcome for back_propagate.
-- @param input input[1] is the nn output, input[2] the utterance key
-- @param output not written by this layer
-- @return whatever kaldi_mmi:check returned (also stored in self.valid)
function MMILayer:propagate(input, output)
    -- Mark invalid up front so that, if check raises, a stale `true` from a
    -- previous batch cannot let back_propagate proceed.
    self.valid = false
    local passed = self.kaldi_mmi:check(input[1], input[2])
    self.valid = passed
    return self.valid
end
--- Backward pass: fill next_bp_err[1] with the error signal produced by
-- kaldi_mmi:calc_diff and add this utterance's frame count to total_frames.
-- Raises via nerv.error unless the preceding propagate set self.valid to true.
-- @param bp_err not used by this layer
-- @param next_bp_err next_bp_err[1] receives the result (via copy_fromh)
-- @param input the same inputs that were given to propagate
-- @param output not used by this layer
function MMILayer:back_propagate(bp_err, next_bp_err, input, output)
    local ready = (self.valid == true)
    if not ready then
        nerv.error("kaldi sequence training back_propagate fail")
    end
    local target = next_bp_err[1]
    -- presumably a host-side copy of the nn output, as calc_diff works on it
    local host_copy = input[1]:new_to_host()
    local mmi = self.kaldi_mmi
    target:copy_fromh(mmi:calc_diff(host_copy, input[2]))
    self.total_frames = self.total_frames + mmi:get_num_frames()
end
--- Report this layer's trainable parameters: there are none.
-- @return a nerv.ParamRepo built from an empty list
function MMILayer:get_params()
    local no_params = {}
    return nerv.ParamRepo(no_params)
end
|