// (cgit view: summary | refs | log | blame | commit | diff)
// path: root/kaldi_seq/src/kaldi_mpe.cpp
// blob: 5d4149c5109a0c0e811b3344b55d75c03685e65c (plain) (tree)





































































                                                                                             

                                    
 
                                                                                                              





































































































































                                                                                                                     
                                                          



                                                      












































                                                                                    
                                                     



















































                                                                                           
                                                     
 
                                      















                                                                                        
                                                                                    

                                                         
                                                                                  
                                            





















                                                                                                      











                                                                                  
                                                                                









































                                                                                            
#include <string>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/faster-decoder.h"
#include "decoder/decodable-matrix.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"

#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-component.h"
#include "nnet/nnet-activation.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-pdf-prior.h"
#include "nnet/nnet-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"

// Global alias so code outside namespace kaldi can name Kaldi's floating-point
// type unqualified.
typedef kaldi::BaseFloat BaseFloat;
// Alias for a global `struct Matrix` (presumably the nerv C matrix type from
// nerv/lib/matrix/mmatrix.h — confirm). A distinct name is needed because
// `Matrix` is also Kaldi's kaldi::Matrix template, which the later
// `using namespace kaldi;` makes visible unqualified.
typedef struct Matrix NervMatrix;

namespace kaldi {
    namespace nnet1 {

        void LatticeAcousticRescore(const Matrix<BaseFloat> &log_like,
                const TransitionModel &trans_model,
                const std::vector<int32> &state_times,
                Lattice *lat) {
            kaldi::uint64 props = lat->Properties(fst::kFstProperties, false);
            if (!(props & fst::kTopSorted))
                KALDI_ERR << "Input lattice must be topologically sorted.";

            KALDI_ASSERT(!state_times.empty());
            std::vector<std::vector<int32> > time_to_state(log_like.NumRows());
            for (size_t i = 0; i < state_times.size(); i++) {
                KALDI_ASSERT(state_times[i] >= 0);
                if (state_times[i] < log_like.NumRows())  // end state may be past this..
                    time_to_state[state_times[i]].push_back(i);
                else
                    KALDI_ASSERT(state_times[i] == log_like.NumRows()
                            && "There appears to be lattice/feature mismatch.");
            }

            for (int32 t = 0; t < log_like.NumRows(); t++) {
                for (size_t i = 0; i < time_to_state[t].size(); i++) {
                    int32 state = time_to_state[t][i];
                    for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
                            aiter.Next()) {
                        LatticeArc arc = aiter.Value();
                        int32 trans_id = arc.ilabel;
                        if (trans_id != 0) {  // Non-epsilon input label on arc
                            int32 pdf_id = trans_model.TransitionIdToPdf(trans_id);
                            arc.weight.SetValue2(-log_like(t, pdf_id) + arc.weight.Value2());
                            aiter.SetValue(arc);
                        }
                    }
                }
            }
        }

    }  // namespace nnet1
}  // namespace kaldi


extern "C" {
#include "kaldi_mpe.h"
#include "string.h"
#include "assert.h"
#include "nerv/lib/common.h"
#include "nerv/lib/matrix/mmatrix.h"

    // Hand-written declarations of two nerv host-matrix routines. They get C
    // linkage from the enclosing extern "C" block. Presumably they are defined
    // in the nerv library and not declared by the headers included above —
    // confirm.
    // Name suggests it creates an nrow x ncol host float matrix, reporting
    // errors through `status` — confirm against the nerv sources.
    extern NervMatrix *nerv_matrix_host_float_create(long nrow, long ncol, MContext *context, Status *status);
    // Name suggests it copies into `mat` from `cumat` (the "d" suffix perhaps
    // meaning device). The meaning of the three unnamed int arguments is not
    // visible here — confirm.
    extern void nerv_matrix_host_float_copy_fromd(NervMatrix *mat, const NervMatrix *cumat, int, int, int, Status *);
    // Lets the code below name Kaldi / nnet1 types (TransitionModel, Lattice,
    // PdfPrior, ...) unqualified.
    using namespace kaldi;
    using namespace kaldi::nnet1;
    // Pins `int32` to Kaldi's definition. Presumably this avoids ambiguity with
    // another `int32` made visible by the headers or using-directives above —
    // confirm.
    typedef kaldi::int32 int32;

    struct KaldiMPE {
        TransitionModel *trans_model;
        RandomAccessLatticeReader *den_lat_reader;
        RandomAccessInt32VectorReader *ref_ali_reader;

        Lattice den_lat;
        vector<int32> state_times;

        PdfPriorOptions *prior_opts;
        PdfPrior *log_prior;

        std::vector<int32> silence_phones;
        std::vector<int32> ref_ali;

        Timer *time;
        double time_now;

        int32 num_done, num_no_ref_ali, num_no_den_lat, num_other_error;

        kaldi::int64 total_frames;
        int32 num_frames;
        double total_frame_acc, utt_frame_acc;

        bool binary;
        bool one_silence_class;
        BaseFloat acoustic_scale, lm_scale, old_acoustic_scale;
        kaldi::int32 max_frames;
        bool do_smbr;
        std::string use_gpu;
    };

    KaldiMPE * new_KaldiMPE(const char* arg, const char* mdl, const char* lat, const char* ali)
    {
        KaldiMPE * mpe = new KaldiMPE;

        const char *usage =
            "Perform iteration of Neural Network MPE/sMBR training by stochastic "
            "gradient descent.\n"
            "The network weights are updated on each utterance.\n"
            "Usage:  nnet-train-mpe-sequential [options] <model-in> <transition-model-in> "
            "<feature-rspecifier> <den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n"
            "e.g.: \n"
            " nnet-train-mpe-sequential nnet.init trans.mdl scp:train.scp scp:denlats.scp ark:train.ali "
            "nnet.iter1\n";

        ParseOptions po(usage);

        NnetTrainOptions trn_opts; trn_opts.learn_rate=0.00001;
        trn_opts.Register(&po);

        mpe->binary = true;
        po.Register("binary", &(mpe->binary), "Write output in binary mode");

        std::string feature_transform;
        po.Register("feature-transform", &feature_transform,
                "Feature transform in Nnet format");
        std::string silence_phones_str;
        po.Register("silence-phones", &silence_phones_str, "Colon-separated list "
                "of integer id's of silence phones, e.g. 46:47");

        mpe->prior_opts = new PdfPriorOptions;
        PdfPriorOptions &prior_opts = *(mpe->prior_opts);
        prior_opts.Register(&po);

        mpe->one_silence_class = false;
        mpe->acoustic_scale = 1.0,
            mpe->lm_scale = 1.0,
            mpe->old_acoustic_scale = 0.0;
        po.Register("acoustic-scale", &(mpe->acoustic_scale),
                "Scaling factor for acoustic likelihoods");
        po.Register("lm-scale", &(mpe->lm_scale),
                "Scaling factor for \"graph costs\" (including LM costs)");
        po.Register("old-acoustic-scale", &(mpe->old_acoustic_scale),
                "Add in the scores in the input lattices with this scale, rather "
                "than discarding them.");
        po.Register("one-silence-class", &(mpe->one_silence_class), "If true, newer "
                "behavior which will tend to reduce insertions.");
        mpe->max_frames = 6000; // Allow segments maximum of one minute by default
        po.Register("max-frames",&(mpe->max_frames), "Maximum number of frames a segment can have to be processed");
        mpe->do_smbr = false;
        po.Register("do-smbr", &(mpe->do_smbr), "Use state-level accuracies instead of "
                "phone accuracies.");

        mpe->use_gpu=std::string("yes");
        po.Regi