aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/ptb/reader.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/ptb/reader.lua')
-rw-r--r--nerv/examples/ptb/reader.lua67
1 files changed, 67 insertions, 0 deletions
diff --git a/nerv/examples/ptb/reader.lua b/nerv/examples/ptb/reader.lua
new file mode 100644
index 0000000..70c0c97
--- /dev/null
+++ b/nerv/examples/ptb/reader.lua
@@ -0,0 +1,67 @@
+local Reader = nerv.class('nerv.Reader')
+
+function Reader:__init(vocab_file, input_file)
+ self:get_vocab(vocab_file)
+ self:get_seq(input_file)
+ self.offset = 1
+end
+
+function Reader:get_vocab(vocab_file)
+ local f = io.open(vocab_file, 'r')
+ local id = 0
+ self.vocab = {}
+ while true do
+ local word = f:read()
+ if word == nil then
+ break
+ end
+ self.vocab[word] = id
+ id = id + 1
+ end
+ self.size = id
+end
+
+function Reader:split(s, t)
+ local ret = {}
+ for x in (s .. t):gmatch('(.-)' .. t) do
+ table.insert(ret, x)
+ end
+ return ret
+end
+
+function Reader:get_seq(input_file)
+ local f = io.open(input_file, 'r')
+ self.seq = {}
+ -- while true do
+ for i = 1, 26 do
+ local seq = f:read()
+ if seq == nil then
+ break
+ end
+ seq = self:split(seq, ' ')
+ local tmp = {}
+ for i = 1, #seq do
+ if seq[i] ~= '' then
+ table.insert(tmp, self.vocab[seq[i]])
+ end
+ end
+ table.insert(self.seq, tmp)
+ end
+end
+
+function Reader:get_data()
+ if self.offset > #self.seq then
+ return nil
+ end
+ local tmp = self.seq[self.offset]
+ local res = {
+ input = nerv.MMatrixFloat(#tmp - 1, 1),
+ label = nerv.MMatrixFloat(#tmp - 1, 1),
+ }
+ for i = 1, #tmp - 1 do
+ res.input[i - 1][0] = tmp[i]
+ res.label[i - 1][0] = tmp[i + 1]
+ end
+ self.offset = self.offset + 1
+ return res
+end