From 74809198a31cb7d902de23c217ca7492b5f8a29b Mon Sep 17 00:00:00 2001 From: Yimmon Zhuang Date: Fri, 18 Sep 2015 22:18:47 +0800 Subject: mpe implement --- kaldi_seq/src/kaldi_mpe.cpp | 409 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 kaldi_seq/src/kaldi_mpe.cpp (limited to 'kaldi_seq/src/kaldi_mpe.cpp') diff --git a/kaldi_seq/src/kaldi_mpe.cpp b/kaldi_seq/src/kaldi_mpe.cpp new file mode 100644 index 0000000..5c4f7fc --- /dev/null +++ b/kaldi_seq/src/kaldi_mpe.cpp @@ -0,0 +1,409 @@ +#include +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/faster-decoder.h" +#include "decoder/decodable-matrix.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" + +#include "nnet/nnet-trnopts.h" +#include "nnet/nnet-component.h" +#include "nnet/nnet-activation.h" +#include "nnet/nnet-nnet.h" +#include "nnet/nnet-pdf-prior.h" +#include "nnet/nnet-utils.h" +#include "base/timer.h" +#include "cudamatrix/cu-device.h" + +typedef kaldi::BaseFloat BaseFloat; +typedef struct Matrix NervMatrix; + +namespace kaldi { + namespace nnet1 { + + void LatticeAcousticRescore(const Matrix &log_like, + const TransitionModel &trans_model, + const std::vector &state_times, + Lattice *lat) { + kaldi::uint64 props = lat->Properties(fst::kFstProperties, false); + if (!(props & fst::kTopSorted)) + KALDI_ERR << "Input lattice must be topologically sorted."; + + KALDI_ASSERT(!state_times.empty()); + std::vector > time_to_state(log_like.NumRows()); + for (size_t i = 0; i < state_times.size(); i++) { + KALDI_ASSERT(state_times[i] >= 0); + if (state_times[i] < log_like.NumRows()) // end state may be past this.. + time_to_state[state_times[i]].push_back(i); + else + KALDI_ASSERT(state_times[i] == log_like.NumRows() + && "There appears to be lattice/feature mismatch."); + } + + for (int32 t = 0; t < log_like.NumRows(); t++) { + for (size_t i = 0; i < time_to_state[t].size(); i++) { + int32 state = time_to_state[t][i]; + for (fst::MutableArcIterator aiter(lat, state); !aiter.Done(); + aiter.Next()) { + LatticeArc arc = aiter.Value(); + int32 trans_id = arc.ilabel; + if (trans_id != 0) { // Non-epsilon input label on arc + int32 pdf_id = trans_model.TransitionIdToPdf(trans_id); + arc.weight.SetValue2(-log_like(t, pdf_id) + arc.weight.Value2()); + aiter.SetValue(arc); + } + } + } + } + } + + } // namespace nnet1 +} // namespace kaldi + + +extern "C" { +#include "kaldi_mpe.h" +#include "string.h" +#include "assert.h" +#include "nerv/common.h" + + extern NervMatrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status); + extern void nerv_matrix_host_float_copy_fromd(NervMatrix *mat, const NervMatrix *cumat, int, int, int, Status *); + using namespace kaldi; + using namespace kaldi::nnet1; + typedef kaldi::int32 int32; + + struct KaldiMPE { + TransitionModel *trans_model; + RandomAccessLatticeReader *den_lat_reader; + RandomAccessInt32VectorReader *ref_ali_reader; + + Lattice den_lat; + vector state_times; + + PdfPriorOptions *prior_opts; + PdfPrior *log_prior; + + std::vector silence_phones; + std::vector ref_ali; + + Timer *time; + double time_now; + + int32 num_done, num_no_ref_ali, num_no_den_lat, num_other_error; + + kaldi::int64 total_frames; + int32 num_frames; + double total_frame_acc, utt_frame_acc; + + bool binary; + bool one_silence_class; + BaseFloat acoustic_scale, lm_scale, old_acoustic_scale; + kaldi::int32 max_frames; + bool do_smbr; + std::string use_gpu; + }; + + KaldiMPE * new_KaldiMPE(const char* arg, const char* mdl, const char* lat, const char* ali) + { + KaldiMPE * mpe = new KaldiMPE; + + const char *usage = + "Perform iteration of Neural Network MPE/sMBR training by stochastic " + "gradient descent.\n" + "The network weights are updated on each utterance.\n" + "Usage: nnet-train-mpe-sequential [options] " + " []\n" + "e.g.: \n" + " nnet-train-mpe-sequential nnet.init trans.mdl scp:train.scp scp:denlats.scp ark:train.ali " + "nnet.iter1\n"; + + ParseOptions po(usage); + + NnetTrainOptions trn_opts; trn_opts.learn_rate=0.00001; + trn_opts.Register(&po); + + mpe->binary = true; + po.Register("binary", &(mpe->binary), "Write output in binary mode"); + + std::string feature_transform; + po.Register("feature-transform", &feature_transform, + "Feature transform in Nnet format"); + std::string silence_phones_str; + po.Register("silence-phones", &silence_phones_str, "Colon-separated list " + "of integer id's of silence phones, e.g. 46:47"); + + mpe->prior_opts = new PdfPriorOptions; + PdfPriorOptions &prior_opts = *(mpe->prior_opts); + prior_opts.Register(&po); + + mpe->one_silence_class = false; + mpe->acoustic_scale = 1.0, + mpe->lm_scale = 1.0, + mpe->old_acoustic_scale = 0.0; + po.Register("acoustic-scale", &(mpe->acoustic_scale), + "Scaling factor for acoustic likelihoods"); + po.Register("lm-scale", &(mpe->lm_scale), + "Scaling factor for \"graph costs\" (including LM costs)"); + po.Register("old-acoustic-scale", &(mpe->old_acoustic_scale), + "Add in the scores in the input lattices with this scale, rather " + "than discarding them."); + po.Register("one-silence-class", &(mpe->one_silence_class), "If true, newer " + "behavior which will tend to reduce insertions."); + mpe->max_frames = 6000; // Allow segments maximum of one minute by default + po.Register("max-frames",&(mpe->max_frames), "Maximum number of frames a segment can have to be processed"); + mpe->do_smbr = false; + po.Register("do-smbr", &(mpe->do_smbr), "Use state-level accuracies instead of " + "phone accuracies."); + + mpe->use_gpu=std::string("yes"); + po.Register("use-gpu", &(mpe->use_gpu), "yes|no|optional, only has effect if compiled with CUDA"); + + int narg = 0; + char args[64][1024]; + char *token; + char *saveptr = NULL; + char tmpstr[1024]; + + strcpy(tmpstr, arg); + strcpy(args[0], "nnet-train-mpe-sequential"); + for(narg = 1, token = strtok_r(tmpstr, " ", &saveptr); token; token = strtok_r(NULL, " ", &saveptr)) + strcpy(args[narg++], token); + strcpy(args[narg++], "0.nnet"); + strcpy(args[narg++], mdl); + strcpy(args[narg++], "feat"); + strcpy(args[narg++], lat); + strcpy(args[narg++], ali); + strcpy(args[narg++], "1.nnet"); + + char **argsv = new char*[narg]; + for(int _i = 0; _i < narg; _i++) + argsv[_i] = args[_i]; + + po.Read(narg, argsv); + delete [] argsv; + + if (po.NumArgs() != 6) { + po.PrintUsage(); + exit(1); + } + + std::string transition_model_filename = po.GetArg(2), + den_lat_rspecifier = po.GetArg(4), + ref_ali_rspecifier = po.GetArg(5); + + std::vector &silence_phones = mpe->silence_phones; + if (!kaldi::SplitStringToIntegers(silence_phones_str, ":", false, + &silence_phones)) + KALDI_ERR << "Invalid silence-phones string " << silence_phones_str; + kaldi::SortAndUniq(&silence_phones); + if (silence_phones.empty()) + KALDI_LOG << "No silence phones specified."; + + // Select the GPU +#if HAVE_CUDA == 1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + // Read the class-frame-counts, compute priors + mpe->log_prior = new PdfPrior(prior_opts); + PdfPrior &log_prior = *(mpe->log_prior); + + // Read transition model + mpe->trans_model = new TransitionModel; + ReadKaldiObject(transition_model_filename, mpe->trans_model); + + mpe->den_lat_reader = new RandomAccessLatticeReader(den_lat_rspecifier); + mpe->ref_ali_reader = new RandomAccessInt32VectorReader(ref_ali_rspecifier); + + mpe->time = new Timer; + mpe->time_now = 0; + mpe->num_done =0; + mpe->num_no_ref_ali = 0; + mpe->num_no_den_lat = 0; + mpe->num_other_error = 0; + mpe->total_frames = 0; + mpe->total_frame_acc = 0.0; + mpe->utt_frame_acc = 0.0; + + return mpe; + } + + void destroy_KaldiMPE(KaldiMPE *mpe) + { + delete mpe->trans_model; + delete mpe->den_lat_reader; + delete mpe->ref_ali_reader; + delete mpe->time; + delete mpe->prior_opts; + delete mpe->log_prior; + } + + int check_mpe(KaldiMPE *mpe, const NervMatrix* mat, const char *key) + { + std::string utt(key); + if (!mpe->den_lat_reader->HasKey(utt)) { + KALDI_WARN << "Utterance " << utt << ": found no lattice."; + mpe->num_no_den_lat++; + return 0; + } + if (!mpe->ref_ali_reader->HasKey(utt)) { + KALDI_WARN << "Utterance " << utt << ": found no reference alignment."; + mpe->num_no_ref_ali++; + return 0; + } + + assert(sizeof(BaseFloat) == sizeof(float)); + // 1) get the features, numerator alignment + mpe->ref_ali = mpe->ref_ali_reader->Value(utt); + long mat_nrow = mat->nrow, mat_ncol = mat->ncol; + // check for temporal length of numerator alignments + if (static_cast(mpe->ref_ali.size()) != mat_nrow) { + KALDI_WARN << "Numerator alignment has wrong length " + << mpe->ref_ali.size() << " vs. "<< mat_nrow; + mpe->num_other_error++; + return 0; + } + if (mat_nrow > mpe->max_frames) { + KALDI_WARN << "Utterance " << utt << ": Skipped because it has " << mat_nrow << + " frames, which is more than " << mpe->max_frames << "."; + mpe->num_other_error++; + return 0; + } + // 2) get the denominator lattice, preprocess + mpe->den_lat = mpe->den_lat_reader->Value(utt); + Lattice &den_lat = mpe->den_lat; + if (den_lat.Start() == -1) { + KALDI_WARN << "Empty lattice for utt " << utt; + mpe->num_other_error++; + return 0; + } + if (mpe->old_acoustic_scale != 1.0) { + fst::ScaleLattice(fst::AcousticLatticeScale(mpe->old_acoustic_scale), + &den_lat); + } + // optional sort it topologically + kaldi::uint64 props = den_lat.Properties(fst::kFstProperties, false); + if (!(props & fst::kTopSorted)) { + if (fst::TopSort(&den_lat) == false) + KALDI_ERR << "Cycles detected in lattice."; + } + // get the lattice length and times of states + mpe->state_times.clear(); + vector &state_times = mpe->state_times; + int32 max_time = kaldi::LatticeStateTimes(den_lat, &state_times); + // check for temporal length of denominator lattices + if (max_time != mat_nrow) { + KALDI_WARN << "Denominator lattice has wrong length " + << max_time << " vs. " << mat_nrow; + mpe->num_other_error++; + return 0; + } + + return 1; + } + + NervMatrix * calc_diff_mpe(KaldiMPE * mpe, NervMatrix * mat, const char * key) + { + std::string utt(key); + assert(sizeof(BaseFloat) == sizeof(float)); + + kaldi::Matrix nnet_out_h; + nnet_out_h.Resize(mat->nrow, mat->ncol, kUndefined); + + size_t stride = mat->stride; + for (int i = 0; i < mat->nrow; i++) + { + const BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride); + BaseFloat *row = nnet_out_h.RowData(i); + memmove(row, nerv_row, sizeof(BaseFloat) * mat->ncol); + } + + mpe->num_frames = nnet_out_h.NumRows(); + + PdfPriorOptions &prior_opts = *(mpe->prior_opts); + if (prior_opts.class_frame_counts != "") { + CuMatrix nnet_out; + nnet_out.Resize(mat->nrow, mat->ncol, kUndefined); + nnet_out.CopyFromMat(nnet_out_h); + mpe->log_prior->SubtractOnLogpost(&nnet_out); + nnet_out.Resize(0,0); + } + + // 4) rescore the latice + LatticeAcousticRescore(nnet_out_h, *(mpe->trans_model), mpe->state_times, &(mpe->den_lat)); + if (mpe->acoustic_scale != 1.0 || mpe->lm_scale != 1.0) + fst::ScaleLattice(fst::LatticeScale(mpe->lm_scale, mpe->acoustic_scale), &(mpe->den_lat)); + + kaldi::Posterior post; + std::vector &silence_phones = mpe->silence_phones; + + if (mpe->do_smbr) { // use state-level accuracies, i.e. sMBR estimation + mpe->utt_frame_acc = LatticeForwardBackwardMpeVariants( + *(mpe->trans_model), silence_phones, mpe->den_lat, mpe->ref_ali, "smbr", + mpe->one_silence_class, &post); + } else { // use phone-level accuracies, i.e. MPFE (minimum phone frame error) + mpe->utt_frame_acc = LatticeForwardBackwardMpeVariants( + *(mpe->trans_model), silence_phones, mpe->den_lat, mpe->ref_ali, "mpfe", + mpe->one_silence_class, &post); + } + + // 6) convert the Posterior to a matrix, + CuMatrix nnet_diff; + PosteriorToMatrixMapped(post, *(mpe->trans_model), &nnet_diff); + nnet_diff.Scale(-1.0); // need to flip the sign of derivative, + + KALDI_VLOG(1) << "Lattice #" << mpe->num_done + 1 << " processed" + << " (" << utt << "): found " << mpe->den_lat.NumStates() + << " states and " << fst::NumArcs(mpe->den_lat) << " arcs."; + + KALDI_VLOG(1) << "Utterance " << utt << ": Average frame accuracy = " + << (mpe->utt_frame_acc/mpe->num_frames) << " over " << mpe->num_frames + << " frames," + << " diff-range(" << nnet_diff.Min() << "," << nnet_diff.Max() << ")"; + + nnet_diff.CopyToMat(&nnet_out_h); + nnet_diff.Resize(0,0); // release GPU memory, + + assert(mat->nrow == nnet_out_h.NumRows() && mat->ncol == nnet_out_h.NumCols()); + stride = mat->stride; + for (int i = 0; i < mat->nrow; i++) + { + const BaseFloat *row = nnet_out_h.RowData(i); + BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride); + memmove(nerv_row, row, sizeof(BaseFloat) * mat->ncol); + } + nnet_out_h.Resize(0,0); + + // increase time counter + mpe->total_frame_acc += mpe->utt_frame_acc; + mpe->total_frames += mpe->num_frames; + mpe->num_done++; + + if (mpe->num_done % 100 == 0) { + mpe->time_now = mpe->time->Elapsed(); + KALDI_VLOG(1) << "After " << mpe->num_done << " utterances: time elapsed = " + << mpe->time_now/60 << " min; processed " << mpe->total_frames/mpe->time_now + << " frames per second."; +#if HAVE_CUDA==1 + // check the GPU is not overheated + CuDevice::Instantiate().CheckGpuHealth(); +#endif + } + return mat; + } + + double get_num_frames_mpe(const KaldiMPE *mpe) + { + return (double)mpe->num_frames; + } + + double get_utt_frame_acc_mpe(const KaldiMPE *mpe) + { + return (double)mpe->utt_frame_acc; + } + +} -- cgit v1.2.3-70-g09d2 From 3d85aeadc910c6f7fe067061a960e30aed5f7135 Mon Sep 17 00:00:00 2001 From: Yimmon Zhuang Date: Thu, 8 Oct 2015 15:01:40 +0800 Subject: mpe bugfix --- kaldi_seq/src/kaldi_mpe.cpp | 1 + 1 file changed, 1 insertion(+) (limited to 'kaldi_seq/src/kaldi_mpe.cpp') diff --git a/kaldi_seq/src/kaldi_mpe.cpp b/kaldi_seq/src/kaldi_mpe.cpp index 5c4f7fc..8cdf010 100644 --- a/kaldi_seq/src/kaldi_mpe.cpp +++ b/kaldi_seq/src/kaldi_mpe.cpp @@ -330,6 +330,7 @@ extern "C" { nnet_out.Resize(mat->nrow, mat->ncol, kUndefined); nnet_out.CopyFromMat(nnet_out_h); mpe->log_prior->SubtractOnLogpost(&nnet_out); + nnet_out.CopyToMat(&nnet_out_h); nnet_out.Resize(0,0); } -- cgit v1.2.3-70-g09d2 From 85a76e8a91a3114e2061b8ffc1cad979c37a5873 Mon Sep 17 00:00:00 2001 From: Yimmon Zhuang Date: Sat, 10 Oct 2015 19:25:17 +0800 Subject: disable batch when sequence training --- kaldi_seq/Makefile | 11 +++++++---- kaldi_seq/src/kaldi_mmi.cpp | 3 +-- kaldi_seq/src/kaldi_mpe.cpp | 13 +++++++------ 3 files changed, 15 insertions(+), 12 deletions(-) (limited to 'kaldi_seq/src/kaldi_mpe.cpp') diff --git a/kaldi_seq/Makefile b/kaldi_seq/Makefile index 7e2dd2e..1232c5a 100644 --- a/kaldi_seq/Makefile +++ b/kaldi_seq/Makefile @@ -23,18 +23,21 @@ build: $(OBJ_DIR) $(OBJ_SUBDIR) $(OBJS) install: $(LUA_DIR) $(LUA_SUBDIR) $(LIBS) include $(KDIR)/src/kaldi.mk -KL := -L/home/slhome/ymz09/mylibs/ -L$(KDIR)/tools/openfst-1.3.4/lib/ $(KDIR)/src/feat/kaldi-feat.a $(KDIR)/src/matrix/kaldi-matrix.a $(KDIR)/src/base/kaldi-base.a $(KDIR)/src/util/kaldi-util.a $(KDIR)/src/hmm/kaldi-hmm.a $(KDIR)/src/tree/kaldi-tree.a $(KDIR)/src/nnet/kaldi-nnet.a $(KDIR)/src/cudamatrix/kaldi-cudamatrix.a $(KDIR)/src/lat/kaldi-lat.a $(KDIR)/src/hmm/kaldi-hmm.a $(KDIR)/src/tree/kaldi-tree.a $(KDIR)/src/matrix/kaldi-matrix.a $(KDIR)/src/util/kaldi-util.a $(KDIR)/src/base/kaldi-base.a -lcblas -llapack_atlas -lfst -lcudart -lcublas -DHAVE_CUDA + +KL1 := -rdynamic -Wl,-rpath=$(KDIR)/tools/openfst/lib -L/usr/local/cuda/lib64 -Wl,-rpath,/usr/local/cuda/lib64 -Wl,-rpath=$(KDIR)/src/lib -L. -L$(KDIR)/src/nnet/ -L$(KDIR)/src/cudamatrix/ -L$(KDIR)/src/lat/ -L$(KDIR)/src/hmm/ -L$(KDIR)/src/tree/ -L$(KDIR)/src/matrix/ -L$(KDIR)/src/util/ -L$(KDIR)/src/base/ $(KDIR)/src/nnet//libkaldi-nnet.so $(KDIR)/src/cudamatrix//libkaldi-cudamatrix.so $(KDIR)/src/lat//libkaldi-lat.so $(KDIR)/src/hmm//libkaldi-hmm.so $(KDIR)/src/tree//libkaldi-tree.so $(KDIR)/src/matrix//libkaldi-matrix.so $(KDIR)/src/util//libkaldi-util.so $(KDIR)/src/base//libkaldi-base.so -L$(KDIR)/tools/openfst/lib -lfst /usr/lib/liblapack.so /usr/lib/libcblas.so /usr/lib/libatlas.so /usr/lib/libf77blas.so -lm -lpthread -ldl -lcublas -lcudart -lkaldi-nnet -lkaldi-cudamatrix -lkaldi-lat -lkaldi-hmm -lkaldi-tree -lkaldi-matrix -lkaldi-util -lkaldi-base + +KL2 := -msse -msse2 -Wall -pthread -DKALDI_DOUBLEPRECISION=0 -DHAVE_POSIX_MEMALIGN -Wno-sign-compare -Wno-unused-local-typedefs -Winit-self -DHAVE_EXECINFO_H=1 -rdynamic -DHAVE_CXXABI_H -DHAVE_ATLAS -I$(KDIR)/tools/ATLAS/include -I$(KDIR)/tools/openfst/include -Wno-sign-compare -g -fPIC -I/usr/local/cuda/include -L/usr/local/cuda/lib64 -DKALDI_NO_EXPF $(OBJ_DIR) $(LUA_DIR) $(OBJ_SUBDIR) $(LUA_SUBDIR): -mkdir -p $@ $(LUA_DIR)/%.lua: %.lua cp $< $@ $(LIBS): $(OBJ_DIR)/src/kaldi_mpe.o $(OBJ_DIR)/src/kaldi_mmi.o $(OBJ_DIR)/init.o $(OBJ_DIR)/src/init.o - gcc -shared -fPIC -o $@ $(OBJ_DIR)/src/kaldi_mpe.o $(OBJ_DIR)/src/kaldi_mmi.o $(OBJ_DIR)/init.o $(OBJ_DIR)/src/init.o -lstdc++ -Wl,-rpath=$(LIB_PATH) -L$(LIB_PATH) -lnervcore -lluaT $(KL) + gcc -shared -fPIC -o $@ $(OBJ_DIR)/src/kaldi_mpe.o $(OBJ_DIR)/src/kaldi_mmi.o $(OBJ_DIR)/init.o $(OBJ_DIR)/src/init.o -lstdc++ -Wl,-rpath=$(LIB_PATH) -L$(LIB_PATH) -lnervcore -lluaT $(KL1) $(OBJ_DIR)/src/kaldi_mpe.o: src/kaldi_mpe.cpp - g++ -o $@ -c $< -DHAVE_ATLAS $(KALDIINCLUDE) -g -fPIC $(INCLUDE) -DKALDI_DOUBLEPRECISION=0 -msse2 -DHAVE_POSIX_MEMALIGN + g++ -o $@ -c $< $(KALDIINCLUDE) -g -fPIC $(INCLUDE) $(KL2) $(OBJ_DIR)/src/kaldi_mmi.o: src/kaldi_mmi.cpp - g++ -o $@ -c $< -DHAVE_ATLAS $(KALDIINCLUDE) -g -fPIC $(INCLUDE) -DKALDI_DOUBLEPRECISION=0 -msse2 -DHAVE_POSIX_MEMALIGN + g++ -o $@ -c $< $(KALDIINCLUDE) -g -fPIC $(INCLUDE) $(KL2) $(OBJ_DIR)/%.o: %.c gcc -o $@ -c $< -g $(INCLUDE) -fPIC clean: diff --git a/kaldi_seq/src/kaldi_mmi.cpp b/kaldi_seq/src/kaldi_mmi.cpp index a64abd0..ea9b4f1 100644 --- a/kaldi_seq/src/kaldi_mmi.cpp +++ b/kaldi_seq/src/kaldi_mmi.cpp @@ -164,12 +164,11 @@ extern "C" { // Select the GPU #if HAVE_CUDA == 1 - CuDevice::Instantiate().SelectGpuId(use_gpu); + CuDevice::Instantiate().SelectGpuId(mmi->use_gpu); #endif // Read the class-frame-counts, compute priors mmi->log_prior = new PdfPrior(prior_opts); - PdfPrior &log_prior = *(mmi->log_prior); // Read transition model mmi->trans_model = new TransitionModel; diff --git a/kaldi_seq/src/kaldi_mpe.cpp b/kaldi_seq/src/kaldi_mpe.cpp index 8cdf010..60384e2 100644 --- a/kaldi_seq/src/kaldi_mpe.cpp +++ b/kaldi_seq/src/kaldi_mpe.cpp @@ -205,12 +205,11 @@ extern "C" { // Select the GPU #if HAVE_CUDA == 1 - CuDevice::Instantiate().SelectGpuId(use_gpu); + CuDevice::Instantiate().SelectGpuId(mpe->use_gpu); #endif // Read the class-frame-counts, compute priors mpe->log_prior = new PdfPrior(prior_opts); - PdfPrior &log_prior = *(mpe->log_prior); // Read transition model mpe->trans_model = new TransitionModel; @@ -256,7 +255,7 @@ extern "C" { return 0; } - assert(sizeof(BaseFloat) == sizeof(float)); + //assert(sizeof(BaseFloat) == sizeof(float)); // 1) get the features, numerator alignment mpe->ref_ali = mpe->ref_ali_reader->Value(utt); long mat_nrow = mat->nrow, mat_ncol = mat->ncol; @@ -309,8 +308,9 @@ extern "C" { NervMatrix * calc_diff_mpe(KaldiMPE * mpe, NervMatrix * mat, const char * key) { std::string utt(key); - assert(sizeof(BaseFloat) == sizeof(float)); + //assert(sizeof(BaseFloat) == sizeof(float)); + CuMatrix nnet_diff; kaldi::Matrix nnet_out_h; nnet_out_h.Resize(mat->nrow, mat->ncol, kUndefined); @@ -327,9 +327,10 @@ extern "C" { PdfPriorOptions &prior_opts = *(mpe->prior_opts); if (prior_opts.class_frame_counts != "") { CuMatrix nnet_out; - nnet_out.Resize(mat->nrow, mat->ncol, kUndefined); + nnet_out.Resize(nnet_out_h.NumRows(), nnet_out_h.NumCols(), kUndefined); nnet_out.CopyFromMat(nnet_out_h); mpe->log_prior->SubtractOnLogpost(&nnet_out); + nnet_out_h.Resize(nnet_out.NumRows(), nnet_out.NumCols(), kUndefined); nnet_out.CopyToMat(&nnet_out_h); nnet_out.Resize(0,0); } @@ -353,7 +354,6 @@ extern "C" { } // 6) convert the Posterior to a matrix, - CuMatrix nnet_diff; PosteriorToMatrixMapped(post, *(mpe->trans_model), &nnet_diff); nnet_diff.Scale(-1.0); // need to flip the sign of derivative, @@ -366,6 +366,7 @@ extern "C" { << " frames," << " diff-range(" << nnet_diff.Min() << "," << nnet_diff.Max() << ")"; + nnet_out_h.Resize(nnet_diff.NumRows(), nnet_diff.NumCols(), kUndefined); nnet_diff.CopyToMat(&nnet_out_h); nnet_diff.Resize(0,0); // release GPU memory, -- cgit v1.2.3-70-g09d2