--- nerv.Reader: loads a word vocabulary and a space-separated corpus, then
-- hands out one sentence at a time as a pair of matrices where `label` row k
-- holds the word id that follows the `input` row k word id.
local Reader = nerv.class('nerv.Reader')

--- Create a reader.
-- @tparam string vocab_file path to a file with one word per line; the word
--   on line k is assigned id k - 1.
-- @tparam string input_file path to a file with one sentence per line,
--   words separated by spaces.
function Reader:__init(vocab_file, input_file)
    self:get_vocab(vocab_file)
    self:get_seq(input_file)
    -- index into self.seq of the next sentence get_data will return
    self.offset = 1
end

--- Read the vocabulary file into `self.vocab` (word -> id) and `self.size`.
-- Ids are assigned by line number starting at 0. `self.size` is the number of
-- lines read; if a word repeats, its last occurrence wins in `self.vocab`.
-- Raises an error if the file cannot be opened.
-- @tparam string vocab_file
function Reader:get_vocab(vocab_file)
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    f:close()
    self.size = id
end

--- Split `s` on every occurrence of the literal separator `t`.
-- Adjacent separators produce empty strings in the result.
-- @tparam string s string to split
-- @tparam string t separator, matched literally (not as a Lua pattern)
-- @treturn table array of the pieces, in order
function Reader:split(s, t)
    -- Escape pattern magic characters so the separator matches literally.
    local sep = t:gsub('[%^%$%(%)%%%.%[%]%*%+%-%?]', '%%%0')
    local ret = {}
    -- Appending the raw separator lets the lazy match capture the last piece.
    for x in (s .. t):gmatch('(.-)' .. sep) do
        ret[#ret + 1] = x
    end
    return ret
end

--- Read the corpus into `self.seq`, one array of word ids per input line.
-- Empty tokens (from repeated spaces) are ignored.
-- Raises an error if the file cannot be opened.
-- @tparam string input_file
function Reader:get_seq(input_file)
    local f = assert(io.open(input_file, 'r'))
    self.seq = {}
    for line in f:lines() do
        local ids = {}
        for _, word in ipairs(self:split(line, ' ')) do
            if word ~= '' then
                local id = self.vocab[word]
                -- NOTE(review): words missing from the vocabulary are dropped
                -- (this preserves the previous behaviour); confirm this is
                -- intended rather than mapping them to an unknown-word id.
                if id ~= nil then
                    ids[#ids + 1] = id
                end
            end
        end
        self.seq[#self.seq + 1] = ids
    end
    f:close()
end

--- Return the next sentence as next-word-prediction matrices, or nil when
-- the corpus is exhausted.
-- For a sentence of n word ids, `input` and `label` are created with
-- nerv.MMatrixFloat(n - 1, 1); row k - 1 of `input` holds id k and row k - 1
-- of `label` holds id k + 1.
-- @treturn table|nil `{ input = ..., label = ... }`, or nil at end of data
function Reader:get_data()
    -- Sentences with fewer than two ids contribute no (input, label) pair and
    -- would otherwise request a matrix with zero or negative rows; skip them.
    while self.offset <= #self.seq and #self.seq[self.offset] < 2 do
        self.offset = self.offset + 1
    end
    if self.offset > #self.seq then
        return nil
    end
    local ids = self.seq[self.offset]
    local n = #ids - 1
    local res = {
        input = nerv.MMatrixFloat(n, 1),
        label = nerv.MMatrixFloat(n, 1),
    }
    for i = 1, n do
        -- matrix rows/columns are addressed from 0 here
        res.input[i - 1][0] = ids[i]
        res.label[i - 1][0] = ids[i + 1]
    end
    self.offset = self.offset + 1
    return res
end