summaryrefslogtreecommitdiff
path: root/kaldi_io/tools
diff options
context:
space:
mode:
Diffstat (limited to 'kaldi_io/tools')
-rw-r--r--kaldi_io/tools/kaldi_to_nerv.cpp109
-rw-r--r--kaldi_io/tools/nerv_to_kaldi.lua66
2 files changed, 175 insertions, 0 deletions
diff --git a/kaldi_io/tools/kaldi_to_nerv.cpp b/kaldi_io/tools/kaldi_to_nerv.cpp
new file mode 100644
index 0000000..1edb0f2
--- /dev/null
+++ b/kaldi_io/tools/kaldi_to_nerv.cpp
@@ -0,0 +1,109 @@
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <cstring>
+#include <cassert>
+
+char token[1024];
+char output[1024];
+double mat[4096][4096];
+int main(int argc, char **argv) {
+ std::ofstream fout;
+ fout.open(argv[1]);
+ int cnt = 0;
+ bool shift;
+ while (scanf("%s", token) != EOF)
+ {
+ int nrow, ncol;
+ int i, j;
+ if (strcmp(token, "<AffineTransform>") == 0)
+ {
+ double lrate, blrate, mnorm;
+ scanf("%d %d", &ncol, &nrow);
+ scanf("%s %lf %s %lf %s %lf",
+ token, &lrate, token, &blrate, token, &mnorm);
+ scanf("%s", token);
+ assert(*token == '[');
+ printf("%d %d\n", nrow, ncol);
+ for (j = 0; j < ncol; j++)
+ for (i = 0; i < nrow; i++)
+ scanf("%lf", mat[i] + j);
+ long base = fout.tellp();
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.LinearTransParam\",id=\"affine%d_ltp\"}\n",
+ cnt);
+ fout << output;
+ sprintf(output, "%d %d\n", nrow, ncol);
+ fout << output;
+ for (i = 0; i < nrow; i++)
+ {
+ for (j = 0; j < ncol; j++)
+ fout << mat[i][j] << " ";
+ fout << std::endl;
+ }
+ long length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ scanf("%s", token);
+ assert(*token == ']');
+ if (scanf("%s", token) == 1 && *token == '[')
+ {
+ base = fout.tellp();
+ for (j = 0; j < ncol; j++)
+ scanf("%lf", mat[0] + j);
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.BiasParam\",id=\"affine%d_bp\"}\n",
+ cnt);
+ fout << output;
+ sprintf(output, "1 %d\n", ncol);
+ fout << output;
+ for (j = 0; j < ncol; j++)
+ fout << mat[0][j] << " ";
+ fout << std::endl;
+ length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ cnt++;
+ }
+ }
+ else if ((shift = (strcmp(token, "<AddShift>") == 0)) ||
+ strcmp(token, "<Rescale>") == 0)
+ {
+ double lrate, blrate, mnorm;
+ scanf("%d %d", &ncol, &ncol);
+ scanf("%s %lf",
+ token, &lrate);
+ scanf("%s", token);
+ assert(*token == '[');
+ printf("%d\n", ncol);
+ for (j = 0; j < ncol; j++)
+ scanf("%lf", mat[0] + j);
+ long base = fout.tellp();
+ sprintf(output, "%16d", 0);
+ fout << output;
+ sprintf(output, "{type=\"nerv.BiasParam\",id=\"%s%d\"}\n",
+ shift ? "bias" : "window",
+ cnt);
+ fout << output;
+ sprintf(output, "%d %d\n", 1, ncol);
+ fout << output;
+ for (j = 0; j < ncol; j++)
+ fout << mat[0][j] << " ";
+ fout << std::endl;
+ long length = fout.tellp() - base;
+ fout.seekp(base);
+ sprintf(output, "[%13lu]\n", length);
+ fout << output;
+ fout.seekp(0, std::ios_base::end);
+ scanf("%s", token);
+ assert(*token == ']');
+ }
+ }
+ return 0;
+}
diff --git a/kaldi_io/tools/nerv_to_kaldi.lua b/kaldi_io/tools/nerv_to_kaldi.lua
new file mode 100644
index 0000000..804f09b
--- /dev/null
+++ b/kaldi_io/tools/nerv_to_kaldi.lua
@@ -0,0 +1,66 @@
+-- usage: nerv config_file nerv_param_input tnet_output
+
+dofile(arg[1])
+param_repo = nerv.ParamRepo()
+param_repo:import({arg[2], gconf.initialized_param[2]}, nil, gconf)
+layer_repo = make_layer_repo(param_repo)
+f = assert(io.open(arg[3], "w"))
+
+function print_tnet_matrix(cumat)
+ local strs = {}
+ collectgarbage()
+ if cumat:nrow() == 1 then
+ local mat = nerv.MMatrixFloat(1, cumat:ncol())
+ cumat:copy_toh(mat)
+ table.insert(strs, "[ ")
+ for j = 0, mat:ncol() - 1 do
+ table.insert(strs, string.format("%.8f ", mat[0][j]))
+ end
+ table.insert(strs, " ]\n")
+ f:write(table.concat(strs))
+ else
+ cumat = cumat:trans()
+ local mat = nerv.MMatrixFloat(cumat:nrow(), cumat:ncol())
+ cumat:copy_toh(mat)
+ table.insert(strs, string.format(" [\n", mat:nrow(), mat:ncol()))
+ for i = 0, mat:nrow() - 1 do
+ local row = mat[i]
+ for j = 0, mat:ncol() - 1 do
+ table.insert(strs, string.format("%.8f ", row[j]))
+ end
+ if i == mat:nrow() - 1 then
+ table.insert(strs, " ]\n")
+ else
+ table.insert(strs, "\n")
+ end
+ f:write(table.concat(strs))
+ strs = {}
+ end
+ end
+end
+local lnames = {"affine0", "sigmoid0",
+ "affine1", "sigmoid1",
+ "affine2", "sigmoid2",
+ "affine3", "sigmoid3",
+ "affine4", "sigmoid4",
+ "affine5", "sigmoid5",
+ "affine6", "ce_crit"}
+f:write("<Nnet>\n")
+for i, name in ipairs(lnames) do
+ local layer = layer_repo:get_layer(name)
+ local layer_type = layer.__typename
+ if layer_type == "nerv.AffineLayer" then
+ f:write(string.format("<AffineTransform> %d %d\n<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0",
+ layer.dim_out[1], layer.dim_in[1]))
+ print_tnet_matrix(layer.ltp.trans)
+ print_tnet_matrix(layer.bp.trans)
+ elseif layer_type == "nerv.SigmoidLayer" then
+ f:write(string.format("<Sigmoid> %d %d\n", layer.dim_out[1], layer.dim_in[1]))
+ elseif layer_type == "nerv.SoftmaxCELayer" then
+ f:write(string.format("<Softmax> %d %d\n", layer.dim_in[1], layer.dim_in[1]))
+ else
+ nerv.error("unknown layer type %s", layer_type)
+ end
+end
+f:write("</Nnet>\n")
+f:close()