summaryrefslogtreecommitdiff
path: root/kaldi_io/tools
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/tools')
-rw-r--r--kaldi_io/tools/kaldi_to_nerv.cpp109
1 files changed, 109 insertions, 0 deletions
diff --git a/kaldi_io/tools/kaldi_to_nerv.cpp b/kaldi_io/tools/kaldi_to_nerv.cpp
new file mode 100644
index 0000000..1edb0f2
--- /dev/null
+++ b/kaldi_io/tools/kaldi_to_nerv.cpp
@@ -0,0 +1,109 @@
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <cstring>
+#include <cassert>
+
+char token[1024];
+char output[1024];
+double mat[4096][4096];
+int main(int argc, char **argv) {
+ std::ofstream fout;
+ fout.open(argv[1]);
+ int cnt = 0;
+ bool shift;
+ while (scanf("%s", token) != EOF)
+ {
+ int nrow, ncol;
+ int i, j;
+ if (strcmp(token, "<AffineTransform>") == 0)
+ {
+ double lrate, blrate, mnorm;
+ scanf("%d %d", &ncol, &nrow);
+ scanf("%s %lf %s %lf %s %lf",
+ token, &lrate, token, &blrate, token, &mnorm);
+ scanf("%s", token);
+ assert(*token == '[');
+ printf("%d %d\n", nrow, ncol);
+ for (j = 0; j < ncol; j++)
+ for (i = 0; i < nrow; i++)
+ scanf("%lf", mat[i] + j);
+ long base = fout.tellp();
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n",
+ cnt);
+ fout << output;
+ sprintf(output, "%d %d\n", nrow, ncol);
+ fout << output;
+ for (i = 0; i < nrow; i++)
+ {
+ for (j = 0; j < ncol; j++)
+ fout << mat[i][j] << " ";
+ fout << std::endl;
+ }
+ long length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ scanf("%s", token);
+ assert(*token == ']');
+ if (scanf("%s", token) == 1 && *token == '[')
+ {
+ base = fout.tellp();
+ for (j = 0; j < ncol; j++)
+ scanf("%lf", mat[0] + j);
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.BiasParam\",id=\"affine%d_bp\"}\n",
+ cnt);
+ fout << output;
+ sprintf(output, "1 %d\n", ncol);
+ fout << output;
+ for (j = 0; j < ncol; j++)
+ fout << mat[0][j] << " ";
+ fout << std::endl;
+ length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ cnt++;
+ }
+ }
+ else if ((shift = (strcmp(token, "<AddShift>") == 0)) ||
+ strcmp(token, "<Rescale>") == 0)
+ {
+ double lrate, blrate, mnorm;
+ scanf("%d %d", &ncol, &ncol);
+ scanf("%s %lf",
+ token, &lrate);
+ scanf("%s", token);
+ assert(*token == '[');
+ printf("%d\n", ncol);
+ for (j = 0; j < ncol; j++)
+ scanf("%lf", mat[0] + j);
+ long base = fout.tellp();
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.BiasParam\",id=\"%s%d\"}\n",
+ shift ? "bias" : "window",
+ cnt);
+ fout << output;
+ sprintf(output, "%d %d\n", 1, ncol);
+ fout << output;
+ for (j = 0; j < ncol; j++)
+ fout << mat[0][j] << " ";
+ fout << std::endl;
+ long length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ scanf("%s", token);
+ assert(*token == ']');
+ }
+ }
+ return 0;
+}