// path: root/htk_io/src/KaldiLib/Labels.cc
// blob: 8b04cdec81af413fa03939d48c33474fd5ab3208





















































































































































































































































































































































































                                                                                                                               



                                                         









































































































































































































































                                                                                                                               
#include "Labels.h"
#include "Timer.h"
#include "Error.h"
#include <cstdio>
#include <sstream>


namespace TNet {


    ////////////////////////////////////////////////////////////////////////
    // Class LabelRepository::
    void
        LabelRepository::
        Init(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt) {
            InitMap(pLabelMlfFile, pOutputLabelMapFile, pLabelDir, pLabelExt);
        }
    void
        LabelRepository::
        InitExt(const char* pLabelMlfFile, const char* fmt, const char *arg, const char* pLabelDir, const char* pLabelExt) {
            if (strcmp(fmt, "map") == 0)
            {
                InitMap(pLabelMlfFile, arg, pLabelDir, pLabelExt);
                mlf_fmt = MAP;
            }
            else if (strcmp(fmt, "raw") == 0)
            {
                InitRaw(pLabelMlfFile, arg, pLabelDir, pLabelExt);
                mlf_fmt = RAW;
            }
        }

    void
        LabelRepository::
        InitMap(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt)
        {
            assert(NULL != pLabelMlfFile);
            assert(NULL != pOutputLabelMapFile);

            // initialize the label streams
            delete mpLabelStream; //if NULL, does nothing
            delete _mpLabelStream;
            _mpLabelStream = new std::ifstream(pLabelMlfFile);
            mpLabelStream  = new IMlfStream(*_mpLabelStream);

            // Label stream is initialized, just test it
            if(!mpLabelStream->good()) 
                Error(std::string("Cannot open Label MLF file: ")+pLabelMlfFile);

            // Index the labels (good for randomized file lists)
            Timer tim; tim.Start();
            mpLabelStream->Index();
            tim.End(); mIndexTime += tim.Val(); 

            // Read the state-label to state-id map
            ReadOutputLabelMap(pOutputLabelMapFile);

            // Store the label dir/ext
            mpLabelDir = pLabelDir;
            mpLabelExt = pLabelExt;
        }

    // (Re)initialize the repository for the "raw" label format.
    // arg must hold a positive integer, the raw label dimension, which is
    // stored in raw_dim. No output-label map is read in this mode.
    void
        LabelRepository::
        InitRaw(const char* pLabelMlfFile, const char *arg, const char* pLabelDir, const char* pLabelExt)
        {
            assert(NULL != pLabelMlfFile);
            assert(NULL != arg); // building a std::string from NULL is undefined

            // Parse into a signed type. Extracting "-3" into an unsigned size_t
            // does not set failbit (it wraps to a huge value), and `dim <= 0` on
            // an unsigned can only catch 0. Leftover characters after the
            // number (e.g. "12abc") are rejected as well.
            std::istringstream iss(arg);
            long long dim = 0;
            char trailing;
            iss >> dim;
            if (iss.fail() || dim <= 0 || (iss >> trailing))
                PError("[lab] malformed dimension specification");
            raw_dim = static_cast<size_t>(dim);

            // initialize the label streams; the MLF reader is released before
            // the file stream it wraps (delete on NULL does nothing)
            delete mpLabelStream;
            delete _mpLabelStream;
            _mpLabelStream = new std::ifstream(pLabelMlfFile);
            mpLabelStream  = new IMlfStream(*_mpLabelStream);

            // Label stream is initialized, just test it
            if(!mpLabelStream->good()) 
                Error(std::string("Cannot open Label MLF file: ")+pLabelMlfFile);

            // Index the labels (good for randomized file lists)
            Timer tim; tim.Start();
            mpLabelStream->Index();
            tim.End(); mIndexTime += tim.Val(); 

            // Raw mode has no state-label -> state-id map to read.

            // Store the label dir/ext
            mpLabelDir = pLabelDir;
            mpLabelExt = pLabelExt;
        }


    void 
        LabelRepository::
        GenDesiredMatrix(BfMatrix& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical, bool has_vad)
        {
            //timer
            Timer tim; tim.Start();

            //Get the MLF stream reference...
            IMlfStream& mLabelStream = *mpLabelStream;
            //Build the file name of the label
            MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt);

            //Find block in MLF file
            mLabelStream.Open(mpLabelFile);
            if(!mLabelStream.good()) {
                Error(std::string("Cannot open label MLF record: ") + mpLabelFile);
            }


            //resize the matrix
            if(nFrames < 1) {
                KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n"
                    << pFeatureLogical;
            }
            int label_map_size = mLabelMap.size();
            rDesired.Init(nFrames, label_map_size + (has_vad ? 2 : 0), true); //true: Zero()

            //aux variables
            std::string line, state;
            unsigned long long beg, end;
            size_t state_index;
            size_t trunc_frames = 0;
            TagToIdMap::iterator it;
            int vad_state;

            //parse the label file
            while(!mLabelStream.eof()) {
                std::getline(mLabelStream, line);
                if(line == "") continue; //skip newlines/comments from MLF
                if(line[0] == '#') continue;

                std::istringstream& iss = mGenDesiredMatrixStream;
                iss.clear();
                iss.str(line);

                //parse the line
                //begin
                iss >> std::ws >> beg;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 1 (begin)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //end
                iss >> std::ws >> end;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 2 (end)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //state tag
                iss >> std::ws >> state;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 3 (state_tag)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }

                if (has_vad) /* an additional column for vad */
                {
                    iss >> std::ws >> vad_state;
                    if(iss.fail()) { 
                        KALDI_ERR << "Cannot parse column 4 (vad_state)\n"
                            << "line: " << line << "\n"
                            << "file: " << mpLabelFile << "\n";
                    }
                }

                //fprintf(stderr, "Parsed: %lld %lld %s\n", beg, end, state.c_str());

                //divide beg/end by sourceRate and round up to get interval of frames
                beg = (beg+sourceRate/2)/sourceRate;
                end = (end+sourceRate/2)/sourceRate; 
                //beg = (int)round(beg / (double)sourceRate);
                //end = (int)round(end / (double)sourceRate); 

                //find the state id
                it =