summaryrefslogblamecommitdiff
path: root/htk_io/tools/nerv_to_tnet.lua
blob: 17ff3db31bf197000f296f471f6310421bad6864 (plain) (tree)
1
2
3
4
5
6
                                                                              

              
                                   
                             
                                       






















































                                                                                              
-- usage: nerv nerv_to_tnet.lua config_file nerv_param_input tnet_param_output
--
-- Converts a nerv parameter file into a TNet text parameter file.
-- The config file is run with dofile and must define the globals `gconf`
-- (a table) and `make_layer_repo` (a function taking a nerv.ParamRepo).
if not (arg[1] and arg[2] and arg[3]) then
    -- without this check, dofile(nil) would silently execute standard input
    io.stderr:write("usage: nerv nerv_to_tnet.lua config_file nerv_param_input tnet_param_output\n")
    os.exit(1)
end

dofile(arg[1])
if type(gconf) ~= "table" or type(make_layer_repo) ~= "function" then
    nerv.error("config file %s must define gconf and make_layer_repo", arg[1])
end
-- NOTE(review): presumably selects the host matrix type used when importing
-- parameters -- confirm against nerv.ParamRepo:import
gconf.mmat_type = nerv.MMatrixFloat
param_repo = nerv.ParamRepo()
param_repo:import({arg[2]}, nil, gconf)
layer_repo = make_layer_repo(param_repo)
-- kept global: print_tnet_matrix writes to `f` by default
f = assert(io.open(arg[3], "w"))

--- Format one host-matrix row as TNet text: each value printed with "%.8f"
-- followed by a single space (the trailing space is kept for byte-compatible
-- output).
-- @param row 0-indexed row of a host matrix
-- @param ncol number of columns in the row
-- @treturn string
local function tnet_row(row, ncol)
    local vals = {}
    for j = 0, ncol - 1 do
        vals[j + 1] = string.format("%.8f ", row[j])
    end
    return table.concat(vals)
end

--- Write a parameter matrix to the TNet output file.
-- Vector output: "v <n>\n" followed by one line holding the n values.
-- Matrix output: the matrix is transposed first, then written as
-- "m <rows> <cols>\n" followed by one line per row of the transpose.
-- @param cumat matrix providing nrow/ncol/copy_toh/trans (presumably a nerv
--        device matrix -- TODO confirm against callers)
-- @param out optional file handle to write to; defaults to the global `f`
-- @param as_vector optional boolean: true forces vector output (the matrix must
--        have exactly one row), false forces matrix output; when nil, a
--        matrix with exactly one row is written as a vector
function print_tnet_matrix(cumat, out, as_vector)
    out = out or f
    if as_vector == nil then
        as_vector = cumat:nrow() == 1
    end
    if as_vector and cumat:nrow() ~= 1 then
        error(string.format(
            "print_tnet_matrix: vector output needs a 1-row matrix, got %d rows",
            cumat:nrow()), 2)
    end
    -- full collection before allocating another host copy; presumably frees
    -- matrices left over from earlier calls -- NOTE(review): confirm intent
    collectgarbage()
    if as_vector then
        local ncol = cumat:ncol()
        local host = nerv.MMatrixFloat(1, ncol)
        cumat:copy_toh(host)
        out:write(string.format("v %d\n", ncol), tnet_row(host[0], ncol), "\n")
    else
        local t = cumat:trans()
        local nrow, ncol = t:nrow(), t:ncol()
        local host = nerv.MMatrixFloat(nrow, ncol)
        t:copy_toh(host)
        out:write(string.format("m %d %d\n", nrow, ncol))
        -- one write per row keeps only a single row's text in memory
        for i = 0, nrow - 1 do
            out:write(tnet_row(host[i], ncol), "\n")
        end
    end
end
-- Layers to export, in network order. The names must match layers built by
-- make_layer_repo in the config file; this list is hard-coded for a network of
-- eight affine layers with sigmoid activations between them and a softmax-CE
-- criterion at the end.
local lnames = {"affine0", "sigmoid0",
                "affine1", "sigmoid1",
                "affine2", "sigmoid2",
                "affine3", "sigmoid3",
                "affine4", "sigmoid4",
                "affine5", "sigmoid5",
                "affine6", "sigmoid6",
                "affine7", "ce_crit"}
for _, name in ipairs(lnames) do
    local layer = layer_repo:get_layer(name)
    local layer_type = layer.__typename
    if layer_type == "nerv.AffineLayer" then
        -- header is "<output dim> <input dim>", followed by weights then bias
        f:write(string.format("<biasedlinearity> %d %d\n", layer.dim_out[1], layer.dim_in[1]))
        -- state the record kind explicitly instead of inferring it from the
        -- row count, so a single-row weight is never written as a vector
        print_tnet_matrix(layer.ltp.trans, f, false)
        print_tnet_matrix(layer.bp.trans, f, true)
    elseif layer_type == "nerv.SigmoidLayer" then
        f:write(string.format("<sigmoid> %d %d\n", layer.dim_out[1], layer.dim_in[1]))
    elseif layer_type == "nerv.SoftmaxCELayer" then
        -- softmax preserves dimensionality, so the input dim is written twice
        f:write(string.format("<softmax> %d %d\n", layer.dim_in[1], layer.dim_in[1]))
    else
        nerv.error("unknown layer type %s", layer_type)
    end
end
f:close()