summaryrefslogblamecommitdiff
path: root/tnet_io/cwrapper.cpp
blob: 800df2e821d541e7c4d2a21c71419cb467e9b49b (plain) (tree)
1
2
3
4
5
6
7
8
9
10









                                   
                         
 
                                                                                   






















                                                                                                   


















                                                                                               
                                                                                              





                                                                       
                                                   
                                    

                                                                                   




                                                                                  
                                                      







                                                        



                                                         













                                                                        
                    

















                                                                             
                                                            

                                                         





                                                                                    
                                                   
                                    

                                                                   










                                                                                  
                    

     
#include "KaldiLib/Features.h"
#include "KaldiLib/Labels.h"
#include "KaldiLib/Common.h"
#include "KaldiLib/UserInterface.h"
#include <string>
#define SNAME "TNET"

extern "C" {
#include "cwrapper.h"
#include "string.h"
#include "../../common.h"

    extern Matrix *nerv_matrix_host_float_new_(lua_State *L, long nrow, long ncol);

    /* Opaque handle handed to C callers: a TNet feature repository plus the
     * configuration it was initialised from.  In tnet_feature_repo_new(),
     * target_kind receives the return value of
     * TNet::UserInterface::GetFeatureParams() and deriv_order..cvg_file are
     * filled in through its out-parameters. */
    struct TNetFeatureRepo {
        TNet::FeatureRepository feature_repo;
        TNet::UserInterface ui;              /* holds the parsed config file */
        bool swap_features;                  /* negation of the NATURALREADORDER setting */
        int target_kind;
        int deriv_order;
        int* p_deriv_win_lenghts;            /* released with free() in tnet_feature_repo_destroy */
        int start_frm_ext;                   /* overwritten with the `context` argument after loading */
        int end_frm_ext;                     /* overwritten with the `context` argument after loading */
        char* cmn_path;                      /* freed on destroy only when cmn_mask is non-NULL */
        char* cmn_file;
        const char* cmn_mask;
        char* cvn_path;                      /* freed on destroy only when cvn_mask is non-NULL */
        char* cvn_file;
        const char* cvn_mask;
        const char* cvg_file;
        TNet::Matrix<float> feats_host; /* KaldiLib implementation */
    };

    /* Create a feature repository handle.
     *   p_script - file list handed to FeatureRepository::AddFileList()
     *   config   - config file read with UserInterface::ReadConfig()
     *   context  - used for both the start and end frame extension, replacing
     *              whatever GetFeatureParams() loaded from the config
     * The repository is rewound before the handle is returned.  Release the
     * handle with tnet_feature_repo_destroy(). */
    TNetFeatureRepo *tnet_feature_repo_new(const char *p_script, const char *config, int context) {
        TNetFeatureRepo *handle = new TNetFeatureRepo();
        TNet::UserInterface &cfg = handle->ui;
        cfg.ReadConfig(config);

        /* NATURALREADORDER defaults to TNet::IsBigEndian(); swap when it is off */
        const bool natural_order = cfg.GetBool(SNAME":NATURALREADORDER", TNet::IsBigEndian());
        handle->swap_features = !natural_order;

        /* load defaults */
        handle->target_kind = cfg.GetFeatureParams(&handle->deriv_order,
                &handle->p_deriv_win_lenghts,
                &handle->start_frm_ext, &handle->end_frm_ext,
                &handle->cmn_path, &handle->cmn_file, &handle->cmn_mask,
                &handle->cvn_path, &handle->cvn_file, &handle->cvn_mask,
                &handle->cvg_file, SNAME":", 0);

        /* the caller-supplied context wins over the configured frame extension */
        handle->end_frm_ext = context;
        handle->start_frm_ext = context;

        TNet::FeatureRepository &features = handle->feature_repo;
        features.Init(handle->swap_features,
                handle->start_frm_ext, handle->end_frm_ext, handle->target_kind,
                handle->deriv_order, handle->p_deriv_win_lenghts,
                handle->cmn_path, handle->cmn_mask,
                handle->cvn_path, handle->cvn_mask, handle->cvg_file);
        features.AddFileList(p_script);
        features.Rewind();
        return handle;
    }

    /* Read the current entry's feature matrix and copy it into a freshly
     * allocated nerv host matrix of the same size.
     *   repo  - feature repository handle
     *   L     - Lua state used to allocate the nerv matrix
     *   debug - when non-zero, print the utterance name and size to stderr
     * The KaldiLib matrix is passed to CheckData() with the utterance's
     * logical name before it is copied. */
    Matrix *tnet_feature_repo_read_utterance(TNetFeatureRepo *repo, lua_State *L, int debug) {
        TNet::Matrix<float> &src = repo->feats_host;   /* KaldiLib implementation */
        repo->feature_repo.ReadFullMatrix(src);
        std::string utter_str = repo->feature_repo.Current().Logical();
        src.CheckData(utter_str);

        const int rows = src.Rows();
        const int cols = src.Cols();
        Matrix *dst = nerv_matrix_host_float_new_(L, rows, cols);   /* nerv implementation */
        const size_t dst_stride = dst->stride;
        if (debug)
            fprintf(stderr, "[tnet] feature: %s %d %d\n", utter_str.c_str(), rows, cols);

        /* Copy row by row: the nerv matrix addresses rows through its own
         * byte stride, so a single bulk copy is not possible. */
        char *dst_base = reinterpret_cast<char *>(dst->data.f);
        const size_t row_bytes = sizeof(float) * cols;
        for (int r = 0; r < rows; ++r)
            memmove(dst_base + r * dst_stride, src.pRowData(r), row_bytes);
        return dst;
    }

    /* Step the repository forward via FeatureRepository::MoveNext(). */
    void tnet_feature_repo_next(TNetFeatureRepo *repo) {
        TNet::FeatureRepository &features = repo->feature_repo;
        features.MoveNext();
    }

    /* Report FeatureRepository::EndOfList() to C callers as an int. */
    int tnet_feature_repo_is_end(TNetFeatureRepo *repo) {
        TNet::FeatureRepository &features = repo->feature_repo;
        int finished = features.EndOfList();
        return finished;
    }

    /* Return the mSamplePeriod field of the current entry's header.
     * Despite the function name this is a sample *period*, not a rate.
     * NOTE(review): units come from the feature-file header (HTK headers use
     * 100ns units) -- confirm before interpreting the value. */
    size_t tnet_feature_repo_current_samplerate(TNetFeatureRepo *repo) {
        TNet::FeatureRepository &features = repo->feature_repo;
        return features.CurrentHeader().mSamplePeriod;
    }

    /* Return the logical name of the current entry as a C string.
     * NOTE(review): the pointer stays valid only while the string returned
     * by Logical() is alive; if Logical() returns by value rather than by
     * reference into the repository, the result dangles -- confirm. */
    const char *tnet_feature_repo_current_tag(TNetFeatureRepo *repo) {
        const std::string &logical_name = repo->feature_repo.Current().Logical();
        return logical_name.c_str();
    }

    /* Release a handle created by tnet_feature_repo_new().
     * A NULL handle is accepted and ignored, mirroring free()/delete; the
     * original code dereferenced it unconditionally.
     * cmn_path/cvn_path are freed only when the matching mask is set, and
     * p_deriv_win_lenghts is released with free().  NOTE(review): this
     * presumably mirrors how GetFeatureParams() allocates these buffers --
     * confirm against KaldiLib. */
    void tnet_feature_repo_destroy(TNetFeatureRepo *repo) {
        if (repo == NULL)
            return;
        if (repo->cmn_mask)
            free(repo->cmn_path);
        if (repo->cvn_mask)
            free(repo->cvn_path);
        free(repo->p_deriv_win_lenghts);
        delete repo;
    }

    /* Opaque handle handed to C callers; wraps a TNet::LabelRepository. */
    struct TNetLabelRepo {
        TNet::LabelRepository label_repo;
    };

    /* Create a label repository handle.  All arguments are forwarded
     * unchanged to TNet::LabelRepository::InitExt().  Release the handle
     * with tnet_label_repo_destroy(). */
    TNetLabelRepo *tnet_label_repo_new(const char *mlf, const char *fmt,
                                        const char *fmt_arg, const char *dir,
                                        const char *ext) {
        TNetLabelRepo *handle = new TNetLabelRepo();
        TNet::LabelRepository &labels = handle->label_repo;
        /* InitExt() is used instead of the older Init(mlf, fmt_arg, dir, ext) */
        labels.InitExt(mlf, fmt, fmt_arg, dir, ext);
        return handle;
    }

    /* Generate label matrices for one utterance and copy the first one into
     * a freshly allocated nerv host matrix of the same size.
     *   repo        - label repository handle
     *   frames      - forwarded to GenDesiredMatrixExt()
     *   sample_rate - forwarded to GenDesiredMatrixExt()
     *   tag         - forwarded to GenDesiredMatrixExt(); also printed in debug
     *   L           - Lua state used to allocate the nerv matrix
     *   debug       - when non-zero, print tag and matrix size to stderr
     * Only element 0 of the generated vector is used.  If the vector is
     * empty, std::out_of_range is thrown instead of reading out of bounds. */
    Matrix *tnet_label_repo_read_utterance(TNetLabelRepo *repo,
                                            size_t frames,
                                            size_t sample_rate,
                                            const char *tag,
                                            lua_State *L,
                                            int debug) {
        std::vector<TNet::Matrix<float> >  labs_hosts; /* KaldiLib implementation */
        repo->label_repo.GenDesiredMatrixExt(labs_hosts, frames,
                                            sample_rate, tag);
        /* at() instead of [0]: an empty result must not be dereferenced */
        TNet::Matrix<float> &labs = labs_hosts.at(0);
        const int n = labs.Rows();
        const int m = labs.Cols();
        Matrix *mat = nerv_matrix_host_float_new_(L, n, m);
        const size_t stride = mat->stride;
        const size_t row_bytes = sizeof(float) * m;
        if (debug)
            fprintf(stderr, "[tnet] label: %s %d %d\n", tag, n, m);
        for (int i = 0; i < n; i++)
        {
            /* Copy row by row because the nerv matrix uses its own byte
             * stride.  The buffers are distinct allocations, so memcpy
             * suffices. */
            float *nerv_row = (float *)((char *)mat->data.f + i * stride);
            memcpy(nerv_row, labs.pRowData(i), row_bytes);
        }
        return mat;
    }

    /* Release a handle created by tnet_label_repo_new().  delete on a NULL
     * pointer is a no-op, so no explicit guard is needed. */
    void tnet_label_repo_destroy(TNetLabelRepo *repo) {
        TNetLabelRepo *doomed = repo;
        delete doomed;
    }
}