summaryrefslogtreecommitdiff
path: root/kaldi_io/src/cwrapper_kaldi.cpp
blob: 83331ceccf8285f2e58daedb8afac4da4702cfab (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#include <string>
#include <map>
#include "base/kaldi-common.h"
#include "hmm/posterior.h"
#include "util/table-types.h"
/* Floating-point element type of the Kaldi build, made available unqualified. */
typedef kaldi::BaseFloat BaseFloat;
/* Maps a string key to a Kaldi matrix of BaseFloat values. */
typedef std::map<std::string, kaldi::Matrix<BaseFloat> > StringToMatrix_t;
/* Maps a string key to another string. */
typedef std::map<std::string, std::string > StringToString_t;
extern "C" {
#include "cwrapper_kaldi.h"
#include "string.h"
#include "assert.h"
#include "nerv/lib/common.h"
#include "nerv/lib/matrix/mmatrix.h"

    /* Host-matrix constructors, declared by hand in this file. Each takes the
     * row and column counts, a matrix context and a Status out-parameter, and
     * returns the newly created Matrix. */
    extern Matrix *nerv_matrix_host_float_create(long nrow, long ncol, MContext *context, Status *status);
    extern Matrix *nerv_matrix_host_double_create(long nrow, long ncol, MContext *context, Status *status);

    /* State for sequential feature reading.
     *   feature_reader: Kaldi sequential matrix reader; allocated by
     *                   kaldi_feature_repo_new and released by
     *                   kaldi_feature_repo_destroy.
     *   utt:            key stored by the most recent call to
     *                   kaldi_feature_repo_read_utterance. */
    struct KaldiFeatureRepo {
        kaldi::SequentialBaseFloatMatrixReader* feature_reader;
        /* fully qualified: no using-declaration for `string` is visible in
         * this file, so the bare name only works if a header leaks it */
        std::string utt;
    };

    /* Allocate a feature repository and open a sequential Kaldi matrix reader
     * on feature_rspecifier. Release the result with
     * kaldi_feature_repo_destroy. */
    KaldiFeatureRepo *kaldi_feature_repo_new(const char *feature_rspecifier) {
        KaldiFeatureRepo *handle = new KaldiFeatureRepo();
        const std::string rspecifier(feature_rspecifier);
        handle->feature_reader = new kaldi::SequentialBaseFloatMatrixReader(rspecifier);
        return handle;
    }

    /* Copy the matrix at the reader's current position into a newly created
     * nerv host matrix and remember its key in repo->utt.
     *
     *   repo:    feature repository whose reader supplies the key and matrix
     *   L:       Lua state handed to NERV_LUA_CHECK_STATUS
     *   debug:   non-zero prints the key and the dimensions to stderr
     *   context: matrix context forwarded to the nerv constructor
     *
     * Returns the new matrix with NumRows() x NumCols() elements of the Kaldi
     * matrix. The reader is not advanced here; kaldi_feature_repo_next does
     * that. */
    Matrix *kaldi_feature_repo_read_utterance(KaldiFeatureRepo *repo, lua_State *L,
                                            int debug, MContext *context) {
        Matrix *mat = NULL;             /* nerv implementation */

        repo->utt = repo->feature_reader->Key();
        /* bind by reference instead of deep-copying the utterance: the reader
         * is not advanced inside this function, so the matrix it holds stays
         * valid for the whole copy below */
        const kaldi::Matrix<BaseFloat> &kmat = repo->feature_reader->Value();

        int n = kmat.NumRows();
        int m = kmat.NumCols();
        Status status;
        assert(sizeof(BaseFloat) == sizeof(float));
        if(sizeof(BaseFloat) == sizeof(float))
            mat = nerv_matrix_host_float_create(n, m, context, &status);
        else if(sizeof(BaseFloat) == sizeof(double))
            mat = nerv_matrix_host_double_create(n, m, context, &status);
        NERV_LUA_CHECK_STATUS(L, status);
        size_t stride = mat->stride;
        if (debug)
            fprintf(stderr, "[kaldi] feature: %s %d %d\n", repo->utt.c_str(), n, m);

        /* copy row by row: the destination advances by its own byte stride,
         * while the source row pointer comes from Kaldi */
        const size_t row_bytes = sizeof(BaseFloat) * m;
        for (int i = 0; i < n; i++)
        {
            const BaseFloat *row = kmat.RowData(i);
            BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
            memmove(nerv_row, row, row_bytes);
        }
        return mat;
    }

    /* Advance the underlying Kaldi reader to its next entry. */
    void kaldi_feature_repo_next(KaldiFeatureRepo *repo) {
        kaldi::SequentialBaseFloatMatrixReader *reader = repo->feature_reader;
        reader->Next();
    }

    /* Report the reader's Done() state as an int: 1 when done, 0 otherwise. */
    int kaldi_feature_repo_is_end(KaldiFeatureRepo *repo) {
        const bool done = repo->feature_reader->Done();
        return done ? 1 : 0;
    }

    /* Return the key stored in repo->utt as a C string. The pointer refers to
     * storage owned by repo->utt, so it is only usable while that string is
     * neither modified nor destroyed. */
    const char *kaldi_feature_repo_key(KaldiFeatureRepo *repo) {
        const std::string &current = repo->utt;
        return current.c_str();
    }

    /* Release the Kaldi reader owned by the repository, then the repository
     * itself. */
    void kaldi_feature_repo_destroy(KaldiFeatureRepo *repo) {
        /* deleting a null pointer is a no-op, so no guard is needed */
        delete repo->feature_reader;
        delete repo;
    }

    /* In-memory lookup table filled by kaldi_lookup_feature_repo_new.
     *   key2mat: feature matrix stored under each key of the feature table
     *   map:     for each key of the map table, the first token of its entry;
     *            kaldi_lookup_feature_repo_read_utterance uses that token as
     *            the key into key2mat */
    struct KaldiLookupFeatureRepo {
        StringToMatrix_t key2mat;
        StringToString_t map;
    };

    /* Build a lookup repository by loading two Kaldi tables into memory.
     *
     *   feature_rspecifier: table of matrices; each is stored in key2mat under
     *                       its key (a repeated key triggers a warning and the
     *                       later matrix replaces the earlier one)
     *   map_rspecifier:     table of token vectors; the first token of each
     *                       entry is stored in map under the entry's key
     *
     * Returns the new repository; release it with
     * kaldi_lookup_feature_repo_destroy. */
    KaldiLookupFeatureRepo *kaldi_lookup_feature_repo_new(const char *feature_rspecifier, const char *map_rspecifier) {
        KaldiLookupFeatureRepo *repo = new KaldiLookupFeatureRepo();

        /* readers are direct-initialised from a named string: this needs no
         * copy constructor and avoids the function-declaration parse that
         * `Reader r(std::string(x));` would trigger */
        const std::string feature_rspec(feature_rspecifier);
        kaldi::SequentialBaseFloatMatrixReader feature_reader(feature_rspec);
        for (; !feature_reader.Done(); feature_reader.Next())
        {
            const std::string key = feature_reader.Key();
            if (repo->key2mat.find(key) != repo->key2mat.end())
                fprintf(stderr, "[kaldi] warning: lookup feature for key %s already exists\n", key.c_str());
            repo->key2mat[key] = feature_reader.Value();
        }

        const std::string map_rspec(map_rspecifier);
        kaldi::SequentialTokenVectorReader map_reader(map_rspec);
        for (; !map_reader.Done(); map_reader.Next())
        {
            /* reference, not a copy: only the first token is kept */
            const std::vector<std::string> &target = map_reader.Value();
            assert(target.size() >= 1);
            /* the assert vanishes under NDEBUG; never read the front of an
             * empty vector in that configuration */
            if (target.empty())
            {
                fprintf(stderr, "[kaldi] warning: empty mapping for key %s ignored\n", map_reader.Key().c_str());
                continue;
            }
            repo->map[map_reader.Key()] = target.front();
        }
        return repo;
    }

    /* Fetch the lookup feature that belongs to the utterance currently held by
     * frepo and copy up to nframes of its rows into a new nerv host matrix.
     *
     * The utterance key (frepo->utt) is first translated through repo->map;
     * the translated key then selects the matrix in repo->key2mat. A missing
     * entry in either table is reported through nerv_error. The row count of
     * the result is the smaller of nframes and the stored matrix's row count.
     * debug != 0 prints the utterance key and the dimensions to stderr. */
    Matrix *kaldi_lookup_feature_repo_read_utterance(KaldiLookupFeatureRepo *repo,
                                                    KaldiFeatureRepo *frepo,
                                                    int nframes, lua_State *L,
                                                    int debug, MContext *context) {
        StringToString_t::iterator alias = repo->map.find(frepo->utt);
        if (alias == repo->map.end())
            nerv_error(L, "[kaldi] mapped key for key %s not found", frepo->utt.c_str());
        const std::string &lookup_key = alias->second;

        StringToMatrix_t::iterator entry = repo->key2mat.find(lookup_key);
        if (entry == repo->key2mat.end())
            nerv_error(L, "[kaldi] lookup feature for key %s not found", lookup_key.c_str());
        const kaldi::Matrix<BaseFloat> &src = entry->second;

        /* clamp the number of copied rows to what the stored matrix offers */
        int nrow = nframes;
        if (src.NumRows() < nframes)
            nrow = src.NumRows();
        int ncol = src.NumCols();

        Matrix *mat;                    /* nerv implementation */
        Status status;
        assert(sizeof(BaseFloat) == sizeof(float));
        if (sizeof(BaseFloat) == sizeof(float))
            mat = nerv_matrix_host_float_create(nrow, ncol, context, &status);
        else if (sizeof(BaseFloat) == sizeof(double))
            mat = nerv_matrix_host_double_create(nrow, ncol, context, &status);
        NERV_LUA_CHECK_STATUS(L, status);
        size_t stride = mat->stride;
        if (debug)
            fprintf(stderr, "[kaldi] lookup feature: %s %d %d\n", frepo->utt.c_str(), nrow, ncol);

        /* one memmove per row; the destination advances by its byte stride */
        for (int r = 0; r < nrow; r++)
        {
            char *dst = (char *)mat->data.f + r * stride;
            memmove((BaseFloat *)dst, src.RowData(r), sizeof(BaseFloat) * ncol);
        }
        return mat;
    }

    /* Free a repository created by kaldi_lookup_feature_repo_new; the two maps
     * it contains are destroyed together with it. */
    void kaldi_lookup_feature_repo_destroy(KaldiLookupFeatureRepo *repo) {
        delete repo;
    }

    /* State for label reading: a random-access Kaldi posterior reader,
     * allocated by kaldi_label_repo_new and released by
     * kaldi_label_repo_destroy. */
    struct KaldiLabelRepo {
        kaldi::RandomAccessPosteriorReader *targets_reader;
    };

    /* Allocate a label repository and open a random-access Kaldi posterior
     * reader on targets_rspecifier. The fmt argument is accepted but not used
     * by this implementation. Release the result with
     * kaldi_label_repo_destroy. */
    KaldiLabelRepo *kaldi_label_repo_new(const char *targets_rspecifier, const char *fmt) {
        (void)fmt;
        KaldiLabelRepo *handle = new KaldiLabelRepo();
        const std::string rspecifier(targets_rspecifier);
        handle->targets_reader = new kaldi::RandomAccessPosteriorReader(rspecifier);
        return handle;
    }

    /* Build a nerv host matrix holding the labels of the utterance currently
     * held by frepo.
     *
     *   repo:    label repository providing the posterior reader
     *   frepo:   feature repository; frepo->utt is the key that is looked up
     *   nframes: upper bound on the number of rows; a negative value imposes
     *            no bound (this matches what the previous unsigned comparison
     *            did in effect)
     *   L:       Lua state used for error reporting
     *   debug:   non-zero prints the key and the dimensions to stderr
     *   context: matrix context forwarded to the nerv constructor
     *
     * Row i, column j of the result receives the integer id (.first) of the
     * j-th entry of frame i, converted to BaseFloat; the weight (.second) is
     * not used. The column count is the number of entries of frame 0.
     *
     * Errors are raised through nerv_error when the key is absent from the
     * table or when a copied frame has fewer entries than frame 0. */
    Matrix *kaldi_label_repo_read_utterance(KaldiLabelRepo *repo, KaldiFeatureRepo *frepo, int nframes,
                                            lua_State *L,
                                            int debug, MContext *context) {
        Matrix *mat = NULL;
        const std::string &utt = frepo->utt;

        /* check first: asking the reader for a key it does not have is an
         * error on the Kaldi side rather than a Lua error */
        if (!repo->targets_reader->HasKey(utt))
            nerv_error(L, "[kaldi] label for key %s not found", utt.c_str());
        /* reference instead of copying the whole posterior */
        const kaldi::Posterior &targets = repo->targets_reader->Value(utt);

        const size_t total = targets.size();
        int n = (nframes < 0 || total < (size_t)nframes) ? (int)total : nframes;
        /* an empty posterior has no frame 0 to take the width from */
        int m = total == 0 ? 0 : (int)targets[0].size();

        /* validate before allocating so that an error leaves nothing behind */
        for (int i = 0; i < n; i++)
            if (targets[i].size() < (size_t)m)
                nerv_error(L, "[kaldi] label for key %s has %d entries at frame %d, expected %d",
                           utt.c_str(), (int)targets[i].size(), i, m);

        Status status;
        assert(sizeof(BaseFloat) == sizeof(float));
        if(sizeof(BaseFloat) == sizeof(float))
            mat = nerv_matrix_host_float_create(n, m, context, &status);
        else if(sizeof(BaseFloat) == sizeof(double))
            mat = nerv_matrix_host_double_create(n, m, context, &status);
        NERV_LUA_CHECK_STATUS(L, status);
        size_t stride = mat->stride;
        if (debug)
            fprintf(stderr, "[kaldi] label: %s %d %d\n", utt.c_str(), n, m);
        for (int i = 0; i < n; i++)
        {
            /* stride is in bytes, so it is applied to the char pointer; the
             * column index is applied to the typed row pointer (element units) */
            BaseFloat *nerv_row = (BaseFloat *)((char *)mat->data.f + i * stride);
            for (int j = 0; j < m; j++)
                nerv_row[j] = (BaseFloat)targets[i][j].first;
        }
        return mat;
    }

    /* Release the posterior reader owned by the repository, then the
     * repository itself. */
    void kaldi_label_repo_destroy(KaldiLabelRepo *repo) {
        /* deleting a null pointer is a no-op, so no guard is needed */
        delete repo->targets_reader;
        delete repo;
    }
}