// nnetbin/nnet-forward.cc

// Copyright 2011-2013  Brno University of Technology (Author: Karel Vesely)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

extern "C"{
#include "lua.h"
#include "lauxlib.h"
#include "lualib.h"
#include "nerv/lib/matrix/matrix.h"
#include "nerv/lib/common.h"
#include "nerv/lib/luaT/luaT.h"
}

#include <limits>

#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-pdf-prior.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"

typedef kaldi::BaseFloat BaseFloat;
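// NERV's C matrix struct (from nerv/lib/matrix/matrix.h); the copy loop below
// uses its nrow, ncol, stride and data.f fields.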
typedef struct Matrix NervMatrix;


int main(int argc, char *argv[]) {
    using namespace kaldi;
    using namespace kaldi::nnet1;
    try {
        const char *usage =
            "Perform forward pass through Neural Network.\n"
            "\n"
            "Usage:  nnet-forward [options] <nerv-config> <feature-rspecifier> <feature-wspecifier> [asr_propagator.lua]\n"
            "e.g.: \n"
            " nnet-forward config.lua ark:features.ark ark:mlpoutput.ark\n";

        ParseOptions po(usage);

        PdfPriorOptions prior_opts;
        prior_opts.Register(&po);

        bool apply_log = false;
        po.Register("apply-log", &apply_log, "Transform MLP output to logscale");

        std::string use_gpu="no";
        po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA");

        typedef kaldi::int32 int32;

        int32 time_shift = 0;
        po.Register("time-shift", &time_shift, "LSTM : repeat last input frame N-times, discrad N initial output frames.");

        po.Read(argc, argv);

        if (po.NumArgs() < 3) {
            po.PrintUsage();
            exit(1);
        }

        std::string config = po.GetArg(1),
            feature_rspecifier = po.GetArg(2),
            feature_wspecifier = po.GetArg(3),
            propagator = "src/asr_propagator.lua";
        if (po.NumArgs() >= 4)
            propagator = po.GetArg(4);

        // Select the GPU
#if HAVE_CUDA==1
        CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

        // we will subtract log-priors later,
        PdfPrior pdf_prior(prior_opts);

        kaldi::int64 tot_t = 0;

        BaseFloatMatrixWriter feature_writer(feature_wspecifier);

        CuMatrix<BaseFloat> nnet_out;
        kaldi::Matrix<BaseFloat> nnet_out_host;

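        // create the Lua state and load the NERV propagator script; the script
        // is expected to define the global functions init(config, rspecifier)
        // and feed(), which are looked up and called below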
        lua_State *L = lua_open();
        luaL_openlibs(L);
        if(luaL_loadfile(L, propagator.c_str()))
            KALDI_ERR << "luaL_loadfile() " << propagator << " failed " << lua_tostring(L, -1);

        if(lua_pcall(L, 0, 0, 0))
            KALDI_ERR << "lua_pall failed " << lua_tostring(L, -1);

        lua_settop(L, 0);
        lua_getglobal(L, "init");
        lua_pushstring(L, config.c_str());
        lua_pushstring(L, feature_rspecifier.c_str());
        if(lua_pcall(L, 2, 0, 0))
            KALDI_ERR << "lua_pcall failed " << lua_tostring(L, -1);

        Timer time;
        double time_now = 0;
        int32 num_done = 0;
        // iterate over all feature files
        for(;;){
            lua_settop(L, 0);
            lua_getglobal(L, "feed");
            if(lua_pcall(L, 0, 2, 0))
                KALDI_ERR << "lua_pcall failed " << lua_tostring(L, -1);

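            // feed() returns the utterance key and the propagated output matrix;
            // an empty key signals that all features have been consumed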
            std::string utt = std::string(lua_tostring(L, -2));
            if(utt == "")
                break;
            NervMatrix *mat = *(NervMatrix **)lua_touserdata(L, -1);

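            // copy the NERV output into the Kaldi host matrix row by row,
            // honoring NERV's per-row byte stride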
            nnet_out_host.Resize(mat->nrow, mat->ncol, kUndefined);

            size_t stride = mat->stride;
            for(int i = 0; i < mat->nrow; i++){
                const BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
                BaseFloat *row = nnet_out_host.RowData(i);
                memmove(row, nerv_row, sizeof(BaseFloat) * mat->ncol);
            }

            KALDI_VLOG(2) << "Processing utterance " << num_done+1
                << ", " << utt
                << ", " << nnet_out_host.NumRows() << "frm";

            nnet_out.Resize(nnet_out_host.NumRows(), nnet_out_host.NumCols(), kUndefined);
            nnet_out.CopyFromMat(nnet_out_host);

            if (!KALDI_ISFINITE(nnet_out.Sum())) { // check there's no nan/inf,
                KALDI_ERR << "NaN or inf found in nn-output for " << utt;
            }

            // convert posteriors to log-posteriors,
            if (apply_log) {
                if (!(nnet_out.Min() >= 0.0 && nnet_out.Max() <= 1.0)) {
                    KALDI_WARN << utt << " "
                        << "Applying 'log' to data which don't seem to be probabilities "
                        << "(is there a softmax somwhere?)";
                }
                nnet_out.Add(1e-20); // avoid log(0),
                nnet_out.ApplyLog();
            }

            // subtract log-priors from log-posteriors or pre-softmax,
            if (prior_opts.class_frame_counts != "") {
                if (nnet_out.Min() >= 0.0 && nnet_out.Max() <= 1.0) {
                    KALDI_WARN << utt << " "
                        << "Subtracting log-prior on 'probability-like' data in range [0..1] "
                        << "(Did you forget --no-softmax=true or --apply-log=true ?)";
                }
                pdf_prior.SubtractOnLogpost(&nnet_out);
            }

            // download from GPU,
            nnet_out_host.Resize(nnet_out.NumRows(), nnet_out.NumCols());
            nnet_out.CopyToMat(&nnet_out_host);

            // time-shift, remove N first frames of LSTM output,
            if (time_shift > 0) {
                kaldi::Matrix<BaseFloat> tmp(nnet_out_host);
                nnet_out_host = tmp.RowRange(time_shift, tmp.NumRows() - time_shift);
            }

            // write,
            if (!KALDI_ISFINITE(nnet_out_host.Sum())) { // check there's no nan/inf,
                KALDI_ERR << "NaN or inf found in final output nn-output for " << utt;
            }
            feature_writer.Write(utt, nnet_out_host);

            // progress log
            if (num_done % 100 == 0) {
                time_now = time.Elapsed();
                KALDI_VLOG(1) << "After " << num_done << " utterances: time elapsed = "
                    << time_now/60 << " min; processed " << tot_t/time_now
                    << " frames per second.";
            }
            num_done++;
            tot_t += nnet_out_host.NumRows();
        }

        // final message
        KALDI_LOG << "Done " << num_done << " files"
            << " in " << time.Elapsed()/60 << "min,"
            << " (fps " << tot_t/time.Elapsed() << ")";

#if HAVE_CUDA==1
        if (kaldi::g_kaldi_verbose_level >= 1) {
            CuDevice::Instantiate().PrintProfile();
        }
#endif
        lua_close(L);
        if (num_done == 0) return -1;
        return 0;
    } catch(const std::exception &e) {
        KALDI_ERR << e.what();
        return -1;
    }
}