summaryrefslogtreecommitdiff
path: root/kaldi_seq/tools/transf_kaldi2nerv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_seq/tools/transf_kaldi2nerv.cpp')
-rw-r--r--kaldi_seq/tools/transf_kaldi2nerv.cpp106
1 files changed, 106 insertions, 0 deletions
diff --git a/kaldi_seq/tools/transf_kaldi2nerv.cpp b/kaldi_seq/tools/transf_kaldi2nerv.cpp
new file mode 100644
index 0000000..525bcda
--- /dev/null
+++ b/kaldi_seq/tools/transf_kaldi2nerv.cpp
@@ -0,0 +1,106 @@
+#include <iostream>
+#include <cstdio>
+#include <cstring>
+#include <cstdlib>
+#include <cassert>
+using namespace std;
+
+const char fmt[] = "[%013d]\n";
+
+int main(int argc, char *argv[])
+{
+ if(argc < 3){
+ printf("Usage: %s kaldi_transf nerv_output\n", argv[0]);
+ exit(1);
+ }
+
+ FILE *fin = fopen(argv[1], "r");
+ FILE *fout = fopen(argv[2], "w");
+ if(!fin || !fout){
+ puts("fopen error");
+ exit(1);
+ }
+
+ char buf[1024], tag[64];
+ int a, b;
+ int size_window, size_bias;
+ char **window, **bias;
+
+ while(fgets(buf, sizeof(buf), fin))
+ {
+ if(sscanf(buf, "%s%d%d", tag, &a, &b) == 3){
+ if(strcmp(tag, "<AddShift>") == 0){
+ assert(a == b);
+ size_bias = a;
+ fscanf(fin, "%*s%*s%*s");
+ bias = new char*[size_bias];
+ for(int i = 0; i < size_bias; i++){
+ bias[i] = new char[16];
+ fscanf(fin, "%s", bias[i]);
+ }
+ } else if(strcmp(tag, "<Rescale>") == 0){
+ assert(a == b);
+ size_window = a;
+ fscanf(fin, "%*s%*s%*s");
+ window = new char*[size_window];
+ for(int i = 0; i < size_window; i++){
+ window[i] = new char[16];
+ fscanf(fin, "%s", window[i]);
+ }
+ }
+ }
+ }
+
+ long start = ftell(fout), size;
+ fprintf(fout, fmt, 0);
+ fprintf(fout, "{id = \"bias1\", type = \"nerv.MatrixParam\"}\n");
+ fprintf(fout, "1 %d\n", size_bias);
+ for(int i = 0; i<size_bias; i++)
+ fprintf(fout, "0 ");
+ fputs("\n", fout);
+ size = ftell(fout) - start;
+ fseek(fout, start, SEEK_SET);
+ fprintf(fout, fmt, (int)size);
+ fseek(fout, 0, SEEK_END);
+
+ start = ftell(fout);
+ fprintf(fout, fmt, 0);
+ fprintf(fout, "{id = \"window1\", type = \"nerv.MatrixParam\"}\n");
+ fprintf(fout, "1 %d\n", size_window);
+ for(int i = 0; i<size_window; i++)
+ fprintf(fout, "1 ");
+ fputs("\n", fout);
+ size = ftell(fout) - start;
+ fseek(fout, start, SEEK_SET);
+ fprintf(fout, fmt, (int)size);
+ fseek(fout, 0, SEEK_END);
+
+ start = ftell(fout);
+ fprintf(fout, fmt, 0);
+ fprintf(fout, "{id = \"bias2\", type = \"nerv.MatrixParam\"}\n");
+ fprintf(fout, "1 %d\n", size_bias);
+ for(int i = 0; i<size_bias; i++)
+ fprintf(fout, "%s ", bias[i]);
+ fputs("\n", fout);
+ size = ftell(fout) - start;
+ fseek(fout, start, SEEK_SET);
+ fprintf(fout, fmt, (int)size);
+ fseek(fout, 0, SEEK_END);
+
+ start = ftell(fout);
+ fprintf(fout, fmt, 0);
+ fprintf(fout, "{id = \"window2\", type = \"nerv.MatrixParam\"}\n");
+ fprintf(fout, "1 %d\n", size_window);
+ for(int i = 0; i<size_window; i++)
+ fprintf(fout, "%s ", window[i]);
+ fputs("\n", fout);
+ size = ftell(fout) - start;
+ fseek(fout, start, SEEK_SET);
+ fprintf(fout, fmt, (int)size);
+ fseek(fout, 0, SEEK_END);
+
+ fclose(fin);
+ fclose(fout);
+
+ return 0;
+}