diff options
Diffstat (limited to 'kaldi_seq/tools/net_kaldi2nerv.cpp')
-rw-r--r-- | kaldi_seq/tools/net_kaldi2nerv.cpp | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/kaldi_seq/tools/net_kaldi2nerv.cpp b/kaldi_seq/tools/net_kaldi2nerv.cpp new file mode 100644 index 0000000..bbac3db --- /dev/null +++ b/kaldi_seq/tools/net_kaldi2nerv.cpp @@ -0,0 +1,85 @@ +#include <iostream> +#include <cstdio> +#include <cstring> +#include <cstdlib> +#include <cassert> +using namespace std; + +const char fmt[] = "[%013d]\n"; + +int main(int argc, char *argv[]) +{ + if(argc < 3){ + printf("Usage: %s kaldi_nnet nerv_output\n", argv[0]); + exit(0); + } + + FILE *fin = fopen(argv[1], "r"); + FILE *fout = fopen(argv[2], "w"); + + if(!fin || !fout){ + printf("fopen error\n"); + exit(1); + } + + char buf[1024], tag[64]; + int a, b; + char ***arr; + long start, size; + int affine_ltp = 0, affine_bp = 0; + + while(fgets(buf, 1024, fin)){ + if(sscanf(buf, "%s%d%d", tag, &b, &a) == 3 && strcmp(tag, "<AffineTransform>") == 0){ + fgets(buf, 1024, fin); + arr = new char**[a]; + for(int i = 0; i < a; i++) + arr[i] = new char*[b]; + for(int j = 0; j < b; j++) + for(int i = 0; i < a; i++){ + arr[i][j] = new char[16]; + fscanf(fin, "%s", arr[i][j]); + } + + start = ftell(fout); + fprintf(fout, fmt, 0); + fprintf(fout, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n", affine_ltp++); + fprintf(fout, "%d %d\n", a, b); + for(int i = 0; i < a; i++){ + for(int j = 0; j < b; j++){ + fprintf(fout, "%s ", arr[i][j]); + delete [] arr[i][j]; + } + fprintf(fout, "\n"); + delete [] arr[i]; + } + delete [] arr; + + size = ftell(fout) - start; + fseek(fout, start, SEEK_SET); + fprintf(fout, fmt, (int)size); + fseek(fout, 0, SEEK_END); + + fgets(buf, 1024, fin); + fscanf(fin, "%*s"); + + start = ftell(fout); + fprintf(fout, fmt, 0); + fprintf(fout, "{type=\"nerv.BiasParam\",id=\"affine%d_bp\"}\n", affine_bp++); + fprintf(fout, "%d %d\n", 1, b); + for(int i = 0; i < b; i++){ + fscanf(fin, "%s", buf); + fprintf(fout, "%s ", buf); + } + fputs("\n", fout); + size = ftell(fout) - start; + fseek(fout, start, SEEK_SET); + fprintf(fout, fmt, (int)size); + fseek(fout, 0, SEEK_END); + } + } + + fclose(fin); + fclose(fout); + + return 0; +} |