summaryrefslogtreecommitdiff
path: root/kaldi_seq/tools/net_kaldi2nerv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_seq/tools/net_kaldi2nerv.cpp')
-rw-r--r--kaldi_seq/tools/net_kaldi2nerv.cpp85
1 files changed, 85 insertions, 0 deletions
diff --git a/kaldi_seq/tools/net_kaldi2nerv.cpp b/kaldi_seq/tools/net_kaldi2nerv.cpp
new file mode 100644
index 0000000..bbac3db
--- /dev/null
+++ b/kaldi_seq/tools/net_kaldi2nerv.cpp
@@ -0,0 +1,85 @@
+#include <iostream>
+#include <cstdio>
+#include <cstring>
+#include <cstdlib>
+#include <cassert>
+using namespace std;
+
+const char fmt[] = "[%013d]\n";
+
+int main(int argc, char *argv[])
+{
+ if(argc < 3){
+ printf("Usage: %s kaldi_nnet nerv_output\n", argv[0]);
+ exit(0);
+ }
+
+ FILE *fin = fopen(argv[1], "r");
+ FILE *fout = fopen(argv[2], "w");
+
+ if(!fin || !fout){
+ printf("fopen error\n");
+ exit(1);
+ }
+
+ char buf[1024], tag[64];
+ int a, b;
+ char ***arr;
+ long start, size;
+ int affine_ltp = 0, affine_bp = 0;
+
+ while(fgets(buf, 1024, fin)){
+ if(sscanf(buf, "%s%d%d", tag, &b, &a) == 3 && strcmp(tag, "<AffineTransform>") == 0){
+ fgets(buf, 1024, fin);
+ arr = new char**[a];
+ for(int i = 0; i < a; i++)
+ arr[i] = new char*[b];
+ for(int j = 0; j < b; j++)
+ for(int i = 0; i < a; i++){
+ arr[i][j] = new char[16];
+ fscanf(fin, "%s", arr[i][j]);
+ }
+
+ start = ftell(fout);
+ fprintf(fout, fmt, 0);
+ fprintf(fout, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n", affine_ltp++);
+ fprintf(fout, "%d %d\n", a, b);
+ for(int i = 0; i < a; i++){
+ for(int j = 0; j < b; j++){
+ fprintf(fout, "%s ", arr[i][j]);
+ delete [] arr[i][j];
+ }
+ fprintf(fout, "\n");
+ delete [] arr[i];
+ }
+ delete [] arr;
+
+ size = ftell(fout) - start;
+ fseek(fout, start, SEEK_SET);
+ fprintf(fout, fmt, (int)size);
+ fseek(fout, 0, SEEK_END);
+
+ fgets(buf, 1024, fin);
+ fscanf(fin, "%*s");
+
+ start = ftell(fout);
+ fprintf(fout, fmt, 0);
+ fprintf(fout, "{type=\"nerv.BiasParam\",id=\"affine%d_bp\"}\n", affine_bp++);
+ fprintf(fout, "%d %d\n", 1, b);
+ for(int i = 0; i < b; i++){
+ fscanf(fin, "%s", buf);
+ fprintf(fout, "%s ", buf);
+ }
+ fputs("\n", fout);
+ size = ftell(fout) - start;
+ fseek(fout, start, SEEK_SET);
+ fprintf(fout, fmt, (int)size);
+ fseek(fout, 0, SEEK_END);
+ }
+ }
+
+ fclose(fin);
+ fclose(fout);
+
+ return 0;
+}