#include <string>
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/faster-decoder.h"
#include "decoder/decodable-matrix.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-component.h"
#include "nnet/nnet-activation.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-pdf-prior.h"
#include "nnet/nnet-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"
#include <iomanip>
// File-wide aliases.
// BaseFloat: Kaldi's floating-point type, usable unqualified at global scope.
// NervMatrix: nerv's matrix struct (the `struct Matrix` returned by
// nerv_matrix_host_float_create below). It is declared at global scope here,
// before `using namespace kaldi`, so the alias stays unambiguous even after
// kaldi::Matrix becomes visible unqualified.
typedef kaldi::BaseFloat BaseFloat;
typedef struct Matrix NervMatrix;
// Forward declaration of a Kaldi-side helper used by this wrapper. Its
// definition is not visible here (presumably later in this file or in another
// translation unit -- confirm).
// NOTE(review): judging by the parameters, it rescores `lat` using the
// per-frame log-likelihoods in `log_like`, with `state_times` mapping lattice
// states to frames -- verify against the definition.
namespace kaldi{
namespace nnet1{
void LatticeAcousticRescore(const kaldi::Matrix<BaseFloat> &log_like,
const TransitionModel &trans_model,
const std::vector<int32> &state_times,
Lattice *lat);
}
}
extern "C" {
#include "kaldi_mmi.h"
#include "string.h"
#include "assert.h"
#include "nerv/common.h"
// nerv host-matrix routines implemented outside this file. They are declared
// with C linkage because they sit inside the enclosing extern "C" block.
// NOTE(review): the meaning of the three unnamed int parameters of
// nerv_matrix_host_float_copy_fromd is not visible here -- TODO document them
// from nerv's definition.
extern NervMatrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status);
extern void nerv_matrix_host_float_copy_fromd(NervMatrix *mat, const NervMatrix *cumat, int, int, int, Status *);
// File-scope using-directives: the rest of this file (e.g. struct KaldiMMI)
// names Kaldi types such as TransitionModel and Lattice unqualified.
using namespace kaldi;
using namespace kaldi::nnet1;
typedef kaldi::int32 int32;
/*
 * State for one MMI-training session.
 *
 * Every pointer member below is heap-allocated by new_KaldiMMI() and freed by
 * destroy_KaldiMMI(). Field order is deliberately unchanged.
 */
struct KaldiMMI {
    /* Model and data sources (owned). */
    TransitionModel *trans_model;
    RandomAccessLatticeReader *den_lat_reader;
    RandomAccessInt32VectorReader *ref_ali_reader;

    /* Scratch data. ref_ali is reloaded per utterance in check_mmi;
     * den_lat/state_times presumably likewise -- confirm. */
    Lattice den_lat;
    std::vector<int32> state_times;

    /* Prior configuration and the prior itself (owned). */
    PdfPriorOptions *prior_opts;
    PdfPrior *log_prior;

    std::vector<int32> ref_ali;

    /* Timing (timer is owned). */
    Timer *time;
    double time_now;

    /* Statistics. Most are zeroed in new_KaldiMMI; lat_like, lat_ac_like
     * and num_frames are not set there. */
    int32 num_done;
    int32 num_no_ref_ali;
    int32 num_no_den_lat;
    int32 num_other_error;
    int32 num_frm_drop;
    kaldi::int64 total_frames;
    double lat_like;     // total likelihood of the lattice
    double lat_ac_like;  // acoustic likelihood weighted by posterior.
    double total_mmi_obj;
    double mmi_obj;
    double total_post_on_ali;
    double post_on_ali;
    int32 num_frames;

    /* Options registered with ParseOptions in new_KaldiMMI. */
    bool binary;
    BaseFloat acoustic_scale;
    BaseFloat lm_scale;
    BaseFloat old_acoustic_scale;
    kaldi::int32 max_frames;
    bool drop_frames;
    std::string use_gpu;
};
/**
 * Create and configure a KaldiMMI session.
 *
 * The option string `arg` is split on spaces and parsed as if it were the
 * command line of Kaldi's nnet-train-mmi-sequential, followed by that tool's
 * six positional arguments. Only `mdl`, `lat` and `ali` carry real values;
 * the other positionals ("0.nnet", "feat", "1.nnet") are placeholders so that
 * the argument count matches.
 *
 * @param arg  space-separated Kaldi options (e.g. "--acoustic-scale=0.1");
 *             must not be NULL.
 * @param mdl  transition-model filename (positional argument 2).
 * @param lat  denominator-lattice rspecifier (positional argument 4).
 * @param ali  reference-alignment rspecifier (positional argument 5).
 * @return a heap-allocated KaldiMMI; release it with destroy_KaldiMMI().
 *         Prints usage and calls exit(1) if the parsed positional count is
 *         not 6 (e.g. `arg` contained a non-option token).
 */
KaldiMMI * new_KaldiMMI(const char* arg, const char* mdl, const char* lat, const char* ali)
{
    // Value-initialize (note the parentheses): KaldiMMI has no user-provided
    // constructor, so scalar members this function never assigns
    // (num_frames, lat_like, lat_ac_like) start at zero instead of garbage.
    KaldiMMI * mmi = new KaldiMMI();
    const char *usage =
        "Perform one iteration of DNN-MMI training by stochastic "
        "gradient descent.\n"
        "The network weights are updated on each utterance.\n"
        "Usage: nnet-train-mmi-sequential [options] <model-in> <transition-model-in> "
        "<feature-rspecifier> <den-lat-rspecifier> <ali-rspecifier> [<model-out>]\n"
        "e.g.: \n"
        " nnet-train-mmi-sequential nnet.init trans.mdl scp:train.scp scp:denlats.scp ark:train.ali "
        "nnet.iter1\n";
    ParseOptions po(usage);

    // trn_opts and feature_transform are locals that are never stored in mmi.
    // They are registered only so that option strings written for the
    // original tool are still accepted.
    NnetTrainOptions trn_opts;
    trn_opts.learn_rate = 0.00001;
    trn_opts.Register(&po);

    mmi->binary = true;
    po.Register("binary", &(mmi->binary), "Write output in binary mode");
    std::string feature_transform;
    po.Register("feature-transform", &feature_transform,
                "Feature transform in Nnet format");

    mmi->prior_opts = new PdfPriorOptions;
    PdfPriorOptions &prior_opts = *(mmi->prior_opts);
    prior_opts.Register(&po);

    mmi->acoustic_scale = 1.0;
    mmi->lm_scale = 1.0;
    mmi->old_acoustic_scale = 0.0;
    po.Register("acoustic-scale", &(mmi->acoustic_scale),
                "Scaling factor for acoustic likelihoods");
    po.Register("lm-scale", &(mmi->lm_scale),
                "Scaling factor for \"graph costs\" (including LM costs)");
    po.Register("old-acoustic-scale", &(mmi->old_acoustic_scale),
                "Add in the scores in the input lattices with this scale, rather "
                "than discarding them.");
    mmi->max_frames = 6000; // Allow segments maximum of one minute by default
    po.Register("max-frames", &(mmi->max_frames), "Maximum number of frames a segment can have to be processed");
    mmi->drop_frames = true;
    po.Register("drop-frames", &(mmi->drop_frames),
                "Drop frames, where is zero den-posterior under numerator path "
                "(ie. path not in lattice)");
    mmi->use_gpu = std::string("yes");
    po.Register("use-gpu", &(mmi->use_gpu), "yes|no|optional, only has effect if compiled with CUDA");

    // Build the equivalent of
    //   nnet-train-mmi-sequential <arg tokens...> 0.nnet <mdl> feat <lat> <ali> 1.nnet
    // in growable std::string storage. The previous fixed char[64][1024] and
    // char[1024] buffers filled with strcpy overflowed on long option strings
    // or paths, or on more than 57 tokens.
    std::vector<std::string> args;
    args.push_back("nnet-train-mmi-sequential");
    {
        // strtok_r writes into its input, so tokenize a private NUL-terminated
        // copy of `arg`. Runs of spaces are skipped, exactly as before.
        std::vector<char> buf(arg, arg + strlen(arg) + 1);
        char *saveptr = NULL;
        for (char *token = strtok_r(&buf[0], " ", &saveptr); token != NULL;
             token = strtok_r(NULL, " ", &saveptr))
            args.push_back(token);
    }
    args.push_back("0.nnet");  // placeholder <model-in>
    args.push_back(mdl);       // <transition-model-in>
    args.push_back("feat");    // placeholder <feature-rspecifier>
    args.push_back(lat);       // <den-lat-rspecifier>
    args.push_back(ali);       // <ali-rspecifier>
    args.push_back("1.nnet");  // placeholder [<model-out>]

    // Pointer view into `args`; it stays valid because `args` is not modified
    // again before po.Read() returns.
    std::vector<const char *> argv(args.size());
    for (size_t i = 0; i < args.size(); i++)
        argv[i] = args[i].c_str();
    po.Read(static_cast<int>(argv.size()), &argv[0]);

    if (po.NumArgs() != 6) {
        po.PrintUsage();
        exit(1);
    }
    std::string transition_model_filename = po.GetArg(2),
        den_lat_rspecifier = po.GetArg(4),
        ref_ali_rspecifier = po.GetArg(5);

    // Select the GPU. The --use-gpu value is stored in mmi->use_gpu; there is
    // no local `use_gpu` in this function.
#if HAVE_CUDA == 1
    CuDevice::Instantiate().SelectGpuId(mmi->use_gpu);
#endif

    // Read the class-frame-counts, compute priors
    mmi->log_prior = new PdfPrior(prior_opts);
    // Read transition model
    mmi->trans_model = new TransitionModel;
    ReadKaldiObject(transition_model_filename, mmi->trans_model);
    mmi->den_lat_reader = new RandomAccessLatticeReader(den_lat_rspecifier);
    mmi->ref_ali_reader = new RandomAccessInt32VectorReader(ref_ali_rspecifier);

    if (mmi->drop_frames) {
        KALDI_LOG << "--drop-frames=true :"
            " we will zero gradient for frames with total den/num mismatch."
            " The mismatch is likely to be caused by missing correct path "
            " from den-lattice due wrong annotation or search error."
            " Leaving such frames out stabilizes the training.";
    }

    // Session timer and zeroed statistics.
    mmi->time = new Timer;
    mmi->time_now = 0;
    mmi->num_done = 0;
    mmi->num_no_ref_ali = 0;
    mmi->num_no_den_lat = 0;
    mmi->num_other_error = 0;
    mmi->total_frames = 0;
    mmi->num_frm_drop = 0;
    mmi->total_mmi_obj = 0.0;
    mmi->mmi_obj = 0.0;
    mmi->total_post_on_ali = 0.0;
    mmi->post_on_ali = 0.0;
    return mmi;
}
/**
 * Release a KaldiMMI created by new_KaldiMMI(), including the object itself.
 *
 * @param mmi  session to free; NULL is ignored. The pointer must not be used
 *             after this call.
 */
void destroy_KaldiMMI(KaldiMMI *mmi)
{
    if (mmi == NULL)
        return;
    // Heap-allocated members created in new_KaldiMMI.
    delete mmi->trans_model;
    delete mmi->den_lat_reader;
    delete mmi->ref_ali_reader;
    delete mmi->time;
    delete mmi->prior_opts;
    delete mmi->log_prior;
    // The struct itself was allocated with `new` in new_KaldiMMI. Deleting it
    // also destroys its by-value members (den_lat, state_times, ref_ali,
    // use_gpu), which were previously leaked.
    delete mmi;
}
int check_mmi(KaldiMMI *mmi, const NervMatrix* mat, const char *key)
{
std::string utt(key);
if (!mmi->den_lat_reader->HasKey(utt)) {
KALDI_WARN << "Utterance " << utt << ": found no lattice.";
mmi->num_no_den_lat++;
return 0;
}
if (!mmi->ref_ali_reader->HasKey(utt)) {
KALDI_WARN << "Utterance " << utt << ": found no reference alignment.";
mmi->num_no_ref_ali++;
return 0;
}
assert(sizeof(BaseFloat) == sizeof(float));
// 1) get the features, numerator alignment
mmi->ref_ali = mmi->ref_ali_reader->Value(utt);
long