summaryrefslogtreecommitdiff
path: root/kaldi_seq/src
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_seq/src')
-rw-r--r--kaldi_seq/src/init.c60
-rw-r--r--kaldi_seq/src/kaldi_mmi.cpp428
-rw-r--r--kaldi_seq/src/kaldi_mmi.h20
3 files changed, 507 insertions, 1 deletions
diff --git a/kaldi_seq/src/init.c b/kaldi_seq/src/init.c
index 88d0a80..9b38056 100644
--- a/kaldi_seq/src/init.c
+++ b/kaldi_seq/src/init.c
@@ -1,8 +1,10 @@
#include "nerv/common.h"
#include "kaldi_mpe.h"
+#include "kaldi_mmi.h"
#include <stdio.h>
const char *nerv_kaldi_mpe_tname = "nerv.KaldiMPE";
+const char *nerv_kaldi_mmi_tname = "nerv.KaldiMMI";
const char *nerv_matrix_cuda_float_tname = "nerv.CuMatrixFloat";
const char *nerv_matrix_host_float_tname = "nerv.MMatrixFloat";
@@ -63,11 +65,67 @@ static const luaL_Reg mpe_methods[] = {
static void mpe_init(lua_State *L) {
luaT_newmetatable(L, nerv_kaldi_mpe_tname, NULL,
- mpe_new, mpe_destroy, NULL);
+ mpe_new, mpe_destroy, NULL);
luaL_register(L, NULL, mpe_methods);
lua_pop(L, 1);
}
+/* Lua-side constructor for nerv.KaldiMMI.
+ * Stack slots 1..4 must be strings; they are handed, in order, to
+ * new_KaldiMMI() as (arg, mdl, lat, ali). The resulting object is pushed
+ * as a userdata tagged with nerv_kaldi_mmi_tname. Returns one Lua value. */
+static int mmi_new(lua_State *L) {
+    const char *options = luaL_checkstring(L, 1);
+    const char *trans_model_path = luaL_checkstring(L, 2);
+    const char *den_lat_spec = luaL_checkstring(L, 3);
+    const char *ref_ali_spec = luaL_checkstring(L, 4);
+
+    luaT_pushudata(L,
+                   new_KaldiMMI(options, trans_model_path,
+                                den_lat_spec, ref_ali_spec),
+                   nerv_kaldi_mmi_tname);
+    return 1;
+}
+
+/* Releases the KaldiMMI object held by the userdata at stack slot 1.
+ * Passed to luaT_newmetatable() next to mmi_new in mmi_init (presumably the
+ * destructor slot -- confirm against the luaT API). Returns no Lua values. */
+static int mmi_destroy(lua_State *L) {
+    KaldiMMI *self = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
+
+    destroy_KaldiMMI(self);
+    return 0;
+}
+
+/* Lua method `check(self, cumat, utt)`.
+ * self  : nerv.KaldiMMI userdata (slot 1)
+ * cumat : nerv.CuMatrixFloat userdata (slot 2)
+ * utt   : utterance key string (slot 3)
+ * Pushes the result of check_mmi() as a boolean. */
+static int mmi_check(lua_State *L) {
+    KaldiMMI *self = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
+    const Matrix *net_out = luaT_checkudata(L, 2, nerv_matrix_cuda_float_tname);
+    const char *utt_key = luaL_checkstring(L, 3);
+    int usable = check_mmi(self, net_out, utt_key);
+
+    lua_pushboolean(L, usable);
+    return 1;
+}
+
+/* Lua method `calc_diff(self, mat, utt)`.
+ * self : nerv.KaldiMMI userdata (slot 1)
+ * mat  : nerv.MMatrixFloat userdata (slot 2)
+ * utt  : utterance key string (slot 3)
+ * Returns one Lua value: the matrix produced by calc_diff_mmi().
+ *
+ * calc_diff_mmi() may hand back the very same Matrix it was given. That
+ * pointer is already owned by the userdata in slot 2, so wrapping it in a
+ * second userdata would give it two owners and let it be finalised twice.
+ * In that case the existing userdata is returned instead. */
+static int mmi_calc_diff(lua_State *L) {
+    KaldiMMI *mmi = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
+    Matrix *mat = luaT_checkudata(L, 2, nerv_matrix_host_float_tname);
+    const char *utt = luaL_checkstring(L, 3);
+
+    Matrix *diff = calc_diff_mmi(mmi, mat, utt);
+    if (diff == mat) {
+        /* in-place result: reuse the caller's userdata */
+        lua_pushvalue(L, 2);
+    } else {
+        luaT_pushudata(L, diff, nerv_matrix_host_float_tname);
+    }
+    return 1;
+}
+
+/* Lua method `get_num_frames(self)`: pushes get_num_frames_mmi(self)
+ * as a Lua number. */
+static int mmi_get_num_frames(lua_State *L) {
+    const KaldiMMI *self = luaT_checkudata(L, 1, nerv_kaldi_mmi_tname);
+    double frames = get_num_frames_mmi(self);
+
+    lua_pushnumber(L, frames);
+    return 1;
+}
+
+/* Method table installed on the nerv.KaldiMMI metatable by mmi_init(). */
+static const luaL_Reg mmi_methods[] = {
+    { "check",          mmi_check },          /* validate an utterance */
+    { "calc_diff",      mmi_calc_diff },      /* compute the MMI derivative */
+    { "get_num_frames", mmi_get_num_frames }, /* frames of last calc_diff */
+    { NULL,             NULL }                /* sentinel */
+};
+
+/* Registers the nerv.KaldiMMI type: creates its metatable with mmi_new and
+ * mmi_destroy, attaches mmi_methods to it, then pops the metatable so the
+ * Lua stack is left as it was found. */
+static void mmi_init(lua_State *L) {
+    luaT_newmetatable(L, nerv_kaldi_mmi_tname, NULL,
+                      mmi_new, mmi_destroy, NULL);
+    luaL_register(L, NULL, mmi_methods);
+    lua_pop(L, 1);
+}
+
void kaldi_seq_init(lua_State *L) {
mpe_init(L);
+ mmi_init(L);
}
diff --git a/kaldi_seq/src/kaldi_mmi.cpp b/kaldi_seq/src/kaldi_mmi.cpp
new file mode 100644
index 0000000..a64abd0
--- /dev/null
+++ b/kaldi_seq/src/kaldi_mmi.cpp
@@ -0,0 +1,428 @@
+#include <string>
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "tree/context-dep.h"
+#include "hmm/transition-model.h"
+#include "fstext/fstext-lib.h"
+#include "decoder/faster-decoder.h"
+#include "decoder/decodable-matrix.h"
+#include "lat/kaldi-lattice.h"
+#include "lat/lattice-functions.h"
+
+#include "nnet/nnet-trnopts.h"
+#include "nnet/nnet-component.h"
+#include "nnet/nnet-activation.h"
+#include "nnet/nnet-nnet.h"
+#include "nnet/nnet-pdf-prior.h"
+#include "nnet/nnet-utils.h"
+#include "base/timer.h"
+#include "cudamatrix/cu-device.h"
+
+#include <iomanip>
+
+typedef kaldi::BaseFloat BaseFloat;
+typedef struct Matrix NervMatrix;
+
+// Forward declaration only: this file calls LatticeAcousticRescore() from
+// calc_diff_mmi() but does not define it, so the definition has to be
+// supplied at link time.
+namespace kaldi {
+namespace nnet1 {
+
+void LatticeAcousticRescore(const kaldi::Matrix<BaseFloat> &log_like,
+                            const TransitionModel &trans_model,
+                            const std::vector<int32> &state_times,
+                            Lattice *lat);
+
+}  // namespace nnet1
+}  // namespace kaldi
+
+extern "C" {
+#include "kaldi_mmi.h"
+#include "string.h"
+#include "assert.h"
+#include "nerv/common.h"
+
+ extern NervMatrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status);
+ extern void nerv_matrix_host_float_copy_fromd(NervMatrix *mat, const NervMatrix *cumat, int, int, int, Status *);
+ using namespace kaldi;
+ using namespace kaldi::nnet1;
+ typedef kaldi::int32 int32;
+
+ // State shared between new_KaldiMMI(), check_mmi() and calc_diff_mmi().
+ // The type is opaque to C callers (see kaldi_mmi.h).
+ struct KaldiMMI {
+     // Heap objects created in new_KaldiMMI() and released in destroy_KaldiMMI().
+     TransitionModel *trans_model;
+     RandomAccessLatticeReader *den_lat_reader;
+     RandomAccessInt32VectorReader *ref_ali_reader;
+
+     // Per-utterance denominator lattice and its state times, filled by
+     // check_mmi() and consumed by calc_diff_mmi().
+     Lattice den_lat;
+     std::vector<int32> state_times;
+
+     PdfPriorOptions *prior_opts;
+     PdfPrior *log_prior;
+
+     // Per-utterance reference alignment, filled by check_mmi().
+     std::vector<int32> ref_ali;
+
+     Timer *time;
+     double time_now;
+
+     // Running counters.
+     int32 num_done;
+     int32 num_no_ref_ali;
+     int32 num_no_den_lat;
+     int32 num_other_error;
+     int32 num_frm_drop;
+
+     kaldi::int64 total_frames;
+     double lat_like;    // total likelihood of the lattice
+     double lat_ac_like; // acoustic likelihood weighted by posterior.
+     double total_mmi_obj;
+     double mmi_obj;
+     double total_post_on_ali;
+     double post_on_ali;
+
+     // Number of rows of the matrix handled by the latest calc_diff_mmi().
+     int32 num_frames;
+
+     // Option values registered with ParseOptions in new_KaldiMMI().
+     bool binary;
+     BaseFloat acoustic_scale;
+     BaseFloat lm_scale;
+     BaseFloat old_acoustic_scale;
+     kaldi::int32 max_frames;
+     bool drop_frames;
+     std::string use_gpu;
+ };
+
+ // Creates a KaldiMMI context.
+ //   arg : space-separated option string (e.g. "--acoustic-scale=0.1 ...")
+ //   mdl : transition model file name
+ //   lat : denominator lattice rspecifier
+ //   ali : reference alignment rspecifier
+ // A command line in the shape expected by nnet-train-mmi-sequential is
+ // synthesised from these and parsed with ParseOptions; the unused positional
+ // slots are filled with placeholders. Calls exit(1) if the parsed line does
+ // not have exactly six positional arguments. The returned object must be
+ // released with destroy_KaldiMMI().
+ KaldiMMI * new_KaldiMMI(const char* arg, const char* mdl, const char* lat, const char* ali)
+ {
+     KaldiMMI * mmi = new KaldiMMI;
+
+     const char *usage =
+         "Perform one iteration of DNN-MMI training by stochastic "
+         "gradient descent.\n"
+         "The network weights are updated on each utterance.\n"
+         "Usage: nnet-train-mmi-sequential [options] <model-in> <transition-model-in> "
+         "<feature-rspecifier> <den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n"
+         "e.g.: \n"
+         " nnet-train-mmi-sequential nnet.init trans.mdl scp:train.scp scp:denlats.scp ark:train.ali "
+         "nnet.iter1\n";
+
+     ParseOptions po(usage);
+
+     // Registered only so that the corresponding options are accepted.
+     NnetTrainOptions trn_opts; trn_opts.learn_rate=0.00001;
+     trn_opts.Register(&po);
+
+     mmi->binary = true;
+     po.Register("binary", &(mmi->binary), "Write output in binary mode");
+
+     std::string feature_transform;
+     po.Register("feature-transform", &feature_transform,
+                 "Feature transform in Nnet format");
+
+     mmi->prior_opts = new PdfPriorOptions;
+     PdfPriorOptions &prior_opts = *(mmi->prior_opts);
+     prior_opts.Register(&po);
+
+     mmi->acoustic_scale = 1.0;
+     mmi->lm_scale = 1.0;
+     mmi->old_acoustic_scale = 0.0;
+     po.Register("acoustic-scale", &(mmi->acoustic_scale),
+                 "Scaling factor for acoustic likelihoods");
+     po.Register("lm-scale", &(mmi->lm_scale),
+                 "Scaling factor for \"graph costs\" (including LM costs)");
+     po.Register("old-acoustic-scale", &(mmi->old_acoustic_scale),
+                 "Add in the scores in the input lattices with this scale, rather "
+                 "than discarding them.");
+     mmi->max_frames = 6000; // Allow segments maximum of one minute by default
+     po.Register("max-frames",&(mmi->max_frames), "Maximum number of frames a segment can have to be processed");
+
+     mmi->drop_frames = true;
+     po.Register("drop-frames", &(mmi->drop_frames),
+                 "Drop frames, where is zero den-posterior under numerator path "
+                 "(ie. path not in lattice)");
+
+     mmi->use_gpu=std::string("yes");
+     po.Register("use-gpu", &(mmi->use_gpu), "yes|no|optional, only has effect if compiled with CUDA");
+
+     // Build the synthetic argv in growable storage, so neither the length of
+     // the option string or paths nor the number of options is limited by a
+     // fixed-size stack buffer.
+     std::vector<std::string> args;
+     args.push_back("nnet-train-mmi-sequential");
+     {
+         // strtok_r() writes into its input, so tokenise a private copy.
+         std::vector<char> scratch(arg, arg + strlen(arg) + 1);
+         char *saveptr = NULL;
+         for (char *token = strtok_r(&scratch[0], " ", &saveptr); token;
+              token = strtok_r(NULL, " ", &saveptr))
+             args.push_back(token);
+     }
+     args.push_back("0.nnet");   // <model-in>            (placeholder)
+     args.push_back(mdl);        // <transition-model-in>
+     args.push_back("feat");     // <feature-rspecifier>  (placeholder)
+     args.push_back(lat);        // <den-lat-rspecifier>
+     args.push_back(ali);        // <ali-rspecifier>
+     args.push_back("1.nnet");   // <model-out>           (placeholder)
+
+     std::vector<const char *> argsv;
+     argsv.reserve(args.size());
+     for (size_t i = 0; i < args.size(); i++)
+         argsv.push_back(args[i].c_str());
+
+     po.Read(static_cast<int>(argsv.size()), &argsv[0]);
+
+     if (po.NumArgs() != 6) {
+         po.PrintUsage();
+         exit(1);
+     }
+
+     std::string transition_model_filename = po.GetArg(2),
+                 den_lat_rspecifier = po.GetArg(4),
+                 ref_ali_rspecifier = po.GetArg(5);
+
+     // Select the GPU
+#if HAVE_CUDA == 1
+     CuDevice::Instantiate().SelectGpuId(mmi->use_gpu);
+#endif
+
+     // Read the class-frame-counts, compute priors
+     mmi->log_prior = new PdfPrior(prior_opts);
+
+     // Read transition model
+     mmi->trans_model = new TransitionModel;
+     ReadKaldiObject(transition_model_filename, mmi->trans_model);
+
+     mmi->den_lat_reader = new RandomAccessLatticeReader(den_lat_rspecifier);
+     mmi->ref_ali_reader = new RandomAccessInt32VectorReader(ref_ali_rspecifier);
+
+     if (mmi->drop_frames) {
+         KALDI_LOG << "--drop-frames=true :"
+                      " we will zero gradient for frames with total den/num mismatch."
+                      " The mismatch is likely to be caused by missing correct path "
+                      " from den-lattice due wrong annotation or search error."
+                      " Leaving such frames out stabilizes the training.";
+     }
+
+     mmi->time = new Timer;
+     mmi->time_now = 0;
+     mmi->num_done = 0;
+     mmi->num_no_ref_ali = 0;
+     mmi->num_no_den_lat = 0;
+     mmi->num_other_error = 0;
+     mmi->total_frames = 0;
+     mmi->num_frm_drop = 0;
+
+     // Give the per-utterance fields defined values, so get_num_frames_mmi()
+     // is well-defined before the first calc_diff_mmi() call.
+     mmi->num_frames = 0;
+     mmi->lat_like = 0.0;
+     mmi->lat_ac_like = 0.0;
+
+     mmi->total_mmi_obj = 0.0;
+     mmi->mmi_obj = 0.0;
+     mmi->total_post_on_ali = 0.0;
+     mmi->post_on_ali = 0.0;
+     return mmi;
+ }
+
+ // Releases everything owned by a context created with new_KaldiMMI(),
+ // including the KaldiMMI object itself (it is allocated with `new`, so its
+ // lattice, vectors and string members are only destroyed by `delete`).
+ // Passing NULL is a no-op.
+ void destroy_KaldiMMI(KaldiMMI *mmi)
+ {
+     if (mmi == NULL)
+         return;
+     delete mmi->trans_model;
+     delete mmi->den_lat_reader;
+     delete mmi->ref_ali_reader;
+     delete mmi->time;
+     delete mmi->prior_opts;
+     delete mmi->log_prior;
+     delete mmi;
+ }
+
+ // Decides whether utterance `key` can be used for MMI training and, if so,
+ // caches its reference alignment, denominator lattice and lattice state
+ // times inside `mmi` for a following calc_diff_mmi() call.
+ //   mat : only its nrow field is read (expected number of frames)
+ // Returns 1 when the utterance is usable; otherwise logs a warning, bumps
+ // the matching error counter and returns 0.
+ int check_mmi(KaldiMMI *mmi, const NervMatrix* mat, const char *key)
+ {
+     std::string utt(key);
+     if (!mmi->den_lat_reader->HasKey(utt)) {
+         KALDI_WARN << "Utterance " << utt << ": found no lattice.";
+         mmi->num_no_den_lat++;
+         return 0;
+     }
+     if (!mmi->ref_ali_reader->HasKey(utt)) {
+         KALDI_WARN << "Utterance " << utt << ": found no reference alignment.";
+         mmi->num_no_ref_ali++;
+         return 0;
+     }
+
+     assert(sizeof(BaseFloat) == sizeof(float));
+     // 1) get the features, numerator alignment
+     mmi->ref_ali = mmi->ref_ali_reader->Value(utt);
+     long mat_nrow = mat->nrow;
+     // check for temporal length of numerator alignments
+     if (static_cast<MatrixIndexT>(mmi->ref_ali.size()) != mat_nrow) {
+         KALDI_WARN << "Numerator alignment has wrong length "
+                    << mmi->ref_ali.size() << " vs. "<< mat_nrow;
+         mmi->num_other_error++;
+         return 0;
+     }
+     if (mat_nrow > mmi->max_frames) {
+         KALDI_WARN << "Utterance " << utt << ": Skipped because it has " << mat_nrow <<
+             " frames, which is more than " << mmi->max_frames << ".";
+         mmi->num_other_error++;
+         return 0;
+     }
+     // 2) get the denominator lattice, preprocess
+     mmi->den_lat = mmi->den_lat_reader->Value(utt);
+     Lattice &den_lat = mmi->den_lat;
+     if (den_lat.Start() == -1) {
+         KALDI_WARN << "Empty lattice for utt " << utt;
+         mmi->num_other_error++;
+         return 0;
+     }
+     if (mmi->old_acoustic_scale != 1.0) {
+         fst::ScaleLattice(fst::AcousticLatticeScale(mmi->old_acoustic_scale),
+                           &den_lat);
+     }
+     // optional sort it topologically
+     kaldi::uint64 props = den_lat.Properties(fst::kFstProperties, false);
+     if (!(props & fst::kTopSorted)) {
+         if (fst::TopSort(&den_lat) == false) {
+             // This function is called from C; report a cyclic lattice as an
+             // unusable utterance, like every other failure here, rather than
+             // raising a Kaldi error that would have to unwind through C frames.
+             KALDI_WARN << "Utterance " << utt << ": Cycles detected in lattice.";
+             mmi->num_other_error++;
+             return 0;
+         }
+     }
+     // get the lattice length and times of states
+     mmi->state_times.clear();
+     std::vector<int32> &state_times = mmi->state_times;
+     int32 max_time = kaldi::LatticeStateTimes(den_lat, &state_times);
+     // check for temporal length of denominator lattices
+     if (max_time != mat_nrow) {
+         KALDI_WARN << "Denominator lattice has wrong length "
+                    << max_time << " vs. " << mat_nrow;
+         mmi->num_other_error++;
+         return 0;
+     }
+
+     return 1;
+ }
+
+ // Computes the MMI derivative for utterance `key`.
+ //   mat : host float matrix; on entry its rows are read as the network
+ //         outputs, on exit they are overwritten in place with the derivative
+ //         (denominator posteriors minus the reference path, with mismatched
+ //         frames zeroed when drop_frames is set).
+ // Uses the alignment, lattice and state times cached in `mmi`, so
+ // check_mmi() must have succeeded for the same utterance beforehand.
+ // Updates the per-utterance and running statistics in `mmi`.
+ // Returns `mat` itself; no new matrix is allocated.
+ NervMatrix * calc_diff_mmi(KaldiMMI * mmi, NervMatrix * mat, const char * key)
+ {
+     std::string utt(key);
+     assert(sizeof(BaseFloat) == sizeof(float));
+
+     kaldi::Matrix<BaseFloat> nnet_out_h, nnet_diff_h;
+     nnet_out_h.Resize(mat->nrow, mat->ncol, kUndefined);
+
+     // Copy row by row: the source row address is computed from mat->stride,
+     // used as a byte offset.
+     size_t stride = mat->stride;
+     for (int i = 0; i < mat->nrow; i++)
+     {
+         const BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
+         BaseFloat *row = nnet_out_h.RowData(i);
+         memmove(row, nerv_row, sizeof(BaseFloat) * mat->ncol);
+     }
+
+     mmi->num_frames = nnet_out_h.NumRows();
+     const int32 num_frames = mmi->num_frames;
+
+     // The cached alignment is indexed by frame below; without a matching
+     // check_mmi() call those reads would run past its end.
+     assert(mmi->ref_ali.size() >= static_cast<size_t>(num_frames));
+
+     PdfPriorOptions &prior_opts = *(mmi->prior_opts);
+     if (prior_opts.class_frame_counts != "") {
+         CuMatrix<BaseFloat> nnet_out;
+         nnet_out.Resize(mat->nrow, mat->ncol, kUndefined);
+         nnet_out.CopyFromMat(nnet_out_h);
+         mmi->log_prior->SubtractOnLogpost(&nnet_out);
+         nnet_out.CopyToMat(&nnet_out_h);
+         nnet_out.Resize(0,0);
+     }
+
+     // 4) rescore the lattice
+     LatticeAcousticRescore(nnet_out_h, *(mmi->trans_model), mmi->state_times, &(mmi->den_lat));
+     if (mmi->acoustic_scale != 1.0 || mmi->lm_scale != 1.0)
+         fst::ScaleLattice(fst::LatticeScale(mmi->lm_scale, mmi->acoustic_scale), &(mmi->den_lat));
+
+     kaldi::Posterior post;
+     mmi->lat_like = kaldi::LatticeForwardBackward(mmi->den_lat, &post, &(mmi->lat_ac_like));
+
+     // Every posterior frame is written into a row of nnet_diff_h.
+     assert(post.size() <= static_cast<size_t>(num_frames));
+
+     // Accumulate denominator posteriors per (frame, pdf).
+     nnet_diff_h.Resize(mat->nrow, mat->ncol, kSetZero);
+     const int32 num_post_frames = static_cast<int32>(post.size());
+     for (int32 t = 0; t < num_post_frames; t++) {
+         const int32 num_arcs = static_cast<int32>(post[t].size());
+         for (int32 arc = 0; arc < num_arcs; arc++) {
+             int32 pdf = mmi->trans_model->TransitionIdToPdf(post[t][arc].first);
+             nnet_diff_h(t, pdf) += post[t][arc].second;
+         }
+     }
+
+     // pdf of the reference alignment at each frame, computed once and reused
+     // by all the per-frame passes below.
+     std::vector<int32> ref_pdf(num_frames);
+     for (int32 t = 0; t < num_frames; t++)
+         ref_pdf[t] = mmi->trans_model->TransitionIdToPdf(mmi->ref_ali[t]);
+
+     double path_ac_like = 0.0;
+     for (int32 t = 0; t < num_frames; t++)
+         path_ac_like += nnet_out_h(t, ref_pdf[t]);
+     path_ac_like *= mmi->acoustic_scale;
+     mmi->mmi_obj = path_ac_like - mmi->lat_like;
+
+     mmi->post_on_ali = 0.0;
+     for (int32 t = 0; t < num_frames; t++) {
+         double posterior = nnet_diff_h(t, ref_pdf[t]);
+         mmi->post_on_ali += posterior;
+     }
+
+     KALDI_VLOG(1) << "Lattice #" << mmi->num_done + 1 << " processed"
+                   << " (" << utt << "): found " << mmi->den_lat.NumStates()
+                   << " states and " << fst::NumArcs(mmi->den_lat) << " arcs.";
+
+     KALDI_VLOG(1) << "Utterance " << utt << ": Average MMI obj. value = "
+                   << (mmi->mmi_obj/mmi->num_frames) << " over " << mmi->num_frames
+                   << " frames,"
+                   << " (Avg. den-posterior on ali " << mmi->post_on_ali/mmi->num_frames << ")";
+
+     // 7a) Search for the frames with num/den mismatch
+     int32 frm_drop = 0;
+     std::vector<int32> frm_drop_vec;
+     for (int32 t = 0; t < num_frames; t++) {
+         double posterior = nnet_diff_h(t, ref_pdf[t]);
+         if (posterior < 1e-20) {
+             frm_drop++;
+             frm_drop_vec.push_back(t);
+         }
+     }
+
+     // 8) subtract the pdf-Viterbi-path
+     for (int32 t = 0; t < num_frames; t++)
+         nnet_diff_h(t, ref_pdf[t]) -= 1.0;
+
+     // 9) Drop mismatched frames from the training by zeroing the derivative
+     if (mmi->drop_frames) {
+         for (size_t i = 0; i < frm_drop_vec.size(); i++) {
+             nnet_diff_h.Row(frm_drop_vec[i]).Set(0.0);
+         }
+         mmi->num_frm_drop += frm_drop;
+     }
+
+     // Report the frame dropping
+     if (frm_drop > 0) {
+         std::stringstream ss;
+         ss << (mmi->drop_frames?"Dropped":"[dropping disabled] Would drop")
+            << " frames in " << utt << " " << frm_drop << "/" << mmi->num_frames << ",";
+         //get frame intervals from vec frm_drop_vec
+         ss << " intervals :";
+         //search for streaks of consecutive numbers:
+         int32 beg_streak = frm_drop_vec[0];
+         int32 len_streak = 0;
+         size_t i;
+         for (i = 0; i < frm_drop_vec.size(); i++, len_streak++) {
+             if (beg_streak + len_streak != frm_drop_vec[i]) {
+                 // a gap: close the streak that ended at the previous entry
+                 ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm";
+                 beg_streak = frm_drop_vec[i];
+                 len_streak = 0;
+             }
+         }
+         ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm";
+         //print
+         KALDI_WARN << ss.str();
+     }
+
+     // Write the derivative back over the caller's matrix, row by row.
+     assert(mat->nrow == nnet_diff_h.NumRows() && mat->ncol == nnet_diff_h.NumCols());
+     stride = mat->stride;
+     for (int i = 0; i < mat->nrow; i++)
+     {
+         const BaseFloat *row = nnet_diff_h.RowData(i);
+         BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
+         memmove(nerv_row, row, sizeof(BaseFloat) * mat->ncol);
+     }
+     nnet_diff_h.Resize(0,0);
+
+     // increase time counter
+     mmi->total_mmi_obj += mmi->mmi_obj;
+     mmi->total_post_on_ali += mmi->post_on_ali;
+     mmi->total_frames += mmi->num_frames;
+     mmi->num_done++;
+
+     if (mmi->num_done % 100 == 0) {
+         mmi->time_now = mmi->time->Elapsed();
+         KALDI_VLOG(1) << "After " << mmi->num_done << " utterances: time elapsed = "
+                       << mmi->time_now/60 << " min; processed " << mmi->total_frames/mmi->time_now
+                       << " frames per second.";
+#if HAVE_CUDA==1
+         // check the GPU is not overheated
+         CuDevice::Instantiate().CheckGpuHealth();
+#endif
+     }
+     return mat;
+ }
+
+ // Returns mmi->num_frames (set by calc_diff_mmi() to the row count of the
+ // matrix it processed), converted to double.
+ double get_num_frames_mmi(const KaldiMMI *mmi)
+ {
+     const double frames = static_cast<double>(mmi->num_frames);
+     return frames;
+ }
+
+}
diff --git a/kaldi_seq/src/kaldi_mmi.h b/kaldi_seq/src/kaldi_mmi.h
new file mode 100644
index 0000000..ce6787c
--- /dev/null
+++ b/kaldi_seq/src/kaldi_mmi.h
@@ -0,0 +1,20 @@
+#ifndef NERV_kaldi_KALDI_MMI
+#define NERV_kaldi_KALDI_MMI
+#include "nerv/matrix/matrix.h"
+#include "nerv/common.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ /* Opaque MMI training context; its layout is private to kaldi_mmi.cpp. */
+ typedef struct KaldiMMI KaldiMMI;
+
+ /* Creates a context from a space-separated option string, a transition
+  * model file name, a denominator-lattice rspecifier and a reference
+  * alignment rspecifier. Release it with destroy_KaldiMMI(). */
+ KaldiMMI * new_KaldiMMI(const char *arg, const char *mdl,
+                         const char *lat, const char *ali);
+
+ /* Releases a context created by new_KaldiMMI(). */
+ void destroy_KaldiMMI(KaldiMMI *mmi);
+
+ /* Returns 1 if utterance `key` can be processed for a matrix with the row
+  * count of `mat`, 0 otherwise; on success the utterance data is cached in
+  * `mmi` for calc_diff_mmi(). */
+ int check_mmi(KaldiMMI *mmi, const Matrix *mat, const char *key);
+
+ /* Overwrites `mat` (network outputs) with the MMI derivative for utterance
+  * `key` and returns `mat`. */
+ Matrix * calc_diff_mmi(KaldiMMI *mmi, Matrix *mat, const char *key);
+
+ /* Number of frames handled by the latest calc_diff_mmi() call. */
+ double get_num_frames_mmi(const KaldiMMI *mmi);
+
+#ifdef __cplusplus
+}
+#endif
+#endif