aboutsummaryrefslogtreecommitdiff
path: root/nerv/io
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/io')
-rw-r--r--nerv/io/frm_buffer.lua (renamed from nerv/io/sgd_buffer.lua)14
-rw-r--r--nerv/io/init.lua2
-rw-r--r--nerv/io/seq_buffer.lua7
3 files changed, 14 insertions, 9 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/frm_buffer.lua
index d78f6d1..9761f16 100644
--- a/nerv/io/sgd_buffer.lua
+++ b/nerv/io/frm_buffer.lua
@@ -1,6 +1,6 @@
-local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
+local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer")
-function SGDBuffer:__init(global_conf, buffer_conf)
+function FrmBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
self.batch_size = buffer_conf.batch_size
self.buffer_size = math.floor(buffer_conf.buffer_size /
@@ -57,7 +57,7 @@ function SGDBuffer:__init(global_conf, buffer_conf)
end
end
-function SGDBuffer:saturate()
+function FrmBuffer:saturate()
local buffer_size = self.buffer_size
self.head = 0
self.tail = buffer_size
@@ -116,7 +116,7 @@ function SGDBuffer:saturate()
return self.tail >= self.batch_size
end
-function SGDBuffer:get_data()
+function FrmBuffer:get_data()
local batch_size = self.batch_size
if self.head >= self.tail then -- buffer is empty
local t = os.clock()
@@ -132,7 +132,9 @@ function SGDBuffer:get_data()
return nil -- the remaining data cannot build a batch
end
actual_batch_size = math.min(batch_size, self.tail - self.head)
- local res = {}
+ local res = {seq_length = table.vector(gconf.batch_size, 1),
+ new_seq = {},
+ data = {}}
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
local batch = self.output_mat_type(actual_batch_size, buff.width)
@@ -141,7 +143,7 @@ function SGDBuffer:get_data()
else
self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size)
end
- res[id] = batch
+ res.data[id] = {batch}
end
end
self.head = self.head + actual_batch_size
diff --git a/nerv/io/init.lua b/nerv/io/init.lua
index c36d850..d3ba27c 100644
--- a/nerv/io/init.lua
+++ b/nerv/io/init.lua
@@ -56,5 +56,5 @@ function DataBuffer:get_data()
nerv.error_method_not_implemented()
end
-nerv.include('sgd_buffer.lua')
+nerv.include('frm_buffer.lua')
nerv.include('seq_buffer.lua')
diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua
index ad1b3f7..029e7b8 100644
--- a/nerv/io/seq_buffer.lua
+++ b/nerv/io/seq_buffer.lua
@@ -5,7 +5,10 @@ function SeqBuffer:__init(global_conf, buffer_conf)
self.batch_size = buffer_conf.batch_size
self.chunk_size = buffer_conf.chunk_size
- self.readers = buffer_conf.readers
+ self.readers = {}
+ for _, v in ipairs(buffer_conf.readers) do
+ table.insert(self.readers, v.reader)
+ end
self.nn_act_default = buffer_conf.nn_act_default
if self.nn_act_default == nil then
self.nn_act_default = 0
@@ -29,7 +32,7 @@ function SeqBuffer:new_mini_batch()
end
function SeqBuffer:saturate(batch)
- if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then
+ if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then
return true
end
local data = {}