-- Reader: loads a vocabulary file (one word per line, mapped to 0-based ids)
-- and a corpus file (one sentence per line), converts each sentence into a
-- table of word ids, and can pack them into chunked mini-batches
-- (see Reader:get_all_batch).
-- NOTE(review): nerv.class is defined outside this file; presumably it
-- registers the class as nerv.Reader and calls Reader:__init on
-- construction — confirm.
local Reader = nerv.class('nerv.Reader')
--- Build a reader from a vocabulary file and a corpus file.
-- The vocabulary must be loaded first: get_seq maps every word through
-- self.vocab, which get_vocab creates.
-- @tparam string vocab_file path to the vocabulary (one word per line)
-- @tparam string input_file path to the corpus (one sentence per line)
function Reader:__init(vocab_file, input_file)
    -- explicit `self.m(self, ...)` form of the method calls; order matters
    self.get_vocab(self, vocab_file)
    self.get_seq(self, input_file)
end
--- Load the vocabulary: one word per line.
-- Each line gets a 0-based id in file order (self.vocab[word] = id). If a
-- word appears on several lines, the id of its last occurrence wins.
-- self.size is set to the number of lines read (duplicates included).
-- Raises an error naming the file if it cannot be opened.
-- @tparam string vocab_file path to the vocabulary file
function Reader:get_vocab(vocab_file)
    -- assert surfaces io.open's message instead of a later "index nil" error
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    -- the original never closed the handle
    f:close()
    self.size = id
end
--- Split `s` on every occurrence of the separator `t`.
-- `t` is matched as plain text: Lua pattern magic characters in it are
-- escaped, so separators such as '.', '%' or '-' behave literally.
-- Adjacent separators yield empty fields,
-- e.g. split('a  b', ' ') -> {'a', '', 'b'}.
-- Does not use `self`.
-- @tparam string s string to split
-- @tparam string t separator (expected to be non-empty)
-- @treturn table array of fields, in order
function Reader:split(s, t)
    -- escape ^ $ ( ) % . [ ] * + - ? so `t` is not interpreted as a pattern;
    -- gsub returns (string, count) and only the string is kept
    local sep = t:gsub('[%^%$%(%)%%%.%[%]%*%+%-%?]', '%%%0')
    local ret = {}
    -- appending `t` guarantees the last field is terminated by a separator
    for field in (s .. t):gmatch('(.-)' .. sep) do
        ret[#ret + 1] = field
    end
    return ret
end
--- Read the corpus: one sentence per line, words separated by whitespace.
-- Each line becomes a table of vocabulary ids appended to self.seq. Empty
-- lines yield empty tables, so self.seq[k] corresponds to line k.
-- A word missing from self.vocab is mapped to the id of '<unk>' when the
-- vocabulary contains it. Otherwise an error is raised. Previously such
-- words were silently dropped, which shifted the remaining tokens.
-- Requires self.vocab (see get_vocab).
-- @tparam string input_file path to the corpus file
function Reader:get_seq(input_file)
    local f = assert(io.open(input_file, 'r'))
    local vocab = self.vocab
    local unk = vocab['<unk>']
    self.seq = {}
    local line_no = 0
    for line in f:lines() do
        line_no = line_no + 1
        local ids = {}
        -- %S+ skips runs of spaces/tabs and a trailing '\r'
        for word in line:gmatch('%S+') do
            -- ids are numbers (0 is truthy in Lua), so `or` is safe here
            local id = vocab[word] or unk
            if id == nil then
                f:close()
                error(string.format(
                    "%s:%d: word '%s' is not in the vocabulary and no <unk> entry exists",
                    input_file, line_no, word))
            end
            ids[#ids + 1] = id
        end
        self.seq[#self.seq + 1] = ids
    end
    f:close()
end
--- Return the word pair at position `pos` of sequence `id`.
-- @param id index into self.seq
-- @param pos 1-based position within that sequence
-- @return the word id at `pos`, the word id at `pos + 1`, and a boolean
--   that is true when `pos + 1` is the last position of the sequence.
--   Positions past the end yield nil ids.
function Reader:get_in_out(id, pos)
    local words = self.seq[id]
    local next_pos = pos + 1
    local current, following = words[pos], words[next_pos]
    local is_last = (next_pos == #words)
    return current, following, is_last
end
--- Pack every sequence in self.seq into a list of chunked mini-batches.
-- There are global_conf.batch_size streams. Each stream consumes sequences
-- one after another, in order. Every batch covers global_conf.chunk_size
-- time steps.
-- For step j and stream i, row i-1 of input[j] holds the current word id and
-- row i-1 of output[j] holds the next word id. Unused slots keep
-- global_conf.nn_act_default.
-- Sequences with fewer than two words have no (input, output) pair and are
-- skipped.
-- @param global_conf table providing batch_size, chunk_size, mmat_type,
--   cumat_type and nn_act_default (types defined outside this file)
-- @return array of tables {input, output, seq_start, seq_end, seq_len}.
--   input/output are arrays of chunk_size matrices converted with
--   cumat_type.new_from_host. seq_start[i]/seq_end[i] flag streams whose
--   sequence begins/ends in this batch. seq_len[i] is the number of steps
--   stream i filled.
function Reader:get_all_batch(global_conf)
    local batch_size = global_conf.batch_size
    local chunk_size = global_conf.chunk_size
    local n_seq = #self.seq
    local data = {}
    -- pos[i] = {sequence index, position in sequence} for stream i; nil = idle
    local pos = {}
    -- index of the next sequence not yet assigned to a stream
    local offset = 1
    while true do
        local input = {}
        local output = {}
        for j = 1, chunk_size do
            input[j] = global_conf.mmat_type(batch_size, 1)
            input[j]:fill(global_conf.nn_act_default)
            output[j] = global_conf.mmat_type(batch_size, 1)
            output[j]:fill(global_conf.nn_act_default)
        end
        local seq_start = {}
        local seq_end = {}
        local seq_len = {}
        for i = 1, batch_size do
            seq_start[i] = false
            seq_end[i] = false
            seq_len[i] = 0
        end
        local has_new = false
        for i = 1, batch_size do
            if pos[i] == nil then
                -- a sequence of length 0 or 1 would write nil into the
                -- matrices and never report `final`, so skip it
                while offset <= n_seq and #self.seq[offset] < 2 do
                    offset = offset + 1
                end
                -- `<=` (was `<`): the last sequence must be batched too
                if offset <= n_seq then
                    seq_start[i] = true
                    pos[i] = {offset, 1}
                    offset = offset + 1
                end
            end
            if pos[i] ~= nil then
                has_new = true
                for j = 1, chunk_size do
                    local final
                    input[j][i - 1][0], output[j][i - 1][0], final =
                        self:get_in_out(pos[i][1], pos[i][2])
                    seq_len[i] = j
                    if final then
                        -- free the stream; it picks up a new sequence next batch
                        seq_end[i] = true
                        pos[i] = nil
                        break
                    end
                    pos[i][2] = pos[i][2] + 1
                end
            end
        end
        if not has_new then
            break
        end
        for j = 1, chunk_size do
            input[j] = global_conf.cumat_type.new_from_host(input[j])
            output[j] = global_conf.cumat_type.new_from_host(output[j])
        end
        data[#data + 1] = {
            input = input,
            output = output,
            seq_start = seq_start,
            seq_end = seq_end,
            seq_len = seq_len,
        }
    end
    return data
end