diff options
author | Yimmon Zhuang <[email protected]> | 2015-10-14 15:37:20 +0800 |
---|---|---|
committer | Yimmon Zhuang <[email protected]> | 2015-10-14 15:37:20 +0800 |
commit | b33b3a6732c6b6a66bd5c44c615be56d66f4ed67 (patch) | |
tree | 47501412a3324e4c13b1238eeb913aae02b2024a /kaldi_decode/src/nnet-forward.cc | |
parent | e39fb231f64ddc8b79a6eb5434f529aadb3165fe (diff) |
support kaldi decoder
Diffstat (limited to 'kaldi_decode/src/nnet-forward.cc')
-rw-r--r-- | kaldi_decode/src/nnet-forward.cc | 215 |
1 files changed, 215 insertions, 0 deletions
diff --git a/kaldi_decode/src/nnet-forward.cc b/kaldi_decode/src/nnet-forward.cc new file mode 100644 index 0000000..007f623 --- /dev/null +++ b/kaldi_decode/src/nnet-forward.cc @@ -0,0 +1,215 @@ +// nnetbin/nnet-forward.cc + +// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +extern "C"{ +#include "lua.h" +#include "lauxlib.h" +#include "lualib.h" +#include "nerv/matrix/matrix.h" +#include "nerv/common.h" +#include "nerv/luaT/luaT.h" +} + +#include <limits> + +#include "nnet/nnet-nnet.h" +#include "nnet/nnet-loss.h" +#include "nnet/nnet-pdf-prior.h" +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "base/timer.h" + +typedef kaldi::BaseFloat BaseFloat; +typedef struct Matrix NervMatrix; + + +int main(int argc, char *argv[]) { + using namespace kaldi; + using namespace kaldi::nnet1; + try { + const char *usage = + "Perform forward pass through Neural Network.\n" + "\n" + "Usage: nnet-forward [options] <nerv-config> <feature-rspecifier> <feature-wspecifier> [nerv4decode.lua]\n" + "e.g.: \n" + " nnet-forward config.lua ark:features.ark ark:mlpoutput.ark\n"; + + ParseOptions po(usage); + + PdfPriorOptions prior_opts; + prior_opts.Register(&po); + + bool apply_log = false; + po.Register("apply-log", &apply_log, "Transform MLP output to logscale"); + + std::string use_gpu="no"; + po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA"); + + using namespace kaldi; + using namespace kaldi::nnet1; + typedef kaldi::int32 int32; + + int32 time_shift = 0; + po.Register("time-shift", &time_shift, "LSTM : repeat last input frame N-times, discrad N initial output frames."); + + po.Read(argc, argv); + + if (po.NumArgs() < 3) { + po.PrintUsage(); + exit(1); + } + + std::string config = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + feature_wspecifier = po.GetArg(3), + nerv4decode = "src/nerv4decode.lua"; + if(po.NumArgs() >= 4) + nerv4decode = po.GetArg(4); + + //Select the GPU +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + // we will subtract log-priors later, + PdfPrior pdf_prior(prior_opts); + + kaldi::int64 tot_t = 0; + + BaseFloatMatrixWriter feature_writer(feature_wspecifier); + + CuMatrix<BaseFloat> nnet_out; + kaldi::Matrix<BaseFloat> nnet_out_host; + + lua_State *L = lua_open(); + luaL_openlibs(L); + if(luaL_loadfile(L, nerv4decode.c_str())) + KALDI_ERR << "luaL_loadfile() " << nerv4decode << " failed " << lua_tostring(L, -1); + + if(lua_pcall(L, 0, 0, 0)) + KALDI_ERR << "lua_pall failed " << lua_tostring(L, -1); + + lua_settop(L, 0); + lua_getglobal(L, "init"); + lua_pushstring(L, config.c_str()); + lua_pushstring(L, feature_rspecifier.c_str()); + if(lua_pcall(L, 2, 0, 0)) + KALDI_ERR << "lua_pcall failed " << lua_tostring(L, -1); + + Timer time; + double time_now = 0; + int32 num_done = 0; + // iterate over all feature files + for(;;){ + lua_settop(L, 0); + lua_getglobal(L, "feed"); + if(lua_pcall(L, 0, 2, 0)) + KALDI_ERR << "lua_pcall failed " << lua_tostring(L, -1); + + std::string utt = std::string(lua_tostring(L, -2)); + if(utt == "") + break; + NervMatrix *mat = *(NervMatrix **)lua_touserdata(L, -1); + + nnet_out_host.Resize(mat->nrow, mat->ncol, kUndefined); + + size_t stride = mat->stride; + for(int i = 0; i < mat->nrow; i++){ + const BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride); + BaseFloat *row = nnet_out_host.RowData(i); + memmove(row, nerv_row, sizeof(BaseFloat) * mat->ncol); + } + + KALDI_VLOG(2) << "Processing utterance " << num_done+1 + << ", " << utt + << ", " << nnet_out_host.NumRows() << "frm"; + + nnet_out.Resize(nnet_out_host.NumRows(), nnet_out_host.NumCols(), kUndefined); + nnet_out.CopyFromMat(nnet_out_host); + + if (!KALDI_ISFINITE(nnet_out.Sum())) { // check there's no nan/inf, + KALDI_ERR << "NaN or inf found in nn-output for " << utt; + } + + // convert posteriors to log-posteriors, + if (apply_log) { + if (!(nnet_out.Min() >= 0.0 && nnet_out.Max() <= 1.0)) { + KALDI_WARN << utt << " " + << "Applying 'log' to data which don't seem to be probabilities " + << "(is there a softmax somwhere?)"; + } + nnet_out.Add(1e-20); // avoid log(0), + nnet_out.ApplyLog(); + } + + // subtract log-priors from log-posteriors or pre-softmax, + if (prior_opts.class_frame_counts != "") { + if (nnet_out.Min() >= 0.0 && nnet_out.Max() <= 1.0) { + KALDI_WARN << utt << " " + << "Subtracting log-prior on 'probability-like' data in range [0..1] " + << "(Did you forget --no-softmax=true or --apply-log=true ?)"; + } + pdf_prior.SubtractOnLogpost(&nnet_out); + } + + // download from GPU, + nnet_out_host.Resize(nnet_out.NumRows(), nnet_out.NumCols()); + nnet_out.CopyToMat(&nnet_out_host); + + // time-shift, remove N first frames of LSTM output, + if (time_shift > 0) { + kaldi::Matrix<BaseFloat> tmp(nnet_out_host); + nnet_out_host = tmp.RowRange(time_shift, tmp.NumRows() - time_shift); + } + + // write, + if (!KALDI_ISFINITE(nnet_out_host.Sum())) { // check there's no nan/inf, + KALDI_ERR << "NaN or inf found in final output nn-output for " << utt; + } + feature_writer.Write(utt, nnet_out_host); + + // progress log + if (num_done % 100 == 0) { + time_now = time.Elapsed(); + KALDI_VLOG(1) << "After " << num_done << " utterances: time elapsed = " + << time_now/60 << " min; processed " << tot_t/time_now + << " frames per second."; + } + num_done++; + tot_t += nnet_out_host.NumRows(); + } + + // final message + KALDI_LOG << "Done " << num_done << " files" + << " in " << time.Elapsed()/60 << "min," + << " (fps " << tot_t/time.Elapsed() << ")"; + +#if HAVE_CUDA==1 + if (kaldi::g_kaldi_verbose_level >= 1) { + CuDevice::Instantiate().PrintProfile(); + } +#endif + lua_close(L); + if (num_done == 0) return -1; + return 0; + } catch(const std::exception &e) { + KALDI_ERR << e.what(); + return -1; + } +} |