summaryrefslogtreecommitdiff
path: root/kaldi_decode/src/nerv4decode.lua
blob: 898b5a82a98b5cfa0b7dd19995af4a129306040f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
-- Silence stdout: Kaldi reads the decoder's results through its own channel,
-- so anything written via print/io.write must not pollute the pipe.
-- Replace print with a tab-separated writer (original used table.concat
-- directly, which raises on booleans/tables/nil; convert via tostring, and
-- append the newline that the standard print contract guarantees).
print = function(...)
    local parts = {}
    for i = 1, select("#", ...) do
        parts[i] = tostring(select(i, ...))
    end
    io.write(table.concat(parts, "\t"), "\n")
end
-- Redirect the default output file to /dev/null so the writes above vanish.
io.output('/dev/null')
-- path and cpath are correctly set by `path.sh`
-- Best-effort luarocks setup: pcall tolerates a missing luarocks.loader, and
-- add_context("nerv","scm-1") is only attempted when the loader was found.
local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
nerv.info("automatically initialize a default MContext...")
-- Install a default (CPU) matrix context; decoding below runs on MMatrix,
-- so this must exist before any matrix is created.
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- Backward-compatibility shims only; scheduled for removal in the future.
local function _add_profile_method(cls)
    -- Capture the context once at installation time, matching the old
    -- class-level profiling API.
    local ctx = cls._default_context
    cls.print_profile = function() ctx:print_profile() end
    cls.clear_profile = function() ctx:clear_profile() end
end
_add_profile_method(nerv.MMatrix)

--- Build a closure that decodes one utterance per call.
-- Loads the network parameters from `ifname`, wires up the decode network
-- plus the global transform, and opens the feature readers.
-- @param ifname  parameter file(s) imported into a fresh ParamRepo
-- @param feature feature specification forwarded to make_readers
-- @return function() -> (utterance key, host output matrix),
--         or ("", nil) once every reader is exhausted
function build_trainer(ifname, feature)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_decode_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    local readers = make_readers(feature, layer_repo)
    network:init(1)

    -- Return the next utterance from the first reader that still has data,
    -- or nil when all readers are drained.
    local function next_data()
        for _, r in ipairs(readers) do
            local d = r.reader:get_data()
            if d ~= nil then
                return d
            end
        end
        return nil
    end

    local function decode_one()
        local data = next_data()
        if data == nil then
            return "", nil
        end

        -- Assemble the network input in the configured order, applying the
        -- global transform to the streams that request it.
        local input = {}
        for _, entry in ipairs(input_order) do
            local frames = data[entry.id]
            if frames == nil then
                nerv.error("input data %s not found", entry.id)
            end
            if entry.global_transf then
                frames = nerv.speech_utils.global_transf(frames, global_transf,
                                                         gconf.frm_ext or 0, 0,
                                                         gconf)
            end
            input[#input + 1] = frames
        end

        local nframes = input[1]:nrow()
        local output = {nerv.MMatrixFloat(nframes, network.dim_out[1])}
        network:batch_resize(nframes)
        network:propagate(input, output)

        local utt = data["key"]
        if utt == nil then
            nerv.error("no key found.")
        end

        -- Hand the caller a host-side copy of the network output.
        local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
        output[1]:copy_toh(mat)

        collectgarbage("collect")
        return utt, mat
    end

    return decode_one
end

--- Initialize the decoder: load the Lua configuration file and build the
-- per-utterance decoding closure.  Called once from the Kaldi wrapper.
-- @param config  path to a config file that must define the global `gconf`
-- @param feature feature specification forwarded to build_trainer
function init(config, feature)
    dofile(config)
    -- Fail early with a clear message instead of the opaque nil-index error
    -- the next line would otherwise produce.
    if gconf == nil then
        nerv.error("config file %s did not define gconf", config)
    end
    gconf.use_cpu = true -- use CPU to decode
    -- `trainer` is deliberately global: feed() reads it on every call.
    trainer = build_trainer(gconf.decode_param, feature)
end

--- Decode the next utterance.
-- @return the utterance key and the host output matrix from the trainer
--         closure, or "" and nil once the input is exhausted.
function feed()
    -- `trainer` always yields exactly two values, so forward them directly.
    return trainer()
end