--- Reader: loads a word vocabulary and a space-separated corpus, converts the
-- corpus to word-id sequences and slices them into fixed-size batches of
-- `chunk_size` steps x `batch_size` parallel sequences.
local Reader = nerv.class('nerv.Reader')

--- Construct a reader and load both files eagerly.
-- @tparam string vocab_file path to the vocabulary file, one word per line
-- @tparam string input_file path to the corpus file, one sequence per line,
--   words separated by single spaces
function Reader:__init(vocab_file, input_file)
    self:get_vocab(vocab_file)
    self:get_seq(input_file)
end

--- Load the vocabulary: the word on line k (1-based) gets id k - 1.
-- Sets `self.vocab` (word -> id) and `self.size` (number of lines read).
-- If a word appears twice, its last line's id wins, but both lines count
-- towards `self.size`.
-- @tparam string vocab_file path to the vocabulary file
function Reader:get_vocab(vocab_file)
    local f = assert(io.open(vocab_file, 'r'))
    local id = 0
    self.vocab = {}
    for word in f:lines() do
        self.vocab[word] = id
        id = id + 1
    end
    f:close()
    self.size = id
end

--- Split string `s` on the literal separator `t`.
-- Empty fields (e.g. from consecutive separators) are kept as ''.
-- @tparam string s string to split
-- @tparam string t separator, matched literally (pattern-magic characters
--   are escaped)
-- @treturn table array of fields
function Reader:split(s, t)
    local ret = {}
    -- Escape punctuation so the separator is matched literally; `t` is also
    -- appended literally below so the last field is captured.
    local sep = t:gsub('%p', '%%%0')
    for x in (s .. t):gmatch('(.-)' .. sep) do
        ret[#ret + 1] = x
    end
    return ret
end

--- Load the corpus into `self.seq`, one array of word ids per input line.
-- NOTE: words missing from `self.vocab` map to nil and are silently dropped,
-- so a sequence may be shorter than its line (or empty).
-- @tparam string input_file path to the corpus file
function Reader:get_seq(input_file)
    local f = assert(io.open(input_file, 'r'))
    self.seq = {}
    for line in f:lines() do
        local ids = {}
        for _, word in ipairs(self:split(line, ' ')) do
            if word ~= '' then
                ids[#ids + 1] = self.vocab[word] -- no-op when word is OOV
            end
        end
        self.seq[#self.seq + 1] = ids
    end
    f:close()
end

--- Return one training step of sequence `id` at position `pos`.
-- @tparam number id index into `self.seq`
-- @tparam number pos 1-based position of the input token
-- @return input word id, target (next) word id, and true when this is the
--   last step of the sequence (the target is its final token)
function Reader:get_in_out(id, pos)
    local seq = self.seq[id]
    return seq[pos], seq[pos + 1], pos + 1 == #seq
end

--- Slice the whole corpus into batches.
-- Each of `batch_size` slots consumes sequences one after another; a slot
-- that finishes its sequence is refilled with the next unused one at the
-- start of the following chunk.
-- @tparam table global_conf must provide `batch_size`, `chunk_size`,
--   `mmat_type` (called as `mmat_type(rows, cols)` to build a host matrix),
--   `nn_act_default` (fill value for padding) and
--   `cumat_type.new_from_host` (converts a host matrix)
-- @treturn table array of batches. Each batch is `{input, output, seq_start,
--   seq_end, seq_len}`. `input`/`output` are `chunk_size` matrices of size
--   (batch_size, 1). `seq_start[i]`/`seq_end[i]` flag that slot i starts or
--   finishes a sequence in this chunk. `seq_len[i]` is the number of valid
--   steps for slot i in this chunk (0 if the slot is idle).
function Reader:get_all_batch(global_conf)
    local data = {}
    local batch_size = global_conf.batch_size
    local chunk_size = global_conf.chunk_size
    local nseq = #self.seq
    local pos = {}   -- pos[i] = {seq_id, next_pos} for slot i, nil when idle
    local offset = 1 -- next sequence in self.seq not yet given to a slot

    while true do
        local input, output = {}, {}
        for j = 1, chunk_size do
            input[j] = global_conf.mmat_type(batch_size, 1)
            input[j]:fill(global_conf.nn_act_default)
            output[j] = global_conf.mmat_type(batch_size, 1)
            output[j]:fill(global_conf.nn_act_default)
        end

        local seq_start, seq_end, seq_len = {}, {}, {}
        for i = 1, batch_size do
            seq_start[i] = false
            seq_end[i] = false
            seq_len[i] = 0
        end

        local has_new = false -- true if any slot produced data in this chunk
        for i = 1, batch_size do
            if pos[i] == nil then
                -- Sequences with < 2 tokens yield no (input, output) pair and
                -- get_in_out would never flag their end; skip them.
                while offset <= nseq and #self.seq[offset] < 2 do
                    offset = offset + 1
                end
                if offset <= nseq then
                    seq_start[i] = true
                    pos[i] = {offset, 1}
                    offset = offset + 1
                end
            end
            if pos[i] ~= nil then
                has_new = true
                for j = 1, chunk_size do
                    local x, y, final = self:get_in_out(pos[i][1], pos[i][2])
                    -- rows/cols are addressed 0-based here, as in the
                    -- original code
                    input[j][i - 1][0] = x
                    output[j][i - 1][0] = y
                    seq_len[i] = j
                    if final then
                        seq_end[i] = true
                        pos[i] = nil
                        break
                    end
                    pos[i][2] = pos[i][2] + 1
                end
            end
        end
        if not has_new then
            break
        end

        for j = 1, chunk_size do
            input[j] = global_conf.cumat_type.new_from_host(input[j])
            output[j] = global_conf.cumat_type.new_from_host(output[j])
        end
        data[#data + 1] = {
            input = input,
            output = output,
            seq_start = seq_start,
            seq_end = seq_end,
            seq_len = seq_len,
        }
    end
    return data
end