diff options
Diffstat (limited to 'nerv/examples/ptb/reader.lua')
-rw-r--r-- | nerv/examples/ptb/reader.lua | 67 |
1 files changed, 67 insertions, 0 deletions
-- nerv/examples/ptb/reader.lua
--
-- Reader: loads a vocabulary file (one word per line) and a text file
-- (one space-separated sequence per line), converts every sequence to a
-- list of word ids, and hands the sequences out one at a time as a pair
-- of (input, label) column matrices.

local Reader = nerv.class('nerv.Reader')

-- Number of lines get_seq reads when the caller gives no explicit limit.
-- NOTE(review): the original code hard-coded this cap next to a
-- commented-out `while true do`, so it looks like a temporary debugging
-- limit -- confirm. Pass `math.huge` as `max_seq` to read the whole file.
local DEFAULT_MAX_SEQ = 26

--- Build a reader from a vocabulary file and a corpus file.
-- @param vocab_file path of the vocabulary file, one word per line
-- @param input_file path of the corpus file, one sequence per line
-- @param max_seq    (optional) maximum number of corpus lines to load;
--                   defaults to DEFAULT_MAX_SEQ
function Reader:__init(vocab_file, input_file, max_seq)
    self:get_vocab(vocab_file)
    self:get_seq(input_file, max_seq)
    -- 1-based index into self.seq of the next sequence get_data returns.
    self.offset = 1
end

--- Load the vocabulary into self.vocab (word -> id) and self.size.
-- Ids are assigned in file order starting at 0. self.size is the number
-- of lines read, so a word that appears twice keeps its last id while
-- still counting twice towards self.size.
-- Raises an error carrying the io.open message if the file cannot be
-- opened.
-- @param vocab_file path of the vocabulary file
function Reader:get_vocab(vocab_file)
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    f:close()
    self.size = id
end

--- Split string `s` on separator `t`.
-- `t` is spliced into a Lua pattern unescaped, so a separator that
-- contains pattern magic characters must be escaped by the caller.
-- Consecutive separators produce empty strings in the result.
-- @param s string to split
-- @param t separator (interpreted as a Lua pattern)
-- @return array of the pieces of `s` between separators
function Reader:split(s, t)
    local ret = {}
    -- Appending one separator guarantees the last piece is terminated.
    for x in (s .. t):gmatch('(.-)' .. t) do
        ret[#ret + 1] = x
    end
    return ret
end

--- Load up to `max_seq` lines of the corpus into self.seq.
-- Every line becomes one array of word ids (possibly empty). Empty
-- tokens and words missing from self.vocab are skipped.
-- Raises an error carrying the io.open message if the file cannot be
-- opened.
-- @param input_file path of the corpus file
-- @param max_seq    (optional) maximum number of lines to read;
--                   defaults to DEFAULT_MAX_SEQ
function Reader:get_seq(input_file, max_seq)
    local limit = max_seq or DEFAULT_MAX_SEQ
    local f = assert(io.open(input_file, 'r'))
    self.seq = {}
    local count = 0
    while count < limit do
        local line = f:read()
        if line == nil then
            break
        end
        local ids = {}
        for _, word in ipairs(self:split(line, ' ')) do
            if word ~= '' then
                local id = self.vocab[word]
                -- Out-of-vocabulary words are dropped, not mapped to a
                -- placeholder id.
                if id ~= nil then
                    ids[#ids + 1] = id
                end
            end
        end
        self.seq[#self.seq + 1] = ids
        count = count + 1
    end
    f:close()
end

--- Return the next sequence as input/label matrices, or nil when done.
-- For a sequence of n ids the result holds two (n - 1) x 1 matrices:
-- `input` carries ids 1 .. n-1 and `label` carries ids 2 .. n, i.e. each
-- label is the id that follows the corresponding input id.
-- NOTE(review): a sequence with fewer than two ids requests a matrix
-- with 0 or -1 rows; how nerv.MMatrixFloat handles that is not visible
-- here -- confirm.
-- @return table {input = ..., label = ...}, or nil after the last
--         sequence has been returned
function Reader:get_data()
    if self.offset > #self.seq then
        return nil
    end
    local ids = self.seq[self.offset]
    local rows = #ids - 1
    local res = {
        input = nerv.MMatrixFloat(rows, 1),
        label = nerv.MMatrixFloat(rows, 1),
    }
    for i = 1, rows do
        -- The matrices are addressed from 0 here, the id list from 1.
        res.input[i - 1][0] = ids[i]
        res.label[i - 1][0] = ids[i + 1]
    end
    self.offset = self.offset + 1
    return res
end