aboutsummaryrefslogtreecommitdiff
path: root/nerv/io
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io')
-rw-r--r--nerv/io/frm_buffer.lua (renamed from nerv/io/sgd_buffer.lua)14
-rw-r--r--nerv/io/init.lua2
-rw-r--r--nerv/io/seq_buffer.lua108
3 files changed, 117 insertions, 7 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/frm_buffer.lua
index d78f6d1..9761f16 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/frm_buffer.lua
@@ -1,6 +1,6 @@
-local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
+local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer")
-function SGDBuffer:__init(global_conf, buffer_conf)
+function FrmBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
self.batch_size = buffer_conf.batch_size
self.buffer_size = math.floor(buffer_conf.buffer_size /
@@ -57,7 +57,7 @@ function SGDBuffer:__init(global_conf, buffer_conf)
end
end
-function SGDBuffer:saturate()
+function FrmBuffer:saturate()
local buffer_size = self.buffer_size
self.head = 0
self.tail = buffer_size
@@ -116,7 +116,7 @@ function SGDBuffer:saturate()
return self.tail >= self.batch_size
end
-function SGDBuffer:get_data()
+function FrmBuffer:get_data()
local batch_size = self.batch_size
if self.head >= self.tail then -- buffer is empty
local t = os.clock()
@@ -132,7 +132,9 @@ function SGDBuffer:get_data()
return nil -- the remaining data cannot build a batch
end
actual_batch_size = math.min(batch_size, self.tail - self.head)
- local res = {}
+ local res = {seq_length = table.vector(gconf.batch_size, 1),
+ new_seq = {},
+ data = {}}
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
local batch = self.output_mat_type(actual_batch_size, buff.width)
@@ -141,7 +143,7 @@ function SGDBuffer:get_data()
else
self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
end
- res[id] = batch
+ res.data[id] = {batch}
end
end
self.head = self.head + actual_batch_size
diff --git a/nerv/io/init.lua b/nerv/io/init.lua
index c36d850..d3ba27c 100644
--- a/nerv/io/init.lua
+++ b/nerv/io/init.lua
@@ -56,5 +56,5 @@ function DataBuffer:get_data()
nerv.error_method_not_implemented()
end
-nerv.include('sgd_buffer.lua')
+nerv.include('frm_buffer.lua')
nerv.include('seq_buffer.lua')
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index e69de29..029e7b8 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -0,0 +1,108 @@
+local SeqBuffer = nerv.class('nerv.SeqBuffer', 'nerv.DataBuffer')
+
+function SeqBuffer:__init(global_conf, buffer_conf)
+ self.gconf = global_conf
+
+ self.batch_size = buffer_conf.batch_size
+ self.chunk_size = buffer_conf.chunk_size
+ self.readers = {}
+ for _, v in ipairs(buffer_conf.readers) do
+ table.insert(self.readers, v.reader)
+ end
+ self.nn_act_default = buffer_conf.nn_act_default
+ if self.nn_act_default == nil then
+ self.nn_act_default = 0
+ end
+
+ self.mat_type = self.gconf.mmat_type
+ self.queue = {}
+ self.head = 1
+ self.tail = 0
+end
+
+function SeqBuffer:new_mini_batch()
+ local res = {}
+ res.data = {}
+ res.new_seq = {}
+ res.seq_length = {}
+ for i = 1, self.batch_size do
+ res.seq_length[i] = 0
+ end
+ return res
+end
+
+function SeqBuffer:saturate(batch)
+ if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then
+ return true
+ end
+ local data = {}
+ local drow = nil
+ for i = 1, #self.readers do
+ local tmp = self.readers[i]:get_data()
+ if tmp == nil then
+ return false
+ end
+ for id, d in pairs(tmp) do
+ if drow == nil then
+ drow = d:nrow()
+ elseif d:nrow() ~= drow then
+ nerv.error('readers provides with inconsistent rows of data')
+ end
+ data[id] = d
+ end
+ end
+ local offset = 0
+ local head = self.head
+ while offset < drow do
+ local last = math.min(offset + self.chunk_size, drow)
+ if head > self.tail then
+ self.tail = self.tail + 1
+ self.queue[self.tail] = self:new_mini_batch()
+ end
+ self.queue[head].seq_length[batch] = last - offset
+ if offset == 0 then
+ table.insert(self.queue[head].new_seq, batch)
+ end
+ local mini_batch = self.queue[head].data
+ for id, d in pairs(data) do
+ if mini_batch[id] == nil then
+ mini_batch[id] = {}
+ end
+ local tmp = mini_batch[id]
+ for i = offset + 1, last do
+ local chunk = i - offset
+ if tmp[chunk] == nil then
+ tmp[chunk] = self.mat_type(self.batch_size, d:ncol())
+ tmp[chunk]:fill(self.nn_act_default)
+ end
+ tmp[chunk]:copy_from(d, i - 1, i, batch - 1)
+ end
+ end
+ head = head + 1
+ offset = last
+ end
+ return true
+end
+
+function SeqBuffer:get_data()
+ local has_data = false
+ for i = 1, self.batch_size do
+ if self:saturate(i) then
+ has_data = true
+ end
+ end
+ if not has_data then
+ return nil
+ end
+ local res = self.queue[self.head]
+ self.queue[self.head] = nil
+ self.head = self.head + 1
+ if not self.gconf.use_cpu then
+ for id, d in pairs(res.data) do
+ for i = 1, #d do
+ d[i] = self.gconf.cumat_type.new_from_host(d[i])
+ end
+ end
+ end
+ return res
+end