summaryrefslogtreecommitdiff
path: root/kaldi_io/tools
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-08-14 13:54:51 +0800
committerDeterminant <[email protected]>2015-08-14 13:54:51 +0800
commit10cce5f6a5c9e2f8e00d5a2a4d87c9cb7c26bf4c (patch)
treee417bc520e78e749df39652aa61ae29a76957c76 /kaldi_io/tools
parent96a32415ab43377cf1575bd3f4f2980f58028209 (diff)
add example script for converting to kaldi nnet
Diffstat (limited to 'kaldi_io/tools')
-rw-r--r--kaldi_io/tools/nerv_to_kaldi.lua66
1 files changed, 66 insertions, 0 deletions
diff --git a/kaldi_io/tools/nerv_to_kaldi.lua b/kaldi_io/tools/nerv_to_kaldi.lua
new file mode 100644
index 0000000..804f09b
--- /dev/null
+++ b/kaldi_io/tools/nerv_to_kaldi.lua
@@ -0,0 +1,66 @@
+-- usage: nerv config_file nerv_param_input tnet_output
+
+dofile(arg[1])
+param_repo = nerv.ParamRepo()
+param_repo:import({arg[2], gconf.initialized_param[2]}, nil, gconf)
+layer_repo = make_layer_repo(param_repo)
+f = assert(io.open(arg[3], "w"))
+
+function print_tnet_matrix(cumat)
+ local strs = {}
+ collectgarbage()
+ if cumat:nrow() == 1 then
+ local mat = nerv.MMatrixFloat(1, cumat:ncol())
+ cumat:copy_toh(mat)
+ table.insert(strs, "[ ")
+ for j = 0, mat:ncol() - 1 do
+ table.insert(strs, string.format("%.8f ", mat[0][j]))
+ end
+ table.insert(strs, " ]\n")
+ f:write(table.concat(strs))
+ else
+ cumat = cumat:trans()
+ local mat = nerv.MMatrixFloat(cumat:nrow(), cumat:ncol())
+ cumat:copy_toh(mat)
+ table.insert(strs, string.format(" [\n", mat:nrow(), mat:ncol()))
+ for i = 0, mat:nrow() - 1 do
+ local row = mat[i]
+ for j = 0, mat:ncol() - 1 do
+ table.insert(strs, string.format("%.8f ", row[j]))
+ end
+ if i == mat:nrow() - 1 then
+ table.insert(strs, " ]\n")
+ else
+ table.insert(strs, "\n")
+ end
+ f:write(table.concat(strs))
+ strs = {}
+ end
+ end
+end
+local lnames = {"affine0", "sigmoid0",
+ "affine1", "sigmoid1",
+ "affine2", "sigmoid2",
+ "affine3", "sigmoid3",
+ "affine4", "sigmoid4",
+ "affine5", "sigmoid5",
+ "affine6", "ce_crit"}
+f:write("<Nnet>\n")
+for i, name in ipairs(lnames) do
+ local layer = layer_repo:get_layer(name)
+ local layer_type = layer.__typename
+ if layer_type == "nerv.AffineLayer" then
+ f:write(string.format("<AffineTransform> %d %d\n<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0",
+ layer.dim_out[1], layer.dim_in[1]))
+ print_tnet_matrix(layer.ltp.trans)
+ print_tnet_matrix(layer.bp.trans)
+ elseif layer_type == "nerv.SigmoidLayer" then
+ f:write(string.format("<Sigmoid> %d %d\n", layer.dim_out[1], layer.dim_in[1]))
+ elseif layer_type == "nerv.SoftmaxCELayer" then
+ f:write(string.format("<Softmax> %d %d\n", layer.dim_in[1], layer.dim_in[1]))
+ else
+ nerv.error("unknown layer type %s", layer_type)
+ end
+end
+f:write("</Nnet>\n")
+f:close()