require 'libkaldiseq'

-- nerv.MPELayer: a parameter-free layer that delegates its work to a
-- nerv.KaldiMPE object (constructed in MPELayer:init). It takes two
-- inputs: the nn output and the utterance key.
local MPELayer = nerv.class("nerv.MPELayer", "nerv.Layer")

-- Constructor: records the layer id, the global configuration, the
-- input/output dimension specs and the four entries of layer_conf.cmd
-- (arg, mdl, lat, ali) that are later handed to nerv.KaldiMPE.
function MPELayer:__init(id, global_conf, layer_conf)
    self.id = id
    self.gconf = global_conf
    self.dim_in = layer_conf.dim_in
    self.dim_out = layer_conf.dim_out

    local cmd = layer_conf.cmd
    self.arg = cmd.arg
    self.mdl = cmd.mdl
    self.lat = cmd.lat
    self.ali = cmd.ali

    -- two inputs: nn output and utt key
    -- NOTE(review): the -1 presumably disables the output-count check;
    -- confirm against nerv.Layer:check_dim_len.
    self:check_dim_len(2, -1)
end

-- Resets the running statistics and builds the Kaldi MPE helper from the
-- stored arguments. Raises through nerv.error when the helper comes back
-- nil. batch_size is not used here.
function MPELayer:init(batch_size)
    self.total_correct = 0
    self.total_frames = 0

    local mpe = nerv.KaldiMPE(self.arg, self.mdl, self.lat, self.ali)
    self.kaldi_mpe = mpe
    if mpe == nil then
        nerv.error("kaldi arguments is expected: %s %s %s %s",
                   self.arg, self.mdl, self.lat, self.ali)
    end
end

-- Intentionally empty: this layer does nothing on a batch-size change.
function MPELayer:batch_resize(batch_size)
end

-- Intentionally empty: the layer has no parameters to update.
function MPELayer:update(bp_err, input, output)
end

-- Forward step: hands both inputs to kaldi_mpe:check, stores the result
-- in self.valid and returns it. The flag is cleared first so that it
-- stays false if check raises an error.
function MPELayer:propagate(input, output)
    self.valid = false
    local checked = self.kaldi_mpe:check(input[1], input[2])
    self.valid = checked
    return checked
end

-- Backward step: refuses to run unless the preceding propagate stored
-- exactly `true` in self.valid. Otherwise it writes the result of
-- kaldi_mpe:calc_diff into next_bp_err[1] and accumulates the frame
-- statistics reported by the Kaldi helper.
function MPELayer:back_propagate(bp_err, next_bp_err, input, output)
    if not (self.valid == true) then
        nerv.error("kaldi sequence training back_propagate fail")
    end

    -- NOTE(review): new_to_host presumably yields a host-side copy of the
    -- nn output, and calc_diff a host matrix (hence copy_fromh) -- confirm.
    local host_output = input[1]:new_to_host()
    next_bp_err[1]:copy_fromh(self.kaldi_mpe:calc_diff(host_output, input[2]))

    local frames = self.kaldi_mpe:get_num_frames()
    self.total_frames = self.total_frames + frames

    local frame_acc = self.kaldi_mpe:get_utt_frame_acc()
    self.total_correct = self.total_correct + frame_acc
end

-- Returns a parameter repository built from an empty table, since the
-- layer owns no parameters.
function MPELayer:get_params()
    local no_params = {}
    return nerv.ParamRepo(no_params)
end