/*
 * path: root/kaldi_io/src/cwrapper_kaldi.cpp
 * blob: 3dd055f8f304c4884205875c88cf7db6637982ce
 */
#include <string>
#include "kaldi/base/kaldi-common.h"
#include "kaldi/hmm/posterior.h"
#include "kaldi/util/table-types.h"
typedef kaldi::BaseFloat BaseFloat;

extern "C" {
#include "cwrapper_kaldi.h"
#include "string.h"
#include "assert.h"
#include "nerv/common.h"

    extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, Status *status);
    extern Matrix *nerv_matrix_host_double_create(long nrow, long ncol, Status *status);

    /* State for walking sequentially through the feature matrices of a Kaldi table. */
    struct KaldiFeatureRepo {
        /* sequential table reader; allocated in kaldi_feature_repo_new and
           deleted in kaldi_feature_repo_destroy */
        kaldi::SequentialBaseFloatMatrixReader* feature_reader;
        /* key of the utterance most recently read by
           kaldi_feature_repo_read_utterance; kaldi_label_repo_read_utterance
           uses it to look up the matching targets */
        string utt;
    };

    /*
     * Allocate a feature repo and open a sequential Kaldi reader on
     * `feature_rspecifier`.
     *
     * The returned repo must be released with kaldi_feature_repo_destroy.
     * If constructing the reader throws (Kaldi reports an rspecifier it cannot
     * open by throwing), the partially built repo is freed before the
     * exception is propagated, so nothing is leaked.
     */
    KaldiFeatureRepo *kaldi_feature_repo_new(const char *feature_rspecifier) {
        /* value-initialisation leaves feature_reader null until it is assigned */
        KaldiFeatureRepo *repo = new KaldiFeatureRepo();
        try {
            repo->feature_reader =
                new kaldi::SequentialBaseFloatMatrixReader(std::string(feature_rspecifier));
        } catch (...) {
            delete repo;
            throw;
        }
        return repo;
    }

    /*
     * Copy the feature matrix at the reader's current position into a newly
     * created nerv host matrix and return it.
     *
     * repo  - feature repo; its `utt` field is updated with the current key
     * L     - Lua state handed to NERV_LUA_CHECK_STATUS for error reporting
     * debug - when non-zero, print the key and matrix dimensions to stderr
     *
     * The reader is not advanced; call kaldi_feature_repo_next for that.
     */
    Matrix *kaldi_feature_repo_read_utterance(KaldiFeatureRepo *repo, lua_State *L, int debug) {
        /* remember the key so the label repo can fetch the same utterance */
        repo->utt = repo->feature_reader->Key();
        /* bind by reference: every row is copied out below, so taking a deep
           copy of the whole Kaldi matrix first would be wasted work */
        const kaldi::Matrix<BaseFloat> &kmat = repo->feature_reader->Value();

        const int n = kmat.NumRows();
        const int m = kmat.NumCols();
        Status status;
        Matrix *mat = NULL;             /* nerv implementation */
        assert(sizeof(BaseFloat) == sizeof(float));
        if (sizeof(BaseFloat) == sizeof(float))
            mat = nerv_matrix_host_float_create(n, m, &status);
        else if (sizeof(BaseFloat) == sizeof(double))
            mat = nerv_matrix_host_double_create(n, m, &status);
        NERV_LUA_CHECK_STATUS(L, status);
        /* stride is applied to a char pointer below, i.e. it is used as a byte count */
        const size_t stride = mat->stride;
        if (debug)
            fprintf(stderr, "[kaldi] feature: %s %d %d\n", repo->utt.c_str(), n, m);

        for (int i = 0; i < n; i++)
        {
            const BaseFloat *row = kmat.RowData(i);
            BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
            /* copy row by row: the destination rows are `stride` bytes apart,
               which need not match the layout of the Kaldi matrix */
            memmove(nerv_row, row, sizeof(BaseFloat) * m);
        }
        return mat;
    }

    /* Move the repo's sequential reader forward to the next table entry. */
    void kaldi_feature_repo_next(KaldiFeatureRepo *repo) {
        kaldi::SequentialBaseFloatMatrixReader *reader = repo->feature_reader;
        reader->Next();
    }

    /* Return non-zero once the repo's reader reports Done(), zero otherwise. */
    int kaldi_feature_repo_is_end(KaldiFeatureRepo *repo) {
        const bool exhausted = repo->feature_reader->Done();
        return exhausted ? 1 : 0;
    }

    /* Free the reader owned by the repo, then the repo itself. */
    void kaldi_feature_repo_destroy(KaldiFeatureRepo *repo) {
        /* deleting a null pointer is a no-op, so no explicit check is needed */
        delete repo->feature_reader;
        delete repo;
    }

    /* State for looking up per-utterance targets (Kaldi posteriors) by key. */
    struct KaldiLabelRepo {
        /* random-access posterior reader; allocated in kaldi_label_repo_new and
           deleted in kaldi_label_repo_destroy */
        kaldi::RandomAccessPosteriorReader *targets_reader;
    };

    /*
     * Allocate a label repo and open a random-access Kaldi posterior reader on
     * `targets_rspecifier`.
     *
     * `fmt` is accepted for interface compatibility but is not used here.
     * The returned repo must be released with kaldi_label_repo_destroy.
     * If constructing the reader throws (Kaldi reports an rspecifier it cannot
     * open by throwing), the partially built repo is freed before the
     * exception is propagated, so nothing is leaked.
     */
    KaldiLabelRepo *kaldi_label_repo_new(const char *targets_rspecifier, const char *fmt) {
        (void)fmt;
        /* value-initialisation leaves targets_reader null until it is assigned */
        KaldiLabelRepo *repo = new KaldiLabelRepo();
        try {
            repo->targets_reader =
                new kaldi::RandomAccessPosteriorReader(std::string(targets_rspecifier));
        } catch (...) {
            delete repo;
            throw;
        }
        return repo;
    }

    /*
     * Build a nerv host matrix holding the targets of the utterance most
     * recently read through `frepo`.
     *
     * repo    - label repo whose reader is queried with frepo->utt
     * frepo   - feature repo supplying the utterance key
     * frm_ext - accepted for interface compatibility; not used here
     * nframes - upper bound on the number of rows (frames) to emit
     * L       - Lua state handed to NERV_LUA_CHECK_STATUS for error reporting
     * debug   - when non-zero, print the key and matrix dimensions to stderr
     *
     * The result has min(#frames, nframes) rows and as many columns as the
     * first frame has entries; cell (i, j) receives the `.first` member of
     * entry j of frame i, converted to BaseFloat.
     */
    Matrix *kaldi_label_repo_read_utterance(KaldiLabelRepo *repo, KaldiFeatureRepo *frepo, int frm_ext, int nframes,
                                            lua_State *L,
                                            int debug) {
        (void)frm_ext;
        /* bind by reference: the values are copied out below, so a deep copy
           of the whole posterior is unnecessary */
        const kaldi::Posterior &targets = repo->targets_reader->Value(frepo->utt);

        /* the comparison is done in size_t, exactly as the implicit conversion
           of the original signed/unsigned comparison did */
        const int n = targets.size() < (size_t)nframes ? (int)targets.size() : nframes;
        /* an empty posterior has no first frame to take the width from */
        const int m = targets.empty() ? 0 : (int)targets[0].size();

        Status status;
        Matrix *mat = NULL;
        assert(sizeof(BaseFloat) == sizeof(float));
        if (sizeof(BaseFloat) == sizeof(float))
            mat = nerv_matrix_host_float_create(n, m, &status);
        else if (sizeof(BaseFloat) == sizeof(double))
            mat = nerv_matrix_host_double_create(n, m, &status);
        NERV_LUA_CHECK_STATUS(L, status);
        /* stride is applied to a char pointer below, i.e. it is used as a byte count */
        const size_t stride = mat->stride;

        if (debug)
            fprintf(stderr, "[kaldi] label: %s %d %d\n", frepo->utt.c_str(), n, m);
        for (int i = 0; i < n; i++)
        {
            /* the row start is a byte offset, the column an element offset;
               adding j to the byte offset would produce misaligned,
               overlapping writes */
            BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
            /* a frame may hold fewer entries than the first one; never read
               past its end (cells without a source entry are left as created) */
            const size_t avail = targets[i].size();
            for (int j = 0; j < m && (size_t)j < avail; j++)
                nerv_row[j] = (BaseFloat)targets[i][j].first;
        }
        return mat;
    }

    /* Free the posterior reader owned by the repo, then the repo itself. */
    void kaldi_label_repo_destroy(KaldiLabelRepo *repo) {
        /* deleting a null pointer is a no-op, so no explicit check is needed */
        delete repo->targets_reader;
        delete repo;
    }
}