diff options
Diffstat (limited to 'nerv/examples/network_debug/reader.lua')
-rw-r--r-- | nerv/examples/network_debug/reader.lua | 76 |
1 files changed, 15 insertions, 61 deletions
diff --git a/nerv/examples/network_debug/reader.lua b/nerv/examples/network_debug/reader.lua index b10baaf..76a78cf 100644 --- a/nerv/examples/network_debug/reader.lua +++ b/nerv/examples/network_debug/reader.lua @@ -3,6 +3,7 @@ local Reader = nerv.class('nerv.Reader') function Reader:__init(vocab_file, input_file) self:get_vocab(vocab_file) self:get_seq(input_file) + self.offset = 1 end function Reader:get_vocab(vocab_file) @@ -32,6 +33,7 @@ function Reader:get_seq(input_file) local f = io.open(input_file, 'r') self.seq = {} while true do + -- for i = 1, 26 do local seq = f:read() if seq == nil then break @@ -47,67 +49,19 @@ function Reader:get_seq(input_file) end end -function Reader:get_in_out(id, pos) - return self.seq[id][pos], self.seq[id][pos + 1], pos + 1 == #self.seq[id] -end - -function Reader:get_all_batch(global_conf) - local data = {} - local pos = {} - local offset = 1 - for i = 1, global_conf.batch_size do - pos[i] = nil +function Reader:get_data() + if self.offset > #self.seq then + return nil end - while true do - -- for i = 1, 26 do - local input = {} - local output = {} - for i = 1, global_conf.chunk_size do - input[i] = global_conf.mmat_type(global_conf.batch_size, 1) - input[i]:fill(global_conf.nn_act_default) - output[i] = global_conf.mmat_type(global_conf.batch_size, 1) - output[i]:fill(global_conf.nn_act_default) - end - local seq_start = {} - local seq_end = {} - local seq_len = {} - for i = 1, global_conf.batch_size do - seq_start[i] = false - seq_end[i] = false - seq_len[i] = 0 - end - local has_new = false - for i = 1, global_conf.batch_size do - if pos[i] == nil then - if offset < #self.seq then - seq_start[i] = true - pos[i] = {offset, 1} - offset = offset + 1 - end - end - if pos[i] ~= nil then - has_new = true - for j = 1, global_conf.chunk_size do - local final - input[j][i-1][0], output[j][i-1][0], final = self:get_in_out(pos[i][1], pos[i][2]) - seq_len[i] = j - if final then - seq_end[i] = true - pos[i] = nil - break - end - pos[i][2] = pos[i][2] + 1 - end - end - end - if not has_new then - break - end - for i = 1, global_conf.chunk_size do - input[i] = global_conf.cumat_type.new_from_host(input[i]) - output[i] = global_conf.cumat_type.new_from_host(output[i]) - end - table.insert(data, {input = input, output = output, seq_start = seq_start, seq_end = seq_end, seq_len = seq_len}) + local tmp = self.seq[self.offset] + local res = { + input = nerv.MMatrixFloat(#tmp - 1, 1), + label = nerv.MMatrixFloat(#tmp - 1, 1), + } + for i = 1, #tmp - 1 do + res.input[i - 1][0] = tmp[i] + res.label[i - 1][0] = tmp[i + 1] end - return data + self.offset = self.offset + 1 + return res end |