1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
--- Replacement for the global `print` that routes output through `io.write`
-- (and therefore through whatever the default output file currently is).
-- Arguments are joined with tabs; no trailing newline is written.
-- Every argument is converted with `tostring`, so nil, booleans, tables and
-- userdata are accepted instead of making `table.concat` raise an error.
print = function(...)
    -- select("#", ...) gives the true argument count, including trailing nils.
    local n = select("#", ...)
    local parts = {}
    for i = 1, n do
        parts[i] = tostring((select(i, ...)))
    end
    io.write(table.concat(parts, "\t"))
end
-- Make /dev/null the default output file, so anything emitted through
-- io.write (which the print replacement above uses) is discarded.
io.output('/dev/null')
-- path and cpath are correctly set by `path.sh`
do
    -- The luarocks loader is optional: register the nerv context only when
    -- the loader module could actually be required.
    local loaded, loader = pcall(require, "luarocks.loader")
    if loaded then
        loader.add_context("nerv", "scm-1")
    end
end
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
-- Create the default matrix context before any MMatrix is used.
nerv.info("automatically initialize a default MContext...")
nerv.MMatrix._default_context = nerv.MContext()
nerv.info("the default MContext is ok")
-- Only for backward compatibility, will be removed in the future.
-- Attaches `print_profile` and `clear_profile` functions to `cls`; both
-- forward to the context stored in `cls._default_context` at the moment this
-- helper runs (the context is captured once, not looked up on every call).
local function _add_profile_method(cls)
    local ctx = cls._default_context
    function cls.print_profile()
        ctx:print_profile()
    end
    function cls.clear_profile()
        ctx:clear_profile()
    end
end
-- Expose print_profile/clear_profile directly on nerv.MMatrix, forwarding to
-- its default context (must run after _default_context has been assigned).
_add_profile_method(nerv.MMatrix)
--- Build a closure that propagates buffered data through the decode network,
-- returning one accumulated result per call.
--
-- Imports parameters from `ifname`, obtains the decode network, input order
-- and readers from nerv.Trainer, wraps the readers in a nerv.SeqBuffer and
-- initialises a nerv.Network. Reads its settings from the global `gconf`,
-- which must already be populated.
--
-- @param ifname   passed to nerv.ParamRepo:import (presumably the parameter
--                 file to load -- confirm against caller)
-- @param feature  passed through unchanged to trainer.make_decode_readers
-- @return a zero-argument function returning `utt_id, utt_matrix`, or
--         `"", nil` once the buffer has been exhausted
function build_propagator(ifname, feature)
    -- FIXME: this is still a hack
    local trainer = nerv.Trainer
    ----
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, gconf)
    local layer_repo = trainer.make_layer_repo(nil, param_repo)
    local network = trainer.get_decode_network(nil, layer_repo)
    local input_order = trainer.get_decode_input_order(nil)
    local input_name = gconf.decode_input_name or "main_scp"
    local readers = trainer.make_decode_readers(nil, feature)
    local buffer = nerv.SeqBuffer(gconf, {
        buffer_size = gconf.buffer_size,
        batch_size = gconf.batch_size,
        chunk_size = gconf.chunk_size,
        randomize = gconf.randomize,
        readers = readers,
    })
    network = nerv.Network("nt", gconf, {network = network})
    network:init(gconf.batch_size, gconf.chunk_size)

    -- Data fetched ahead of time; consumed by the first loop iteration of
    -- the next propagator call.
    local prev_data = buffer:get_data() or nerv.error("no data in buffer")
    local terminate = false

    -- Position of the main input inside input_order; used to pick the reader
    -- whose `key` identifies the current utterance.
    local input_pos = nil
    for i, v in ipairs(input_order) do
        if v == input_name then
            input_pos = i
        end
    end
    if input_pos == nil then
        nerv.error("input name %s not found in the input order list", input_name)
    end

    local batch_propagator = function()
        if terminate then
            return "", nil
        end
        network:epoch_init()
        local accu_output = {}
        local utt_id = readers[input_pos].reader.key
        if utt_id == nil then
            nerv.error("no key found.")
        end
        while true do
            local d
            if prev_data then
                d = prev_data
                prev_data = nil
            else
                d = buffer:get_data()
                if d == nil then
                    terminate = true
                    break
                elseif #d.new_seq > 0 then
                    prev_data = d -- the first data of the next utterance
                    break
                end
            end

            -- Collect the input data in the order the network expects.
            local input = {}
            for _, id in ipairs(input_order) do
                if d.data[id] == nil then
                    nerv.error("input data %s not found", id)
                end
                input[#input + 1] = d.data[id]
            end

            -- Allocate exactly chunk_size output matrices. This must not be
            -- done once per input: that would leave output[1] with
            -- chunk_size * #input_order entries, and the mask lookup below
            -- would run past chunk_size.
            local frames = {}
            for t = 1, gconf.chunk_size do
                frames[t] = gconf.mmat_type(gconf.batch_size, network.dim_out[1])
            end
            local output = {frames}

            network:mini_batch_init({seq_length = d.seq_length,
                                     new_seq = d.new_seq,
                                     do_train = false,
                                     input = input,
                                     output = output,
                                     err_input = {},
                                     err_output = {}})
            network:propagate()
            for i, v in ipairs(output[1]) do
                if gconf.mask[i][0][0] > 0 then -- is not a hole
                    accu_output[#accu_output + 1] = v
                end
            end
        end

        -- Fail with a clear message instead of indexing accu_output[1] == nil.
        if #accu_output == 0 then
            nerv.error("no output produced for utterance %s", utt_id)
        end

        -- Stack the accumulated outputs into one matrix, one row per entry.
        local utt_matrix = gconf.mmat_type(#accu_output, accu_output[1]:ncol())
        for i, v in ipairs(accu_output) do
            utt_matrix:copy_from(v, 0, v:nrow(), i - 1)
        end
        collectgarbage("collect")
        nerv.info("propagated %d features of an utterance", utt_matrix:nrow())
        return utt_id, utt_matrix
    end
    return batch_propagator
end
--- Load the configuration script and build the global `propagator`.
-- Running `config` via dofile is expected to define the global `gconf`;
-- a few of its fields are then overridden for decoding.
-- @param config   path of the Lua configuration script to execute
-- @param feature  forwarded to build_propagator
function init(config, feature)
    dofile(config)
    local conf = gconf
    conf.mmat_type = nerv.MMatrixFloat
    -- use CPU to decode
    conf.use_cpu = true
    -- NOTE(review): presumably one sequence per batch so that each network
    -- output contributes a single row -- confirm
    conf.batch_size = 1
    propagator = build_propagator(conf.decode_param, feature)
end
--- Run the global `propagator` once and hand back its result.
-- Raises a descriptive error when no propagator has been assigned yet
-- (i.e. `init` has not been called), instead of the generic
-- "attempt to call a nil value".
-- @return the first two values returned by `propagator`
function feed()
    if propagator == nil then
        error("feed() called before init(): no propagator has been built", 2)
    end
    local utt, mat = propagator()
    return utt, mat
end
|