summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/nerv4decode.lua
blob: b2ff344796ab553960f3fa456b2decb0235fe82d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path; 
package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath;
local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
require 'nerv'

--- Build a closure that runs the decode network on one utterance per call.
-- Imports parameters from `ifname` into a new ParamRepo, then builds the
-- layer repo, decode network, global transform, input order and readers via
-- make_layer_repo / get_decode_network / get_global_transf /
-- get_input_order / make_readers (presumably defined by the config file
-- loaded in init() -- confirm). Also reads the global `gconf`.
-- @param ifname path of the parameter file to import
-- @param feature feature specifier forwarded to make_readers
-- @return function() returning `utt, mat` (utterance key and host copy of
--   the first network output), or `"", nil` once every reader is exhausted
function build_trainer(ifname, feature)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_decode_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    local readers = make_readers(feature, layer_repo)
    network:init(1)

    -- Copy a host matrix into a freshly allocated gconf.cumat_type matrix.
    local function to_device(host)
        local dev = gconf.cumat_type(host:nrow(), host:ncol())
        dev:copy_fromh(host)
        return dev
    end

    -- Return the next data table from the first reader that still yields
    -- one, or nil when all readers are exhausted.
    local function next_data()
        for _, entry in ipairs(readers) do
            local data = entry.reader:get_data()
            if data ~= nil then
                return data
            end
        end
        return nil
    end

    local iterative_trainer = function()
        local data = next_data()
        if data == nil then
            return "", nil
        end

        -- Validate the key up front so a malformed record fails before any
        -- device allocation or forward pass is done for it.
        local utt = data["key"]
        if utt == nil then
            nerv.error("no key found.")
        end

        local input = {}
        for _, e in ipairs(input_order) do
            local id = e.id
            local host = data[id]
            if host == nil then
                nerv.error("input data %s not found", id)
            end
            -- Readers hand back host matrices (they are copied with
            -- copy_fromh), so every input is uploaded, not only the ones
            -- that go through the global transform.
            local transformed = to_device(host)
            if e.global_transf then
                transformed = nerv.speech_utils.global_transf(transformed,
                    global_transf,
                    gconf.frm_ext or 0, 0,
                    gconf)
            end
            input[#input + 1] = transformed
        end

        local nframes = input[1]:nrow()
        local output = {nerv.CuMatrixFloat(nframes, network.dim_out[1])}
        network:batch_resize(nframes)
        network:propagate(input, output)

        -- Bring the network output back to host memory for the caller.
        local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
        output[1]:copy_toh(mat)

        -- Full GC each utterance so per-utterance matrices are released
        -- promptly (presumably to bound device memory -- confirm).
        collectgarbage("collect")
        return utt, mat
    end

    return iterative_trainer
end

--- Load the decode config and build the global `trainer` used by feed().
-- Output written through io.write while the config runs and the network is
-- built is suppressed. io.write is restored even when loading fails, and
-- the original error is then re-raised.
-- @param config path of the Lua config file (sets the global `gconf`)
-- @param feature feature specifier forwarded to build_trainer
function init(config, feature)
    local saved_write = io.write
    io.write = function(...)
    end
    local ok, err = pcall(function()
        dofile(config)
        trainer = build_trainer(gconf.decode_param, feature)
    end)
    -- Always undo the silencing; previously an error here left io.write
    -- permanently disabled for the rest of the process.
    io.write = saved_write
    if not ok then
        -- Level 0: the message already carries its original position.
        error(err, 0)
    end
end

--- Decode the next utterance using the global `trainer` built by init().
-- @return utterance key and host output matrix, exactly as the trainer
--   closure returns them (only the first two values are forwarded)
function feed()
    local key, result = trainer()
    return key, result
end