diff options
Diffstat (limited to 'kaldi_seq/src/kaldi_mmi.cpp')
-rw-r--r-- | kaldi_seq/src/kaldi_mmi.cpp | 428 |
1 files changed, 428 insertions, 0 deletions
diff --git a/kaldi_seq/src/kaldi_mmi.cpp b/kaldi_seq/src/kaldi_mmi.cpp new file mode 100644 index 0000000..a64abd0 --- /dev/null +++ b/kaldi_seq/src/kaldi_mmi.cpp @@ -0,0 +1,428 @@ +#include <string> +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/faster-decoder.h" +#include "decoder/decodable-matrix.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" + +#include "nnet/nnet-trnopts.h" +#include "nnet/nnet-component.h" +#include "nnet/nnet-activation.h" +#include "nnet/nnet-nnet.h" +#include "nnet/nnet-pdf-prior.h" +#include "nnet/nnet-utils.h" +#include "base/timer.h" +#include "cudamatrix/cu-device.h" + +#include <iomanip> + +typedef kaldi::BaseFloat BaseFloat; +typedef struct Matrix NervMatrix; + +namespace kaldi{ + namespace nnet1{ + void LatticeAcousticRescore(const kaldi::Matrix<BaseFloat> &log_like, + const TransitionModel &trans_model, + const std::vector<int32> &state_times, + Lattice *lat); + } +} + +extern "C" { +#include "kaldi_mmi.h" +#include "string.h" +#include "assert.h" +#include "nerv/common.h" + + extern NervMatrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status); + extern void nerv_matrix_host_float_copy_fromd(NervMatrix *mat, const NervMatrix *cumat, int, int, int, Status *); + using namespace kaldi; + using namespace kaldi::nnet1; + typedef kaldi::int32 int32; + + struct KaldiMMI { + TransitionModel *trans_model; + RandomAccessLatticeReader *den_lat_reader; + RandomAccessInt32VectorReader *ref_ali_reader; + + Lattice den_lat; + vector<int32> state_times; + + PdfPriorOptions *prior_opts; + PdfPrior *log_prior; + + std::vector<int32> ref_ali; + + Timer *time; + double time_now; + + int32 num_done, num_no_ref_ali, num_no_den_lat, num_other_error; + int32 num_frm_drop; + + kaldi::int64 total_frames; + double lat_like; // total likelihood of the lattice + double lat_ac_like; // acoustic likelihood weighted by posterior. + double total_mmi_obj, mmi_obj; + double total_post_on_ali, post_on_ali; + + int32 num_frames; + + bool binary; + BaseFloat acoustic_scale, lm_scale, old_acoustic_scale; + kaldi::int32 max_frames; + bool drop_frames; + std::string use_gpu; + }; + + KaldiMMI * new_KaldiMMI(const char* arg, const char* mdl, const char* lat, const char* ali) + { + KaldiMMI * mmi = new KaldiMMI; + + const char *usage = + "Perform one iteration of DNN-MMI training by stochastic " + "gradient descent.\n" + "The network weights are updated on each utterance.\n" + "Usage: nnet-train-mmi-sequential [options] <model-in> <transition-model-in> " + "<feature-rspecifier> <den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n" + "e.g.: \n" + " nnet-train-mmi-sequential nnet.init trans.mdl scp:train.scp scp:denlats.scp ark:train.ali " + "nnet.iter1\n"; + + ParseOptions po(usage); + + NnetTrainOptions trn_opts; trn_opts.learn_rate=0.00001; + trn_opts.Register(&po); + + mmi->binary = true; + po.Register("binary", &(mmi->binary), "Write output in binary mode"); + + std::string feature_transform; + po.Register("feature-transform", &feature_transform, + "Feature transform in Nnet format"); + + mmi->prior_opts = new PdfPriorOptions; + PdfPriorOptions &prior_opts = *(mmi->prior_opts); + prior_opts.Register(&po); + + mmi->acoustic_scale = 1.0, + mmi->lm_scale = 1.0, + mmi->old_acoustic_scale = 0.0; + po.Register("acoustic-scale", &(mmi->acoustic_scale), + "Scaling factor for acoustic likelihoods"); + po.Register("lm-scale", &(mmi->lm_scale), + "Scaling factor for \"graph costs\" (including LM costs)"); + po.Register("old-acoustic-scale", &(mmi->old_acoustic_scale), + "Add in the scores in the input lattices with this scale, rather " + "than discarding them."); + mmi->max_frames = 6000; // Allow segments maximum of one minute by default + po.Register("max-frames",&(mmi->max_frames), "Maximum number of frames a segment can have to be processed"); + + mmi->drop_frames = true; + po.Register("drop-frames", &(mmi->drop_frames), + "Drop frames, where is zero den-posterior under numerator path " + "(ie. path not in lattice)"); + + mmi->use_gpu=std::string("yes"); + po.Register("use-gpu", &(mmi->use_gpu), "yes|no|optional, only has effect if compiled with CUDA"); + + int narg = 0; + char args[64][1024]; + char *token; + char *saveptr = NULL; + char tmpstr[1024]; + + strcpy(tmpstr, arg); + strcpy(args[0], "nnet-train-mmi-sequential"); + for(narg = 1, token = strtok_r(tmpstr, " ", &saveptr); token; token = strtok_r(NULL, " ", &saveptr)) + strcpy(args[narg++], token); + strcpy(args[narg++], "0.nnet"); + strcpy(args[narg++], mdl); + strcpy(args[narg++], "feat"); + strcpy(args[narg++], lat); + strcpy(args[narg++], ali); + strcpy(args[narg++], "1.nnet"); + + char **argsv = new char*[narg]; + for(int _i = 0; _i < narg; _i++) + argsv[_i] = args[_i]; + + po.Read(narg, argsv); + delete [] argsv; + + if (po.NumArgs() != 6) { + po.PrintUsage(); + exit(1); + } + + std::string transition_model_filename = po.GetArg(2), + den_lat_rspecifier = po.GetArg(4), + ref_ali_rspecifier = po.GetArg(5); + + // Select the GPU +#if HAVE_CUDA == 1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + // Read the class-frame-counts, compute priors + mmi->log_prior = new PdfPrior(prior_opts); + PdfPrior &log_prior = *(mmi->log_prior); + + // Read transition model + mmi->trans_model = new TransitionModel; + ReadKaldiObject(transition_model_filename, mmi->trans_model); + + mmi->den_lat_reader = new RandomAccessLatticeReader(den_lat_rspecifier); + mmi->ref_ali_reader = new RandomAccessInt32VectorReader(ref_ali_rspecifier); + + if (mmi->drop_frames) { + KALDI_LOG << "--drop-frames=true :" + " we will zero gradient for frames with total den/num mismatch." + " The mismatch is likely to be caused by missing correct path " + " from den-lattice due wrong annotation or search error." + " Leaving such frames out stabilizes the training."; + } + + mmi->time = new Timer; + mmi->time_now = 0; + mmi->num_done =0; + mmi->num_no_ref_ali = 0; + mmi->num_no_den_lat = 0; + mmi->num_other_error = 0; + mmi->total_frames = 0; + mmi->num_frm_drop = 0; + + mmi->total_mmi_obj = 0.0, mmi->mmi_obj = 0.0; + mmi->total_post_on_ali = 0.0, mmi->post_on_ali = 0.0; + return mmi; + } + + void destroy_KaldiMMI(KaldiMMI *mmi) + { + delete mmi->trans_model; + delete mmi->den_lat_reader; + delete mmi->ref_ali_reader; + delete mmi->time; + delete mmi->prior_opts; + delete mmi->log_prior; + } + + int check_mmi(KaldiMMI *mmi, const NervMatrix* mat, const char *key) + { + std::string utt(key); + if (!mmi->den_lat_reader->HasKey(utt)) { + KALDI_WARN << "Utterance " << utt << ": found no lattice."; + mmi->num_no_den_lat++; + return 0; + } + if (!mmi->ref_ali_reader->HasKey(utt)) { + KALDI_WARN << "Utterance " << utt << ": found no reference alignment."; + mmi->num_no_ref_ali++; + return 0; + } + + assert(sizeof(BaseFloat) == sizeof(float)); + // 1) get the features, numerator alignment + mmi->ref_ali = mmi->ref_ali_reader->Value(utt); + long mat_nrow = mat->nrow, mat_ncol = mat->ncol; + // check for temporal length of numerator alignments + if (static_cast<MatrixIndexT>(mmi->ref_ali.size()) != mat_nrow) { + KALDI_WARN << "Numerator alignment has wrong length " + << mmi->ref_ali.size() << " vs. "<< mat_nrow; + mmi->num_other_error++; + return 0; + } + if (mat_nrow > mmi->max_frames) { + KALDI_WARN << "Utterance " << utt << ": Skipped because it has " << mat_nrow << + " frames, which is more than " << mmi->max_frames << "."; + mmi->num_other_error++; + return 0; + } + // 2) get the denominator lattice, preprocess + mmi->den_lat = mmi->den_lat_reader->Value(utt); + Lattice &den_lat = mmi->den_lat; + if (den_lat.Start() == -1) { + KALDI_WARN << "Empty lattice for utt " << utt; + mmi->num_other_error++; + return 0; + } + if (mmi->old_acoustic_scale != 1.0) { + fst::ScaleLattice(fst::AcousticLatticeScale(mmi->old_acoustic_scale), + &den_lat); + } + // optional sort it topologically + kaldi::uint64 props = den_lat.Properties(fst::kFstProperties, false); + if (!(props & fst::kTopSorted)) { + if (fst::TopSort(&den_lat) == false) + KALDI_ERR << "Cycles detected in lattice."; + } + // get the lattice length and times of states + mmi->state_times.clear(); + vector<int32> &state_times = mmi->state_times; + int32 max_time = kaldi::LatticeStateTimes(den_lat, &state_times); + // check for temporal length of denominator lattices + if (max_time != mat_nrow) { + KALDI_WARN << "Denominator lattice has wrong length " + << max_time << " vs. " << mat_nrow; + mmi->num_other_error++; + return 0; + } + + return 1; + } + + NervMatrix * calc_diff_mmi(KaldiMMI * mmi, NervMatrix * mat, const char * key) + { + std::string utt(key); + assert(sizeof(BaseFloat) == sizeof(float)); + + kaldi::Matrix<BaseFloat> nnet_out_h, nnet_diff_h; + nnet_out_h.Resize(mat->nrow, mat->ncol, kUndefined); + + size_t stride = mat->stride; + for (int i = 0; i < mat->nrow; i++) + { + const BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride); + BaseFloat *row = nnet_out_h.RowData(i); + memmove(row, nerv_row, sizeof(BaseFloat) * mat->ncol); + } + + mmi->num_frames = nnet_out_h.NumRows(); + + PdfPriorOptions &prior_opts = *(mmi->prior_opts); + if (prior_opts.class_frame_counts != "") { + CuMatrix<BaseFloat> nnet_out; + nnet_out.Resize(mat->nrow, mat->ncol, kUndefined); + nnet_out.CopyFromMat(nnet_out_h); + mmi->log_prior->SubtractOnLogpost(&nnet_out); + nnet_out.CopyToMat(&nnet_out_h); + nnet_out.Resize(0,0); + } + + // 4) rescore the latice + LatticeAcousticRescore(nnet_out_h, *(mmi->trans_model), mmi->state_times, &(mmi->den_lat)); + if (mmi->acoustic_scale != 1.0 || mmi->lm_scale != 1.0) + fst::ScaleLattice(fst::LatticeScale(mmi->lm_scale, mmi->acoustic_scale), &(mmi->den_lat)); + + kaldi::Posterior post; + mmi->lat_like = kaldi::LatticeForwardBackward(mmi->den_lat, &post, &(mmi->lat_ac_like)); + + nnet_diff_h.Resize(mat->nrow, mat->ncol, kSetZero); + for (int32 t = 0; t < post.size(); t++) { + for (int32 arc = 0; arc < post[t].size(); arc++) { + int32 pdf = mmi->trans_model->TransitionIdToPdf(post[t][arc].first); + nnet_diff_h(t, pdf) += post[t][arc].second; + } + } + + double path_ac_like = 0.0; + for(int32 t=0; t<mmi->num_frames; t++) { + int32 pdf = mmi->trans_model->TransitionIdToPdf(mmi->ref_ali[t]); + path_ac_like += nnet_out_h(t,pdf); + } + path_ac_like *= mmi->acoustic_scale; + mmi->mmi_obj = path_ac_like - mmi->lat_like; + + mmi->post_on_ali = 0.0; + for(int32 t=0; t<mmi->num_frames; t++) { + int32 pdf = mmi->trans_model->TransitionIdToPdf(mmi->ref_ali[t]); + double posterior = nnet_diff_h(t, pdf); + mmi->post_on_ali += posterior; + } + + KALDI_VLOG(1) << "Lattice #" << mmi->num_done + 1 << " processed" + << " (" << utt << "): found " << mmi->den_lat.NumStates() + << " states and " << fst::NumArcs(mmi->den_lat) << " arcs."; + + KALDI_VLOG(1) << "Utterance " << utt << ": Average MMI obj. value = " + << (mmi->mmi_obj/mmi->num_frames) << " over " << mmi->num_frames + << " frames," + << " (Avg. den-posterior on ali " << mmi->post_on_ali/mmi->num_frames << ")"; + + // 7a) Search for the frames with num/den mismatch + int32 frm_drop = 0; + std::vector<int32> frm_drop_vec; + for(int32 t=0; t<mmi->num_frames; t++) { + int32 pdf = mmi->trans_model->TransitionIdToPdf(mmi->ref_ali[t]); + double posterior = nnet_diff_h(t, pdf); + if(posterior < 1e-20) { + frm_drop++; + frm_drop_vec.push_back(t); + } + } + + // 8) subtract the pdf-Viterbi-path + for(int32 t=0; t<nnet_diff_h.NumRows(); t++) { + int32 pdf = mmi->trans_model->TransitionIdToPdf(mmi->ref_ali[t]); + nnet_diff_h(t, pdf) -= 1.0; + } + + // 9) Drop mismatched frames from the training by zeroing the derivative + if(mmi->drop_frames) { + for(int32 i=0; i<frm_drop_vec.size(); i++) { + nnet_diff_h.Row(frm_drop_vec[i]).Set(0.0); + } + mmi->num_frm_drop += frm_drop; + } + + // Report the frame dropping + if (frm_drop > 0) { + std::stringstream ss; + ss << (mmi->drop_frames?"Dropped":"[dropping disabled] Would drop") + << " frames in " << utt << " " << frm_drop << "/" << mmi->num_frames << ","; + //get frame intervals from vec frm_drop_vec + ss << " intervals :"; + //search for streaks of consecutive numbers: + int32 beg_streak=frm_drop_vec[0]; + int32 len_streak=0; + int32 i; + for(i=0; i<frm_drop_vec.size(); i++,len_streak++) { + if(beg_streak + len_streak != frm_drop_vec[i]) { + ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm"; + beg_streak = frm_drop_vec[i]; + len_streak = 0; + } + } + ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm"; + //print + KALDI_WARN << ss.str(); + } + + assert(mat->nrow == nnet_diff_h.NumRows() && mat->ncol == nnet_diff_h.NumCols()); + stride = mat->stride; + for (int i = 0; i < mat->nrow; i++) + { + const BaseFloat *row = nnet_diff_h.RowData(i); + BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride); + memmove(nerv_row, row, sizeof(BaseFloat) * mat->ncol); + } + nnet_diff_h.Resize(0,0); + + // increase time counter + mmi->total_mmi_obj += mmi->mmi_obj; + mmi->total_post_on_ali += mmi->post_on_ali; + mmi->total_frames += mmi->num_frames; + mmi->num_done++; + + if (mmi->num_done % 100 == 0) { + mmi->time_now = mmi->time->Elapsed(); + KALDI_VLOG(1) << "After " << mmi->num_done << " utterances: time elapsed = " + << mmi->time_now/60 << " min; processed " << mmi->total_frames/mmi->time_now + << " frames per second."; +#if HAVE_CUDA==1 + // check the GPU is not overheated + CuDevice::Instantiate().CheckGpuHealth(); +#endif + } + return mat; + } + + double get_num_frames_mmi(const KaldiMMI *mmi) + { + return (double)mmi->num_frames; + } + +} |