-- usage: nerv nerv_to_kaldi.lua <config_file> <nerv_param_input> <kaldi_nnet_output>
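-- example invocation (file names are hypothetical):
--   nerv nerv_to_kaldi.lua asr_config.lua final.nerv final.nnet.txt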

dofile(arg[1]) -- load the network config; defines gconf and make_layer_repo
param_repo = nerv.ParamRepo()
-- merge the trained parameter file with the extra initialized params named in the config
param_repo:import({arg[2], gconf.initialized_param[2]}, nil, gconf)
layer_repo = make_layer_repo(param_repo)
f = assert(io.open(arg[3], "w"))

-- Write a NERV CuMatrix to f as a Kaldi nnet1 text matrix.
-- A 1-row matrix (a bias vector) is written inline as "[ ... ]"; anything
-- larger is transposed first, since NERV stores linear transforms as
-- input x output while Kaldi expects output x input.
function print_tnet_matrix(cumat)
    local strs = {}
    collectgarbage() -- reclaim stale device buffers before allocating host copies
    if cumat:nrow() == 1 then
        local mat = nerv.MMatrixFloat(1, cumat:ncol())
        cumat:copy_toh(mat)
        table.insert(strs, "[ ")
        for j = 0, mat:ncol() - 1 do
            table.insert(strs, string.format("%.8f ", mat[0][j]))
        end
        table.insert(strs, " ]\n")
        f:write(table.concat(strs))
    else
        cumat = cumat:trans()
        local mat = nerv.MMatrixFloat(cumat:nrow(), cumat:ncol())
        cumat:copy_toh(mat)
        table.insert(strs, " [\n")
        for i = 0, mat:nrow() - 1 do
            local row = mat[i]
            for j = 0, mat:ncol() - 1 do
                table.insert(strs, string.format("%.8f ", row[j]))
            end
            if i == mat:nrow() - 1 then
                table.insert(strs, " ]\n")
            else
                table.insert(strs, "\n")
            end
            f:write(table.concat(strs)) -- flush row by row to bound memory use
            strs = {}
        end
    end
end
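
-- For reference, a 2x2 weight matrix comes out of print_tnet_matrix as
-- (values hypothetical):
--  [
-- 0.10000000 0.20000000
-- 0.30000000 0.40000000  ]
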
-- layer names in network order, as built by the config's make_layer_repo
local lnames = {"affine0", "sigmoid0",
                "affine1", "sigmoid1",
                "affine2", "sigmoid2",
                "affine3", "sigmoid3",
                "affine4", "sigmoid4",
                "affine5", "sigmoid5",
                "affine6", "ce_crit"}
f:write("<Nnet>\n")
for i, name in ipairs(lnames) do
    local layer = layer_repo:get_layer(name)
    local layer_type = layer.__typename
    if layer_type == "nerv.AffineLayer" then
        -- Kaldi nnet1 component headers carry <out_dim> <in_dim>; no newline
        -- after <MaxNorm> 0, since the weight matrix's " [" continues the line
        f:write(string.format("<AffineTransform> %d %d\n<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0",
            layer.dim_out[1], layer.dim_in[1]))
        print_tnet_matrix(layer.ltp.trans) -- weight matrix
        print_tnet_matrix(layer.bp.trans)  -- bias vector
    elseif layer_type == "nerv.SigmoidLayer" then
        f:write(string.format("<Sigmoid> %d %d\n", layer.dim_out[1], layer.dim_in[1]))
    elseif layer_type == "nerv.SoftmaxCELayer" then
        -- the CE criterion is training-only; emit a plain <Softmax> component
        f:write(string.format("<Softmax> %d %d\n", layer.dim_in[1], layer.dim_in[1]))
    else
        nerv.error("unknown layer type %s", layer_type)
    end
end
f:write("</Nnet>\n")
f:close()
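
-- The text model written above can be converted to Kaldi's binary format
-- with Kaldi's nnet1 tools (file names hypothetical):
--   nnet-copy --binary=true final.nnet.txt final.nnet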