aboutsummaryrefslogtreecommitdiff
path: root/lua/reader.lua
diff options
context:
space:
mode:
authorQi Liu <[email protected]>2016-03-11 20:11:00 +0800
committerQi Liu <[email protected]>2016-03-11 20:11:00 +0800
commite2a9af061db485d4388902d738c9d8be3f94ab34 (patch)
tree468d6c6afa0801f6a6bf794b3674f8814b8827f7 /lua/reader.lua
parent2f46a5e2b37a054f482f76f4ac3d26b144cf988f (diff)
add recipe and fix bugs
Diffstat (limited to 'lua/reader.lua')
-rw-r--r--lua/reader.lua113
1 files changed, 0 insertions, 113 deletions
diff --git a/lua/reader.lua b/lua/reader.lua
deleted file mode 100644
index d2624d3..0000000
--- a/lua/reader.lua
+++ /dev/null
@@ -1,113 +0,0 @@
-local Reader = nerv.class('nerv.Reader')
-
-function Reader:__init(vocab_file, input_file)
- self:get_vocab(vocab_file)
- self:get_seq(input_file)
-end
-
-function Reader:get_vocab(vocab_file)
- local f = io.open(vocab_file, 'r')
- local id = 0
- self.vocab = {}
- while true do
- local word = f:read()
- if word == nil then
- break
- end
- self.vocab[word] = id
- id = id + 1
- end
- self.size = id
-end
-
-function Reader:split(s, t)
- local ret = {}
- for x in (s .. t):gmatch('(.-)' .. t) do
- table.insert(ret, x)
- end
- return ret
-end
-
-function Reader:get_seq(input_file)
- local f = io.open(input_file, 'r')
- self.seq = {}
- while true do
- local seq = f:read()
- if seq == nil then
- break
- end
- seq = self:split(seq, ' ')
- local tmp = {}
- for i = 1, #seq do
- if seq[i] ~= '' then
- table.insert(tmp, self.vocab[seq[i]])
- end
- end
- table.insert(self.seq, tmp)
- end
-end
-
-function Reader:get_in_out(id, pos)
- return self.seq[id][pos], self.seq[id][pos + 1], pos + 1 == #self.seq[id]
-end
-
-function Reader:get_all_batch(global_conf)
- local data = {}
- local pos = {}
- local offset = 1
- for i = 1, global_conf.batch_size do
- pos[i] = nil
- end
- while true do
- --for i = 1, 100 do
- local input = {}
- local output = {}
- for i = 1, global_conf.chunk_size do
- input[i] = global_conf.mmat_type(global_conf.batch_size, 1)
- input[i]:fill(global_conf.nn_act_default)
- output[i] = global_conf.mmat_type(global_conf.batch_size, 1)
- output[i]:fill(global_conf.nn_act_default)
- end
- local seq_start = {}
- local seq_end = {}
- local seq_len = {}
- for i = 1, global_conf.batch_size do
- seq_start[i] = false
- seq_end[i] = false
- seq_len[i] = 0
- end
- local has_new = false
- for i = 1, global_conf.batch_size do
- if pos[i] == nil then
- if offset < #self.seq then
- seq_start[i] = true
- pos[i] = {offset, 1}
- offset = offset + 1
- end
- end
- if pos[i] ~= nil then
- has_new = true
- for j = 1, global_conf.chunk_size do
- local final
- input[j][i-1][0], output[j][i-1][0], final = self:get_in_out(pos[i][1], pos[i][2])
- seq_len[i] = j
- if final then
- seq_end[i] = true
- pos[i] = nil
- break
- end
- pos[i][2] = pos[i][2] + 1
- end
- end
- end
- if not has_new then
- break
- end
- for i = 1, global_conf.chunk_size do
- input[i] = global_conf.cumat_type.new_from_host(input[i])
- output[i] = global_conf.cumat_type.new_from_host(output[i])
- end
- table.insert(data, {input = input, output = output, seq_start = seq_start, seq_end = seq_end, seq_len = seq_len})
- end
- return data
-end