summaryrefslogblamecommitdiff
path: root/kaldi_seq/src/kaldi_mmi.cpp
blob: 6f8dad91a8e40a4b1289d518ca1f959c0151dbdd (plain) (tree)





































                                                                             

                                    
 
                                                                                                              




























































































































                                                                                                                     
                                                          



                                                      































































































































































































































































                                                                                                      
#include <string>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/faster-decoder.h"
#include "decoder/decodable-matrix.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"

#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-component.h"
#include "nnet/nnet-activation.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-pdf-prior.h"
#include "nnet/nnet-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"

#include <iomanip>

// Opaque handle for NERV's C matrix struct; presumably defined by
// nerv/lib/matrix/mmatrix.h, which is included further down (TODO confirm).
typedef struct Matrix NervMatrix;
// Kaldi's floating-point type. The raw row copies below assert it is `float`.
typedef kaldi::BaseFloat BaseFloat;

// Forward declaration of the nnet1 lattice-rescoring helper used by
// calc_diff_mmi(). It is defined outside this file; presumably it writes the
// per-frame scores from `log_like` into the acoustic costs of `lat`, using
// `state_times` to map lattice states to frames (TODO confirm against the
// defining translation unit).
namespace kaldi {
namespace nnet1 {

void LatticeAcousticRescore(const kaldi::Matrix<BaseFloat> &log_like,
                            const TransitionModel &trans_model,
                            const std::vector<int32> &state_times,
                            Lattice *lat);

}  // namespace nnet1
}  // namespace kaldi

extern "C" {
#include "kaldi_mmi.h"
#include "string.h"
#include "assert.h"
#include "nerv/lib/common.h"
#include "nerv/lib/matrix/mmatrix.h"

    extern NervMatrix *nerv_matrix_host_float_create(long nrow, long ncol, MContext *context, Status *status);
    extern void nerv_matrix_host_float_copy_fromd(NervMatrix *mat, const NervMatrix *cumat, int, int, int, Status *);
    using namespace kaldi;
    using namespace kaldi::nnet1;
    typedef kaldi::int32 int32;

    /*
     * Training state shared by the C entry points in this file
     * (new_KaldiMMI, check_mmi, calc_diff_mmi, get_num_frames_mmi,
     * destroy_KaldiMMI). Pointer members are heap objects that
     * destroy_KaldiMMI() deletes.
     */
    struct KaldiMMI {
        // Kaldi resources opened once in new_KaldiMMI().
        TransitionModel *trans_model;
        RandomAccessLatticeReader *den_lat_reader;
        RandomAccessInt32VectorReader *ref_ali_reader;

        // Current utterance: denominator lattice and the frame index of each state,
        // filled in by check_mmi().
        Lattice den_lat;
        std::vector<int32> state_times;

        // Prior options and the prior object used to subtract log-priors.
        PdfPriorOptions *prior_opts;
        PdfPrior *log_prior;

        // Current utterance: reference (numerator) alignment, one transition-id per frame.
        std::vector<int32> ref_ali;

        // Throughput reporting.
        Timer *time;
        double time_now;

        // Utterance/frame bookkeeping counters.
        int32 num_done;
        int32 num_no_ref_ali;
        int32 num_no_den_lat;
        int32 num_other_error;
        int32 num_frm_drop;

        kaldi::int64 total_frames;
        double lat_like; // total likelihood of the lattice
        double lat_ac_like; // acoustic likelihood weighted by posterior.
        double total_mmi_obj;
        double mmi_obj;
        double total_post_on_ali;
        double post_on_ali;

        // Frame count of the utterance last processed by calc_diff_mmi().
        int32 num_frames;

        // Command-line options (registered in new_KaldiMMI()).
        bool binary;
        BaseFloat acoustic_scale;
        BaseFloat lm_scale;
        BaseFloat old_acoustic_scale;
        kaldi::int32 max_frames;
        bool drop_frames;
        std::string use_gpu;
    };

    /*
     * Create the MMI sequence-training state.
     *
     * arg  space-separated options in Kaldi ParseOptions syntax
     *      (e.g. "--acoustic-scale=0.1 --drop-frames=true"); NULL is treated
     *      as an empty option string.
     * mdl  transition-model filename.
     * lat  denominator-lattice rspecifier.
     * ali  reference-alignment rspecifier.
     *
     * Returns a heap-allocated KaldiMMI, to be released with destroy_KaldiMMI().
     * Calls exit(1) if the synthesized argument list does not have exactly
     * six positional arguments.
     */
    KaldiMMI * new_KaldiMMI(const char* arg, const char* mdl, const char* lat, const char* ali)
    {
        // Value-initialize (note the parentheses): POD members that are not
        // explicitly assigned below, such as num_frames or lat_like, start at
        // zero instead of holding indeterminate values.
        KaldiMMI * mmi = new KaldiMMI();

        const char *usage =
            "Perform one iteration of DNN-MMI training by stochastic "
            "gradient descent.\n"
            "The network weights are updated on each utterance.\n"
            "Usage:  nnet-train-mmi-sequential [options] <model-in> <transition-model-in> "
            "<feature-rspecifier> <den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n"
            "e.g.: \n"
            " nnet-train-mmi-sequential nnet.init trans.mdl scp:train.scp scp:denlats.scp ark:train.ali "
            "nnet.iter1\n";

        ParseOptions po(usage);

        // Registered so these options are accepted on the command line; the
        // values are not used by this file.
        NnetTrainOptions trn_opts;
        trn_opts.learn_rate = 0.00001;
        trn_opts.Register(&po);

        mmi->binary = true;
        po.Register("binary", &(mmi->binary), "Write output in binary mode");

        std::string feature_transform;
        po.Register("feature-transform", &feature_transform,
                "Feature transform in Nnet format");

        mmi->prior_opts = new PdfPriorOptions;
        PdfPriorOptions &prior_opts = *(mmi->prior_opts);
        prior_opts.Register(&po);

        mmi->acoustic_scale = 1.0;
        mmi->lm_scale = 1.0;
        mmi->old_acoustic_scale = 0.0;
        po.Register("acoustic-scale", &(mmi->acoustic_scale),
                "Scaling factor for acoustic likelihoods");
        po.Register("lm-scale", &(mmi->lm_scale),
                "Scaling factor for \"graph costs\" (including LM costs)");
        po.Register("old-acoustic-scale", &(mmi->old_acoustic_scale),
                "Add in the scores in the input lattices with this scale, rather "
                "than discarding them.");
        mmi->max_frames = 6000; // Allow segments maximum of one minute by default
        po.Register("max-frames",&(mmi->max_frames), "Maximum number of frames a segment can have to be processed");

        mmi->drop_frames = true;
        po.Register("drop-frames", &(mmi->drop_frames),
                "Drop frames, where is zero den-posterior under numerator path "
                "(ie. path not in lattice)");

        mmi->use_gpu=std::string("yes");
        po.Register("use-gpu", &(mmi->use_gpu), "yes|no|optional, only has effect if compiled with CUDA");

        // Assemble a synthetic argv for ParseOptions:
        //   <prog> <options from arg...> 0.nnet <mdl> feat <lat> <ali> 1.nnet
        // The placeholder positionals mirror nnet-train-mmi-sequential's usage.
        // Only positions 2, 4 and 5 are read below. std::string storage avoids
        // the fixed-size (1024-byte, 64-entry) buffers that long option strings
        // or paths could overflow.
        std::vector<std::string> args;
        args.push_back("nnet-train-mmi-sequential");
        const std::string opt_str(arg != NULL ? arg : "");
        // Split on single spaces and skip empty fields. This yields the same
        // tokens as strtok_r(..., " ", ...).
        std::string::size_type pos = 0;
        while (pos < opt_str.size()) {
            std::string::size_type end = opt_str.find(' ', pos);
            if (end == std::string::npos)
                end = opt_str.size();
            if (end > pos)
                args.push_back(opt_str.substr(pos, end - pos));
            pos = end + 1;
        }
        args.push_back("0.nnet");
        args.push_back(mdl);
        args.push_back("feat");
        args.push_back(lat);
        args.push_back(ali);
        args.push_back("1.nnet");

        // ParseOptions is not expected to write through argv (TODO confirm).
        // The const_cast only adapts c_str() to the char** the original code passed.
        std::vector<char *> argv_ptrs(args.size());
        for (size_t i = 0; i < args.size(); i++)
            argv_ptrs[i] = const_cast<char *>(args[i].c_str());

        po.Read(static_cast<int>(argv_ptrs.size()), &argv_ptrs[0]);

        if (po.NumArgs() != 6) {
            po.PrintUsage();
            exit(1);
        }

        std::string transition_model_filename = po.GetArg(2),
            den_lat_rspecifier = po.GetArg(4),
            ref_ali_rspecifier = po.GetArg(5);

        // Select the GPU
#if HAVE_CUDA == 1
        CuDevice::Instantiate().SelectGpuId(mmi->use_gpu);
#endif

        // Read the class-frame-counts, compute priors
        mmi->log_prior = new PdfPrior(prior_opts);

        // Read transition model
        mmi->trans_model = new TransitionModel;
        ReadKaldiObject(transition_model_filename, mmi->trans_model);

        mmi->den_lat_reader = new RandomAccessLatticeReader(den_lat_rspecifier);
        mmi->ref_ali_reader = new RandomAccessInt32VectorReader(ref_ali_rspecifier);

        if (mmi->drop_frames) {
            KALDI_LOG << "--drop-frames=true :"
                " we will zero gradient for frames with total den/num mismatch."
                " The mismatch is likely to be caused by missing correct path "
                " from den-lattice due wrong annotation or search error."
                " Leaving such frames out stabilizes the training.";
        }

        // Accumulated statistics and timing.
        mmi->time = new Timer;
        mmi->time_now = 0;
        mmi->num_done = 0;
        mmi->num_no_ref_ali = 0;
        mmi->num_no_den_lat = 0;
        mmi->num_other_error = 0;
        mmi->total_frames = 0;
        mmi->num_frm_drop = 0;
        mmi->num_frames = 0;

        mmi->total_mmi_obj = 0.0;
        mmi->mmi_obj = 0.0;
        mmi->total_post_on_ali = 0.0;
        mmi->post_on_ali = 0.0;
        return mmi;
    }

    /*
     * Release a KaldiMMI created by new_KaldiMMI(), including every Kaldi
     * object it allocated and the KaldiMMI itself. Passing NULL is a no-op.
     * The pointer must not be used after this call.
     */
    void destroy_KaldiMMI(KaldiMMI *mmi)
    {
        if (mmi == NULL)
            return;
        delete mmi->trans_model;
        delete mmi->den_lat_reader;
        delete mmi->ref_ali_reader;
        delete mmi->time;
        // The prior is built from prior_opts in new_KaldiMMI(), so destroy it first.
        delete mmi->log_prior;
        delete mmi->prior_opts;
        // The struct was allocated with `new` in new_KaldiMMI(); C callers cannot
        // free it themselves, so it would leak without this.
        delete mmi;
    }

    /*
     * Load and validate the per-utterance data that calc_diff_mmi() needs.
     *
     * Looks up `key` in the denominator-lattice and reference-alignment
     * readers and checks that:
     *   - the alignment has mat->nrow entries;
     *   - mat->nrow does not exceed max_frames;
     *   - the lattice is non-empty and spans mat->nrow frames.
     * It also applies old_acoustic_scale to the lattice, top-sorts the
     * lattice if needed, and computes its state times. On success the results
     * are left in mmi->ref_ali, mmi->den_lat and mmi->state_times.
     *
     * Returns 1 if the utterance can be trained on. Returns 0 otherwise, after
     * logging a warning and incrementing the matching error counter in `mmi`.
     * Raises a Kaldi error (KALDI_ERR) if the lattice contains cycles.
     */
    int check_mmi(KaldiMMI *mmi, const NervMatrix* mat, const char *key)
    {
        std::string utt(key);
        if (!mmi->den_lat_reader->HasKey(utt)) {
            KALDI_WARN << "Utterance " << utt << ": found no lattice.";
            mmi->num_no_den_lat++;
            return 0;
        }
        if (!mmi->ref_ali_reader->HasKey(utt)) {
            KALDI_WARN << "Utterance " << utt << ": found no reference alignment.";
            mmi->num_no_ref_ali++;
            return 0;
        }

        assert(sizeof(BaseFloat) == sizeof(float));
        // 1) get the numerator alignment
        mmi->ref_ali = mmi->ref_ali_reader->Value(utt);
        const long mat_nrow = mat->nrow;
        // check for temporal length of numerator alignments
        // (compare as long so that no size is narrowed to a 32-bit index)
        if (static_cast<long>(mmi->ref_ali.size()) != mat_nrow) {
            KALDI_WARN << "Numerator alignment has wrong length "
                << mmi->ref_ali.size() << " vs. "<< mat_nrow;
            mmi->num_other_error++;
            return 0;
        }
        if (mat_nrow > mmi->max_frames) {
            KALDI_WARN << "Utterance " << utt << ": Skipped because it has " << mat_nrow <<
                " frames, which is more than " << mmi->max_frames << ".";
            mmi->num_other_error++;
            return 0;
        }
        // 2) get the denominator lattice, preprocess
        mmi->den_lat = mmi->den_lat_reader->Value(utt);
        Lattice &den_lat = mmi->den_lat;
        if (den_lat.Start() == -1) {
            KALDI_WARN << "Empty lattice for utt " << utt;
            mmi->num_other_error++;
            return 0;
        }
        // Keep the stored acoustic scores scaled by old_acoustic_scale
        // (the default 0.0 discards them).
        if (mmi->old_acoustic_scale != 1.0) {
            fst::ScaleLattice(fst::AcousticLatticeScale(mmi->old_acoustic_scale),
                    &den_lat);
        }
        // optional sort it topologically
        kaldi::uint64 props = den_lat.Properties(fst::kFstProperties, false);
        if (!(props & fst::kTopSorted)) {
            if (fst::TopSort(&den_lat) == false)
                KALDI_ERR << "Cycles detected in lattice.";
        }
        // get the lattice length and times of states
        mmi->state_times.clear();
        std::vector<int32> &state_times = mmi->state_times;
        int32 max_time = kaldi::LatticeStateTimes(den_lat, &state_times);
        // check for temporal length of denominator lattices
        if (max_time != mat_nrow) {
            KALDI_WARN << "Denominator lattice has wrong length "
                << max_time << " vs. " << mat_nrow;
            mmi->num_other_error++;
            return 0;
        }

        return 1;
    }

    /*
     * Compute the MMI gradient for one utterance, overwriting `mat` in place.
     *
     * Precondition: check_mmi(mmi, mat, key) returned 1 for the same `key`.
     * mmi->ref_ali, mmi->den_lat and mmi->state_times then hold this
     * utterance's data.
     *
     * On entry, `mat` holds the network output, one row per frame and one
     * column per pdf (presumably log-posteriors; TODO confirm with the caller).
     * It is optionally prior-corrected, then used to rescore the denominator
     * lattice. On return, `mat` holds the lattice (denominator) posteriors
     * minus 1 at each frame's reference pdf. If drop_frames is set, rows whose
     * reference pdf has near-zero denominator posterior are zeroed.
     *
     * Returns `mat`. Also updates per-utterance values (num_frames, lat_like,
     * mmi_obj, post_on_ali) and accumulated statistics in `mmi`.
     */
    NervMatrix * calc_diff_mmi(KaldiMMI * mmi, NervMatrix * mat, const char * key)
    {
        std::string utt(key);
        assert(sizeof(BaseFloat) == sizeof(float));
        // Every per-frame access to ref_ali below assumes one entry per row of `mat`.
        assert(static_cast<long>(mmi->ref_ali.size()) == mat->nrow);

        // Rows of the NERV matrix are `stride` bytes apart; each holds ncol floats.
        const size_t stride = mat->stride;
        const size_t row_bytes = sizeof(BaseFloat) * mat->ncol;

        // 3) import the network output into a Kaldi host matrix
        kaldi::Matrix<BaseFloat> nnet_out_h;
        nnet_out_h.Resize(mat->nrow, mat->ncol, kUndefined);
        for (long i = 0; i < mat->nrow; i++) {
            const char *nerv_row = reinterpret_cast<const char *>(mat->data.f) + i * stride;
            memcpy(nnet_out_h.RowData(i), nerv_row, row_bytes);
        }

        mmi->num_frames = nnet_out_h.NumRows();

        // Optionally subtract log-priors (only when class frame counts were given).
        PdfPriorOptions &prior_opts = *(mmi->prior_opts);
        if (prior_opts.class_frame_counts != "") {
            CuMatrix<BaseFloat> nnet_out;
            nnet_out.Resize(mat->nrow, mat->ncol, kUndefined);
            nnet_out.CopyFromMat(nnet_out_h);
            mmi->log_prior->SubtractOnLogpost(&nnet_out);
            nnet_out.CopyToMat(&nnet_out_h);
        }

        // 4) rescore the lattice with the new scores, then apply lm/acoustic scales
        LatticeAcousticRescore(nnet_out_h, *(mmi->trans_model), mmi->state_times, &(mmi->den_lat));
        if (mmi->acoustic_scale != 1.0 || mmi->lm_scale != 1.0)
            fst::ScaleLattice(fst::LatticeScale(mmi->lm_scale, mmi->acoustic_scale), &(mmi->den_lat));

        // 5) denominator posteriors and total lattice likelihood
        kaldi::Posterior post;
        mmi->lat_like = kaldi::LatticeForwardBackward(mmi->den_lat, &post, &(mmi->lat_ac_like));

        // 6) sum arc posteriors per (frame, pdf)
        kaldi::Matrix<BaseFloat> nnet_diff_h;
        nnet_diff_h.Resize(mat->nrow, mat->ncol, kSetZero);
        for (int32 t = 0; t < static_cast<int32>(post.size()); t++) {
            for (int32 arc = 0; arc < static_cast<int32>(post[t].size()); arc++) {
                int32 pdf = mmi->trans_model->TransitionIdToPdf(post[t][arc].first);
                nnet_diff_h(t, pdf) += post[t][arc].second;
            }
        }

        // 7) one pass over the reference alignment. It computes the numerator
        //    path score, the den-posterior mass on the alignment, and the frames
        //    with a num/den mismatch. The reference pdfs are cached for step 8.
        std::vector<int32> ref_pdf(mmi->num_frames);
        std::vector<int32> frm_drop_vec;
        double path_ac_like = 0.0;
        mmi->post_on_ali = 0.0;
        for (int32 t = 0; t < mmi->num_frames; t++) {
            const int32 pdf = mmi->trans_model->TransitionIdToPdf(mmi->ref_ali[t]);
            ref_pdf[t] = pdf;
            path_ac_like += nnet_out_h(t, pdf);
            const double posterior = nnet_diff_h(t, pdf);
            mmi->post_on_ali += posterior;
            // The reference pdf is (almost) absent from the lattice at this frame.
            if (posterior < 1e-20)
                frm_drop_vec.push_back(t);
        }
        path_ac_like *= mmi->acoustic_scale;
        mmi->mmi_obj = path_ac_like - mmi->lat_like;
        const int32 frm_drop = static_cast<int32>(frm_drop_vec.size());

        KALDI_VLOG(1) << "Lattice #" << mmi->num_done + 1 << " processed"
            << " (" << utt << "): found " << mmi->den_lat.NumStates()
            << " states and " << fst::NumArcs(mmi->den_lat) << " arcs.";

        KALDI_VLOG(1) << "Utterance " << utt << ": Average MMI obj. value = "
            << (mmi->mmi_obj/mmi->num_frames) << " over " << mmi->num_frames
            << " frames,"
            << " (Avg. den-posterior on ali " << mmi->post_on_ali/mmi->num_frames << ")";

        // 8) subtract the pdf-Viterbi-path
        for (int32 t = 0; t < mmi->num_frames; t++)
            nnet_diff_h(t, ref_pdf[t]) -= 1.0;

        // 9) Drop mismatched frames from the training by zeroing the derivative
        if (mmi->drop_frames) {
            for (size_t i = 0; i < frm_drop_vec.size(); i++)
                nnet_diff_h.Row(frm_drop_vec[i]).Set(0.0);
            mmi->num_frm_drop += frm_drop;
        }

        // Report the frame dropping as runs of consecutive frame indices.
        if (frm_drop > 0) {
            std::stringstream ss;
            ss << (mmi->drop_frames?"Dropped":"[dropping disabled] Would drop")
                << " frames in " << utt << " " << frm_drop << "/" << mmi->num_frames << ",";
            ss << " intervals :";
            int32 beg_streak = frm_drop_vec[0];
            int32 len_streak = 0;
            for (size_t i = 0; i < frm_drop_vec.size(); i++, len_streak++) {
                // The first iteration never enters this branch, so i-1 is valid here.
                if (beg_streak + len_streak != frm_drop_vec[i]) {
                    ss << " " << beg_streak << ".." << frm_drop_vec[i-1] << "frm";
                    beg_streak = frm_drop_vec[i];
                    len_streak = 0;
                }
            }
            ss << " " << beg_streak << ".." << frm_drop_vec.back() << "frm";
            KALDI_WARN << ss.str();
        }

        // Export the gradient back into the NERV matrix (same row layout as the input).
        assert(mat->nrow == nnet_diff_h.NumRows() && mat->ncol == nnet_diff_h.NumCols());
        for (long i = 0; i < mat->nrow; i++) {
            char *nerv_row = reinterpret_cast<char *>(mat->data.f) + i * stride;
            memcpy(nerv_row, nnet_diff_h.RowData(i), row_bytes);
        }

        // Accumulate run-level statistics.
        mmi->total_mmi_obj += mmi->mmi_obj;
        mmi->total_post_on_ali += mmi->post_on_ali;
        mmi->total_frames += mmi->num_frames;
        mmi->num_done++;

        if (mmi->num_done % 100 == 0) {
            mmi->time_now = mmi->time->Elapsed();
            KALDI_VLOG(1) << "After " << mmi->num_done << " utterances: time elapsed = "
                << mmi->time_now/60 << " min; processed " << mmi->total_frames/mmi->time_now
                << " frames per second.";
#if HAVE_CUDA==1
            // check the GPU is not overheated
            CuDevice::Instantiate().CheckGpuHealth();
#endif
        }
        return mat;
    }

    /*
     * Number of frames (matrix rows) of the utterance that calc_diff_mmi()
     * processed most recently, widened to double for the C caller.
     */
    double get_num_frames_mmi(const KaldiMMI *mmi)
    {
        const double frames = static_cast<double>(mmi->num_frames);
        return frames;
    }

}