summaryrefslogblamecommitdiff
path: root/kaldi_io/src/cwrapper_kaldi.cpp
blob: 788128bd4301cbf2ff531632fcbd1703c948cc74 (plain) (tree)
1
2
3
4
5
6
7
8
                 
              



                                   

                                                                          



                           

                                    
 

                                                                                                           











                                                                                                      

                                                                                   









                                                                      
                                                                        
                                                    
                                                                         






















                                                                                    
                                                                
                                 

     





                                                             






                                                                                                                       


                                                                                       
         

                                                                           



                                                                                                          



                                                                               
         
                                                                        
                                       
                                                           
         
                          


                    



                                                                                   














                                                                                         
                                                                        
                                                    
                                                                         


















                                                                                            









                                                                                                  
                                                                                                       
                                                         
                                                                           




                                                                                   

                       






                                                                    
                                                                        
                                                    
                                                                         

                                         













                                                                                                          
#include <string>
#include <map>
#include "base/kaldi-common.h"
#include "hmm/posterior.h"
#include "util/table-types.h"
/* Kaldi's floating-point element type; the nerv copy code below asserts it is float. */
typedef kaldi::BaseFloat BaseFloat;
/* utterance key -> Kaldi feature matrix (KaldiLookupFeatureRepo::key2mat) */
typedef std::map<std::string, kaldi::Matrix<BaseFloat> > StringToMatrix_t;
/* utterance key -> mapped key (KaldiLookupFeatureRepo::map) */
typedef std::map<std::string, std::string > StringToString_t;
extern "C" {
#include "cwrapper_kaldi.h"
#include "string.h"
#include "assert.h"
#include "nerv/lib/common.h"
#include "nerv/lib/matrix/mmatrix.h"

    /* Host-side nerv matrix constructors (nrow x ncol) used to hold the copied
     * Kaldi features; failures are reported through *status.
     * NOTE(review): defined in the nerv matrix library, not in this file. */
    extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, MContext *context, Status *status);
    extern Matrix *nerv_matrix_host_double_create(long nrow, long ncol, MContext *context, Status *status);

    /*
     * Streaming feature repository: wraps a Kaldi sequential matrix reader and
     * remembers the key of the utterance most recently read by
     * kaldi_feature_repo_read_utterance().
     */
    struct KaldiFeatureRepo {
        /* Owned reader; released by kaldi_feature_repo_destroy(). */
        kaldi::SequentialBaseFloatMatrixReader* feature_reader;
        /* Key of the current utterance (fully qualified so it does not depend
         * on a using-declaration leaking from some included header). */
        std::string utt;
    };

    /*
     * Create a streaming feature repository over a Kaldi matrix table.
     *
     * feature_rspecifier - Kaldi rspecifier of the feature table to read.
     *
     * Returns a heap-allocated repository; free it with
     * kaldi_feature_repo_destroy(). If the Kaldi reader cannot be constructed
     * its exception propagates, but the half-built repository is freed first
     * instead of being leaked.
     */
    KaldiFeatureRepo *kaldi_feature_repo_new(const char *feature_rspecifier) {
        KaldiFeatureRepo *repo = new KaldiFeatureRepo();
        try {
            repo->feature_reader =
                new kaldi::SequentialBaseFloatMatrixReader(std::string(feature_rspecifier));
        } catch (...) {
            delete repo;
            throw;
        }
        return repo;
    }

    /*
     * Copy the utterance the sequential reader currently points at into a newly
     * created nerv host matrix of the same size (n rows x m columns).
     *
     * repo    - feature repository. Key() and Value() are read unconditionally,
     *           so call this only while kaldi_feature_repo_is_end() is false.
     * L       - Lua state handed to NERV_LUA_CHECK_STATUS to report a failed
     *           matrix creation.
     * debug   - when non-zero, print the key and matrix size to stderr.
     * context - nerv matrix context passed to the matrix constructor.
     *
     * Side effect: repo->utt is set to the current utterance key. The reader
     * is not advanced; use kaldi_feature_repo_next() for that.
     */
    Matrix *kaldi_feature_repo_read_utterance(KaldiFeatureRepo *repo, lua_State *L,
                                            int debug, MContext *context) {
        Matrix *mat = NULL;             /* nerv implementation */

        repo->utt = repo->feature_reader->Key();
        /* Bind by reference: a copy would duplicate the whole utterance, and a
         * local Matrix has a destructor that would be skipped if
         * NERV_LUA_CHECK_STATUS raises a Lua error (presumably via longjmp). */
        const kaldi::Matrix<BaseFloat> &kmat = repo->feature_reader->Value();

        int n = kmat.NumRows();
        int m = kmat.NumCols();
        Status status;
        assert(sizeof(BaseFloat) == sizeof(float));
        if (sizeof(BaseFloat) == sizeof(float))
            mat = nerv_matrix_host_float_create(n, m, context, &status);
        else if (sizeof(BaseFloat) == sizeof(double))
            mat = nerv_matrix_host_double_create(n, m, context, &status);
        NERV_LUA_CHECK_STATUS(L, status);
        size_t stride = mat->stride;
        if (debug)
            fprintf(stderr, "[kaldi] feature: %s %d %d\n", repo->utt.c_str(), n, m);

        for (int i = 0; i < n; i++)
        {
            const BaseFloat *row = kmat.RowData(i);
            BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
            /* nerv rows start `stride` bytes apart, so copy one row at a time;
             * each row holds m contiguous BaseFloat values on both sides. */
            memmove(nerv_row, row, sizeof(BaseFloat) * m);
        }
        return mat;
    }

    /* Move the underlying Kaldi reader on to the following utterance. */
    void kaldi_feature_repo_next(KaldiFeatureRepo *repo) {
        kaldi::SequentialBaseFloatMatrixReader *reader = repo->feature_reader;
        reader->Next();
    }

    /* Report (1 / 0) whether the Kaldi reader has run out of utterances. */
    int kaldi_feature_repo_is_end(KaldiFeatureRepo *repo) {
        const bool exhausted = repo->feature_reader->Done();
        return exhausted ? 1 : 0;
    }

    /*
     * Key of the utterance last read by kaldi_feature_repo_read_utterance().
     * The returned C string belongs to repo->utt: it is invalidated when that
     * key is reassigned or the repository is destroyed.
     */
    const char *kaldi_feature_repo_key(KaldiFeatureRepo *repo) {
        const char *current_key = repo->utt.c_str();
        return current_key;
    }

    /*
     * Release a repository created by kaldi_feature_repo_new(), together with
     * the Kaldi reader it owns. Passing NULL is a harmless no-op, like free().
     */
    void kaldi_feature_repo_destroy(KaldiFeatureRepo *repo) {
        if (repo == NULL)
            return;
        delete repo->feature_reader;    /* delete of a NULL pointer is a no-op */
        delete repo;
    }

    struct KaldiLookupFeatureRepo {
        StringToMatrix_t key2mat;
        StringToString_t map;
    };

    /*
     * Build an in-memory lookup repository.
     *
     * feature_rspecifier - Kaldi rspecifier of a matrix table. Every entry is
     *                      loaded into repo->key2mat; on a duplicate key a
     *                      warning is printed and the later matrix wins.
     * map_rspecifier     - Kaldi rspecifier of a token-vector table. For each
     *                      key, the first token is stored in repo->map. Entries
     *                      with no tokens trip the assert in debug builds and
     *                      are skipped with a warning otherwise.
     *
     * Returns a heap-allocated repository. Exceptions thrown by the Kaldi
     * readers propagate to the caller. The readers are stack objects, so they
     * are closed on every path, and the partial repository is freed first.
     */
    KaldiLookupFeatureRepo *kaldi_lookup_feature_repo_new(const char *feature_rspecifier, const char *map_rspecifier) {
        KaldiLookupFeatureRepo *repo = new KaldiLookupFeatureRepo();
        try {
            {   /* scoped so the feature table is closed before the map is opened */
                const std::string feature_spec(feature_rspecifier);
                kaldi::SequentialBaseFloatMatrixReader feature_reader(feature_spec);
                for (; !feature_reader.Done(); feature_reader.Next())
                {
                    const std::string key = feature_reader.Key();
                    const kaldi::Matrix<BaseFloat> &feat = feature_reader.Value();
                    if (repo->key2mat.find(key) != repo->key2mat.end())
                        fprintf(stderr, "[kaldi] warning: lookup feature for key %s already exists\n", key.c_str());
                    repo->key2mat[key] = feat;
                }
            }
            {
                const std::string map_spec(map_rspecifier);
                kaldi::SequentialTokenVectorReader map_reader(map_spec);
                for (; !map_reader.Done(); map_reader.Next())
                {
                    const std::string key = map_reader.Key();
                    const std::vector<std::string> &target = map_reader.Value();
                    assert(target.size() >= 1);
                    if (target.empty())
                    {
                        /* assert is compiled out under NDEBUG; avoid reading begin() of an empty vector */
                        fprintf(stderr, "[kaldi] warning: empty mapping for key %s, skipped\n", key.c_str());
                        continue;
                    }
                    repo->map[key] = target.front();
                }
            }
        } catch (...) {
            delete repo;
            throw;
        }
        return repo;
    }

    Matrix *kaldi_lookup_feature_repo_read_utterance(KaldiLookupFeatureRepo *repo,
                                                    KaldiFeatureRepo *frepo,
                                                    int nframes, lua_State *L,
                                                    int debug, MContext *context) {
        Matrix *mat;                    /* nerv implementation */
        StringToString_t::iterator mit = repo->map.find(frepo->utt);
        if