#include "Labels.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "Timer.h"
#include "Error.h"


namespace TNet {


    ////////////////////////////////////////////////////////////////////////
    // Class LabelRepository::
    void
        LabelRepository::
        Init(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt) {
            InitMap(pLabelMlfFile, pOutputLabelMapFile, pLabelDir, pLabelExt);
        }
    void
        LabelRepository::
        InitExt(const char* pLabelMlfFile, const char* fmt, const char *arg, const char* pLabelDir, const char* pLabelExt) {
            if (strcmp(fmt, "map") == 0)
            {
                InitMap(pLabelMlfFile, arg, pLabelDir, pLabelExt);
                mlf_fmt = MAP;
            }
            else if (strcmp(fmt, "raw") == 0)
            {
                InitRaw(pLabelMlfFile, arg, pLabelDir, pLabelExt);
                mlf_fmt = RAW;
            }
        }

    void
        LabelRepository::
        InitMap(const char* pLabelMlfFile, const char* pOutputLabelMapFile, const char* pLabelDir, const char* pLabelExt)
        {
            assert(NULL != pLabelMlfFile);
            assert(NULL != pOutputLabelMapFile);

            // initialize the label streams
            delete mpLabelStream; //if NULL, does nothing
            delete _mpLabelStream;
            _mpLabelStream = new std::ifstream(pLabelMlfFile);
            mpLabelStream  = new IMlfStream(*_mpLabelStream);

            // Label stream is initialized, just test it
            if(!mpLabelStream->good()) 
                Error(std::string("Cannot open Label MLF file: ")+pLabelMlfFile);

            // Index the labels (good for randomized file lists)
            Timer tim; tim.Start();
            mpLabelStream->Index();
            tim.End(); mIndexTime += tim.Val(); 

            // Read the state-label to state-id map
            ReadOutputLabelMap(pOutputLabelMapFile);

            // Store the label dir/ext
            mpLabelDir = pLabelDir;
            mpLabelExt = pLabelExt;
        }

    /// Open the label MLF for "raw"-format labels.
    ///
    /// In raw mode each MLF line carries its numeric target values directly,
    /// so no state-tag map is read (contrast with InitMap()).
    ///
    /// @param pLabelMlfFile path of the label MLF file (must not be NULL)
    /// @param arg           textual dimension of the raw vectors, e.g. "3";
    ///                      must parse to a strictly positive integer
    /// @param pLabelDir     label directory stored in mpLabelDir
    /// @param pLabelExt     label extension stored in mpLabelExt
    void
        LabelRepository::
        InitRaw(const char* pLabelMlfFile, const char *arg, const char* pLabelDir, const char* pLabelExt)
        {
            assert(NULL != pLabelMlfFile);

            // Parse into a *signed* type. Extracting "-3" into a size_t wraps to a
            // huge value instead of failing, so a `dim <= 0` test on an unsigned
            // variable could only ever catch 0. A NULL arg is mapped to "" (which
            // then fails to parse) because constructing a stream from NULL is UB.
            std::istringstream iss(NULL != arg ? arg : "");
            long long dim = 0;
            iss >> dim;
            if (iss.fail() || dim <= 0)
                PError("[lab] malformed dimension specification");
            raw_dim = static_cast<size_t>(dim);

            // initialize the label streams (wrapper first: it references the ifstream)
            delete mpLabelStream; //if NULL, does nothing
            delete _mpLabelStream;
            _mpLabelStream = new std::ifstream(pLabelMlfFile);
            mpLabelStream  = new IMlfStream(*_mpLabelStream);

            // Label stream is initialized, just test it
            if(!mpLabelStream->good()) 
                Error(std::string("Cannot open Label MLF file: ")+pLabelMlfFile);

            // Index the labels (good for randomized file lists)
            Timer tim; tim.Start();
            mpLabelStream->Index();
            tim.End(); mIndexTime += tim.Val(); 

            // Store the label dir/ext
            mpLabelDir = pLabelDir;
            mpLabelExt = pLabelExt;
        }


    /// Build the one-hot target matrix for one utterance from its MLF record.
    ///
    /// The record name is built with MakeHtkFileName(pFeatureLogical,
    /// mpLabelDir, mpLabelExt). Every non-empty line not starting with '#' is
    /// parsed as
    ///     <begin> <end> <state_tag> [<vad_state>]   (4th column only if has_vad)
    /// begin/end are divided by sourceRate (rounded to nearest) to give the
    /// half-open frame range [beg, end). An end of -1 means "until the last
    /// frame", as in GenDesiredMatrixExtMap/GenDesiredMatrixExtRaw.
    ///
    /// @param rDesired        re-initialised to nFrames x (#states [+2]) and
    ///                        zeroed; per row the state column is set to 1 and,
    ///                        if has_vad, column #states + (vad_state != 0) too
    /// @param nFrames         number of feature frames; must be >= 1
    /// @param sourceRate      label time units per frame
    /// @param pFeatureLogical logical feature name used to locate the record
    /// @param has_vad         whether lines carry the extra VAD column
    ///
    /// Unopenable records, parse errors, unknown tags, overlapping segments
    /// and unlabelled rows are reported via Error()/KALDI_ERR.
    void 
        LabelRepository::
        GenDesiredMatrix(BfMatrix& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical, bool has_vad)
        {
            //timer
            Timer tim; tim.Start();

            //Get the MLF stream reference...
            IMlfStream& mLabelStream = *mpLabelStream;
            //Build the file name of the label
            MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt);

            //Find block in MLF file
            mLabelStream.Open(mpLabelFile);
            if(!mLabelStream.good()) {
                Error(std::string("Cannot open label MLF record: ") + mpLabelFile);
            }

            //resize the matrix
            if(nFrames < 1) {
                KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n"
                    << pFeatureLogical;
            }
            const int label_map_size = mLabelMap.size();
            // State columns come first, followed (if has_vad) by two VAD columns.
            rDesired.Init(nFrames, label_map_size + (has_vad ? 2 : 0), true); //true: Zero()
            const unsigned long long n_rows = rDesired.Rows();

            //aux variables
            std::string line, state;
            unsigned long long beg = 0, end = 0;
            size_t state_index = 0;
            size_t trunc_frames = 0;
            TagToIdMap::iterator it;
            int vad_state = 0;

            //parse the label file; using getline() as the loop condition also
            //stops on a failed/bad stream, not only on EOF
            while(std::getline(mLabelStream, line)) {
                if(line == "") continue; //skip newlines/comments from MLF
                if(line[0] == '#') continue;

                std::istringstream& iss = mGenDesiredMatrixStream;
                iss.clear();
                iss.str(line);

                //begin
                iss >> std::ws >> beg;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 1 (begin)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //end
                iss >> std::ws >> end;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 2 (end)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //state tag
                iss >> std::ws >> state;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 3 (state_tag)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }

                if (has_vad) /* an additional column for vad */
                {
                    iss >> std::ws >> vad_state;
                    if(iss.fail()) { 
                        KALDI_ERR << "Cannot parse column 4 (vad_state)\n"
                            << "line: " << line << "\n"
                            << "file: " << mpLabelFile << "\n";
                    }
                }

                //divide beg/end by sourceRate, rounding to the nearest frame
                beg = (beg+sourceRate/2)/sourceRate;
                if (end == static_cast<unsigned long long>(-1))
                    end = n_rows; // "-1" end time: segment runs to the last frame
                else
                    end = (end+sourceRate/2)/sourceRate;

                //find the state id
                it = mLabelMap.find(state);
                if(mLabelMap.end() == it) {
                    Error(std::string("Unknown state tag: '") + state + "' file:'" + mpLabelFile);
                }
                state_index = it->second;

                // Fill the desired matrix
                for(unsigned long long frame=beg; frame<end; frame++) { 
                    //don't write after matrix... (possible longer transcript than feature file)
                    if(frame >= n_rows) { trunc_frames++; continue; }

                    //check the next frame is empty:
                    if(0.0 != rDesired[frame].Sum()) {
                        //find out which state was previously filled; only the state
                        //columns are searched so a VAD column is never reported
                        BaseFloat max = rDesired[frame].Max();
                        int idx = -1;
                        for(int i=0; i<label_map_size; i++) { 
                            if(rDesired[frame][i] == max) idx = i; 
                        }
                        for(it=mLabelMap.begin(); it!=mLabelMap.end(); ++it) {
                            if((int)it->second == idx) break;
                        }
                        std::string state_prev = "error";
                        if(it != mLabelMap.end()) {
                            state_prev = it->first;
                        }
                        //print the error message
                        std::ostringstream os; 
                        os << "Frame already assigned to other state, "
                            << " file: " << mpLabelFile 
                            << " frame: " << frame
                            << " nframes: " << nFrames 
                            << " sum: " << rDesired[frame].Sum()  
                            << " previously assigned to: " << state_prev << "(" << idx << ")" 
                            << " now should be assigned to: " << state << "(" << state_index << ")"
                            << "\n";
                        Error(os.str());
                    }

                    //fill the row
                    rDesired[(size_t)frame][state_index] = 1.0f;
                    if (has_vad)
                        rDesired[(size_t)frame][label_map_size + !!vad_state] = 1.0f;
                }
            }

            mLabelStream.Close();

            //check the desired matrix: every row holds exactly one active state
            //(plus exactly one VAD column when has_vad). The former test
            //`!desired_row_sum == 1.0` parsed as `(!sum) == 1.0` and only
            //caught all-zero rows.
            const float expected_row_sum = has_vad ? 2.0f : 1.0f;
            for(size_t i=0; i<rDesired.Rows(); ++i) {
                float desired_row_sum = rDesired[i].Sum();
                if(desired_row_sum != expected_row_sum) {
                    std::ostringstream os;
                    os << "Desired vector sum isn't " << expected_row_sum << ", "
                        << " file: " << mpLabelFile 
                        << " row: " << i 
                        << " nframes: " << nFrames 
                        << " content: " << rDesired[i] 
                        << " sum: " << desired_row_sum  << "\n";
                    Error(os.str());
                }
            }

            //warning when truncating many frames
            if(trunc_frames > 10) {
                std::ostringstream os;
                os << "Truncated frames: " << trunc_frames 
                    << " Check sourcerate in features and validity of labels\n";
                Warning(os.str());
            }

            //timer
            tim.End(); mGenDesiredMatrixTime += tim.Val();
        }



    void
        LabelRepository::
        ReadOutputLabelMap(const char* file)
        {
            assert(mLabelMap.size() == 0);
            int i = 0;
            std::string state_tag;
            std::ifstream in(file);
            if(!in.good())
                Error(std::string("Cannot open OutputLabelMapFile: ")+file);

            in >> std::ws;
            while(!in.eof()) {
                in >> state_tag;
                in >> std::ws;
                assert(mLabelMap.find(state_tag) == mLabelMap.end());
                mLabelMap[state_tag] = i++;
            }

            in.close();
            assert(mLabelMap.size() > 0);
        }


    void 
        LabelRepository::
        GenDesiredMatrixExt(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical) {
            switch (mlf_fmt)
            {
                case MAP: GenDesiredMatrixExtMap(rDesired, nFrames, sourceRate, pFeatureLogical);
                            break;
                case RAW: GenDesiredMatrixExtRaw(rDesired, nFrames, sourceRate, pFeatureLogical);
                          break;
                default: assert(0);
            }
        }

    void 
        LabelRepository::
        GenDesiredMatrixExtMap(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical)
        {
            //timer
            Timer tim; tim.Start();

            //Get the MLF stream reference...
            IMlfStream& mLabelStream = *mpLabelStream;
            //Build the file name of the label
            MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt);

            //Find block in MLF file
            mLabelStream.Open(mpLabelFile);
            if(!mLabelStream.good()) {
                Error(std::string("Cannot open label MLF record: ") + mpLabelFile);
            }


            //resize the matrix
            if(nFrames < 1) {
                KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n"
                    << pFeatureLogical;
            }

            size_t prev = rDesired.size();
            rDesired.resize(prev + 1, BfMatrix()); /* state + vad */
            int label_map_size = mLabelMap.size();
            rDesired[prev].Init(nFrames, 1, true); //true: Zero()

            //aux variables
            std::string line, state;
            unsigned long long beg, end;
            size_t state_index;
            size_t trunc_frames = 0;
            TagToIdMap::iterator it;

            //parse the label file
            while(!mLabelStream.eof()) {
                std::getline(mLabelStream, line);
                if(line == "") continue; //skip newlines/comments from MLF
                if(line[0] == '#') continue;

                std::istringstream& iss = mGenDesiredMatrixStream;
                iss.clear();
                iss.str(line);

                //parse the line
                //begin
                iss >> std::ws >> beg;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 1 (begin)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //end
                iss >> std::ws >> end;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 2 (end)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //state tag
                iss >> std::ws >> state;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 3 (state_tag)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }


                //fprintf(stderr, "Parsed: %lld %lld %s\n", beg, end, state.c_str());

                //divide beg/end by sourceRate and round up to get interval of frames
                beg = (beg+sourceRate/2)/sourceRate;
                if (end == (unsigned long long)-1)
                    end = rDesired[prev].Rows();
                else
                    end = (end+sourceRate/2)/sourceRate; 
                //beg = (int)round(beg / (double)sourceRate);
                //end = (int)round(end / (double)sourceRate); 

                //find the state id
                it = mLabelMap.find(state);
                if(mLabelMap.end() == it) {
                    Error(std::string("Unknown state tag: '") + state + "' file:'" + mpLabelFile);
                }
                state_index = it->second;

                // Fill the desired matrix
                for(unsigned long long frame=beg; frame<end; frame++) { 
                    //don't write after matrix... (possible longer transcript than feature file)
                    if(frame >= (int)rDesired[prev].Rows()) { trunc_frames++; continue; }

                    //check the next frame is empty:
                    if(0.0 != rDesired[prev][frame][0]) {
                        //ERROR!!!
                        //find out what was previously filled!!!
                        /*
                        BaseFloat max = rDesired[prev][frame].Max();
                        int idx = -1;
                        for(int i=0; i<(int)rDesired[prev][frame].Dim(); i++) { 
                            if(rDesired[prev][frame][i] == max) idx = i; 
                        }
                        */
                        BaseFloat max = rDesired[prev][frame][0];
                        int idx = round(max);
                        for(it=mLabelMap.begin(); it!=mLabelMap.end(); ++it) {
                            if((int)it->second == idx) break;
                        }
                        std::string state_prev = "error";
                        if(it != mLabelMap.end()) {
                            state_prev = it->first;
                        }
                        //print the error message
                        std::ostringstream os; 
                        os << "Frame already assigned to other state, "
                            << " file: " << mpLabelFile 
                            << " frame: " << frame
                            << " nframes: " << nFrames 
                            << " sum: " << max
                            << " previously assigned to: " << state_prev << "(" << idx << ")" 
                            << " now should be assigned to: " << state << "(" << state_index << ")"
                            << "\n";
                        Error(os.str());
                    }

                    //fill the row
                    //rDesired[prev][(size_t)frame][state_index] = 1.0f;
                    rDesired[prev][(size_t)frame][0] = state_index;
                }
            }

            mLabelStream.Close();
            /*
            //check the desired matrix (rows sum up to 1.0)
            for(size_t i=0; i<rDesired[prev].Rows(); ++i) {
                float desired_row_sum = rDesired[prev][i].Sum();
                if(!desired_row_sum == 1.0) {
                    std::ostringstream os;
                    os << "Desired vector sum isn't 1.0, "
                        << " file: " << mpLabelFile 
                        << " row: " << i 
                        << " nframes: " << nFrames 
                        << " content: " << rDesired[prev][i] 
                        << " sum: " << desired_row_sum  << "\n";
                    Error(os.str());
                }
            }
            */

            //warning when truncating many frames
            if(trunc_frames > 10) {
                std::ostringstream os;
                os << "Truncated frames: " << trunc_frames 
                    << " Check sourcerate in features and validity of labels\n";
                Warning(os.str());
            }

            //timer
            tim.End(); mGenDesiredMatrixTime += tim.Val();
        }

    void 
        LabelRepository::
        GenDesiredMatrixExtRaw(std::vector<BfMatrix>& rDesired, size_t nFrames, size_t sourceRate, const char* pFeatureLogical)
        {
            //timer
            Timer tim; tim.Start();

            //Get the MLF stream reference...
            IMlfStream& mLabelStream = *mpLabelStream;
            //Build the file name of the label
            MakeHtkFileName(mpLabelFile, pFeatureLogical, mpLabelDir, mpLabelExt);

            //Find block in MLF file
            mLabelStream.Open(mpLabelFile);
            if(!mLabelStream.good()) {
                Error(std::string("Cannot open label MLF record: ") + mpLabelFile);
            }


            //resize the matrix
            if(nFrames < 1) {
                KALDI_ERR << "Number of frames:" << nFrames << " is lower than 1!!!\n"
                    << pFeatureLogical;
            }

            size_t prev = rDesired.size();
            rDesired.resize(prev + 1, BfMatrix()); /* state + vad */
            rDesired[prev].Init(nFrames, raw_dim, true); //true: Zero()

            //aux variables
            std::string line, state;
            unsigned long long beg, end;
            size_t trunc_frames = 0;
            Vector<BaseFloat> raw;
            raw.Init(raw_dim);

            //parse the label file
            while(!mLabelStream.eof()) {
                std::getline(mLabelStream, line);
                if(line == "") continue; //skip newlines/comments from MLF
                if(line[0] == '#') continue;

                std::istringstream& iss = mGenDesiredMatrixStream;
                iss.clear();
                iss.str(line);

                //parse the line
                //begin
                iss >> std::ws >> beg;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 1 (begin)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }
                //end
                iss >> std::ws >> end;
                if(iss.fail()) { 
                    KALDI_ERR << "Cannot parse column 2 (end)\n"
                        << "line: " << line << "\n"
                        << "file: " << mpLabelFile << "\n";
                }

                for (size_t i = 0; i < raw_dim; i++)
                {
                    if (iss.eof())
                        PError("[label] insufficient columns for the label: %s", mpLabelFile);
                    iss >> raw[i];
                    if (iss.fail())
                        PError("[label] cannot parse raw value for the label: %s", mpLabelFile);
                }
                /*
                   for (size_t i = 0; i < raw_dim; i++)
                   fprintf(stderr, "%.3f", raw[i]);
                   fprintf(stderr, "\n");
                   */
                //divide beg/end by sourceRate and round up to get interval of frames
                beg = (beg+sourceRate/2)/sourceRate;
                if (end == (unsigned long long)-1)
                    end = rDesired[prev].Rows();
                else
                    end = (end+sourceRate/2)/sourceRate; 
                //printf("end:%lld\n", end);
                //beg = (int)round(beg / (double)sourceRate);
                //end = (int)round(end / (double)sourceRate); 

                // Fill the desired matrix
                for(unsigned long long frame=beg; frame<end; frame++) { 
                    //don't write after matrix... (possible longer transcript than feature file)
                    if(frame >= (int)rDesired[prev].Rows()) { trunc_frames++; continue; }

                    //check the next frame is empty:
                    if(0.0 != rDesired[prev][frame][0]) {
                        //ERROR!!!
                        //find out what was previously filled!!!
                        /*
                           BaseFloat max = rDesired[prev][frame].Max();
                           int idx = -1;
                           for(int i=0; i<(int)rDesired[prev][frame].Dim(); i++) { 
                           if(rDesired[prev][frame][i] == max) idx = i; 
                           }
                           */
                        BaseFloat max = rDesired[prev][frame][0];
                        //print the error message
                        std::ostringstream os; 
                        os << "Frame already assigned to other state, "
                            << " file: " << mpLabelFile 
                            << " frame: " << frame
                            << " nframes: " << nFrames 
                            << " sum: " << max
                            << "\n";
                        Error(os.str());
                    }

                    //fill the row
                    //rDesired[prev][(size_t)frame][state_index] = 1.0f;
                    rDesired[prev][(size_t)frame].Copy(raw);
                }
            }

            mLabelStream.Close();
            /*
            //check the desired matrix (rows sum up to 1.0)
            for(size_t i=0; i<rDesired[prev].Rows(); ++i) {
            float desired_row_sum = rDesired[prev][i].Sum();
            if(!desired_row_sum == 1.0) {
            std::ostringstream os;
            os << "Desired vector sum isn't 1.0, "
            << " file: " << mpLabelFile 
            << " row: " << i 
            << " nframes: " << nFrames 
            << " content: " << rDesired[prev][i] 
            << " sum: " << desired_row_sum  << "\n";
            Error(os.str());
            }
            }
            */

            //warning when truncating many frames
            if(trunc_frames > 10) {
                std::ostringstream os;
                os << "Truncated frames: " << trunc_frames 
                    << " Check sourcerate in features and validity of labels\n";
                Warning(os.str());
            }

            //timer
            tim.End(); mGenDesiredMatrixTime += tim.Val();
        }

}//namespace