1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
--- Corpus reader.
-- Loads a vocabulary file and a space-separated text corpus, converts the
-- corpus into lists of word ids, and can cut it into chunked batches of
-- (word id, next word id) pairs.
local Reader = nerv.class('nerv.Reader')

--- Construct a reader and eagerly load both files.
-- @tparam string vocab_file path of the vocabulary file (one word per line)
-- @tparam string input_file path of the corpus file (one sentence per line)
function Reader:__init(vocab_file, input_file)
    -- The vocabulary has to exist first: get_seq() translates every word
    -- through self.vocab, which get_vocab() builds.
    self:get_vocab(vocab_file)
    self:get_seq(input_file)
end
--- Load the vocabulary, assigning ids in file order starting from 0.
-- Every line of the file is treated as one word. If a word occurs more
-- than once, its last occurrence determines its id.
-- Sets self.vocab (word -> id) and self.size (number of lines read).
-- @tparam string vocab_file path of the vocabulary file
-- @raise error naming the file if it cannot be opened
function Reader:get_vocab(vocab_file)
    -- assert() turns io.open's (nil, message) into an immediate, readable
    -- error instead of a later "attempt to index a nil value".
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    f:close()
    self.size = id
end
--- Split string s on every occurrence of the separator t.
-- The separator is matched literally: pattern magic characters such as
-- '.', '%' or '-' are escaped first. Consecutive separators yield empty
-- fields, and leading/trailing separators yield leading/trailing empty
-- fields.
-- @tparam string s string to split
-- @tparam string t non-empty separator
-- @treturn table list of fields, in order
-- @raise error if t is empty
function Reader:split(s, t)
    assert(t ~= '', 'Reader:split: separator must not be empty')
    -- Prefix every punctuation character with '%' so it matches literally;
    -- the parentheses keep only gsub's first result (the new string).
    local sep = (t:gsub('%p', '%%%0'))
    local ret = {}
    -- Appending one separator guarantees the last field is also matched.
    for x in (s .. t):gmatch('(.-)' .. sep) do
        ret[#ret + 1] = x
    end
    return ret
end
--- Load the corpus and convert each line into a list of word ids.
-- Lines are split on single spaces. Empty fields (from repeated spaces)
-- are ignored. One id list is appended to self.seq per line, so
-- self.seq[k] corresponds to line k of the file (possibly empty).
-- NOTE(review): words that are not in self.vocab are silently dropped,
-- which shortens the sentence instead of mapping it to an unknown-word
-- token; this keeps the original behaviour -- confirm it is intended.
-- @tparam string input_file path of the corpus file
-- @raise error naming the file if it cannot be opened
function Reader:get_seq(input_file)
    local f = assert(io.open(input_file, 'r'))
    self.seq = {}
    for line in f:lines() do
        local ids = {}
        for _, word in ipairs(self:split(line, ' ')) do
            if word ~= '' then
                local id = self.vocab[word]
                -- Out-of-vocabulary words are skipped explicitly
                -- (previously via table.insert(t, nil) being a no-op).
                if id ~= nil then
                    ids[#ids + 1] = id
                end
            end
        end
        self.seq[#self.seq + 1] = ids
    end
    f:close()
end
--- Fetch one (input, target) step from a sequence.
-- @tparam int id index into self.seq
-- @tparam int pos position of the input word within that sequence
-- @return word id at pos (the input)
-- @return word id at pos + 1 (the target; nil past the end)
-- @treturn boolean true when pos + 1 is the last position of the sequence
function Reader:get_in_out(id, pos)
    local words = self.seq[id]
    local target_pos = pos + 1
    local input, target = words[pos], words[target_pos]
    return input, target, target_pos == #words
end
--- Cut the whole corpus into a list of chunked mini-batches.
-- Each of the global_conf.batch_size slots streams one sequence at a time.
-- A sequence is spread over as many consecutive chunks of
-- global_conf.chunk_size steps as it needs. An idle slot takes the next
-- unused sequence at the start of a chunk. Positions that are not filled
-- keep the value global_conf.nn_act_default.
--
-- Each element of the returned list is a table with fields:
--   input, output  lists of chunk_size matrices, allocated with
--                  mmat_type(batch_size, 1) and then converted with
--                  cumat_type.new_from_host; input[j][i-1][0] and
--                  output[j][i-1][0] hold step j of slot i
--   seq_start[i]   true if slot i began a new sequence in this chunk
--   seq_end[i]     true if slot i's sequence finished in this chunk
--   seq_len[i]     number of steps filled for slot i (0 if idle)
--
-- Sequences with fewer than two words are skipped: they contain no
-- (input, next word) pair.
-- @tparam table global_conf must provide batch_size, chunk_size,
--   mmat_type, cumat_type and nn_act_default
-- @treturn table list of batches, in corpus order
function Reader:get_all_batch(global_conf)
    local batch_size = global_conf.batch_size
    local chunk_size = global_conf.chunk_size
    local num_seq = #self.seq
    local data = {}
    -- pos[i] = {sequence index, current position} for the sequence that
    -- slot i is streaming, or nil while the slot is idle.
    local pos = {}
    local offset = 1 -- next sequence not yet assigned to any slot
    while true do
        local input = {}
        local output = {}
        for j = 1, chunk_size do
            input[j] = global_conf.mmat_type(batch_size, 1)
            input[j]:fill(global_conf.nn_act_default)
            output[j] = global_conf.mmat_type(batch_size, 1)
            output[j]:fill(global_conf.nn_act_default)
        end
        local seq_start = {}
        local seq_end = {}
        local seq_len = {}
        for i = 1, batch_size do
            seq_start[i] = false
            seq_end[i] = false
            seq_len[i] = 0
        end
        local has_new = false
        for i = 1, batch_size do
            if pos[i] == nil then
                -- For sequences shorter than two words get_in_out returns
                -- nil ids and never reports the end, so skip them.
                while offset <= num_seq and #self.seq[offset] < 2 do
                    offset = offset + 1
                end
                -- '<=' (not '<') so the last sequence is used as well.
                if offset <= num_seq then
                    seq_start[i] = true
                    pos[i] = {offset, 1}
                    offset = offset + 1
                end
            end
            if pos[i] ~= nil then
                has_new = true
                for j = 1, chunk_size do
                    local final
                    input[j][i - 1][0], output[j][i - 1][0], final =
                        self:get_in_out(pos[i][1], pos[i][2])
                    seq_len[i] = j
                    if final then
                        seq_end[i] = true
                        pos[i] = nil
                        break
                    end
                    pos[i][2] = pos[i][2] + 1
                end
            end
        end
        -- No slot had anything to stream: the corpus is exhausted.
        if not has_new then
            break
        end
        for j = 1, chunk_size do
            input[j] = global_conf.cumat_type.new_from_host(input[j])
            output[j] = global_conf.cumat_type.new_from_host(output[j])
        end
        data[#data + 1] = {input = input, output = output,
                           seq_start = seq_start, seq_end = seq_end,
                           seq_len = seq_len}
    end
    return data
end
|