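-- asr_propagator.lua: forward-propagation script for kaldi_decode. The host
-- decoder process is assumed to load this file, call init() once, and then
-- call feed() repeatedly to obtain (utterance id, output matrix) pairs; see
-- the sketch at the end of this file.

-- Silence Lua-side printing: stdout is assumed to belong to the host
-- process, so nothing printed here may leak into it.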
print = function(...) io.write(table.concat({...}, "\t")) end
io.output('/dev/null')
-- path and cpath are correctly set by `path.sh`
local k, l, _ = pcall(require, "luarocks.loader")
_ = k and l.add_context("nerv", "scm-1")
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- only for backward compatibility, will be removed in the future
local function _add_profile_method(cls)
    local c = cls._default_context
    cls.print_profile = function () c:print_profile() end
    cls.clear_profile = function () c:clear_profile() end
end
_add_profile_method(nerv.MMatrix)

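-- Build and return a closure that forward-propagates one utterance per
-- call. `ifname` names the parameter file imported into the ParamRepo;
-- `feature` is passed through to the trainer's decode readers.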
function build_propagator(ifname, feature)
    -- FIXME: this is still a hack: Trainer class methods are called below
    -- with a nil `self` instead of through a constructed trainer instance
    local trainer = nerv.Trainer
    ----
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, gconf)
    local layer_repo = trainer.make_layer_repo(nil, param_repo)
    local network = trainer.get_decode_network(nil, layer_repo)
    local input_order = trainer.get_decode_input_order(nil)
    local input_name = gconf.decode_input_name or "main_scp"
    local readers = trainer.make_decode_readers(nil, feature)
    local buffer = nerv.SeqBuffer(gconf, {
                                        buffer_size = gconf.buffer_size,
                                        batch_size = gconf.batch_size,
                                        chunk_size = gconf.chunk_size,
                                        randomize = gconf.randomize,
                                        readers = readers,
                                })

    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)

    local prev_data = buffer:get_data() or nerv.error("no data in buffer")
    local terminate = false
    local input_pos = nil
    for i, v in ipairs(input_order) do
        if v == input_name then
            input_pos = i
        end
    end
    if input_pos == nil then
        nerv.error("input name %s not found in the input order list", input_name)
    end
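    -- Each call consumes mini-batches from the buffer until the next
    -- utterance begins and returns (utt_id, matrix), one output row per
    -- frame; an empty utt_id means all input has been consumed.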
    local batch_propagator = function()
        if terminate then
            return "", nil -- an empty utterance id signals end of data
        end
        network:epoch_init()
        local accu_output = {}
        local utt_id = readers[input_pos].reader.key
        if utt_id == nil then
            nerv.error("no key found.")
        end
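        -- Read mini-batches until the buffer starts a new sequence: that
        -- batch already belongs to the next utterance, so stash it in
        -- prev_data and finish the current utterance first.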
        while true do
            local d
            if prev_data then
                d = prev_data
                prev_data = nil
            else
                d = buffer:get_data()
                if d == nil then
                    terminate = true
                    break
                elseif #d.new_seq > 0 then
                    prev_data = d -- the first data of the next utterance
                    break
                end
            end

            local input = {}
            local output = {{}}
            for i, id in ipairs(input_order) do
                if d.data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                table.insert(input, d.data[id])
            end
            -- allocate one output matrix per chunk position (exactly
            -- chunk_size in total, regardless of the number of inputs)
            for i = 1, gconf.chunk_size do
                table.insert(output[1], gconf.mmat_type(gconf.batch_size, network.dim_out[1]))
            end
            network:mini_batch_init({seq_length = d.seq_length,
                                    new_seq = d.new_seq,
                                    do_train = false,
                                    input = input,
                                    output = output,
                                    err_input = {},
                                    err_output = {}})
            network:propagate()
            for i, v in ipairs(output[1]) do
                if gconf.mask[i][0][0] > 0 then -- not a padding frame (a "hole")
                    table.insert(accu_output, v)
                end
            end
        end
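        -- Stack the accumulated output rows into a single
        -- (#frames x output dim) matrix for the whole utterance.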
        local utt_matrix = gconf.mmat_type(#accu_output, accu_output[1]:ncol())
        for i, v in ipairs(accu_output) do
            utt_matrix:copy_from(v, 0, v:nrow(), i - 1)
        end
        collectgarbage("collect")
        nerv.info("propagated %d features of an utterance", utt_matrix:nrow())
        return utt_id, utt_matrix
    end

    return batch_propagator
end

function init(config, feature)
    dofile(config) -- the config script is expected to define the global `gconf`
    gconf.mmat_type = nerv.MMatrixFloat
    gconf.use_cpu = true -- use CPU to decode
    gconf.batch_size = 1 -- decode processes one utterance stream at a time
    propagator = build_propagator(gconf.decode_param, feature)
end

function feed()
    local utt, mat = propagator()
    return utt, mat
end
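
-- A minimal sketch of the expected calling sequence (hypothetical driver;
-- in practice init() and feed() are called by the host decoder, and the
-- file names below are placeholders):
--
--   init("decode_config.lua", "scp:feats.scp")
--   while true do
--       local utt, mat = feed()
--       if utt == "" then break end -- end of data
--       -- each row of `mat` holds the network output for one frame of `utt`
--   end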