about summary refs log tree commit diff
path: root/lua/reader.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/reader.lua')
-rw-r--r-- lua/reader.lua 112
1 files changed, 112 insertions, 0 deletions
diff --git a/lua/reader.lua b/lua/reader.lua
new file mode 100644
index 0000000..2e51a9c
--- /dev/null
+++ b/lua/reader.lua
@@ -0,0 +1,112 @@
+local Reader = nerv.class('nerv.Reader') -- class object from nerv.class (defined elsewhere); methods are attached below
+
+function Reader:__init(vocab_file, input_file) -- loads the vocabulary file, then the sequence file
+ self:get_vocab(vocab_file) -- must run first: get_seq looks every token up in self.vocab
+ self:get_seq(input_file) -- fills self.seq with one list of word ids per input line
+end
+
+-- Read vocab_file (one word per line) into self.vocab, mapping each word to its
+-- 0-based line index; self.size is set to the number of lines read.
+-- Raises an error carrying io.open's message if the file cannot be opened.
+function Reader:get_vocab(vocab_file)
+    local f = assert(io.open(vocab_file, 'r'))
+    local id = 0
+    self.vocab = {}
+    for word in f:lines() do
+        self.vocab[word] = id
+        id = id + 1
+    end
+    f:close() -- file:lines() does not close the handle, so release it here
+    self.size = id
+end
+
+-- Split s on each occurrence of the literal separator t and return the pieces
+-- in order, empty pieces included. The receiver (self) is not used.
+function Reader:split(s, t)
+    local ret, sep = {}, string.gsub(t, '%p', '%%%0') -- escape magic chars so t matches literally
+    for x in (s .. t):gmatch('(.-)' .. sep) do ret[#ret + 1] = x end
+    return ret
+end
+
+-- Read input_file line by line into self.seq: each line is split on single spaces
+-- and every non-empty token is replaced by its id from self.vocab. Raises an
+-- error if the file cannot be opened; the handle is closed before returning.
+function Reader:get_seq(input_file)
+    local f = assert(io.open(input_file, 'r'))
+    self.seq = {}
+    for line in f:lines() do
+        local ids = {}
+        for _, token in ipairs(self:split(line, ' ')) do
+            local id = self.vocab[token] -- NOTE(review): nil for unknown words, which are dropped (unchanged behavior)
+            if token ~= '' and id ~= nil then
+                ids[#ids + 1] = id
+            end
+        end
+        self.seq[#self.seq + 1] = ids
+    end
+    f:close()
+end
+
+function Reader:get_in_out(id, pos) -- returns token at pos, token at pos + 1, and whether pos + 1 is the last index
+    local tokens, nxt = self.seq[id], pos + 1; return tokens[pos], tokens[nxt], nxt == #tokens
+end
+
+-- Pack self.seq into chunked mini-batches and return them as a list.
+-- Each entry is {input, output, seq_start, seq_end, seq_len}: input[j] and
+-- output[j] (j = 1..chunk_size) come from global_conf.mmat_type(batch_size, 1),
+-- are pre-filled with global_conf.nn_act_default, indexed from 0 as [i - 1][0]
+-- for slot i, and finally passed through global_conf.cumat_type.new_from_host.
+-- seq_start[i] / seq_end[i] flag a sequence beginning / finishing in slot i
+-- during this chunk; seq_len[i] is the number of steps written for slot i.
+-- Sequences with fewer than two tokens are skipped (no input/output pair).
+function Reader:get_all_batch(global_conf)
+    local batch_size, chunk_size = global_conf.batch_size, global_conf.chunk_size
+    local data, pos, offset = {}, {}, 1 -- pos[i] = {sequence index, token index} for slot i
+    while true do
+        local input, output = {}, {}
+        for j = 1, chunk_size do
+            input[j] = global_conf.mmat_type(batch_size, 1)
+            input[j]:fill(global_conf.nn_act_default)
+            output[j] = global_conf.mmat_type(batch_size, 1)
+            output[j]:fill(global_conf.nn_act_default)
+        end
+        local seq_start, seq_end, seq_len = {}, {}, {}
+        local has_new = false
+        for i = 1, batch_size do
+            seq_start[i], seq_end[i], seq_len[i] = false, false, 0
+            if pos[i] == nil then
+                -- get_in_out never reports `final` for a sequence of fewer than two
+                -- tokens, so it would hold its slot forever: skip those
+                while offset <= #self.seq and #self.seq[offset] < 2 do
+                    offset = offset + 1
+                end
+                if offset <= #self.seq then -- `<=`: the last sequence must be used too
+                    seq_start[i] = true
+                    pos[i] = {offset, 1}
+                    offset = offset + 1
+                end
+            end
+            if pos[i] ~= nil then
+                has_new = true
+                for j = 1, chunk_size do
+                    local final
+                    input[j][i - 1][0], output[j][i - 1][0], final = self:get_in_out(pos[i][1], pos[i][2])
+                    seq_len[i] = j
+                    if final then
+                        seq_end[i] = true
+                        pos[i] = nil -- free the slot; the rest of the chunk keeps the fill value
+                        break
+                    end
+                    pos[i][2] = pos[i][2] + 1
+                end
+            end
+        end
+        if not has_new then break end -- every slot is idle and no sequence is left
+        for j = 1, chunk_size do
+            input[j] = global_conf.cumat_type.new_from_host(input[j])
+            output[j] = global_conf.cumat_type.new_from_host(output[j])
+        end
+        table.insert(data, {input = input, output = output, seq_start = seq_start, seq_end = seq_end, seq_len = seq_len})
+    end
+    return data
+end