diff options
author | Qi Liu <[email protected]> | 2016-03-11 20:11:00 +0800 |
---|---|---|
committer | Qi Liu <[email protected]> | 2016-03-11 20:11:00 +0800 |
commit | e2a9af061db485d4388902d738c9d8be3f94ab34 (patch) | |
tree | 468d6c6afa0801f6a6bf794b3674f8814b8827f7 /lua/reader.lua | |
parent | 2f46a5e2b37a054f482f76f4ac3d26b144cf988f (diff) |
add recipe and fix bugs
Diffstat (limited to 'lua/reader.lua')
-rw-r--r-- | lua/reader.lua | 113 |
1 files changed, 0 insertions, 113 deletions
diff --git a/lua/reader.lua b/lua/reader.lua deleted file mode 100644 index d2624d3..0000000 --- a/lua/reader.lua +++ /dev/null @@ -1,113 +0,0 @@ -local Reader = nerv.class('nerv.Reader') - -function Reader:__init(vocab_file, input_file) - self:get_vocab(vocab_file) - self:get_seq(input_file) -end - -function Reader:get_vocab(vocab_file) - local f = io.open(vocab_file, 'r') - local id = 0 - self.vocab = {} - while true do - local word = f:read() - if word == nil then - break - end - self.vocab[word] = id - id = id + 1 - end - self.size = id -end - -function Reader:split(s, t) - local ret = {} - for x in (s .. t):gmatch('(.-)' .. t) do - table.insert(ret, x) - end - return ret -end - -function Reader:get_seq(input_file) - local f = io.open(input_file, 'r') - self.seq = {} - while true do - local seq = f:read() - if seq == nil then - break - end - seq = self:split(seq, ' ') - local tmp = {} - for i = 1, #seq do - if seq[i] ~= '' then - table.insert(tmp, self.vocab[seq[i]]) - end - end - table.insert(self.seq, tmp) - end -end - -function Reader:get_in_out(id, pos) - return self.seq[id][pos], self.seq[id][pos + 1], pos + 1 == #self.seq[id] -end - -function Reader:get_all_batch(global_conf) - local data = {} - local pos = {} - local offset = 1 - for i = 1, global_conf.batch_size do - pos[i] = nil - end - while true do - --for i = 1, 100 do - local input = {} - local output = {} - for i = 1, global_conf.chunk_size do - input[i] = global_conf.mmat_type(global_conf.batch_size, 1) - input[i]:fill(global_conf.nn_act_default) - output[i] = global_conf.mmat_type(global_conf.batch_size, 1) - output[i]:fill(global_conf.nn_act_default) - end - local seq_start = {} - local seq_end = {} - local seq_len = {} - for i = 1, global_conf.batch_size do - seq_start[i] = false - seq_end[i] = false - seq_len[i] = 0 - end - local has_new = false - for i = 1, global_conf.batch_size do - if pos[i] == nil then - if offset < #self.seq then - seq_start[i] = true - pos[i] = {offset, 1} - offset = offset + 1 - end - end - if pos[i] ~= nil then - has_new = true - for j = 1, global_conf.chunk_size do - local final - input[j][i-1][0], output[j][i-1][0], final = self:get_in_out(pos[i][1], pos[i][2]) - seq_len[i] = j - if final then - seq_end[i] = true - pos[i] = nil - break - end - pos[i][2] = pos[i][2] + 1 - end - end - end - if not has_new then - break - end - for i = 1, global_conf.chunk_size do - input[i] = global_conf.cumat_type.new_from_host(input[i]) - output[i] = global_conf.cumat_type.new_from_host(output[i]) - end - table.insert(data, {input = input, output = output, seq_start = seq_start, seq_end = seq_end, seq_len = seq_len}) - end - return data -end |