require 'libkaldiseq'
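
-- MMILayer wraps the MMI sequence-discriminative training criterion from
-- Kaldi through the libkaldiseq binding. It consumes two inputs (the
-- network output and the utterance key), validates them against the
-- configured lattices/alignments, and produces the MMI error signal
-- during back-propagation.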
local MMILayer = nerv.class("nerv.MMILayer", "nerv.Layer")

function MMILayer:__init(id, global_conf, layer_conf)
    self.id = id
    self.gconf = global_conf
    self.dim_in = layer_conf.dim_in
    self.dim_out = layer_conf.dim_out
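    -- Kaldi command-line pieces passed through to libkaldiseq:
    -- extra options (arg), acoustic model (mdl), lattices (lat)
    -- and alignments (ali)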
    self.arg = layer_conf.cmd.arg
    self.mdl = layer_conf.cmd.mdl
    self.lat = layer_conf.cmd.lat
    self.ali = layer_conf.cmd.ali
    self:check_dim_len(2, -1) -- two inputs: nn output and utt key
end

function MMILayer:init(batch_size)
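    -- construct the Kaldi-side MMI computation object once per training run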
    self.total_frames = 0
    self.kaldi_mmi = nerv.KaldiMMI(self.arg, self.mdl, self.lat, self.ali)
    if self.kaldi_mmi == nil then
        nerv.error("invalid kaldi arguments: %s %s %s %s", self.arg,
                   self.mdl, self.lat, self.ali)
    end
end

function MMILayer:batch_resize(batch_size)
    -- do nothing
end

function MMILayer:update(bp_err, input, output)
    -- no params, therefore do nothing
end

function MMILayer:propagate(input, output)
    -- check() verifies that Kaldi has a matching lattice and alignment for
    -- this utterance; the result is returned so the caller can skip
    -- invalid utterances
    self.valid = self.kaldi_mmi:check(input[1], input[2])
    return self.valid
end

function MMILayer:back_propagate(bp_err, next_bp_err, input, output)
    if not self.valid then
        nerv.error("kaldi sequence training back_propagate failed")
    end
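    -- copy the network output to host memory, let Kaldi compute the MMI
    -- error derivatives, then copy them back into the error buffer;
    -- also accumulate the number of frames processed so far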
    local mmat = input[1]:new_to_host()
    next_bp_err[1]:copy_fromh(self.kaldi_mmi:calc_diff(mmat, input[2]))
    self.total_frames = self.total_frames + self.kaldi_mmi:get_num_frames()
end

function MMILayer:get_params()
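    -- the MMI layer has no trainable parameters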
    return nerv.ParamRepo({})
end
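
-- A minimal usage sketch (hypothetical ids, dimensions and Kaldi
-- rspecifiers; the exact values depend on the experiment setup):
--
--   local mmi = nerv.MMILayer("mmi_crit", gconf, {
--       dim_in  = {3001, 1},   -- network output, utterance key
--       dim_out = {},
--       cmd = {
--           arg = "--acoustic-scale=0.1",
--           mdl = "final.mdl",
--           lat = "scp:denlat.scp",
--           ali = "ark:ali.ark",
--       },
--   })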