1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
package.path="/home/slhome/ymz09/.luarocks/share/lua/5.1/?.lua;/home/slhome/ymz09/.luarocks/share/lua/5.1/?/init.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?.lua;/slfs6/users/ymz09/nerv-project/nerv/install/share/lua/5.1/?/init.lua;"..package.path;
package.cpath="/home/slhome/ymz09/.luarocks/lib/lua/5.1/?.so;/slfs6/users/ymz09/nerv-project/nerv/install/lib/lua/5.1/?.so;"..package.cpath;
local k,l,_=pcall(require,"luarocks.loader") _=k and l.add_context("nerv","scm-1")
require 'nerv'
--- Return the first non-nil result of `reader:get_data()`, polling readers in order.
-- Every call starts again from the first reader, so a reader that has
-- returned nil is simply skipped in favour of the next one.
-- @param readers list of entries, each with a `reader` field exposing `get_data()`
-- @return the data table from the first reader that produced one, or nil
--         when every reader returned nil
local function next_utterance(readers)
    for ri = 1, #readers do
        local data = readers[ri].reader:get_data()
        if data ~= nil then
            return data
        end
    end
    return nil
end

--- Build the ordered list of network inputs for one utterance.
-- For each entry `e` of `input_order`, `data[e.id]` is looked up.
-- * Entries flagged `global_transf` are copied (`copy_fromh`) into a new
--   `gconf.cumat_type` matrix and then passed through
--   `nerv.speech_utils.global_transf`.
-- * Other entries are used exactly as the reader returned them.
-- @param data table returned by a reader, keyed by input id
-- @param input_order list of `{ id = ..., global_transf = bool }` entries
-- @param global_transf the global transform obtained from the layer repo
-- @return array of transformed inputs, in `input_order` order
-- @raise nerv.error when an input id is missing from `data`
local function prepare_inputs(data, input_order, global_transf)
    local input = {}
    for _, e in ipairs(input_order) do
        local id = e.id
        local raw = data[id]
        if raw == nil then
            nerv.error("input data %s not found", id)
        end
        local transformed
        if e.global_transf then
            local batch = gconf.cumat_type(raw:nrow(), raw:ncol())
            batch:copy_fromh(raw)
            transformed = nerv.speech_utils.global_transf(batch,
                                                          global_transf,
                                                          gconf.frm_ext or 0, 0,
                                                          gconf)
        else
            transformed = raw
        end
        input[#input + 1] = transformed
    end
    return input
end

--- Load decode parameters and return an iterator that decodes one utterance per call.
-- Relies on the globals `gconf`, `make_layer_repo`, `get_decode_network`,
-- `get_global_transf`, `get_input_order` and `make_readers`. NOTE(review):
-- these are presumably defined by the config file run in `init` — confirm.
-- @param ifname parameter file imported into a fresh `nerv.ParamRepo`
-- @param feature passed through to `make_readers`
-- @return a function that, on each call, returns either
--         * `utt, mat`: the utterance key and a host `nerv.MMatrixFloat`
--           holding the first network output; or
--         * `"", nil`: when all readers are exhausted.
function build_trainer(ifname, feature)
    local param_repo = nerv.ParamRepo()
    param_repo:import(ifname, nil, gconf)
    local layer_repo = make_layer_repo(param_repo)
    local network = get_decode_network(layer_repo)
    local global_transf = get_global_transf(layer_repo)
    local input_order = get_input_order()
    local readers = make_readers(feature, layer_repo)
    network:init(1)

    local iterative_trainer = function()
        local data = next_utterance(readers)
        if data == nil then
            return "", nil
        end

        local input = prepare_inputs(data, input_order, global_transf)
        -- The frame count is taken from the first input, so there must be one.
        if input[1] == nil then
            nerv.error("no network inputs to propagate: input order is empty")
        end

        -- Reject key-less utterances before paying for a full propagation.
        local utt = data["key"]
        if utt == nil then
            nerv.error("no key found.")
        end

        local nframes = input[1]:nrow()
        -- NOTE(review): the output is always float (CuMatrixFloat / MMatrixFloat),
        -- independent of gconf.cumat_type; presumably the consumer expects
        -- float. Confirm before generalizing.
        local output = {nerv.CuMatrixFloat(nframes, network.dim_out[1])}
        network:batch_resize(nframes)
        network:propagate(input, output)

        local mat = nerv.MMatrixFloat(output[1]:nrow(), output[1]:ncol())
        output[1]:copy_toh(mat)

        -- Full GC after every utterance; presumably this releases the
        -- per-utterance matrices promptly (verify it is still needed).
        collectgarbage("collect")
        return utt, mat
    end

    return iterative_trainer
end
--- Run the decode config and build the global `trainer` iterator.
-- `io.write` is replaced by a no-op while the config runs and the trainer is
-- built, so their output is suppressed. The original `io.write` is always
-- restored, even when loading fails; the error is then re-raised unchanged.
-- NOTE(review): the config is assumed to define the global `gconf`
-- (`gconf.decode_param` is read here) — confirm.
-- @param config path of the Lua config file passed to `dofile`
-- @param feature forwarded to `build_trainer`
function init(config, feature)
    local saved_write = io.write
    io.write = function(...)
    end
    local ok, err = pcall(function()
        dofile(config)
        trainer = build_trainer(gconf.decode_param, feature)
    end)
    -- Restore before re-raising, or io.write would stay silenced for the
    -- rest of the process.
    io.write = saved_write
    if not ok then
        error(err, 0)
    end
end
--- Decode the next utterance using the global `trainer` built by `init`.
-- Exactly two values are always returned, however many `trainer` yields.
-- @return the utterance key, or "" once every reader is exhausted
-- @return the decoded output matrix, or nil once every reader is exhausted
function feed()
    local results = { trainer() }
    return results[1], results[2]
end
|