summaryrefslogtreecommitdiff
path: root/kaldi_seq/tools/net_kaldi2nerv.cpp
blob: bbac3dbf33a56f5406d16abd0996001f71c58d3e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
using namespace std;

// Header written in front of every output parameter chunk: the chunk's total
// byte size (this header included), zero-padded to 13 digits so that the
// placeholder written first and the value patched in later have the same
// length. Takes a long, matching what ftell() returns.
static const char fmt[] = "[%013ld]\n";

// Reads the next whitespace-delimited token from `in` into `tok`.
// Same tokenization as fscanf(in, "%s", ...), but the destination grows as
// needed, so an unexpectedly long token cannot overflow a fixed buffer.
// The whitespace character that ends the token is pushed back, as fscanf does.
// Returns false if end of file is reached before any token character.
static bool read_token(FILE *in, std::string &tok)
{
    tok.clear();
    int c;
    do {
        c = fgetc(in);
    } while(c != EOF && isspace(c));
    while(c != EOF && !isspace(c)){
        tok += static_cast<char>(c);
        c = fgetc(in);
    }
    if(c != EOF)
        ungetc(c, in);
    return !tok.empty();
}

// Writes a placeholder chunk header to `out` and returns the file offset at
// which the chunk starts, to be passed to end_chunk() once the body is written.
static long begin_chunk(FILE *out)
{
    long start = ftell(out);
    fprintf(out, fmt, 0L);
    return start;
}

// Overwrites the placeholder header of the chunk beginning at `start` with the
// chunk's real byte size, then repositions at the end of the file.
static void end_chunk(FILE *out, long start)
{
    long size = ftell(out) - start;
    fseek(out, start, SEEK_SET);
    fprintf(out, fmt, size);
    fseek(out, 0, SEEK_END);
}

// Converts the <AffineTransform> components of a text-format Kaldi network
// (argv[1]) into size-prefixed parameter chunks written to argv[2].
// For each line whose first token is "<AffineTransform>" followed by two
// integers out_dim and in_dim, it emits:
//   * "affineN_ltp": nerv.LinearTransParam, the weights transposed to
//     in_dim rows of out_dim values;
//   * "affineN_bp":  nerv.BiasParam, a 1 x out_dim row.
// Values are copied as text tokens; no numeric conversion is done.
// Exits with status 1 on open/read/write errors or malformed input.
int main(int argc, char *argv[])
{
    if(argc < 3){
        printf("Usage: %s kaldi_nnet nerv_output\n", argv[0]);
        exit(0);
    }

    FILE *fin = fopen(argv[1], "r");
    FILE *fout = fopen(argv[2], "w");

    if(!fin || !fout){
        printf("fopen error\n");
        exit(1);
    }

    char buf[1024], tag[64];
    int in_dim, out_dim;
    int affine_ltp = 0, affine_bp = 0;
    std::string tok;
    std::vector<std::string> weights;

    while(fgets(buf, sizeof buf, fin)){
        // "%63s" bounds the write into tag[64].
        if(sscanf(buf, "%63s%d%d", tag, &out_dim, &in_dim) != 3 ||
           strcmp(tag, "<AffineTransform>") != 0)
            continue;
        if(in_dim < 0 || out_dim < 0){
            fprintf(stderr, "invalid <AffineTransform> dimensions: %d %d\n",
                    out_dim, in_dim);
            exit(1);
        }

        // Skip the line that follows the header; the weights start after it.
        fgets(buf, sizeof buf, fin);

        // The input lists out_dim groups of in_dim values; store them
        // transposed (in_dim x out_dim, row-major) for output.
        weights.assign(static_cast<std::size_t>(in_dim) * out_dim, std::string());
        for(int j = 0; j < out_dim; j++)
            for(int i = 0; i < in_dim; i++)
                if(!read_token(fin, weights[static_cast<std::size_t>(i) * out_dim + j])){
                    fprintf(stderr, "unexpected end of input in affine%d weights\n",
                            affine_ltp);
                    exit(1);
                }

        long start = begin_chunk(fout);
        fprintf(fout, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n", affine_ltp++);
        fprintf(fout, "%d %d\n", in_dim, out_dim);
        for(int i = 0; i < in_dim; i++){
            for(int j = 0; j < out_dim; j++)
                fprintf(fout, "%s ", weights[static_cast<std::size_t>(i) * out_dim + j].c_str());
            fprintf(fout, "\n");
        }
        end_chunk(fout, start);

        // Discard the rest of the last weight line (presumably the matrix's
        // closing "]"), then one more token (presumably the bias vector's
        // opening "[") — TODO confirm against the Kaldi text format.
        fgets(buf, sizeof buf, fin);
        read_token(fin, tok);

        start = begin_chunk(fout);
        fprintf(fout, "{type=\"nerv.BiasParam\",id=\"affine%d_bp\"}\n", affine_bp++);
        fprintf(fout, "%d %d\n", 1, out_dim);
        for(int i = 0; i < out_dim; i++){
            if(!read_token(fin, tok)){
                fprintf(stderr, "unexpected end of input in affine%d bias\n",
                        affine_bp - 1);
                exit(1);
            }
            fprintf(fout, "%s ", tok.c_str());
        }
        fputs("\n", fout);
        end_chunk(fout, start);
    }

    if(ferror(fin)){
        fprintf(stderr, "error reading %s\n", argv[1]);
        exit(1);
    }

    fclose(fin);
    // fclose flushes buffered output; a failure here means the file is incomplete.
    if(fclose(fout) != 0){
        fprintf(stderr, "error writing %s\n", argv[2]);
        return 1;
    }

    return 0;
}