diff options
author | Determinant <ted.sybil@gmail.com> | 2016-03-15 15:46:05 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2016-03-15 15:46:05 +0800 |
commit | 07fc1e2794027d44c255e1062c4491346b101a08 (patch) | |
tree | 8e7217b9c5e9570b94af5aaad3f94d1a37cfe40b /nerv/io | |
parent | a5a4d2735b595fc9fadc9c7e91198786d3c0e078 (diff) | |
parent | e15307f071813e2eb56f7f83229b91141961325a (diff) |
Merge branch 'master' of github.com:liuq901/nerv into liuq901-master
Diffstat (limited to 'nerv/io')
-rw-r--r-- | nerv/io/seq_buffer.lua | 105 |
1 files changed, 105 insertions, 0 deletions
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua index e69de29..ad1b3f7 100644 --- a/nerv/io/seq_buffer.lua +++ b/nerv/io/seq_buffer.lua @@ -0,0 +1,105 @@ +local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer') + +function SeqBuffer:__init(global_conf, buffer_conf) + self.gconf = global_conf + + self.batch_size = buffer_conf.batch_size + self.chunk_size = buffer_conf.chunk_size + self.readers = buffer_conf.readers + self.nn_act_default = buffer_conf.nn_act_default + if self.nn_act_default == nil then + self.nn_act_default = 0 + end + + self.mat_type = self.gconf.mmat_type + self.queue = {} + self.head = 1 + self.tail = 0 +end + +function SeqBuffer:new_mini_batch() + local res = {} + res.data = {} + res.new_seq = {} + res.seq_length = {} + for i = 1, self.batch_size do + res.seq_length[i] = 0 + end + return res +end + +function SeqBuffer:saturate(batch) + if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then + return true + end + local data = {} + local drow = nil + for i = 1, #self.readers do + local tmp = self.readers[i]:get_data() + if tmp == nil then + return false + end + for id, d in pairs(tmp) do + if drow == nil then + drow = d:nrow() + elseif d:nrow() ~= drow then + nerv.error('readers provides with inconsistent rows of data') + end + data[id] = d + end + end + local offset = 0 + local head = self.head + while offset < drow do + local last = math.min(offset + self.chunk_size, drow) + if head > self.tail then + self.tail = self.tail + 1 + self.queue[self.tail] = self:new_mini_batch() + end + self.queue[head].seq_length[batch] = last - offset + if offset == 0 then + table.insert(self.queue[head].new_seq, batch) + end + local mini_batch = self.queue[head].data + for id, d in pairs(data) do + if mini_batch[id] == nil then + mini_batch[id] = {} + end + local tmp = mini_batch[id] + for i = offset + 1, last do + local chunk = i - offset + if tmp[chunk] == nil then + tmp[chunk] = self.mat_type(self.batch_size, d:ncol()) + tmp[chunk]:fill(self.nn_act_default) + end + tmp[chunk]:copy_from(d, i - 1, i, batch - 1) + end + end + head = head + 1 + offset = last + end + return true +end + +function SeqBuffer:get_data() + local has_data = false + for i = 1, self.batch_size do + if self:saturate(i) then + has_data = true + end + end + if not has_data then + return nil + end + local res = self.queue[self.head] + self.queue[self.head] = nil + self.head = self.head + 1 + if not self.gconf.use_cpu then + for id, d in pairs(res.data) do + for i = 1, #d do + d[i] = self.gconf.cumat_type.new_from_host(d[i]) + end + end + end + return res +end |