diff options
Diffstat (limited to 'nerv/io')
-rw-r--r-- | nerv/io/frm_buffer.lua (renamed from nerv/io/sgd_buffer.lua) | 14 | ||||
-rw-r--r-- | nerv/io/init.lua | 2 | ||||
-rw-r--r-- | nerv/io/seq_buffer.lua | 7 |
3 files changed, 14 insertions, 9 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/frm_buffer.lua index d78f6d1..9761f16 100644 --- a/nerv/io/sgd_buffer.lua +++ b/nerv/io/frm_buffer.lua @@ -1,6 +1,6 @@ -local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer") +local FrmBuffer = nerv.class("nerv.FrmBuffer", "nerv.DataBuffer") -function SGDBuffer:__init(global_conf, buffer_conf) +function FrmBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf self.batch_size = buffer_conf.batch_size self.buffer_size = math.floor(buffer_conf.buffer_size / @@ -57,7 +57,7 @@ function SGDBuffer:__init(global_conf, buffer_conf) end end -function SGDBuffer:saturate() +function FrmBuffer:saturate() local buffer_size = self.buffer_size self.head = 0 self.tail = buffer_size @@ -116,7 +116,7 @@ function SGDBuffer:saturate() return self.tail >= self.batch_size end -function SGDBuffer:get_data() +function FrmBuffer:get_data() local batch_size = self.batch_size if self.head >= self.tail then -- buffer is empty local t = os.clock() @@ -132,7 +132,9 @@ function SGDBuffer:get_data() return nil -- the remaining data cannot build a batch end actual_batch_size = math.min(batch_size, self.tail - self.head) - local res = {} + local res = {seq_length = table.vector(gconf.batch_size, 1), + new_seq = {}, + data = {}} for i, reader in ipairs(self.readers) do for id, buff in pairs(reader.buffs) do local batch = self.output_mat_type(actual_batch_size, buff.width) @@ -141,7 +143,7 @@ function SGDBuffer:get_data() else self.copy_from(batch, buff.data, self.head, self.head + actual_batch_size) end - res[id] = batch + res.data[id] = {batch} end end self.head = self.head + actual_batch_size diff --git a/nerv/io/init.lua b/nerv/io/init.lua index c36d850..d3ba27c 100644 --- a/nerv/io/init.lua +++ b/nerv/io/init.lua @@ -56,5 +56,5 @@ function DataBuffer:get_data() nerv.error_method_not_implemented() end -nerv.include('sgd_buffer.lua') +nerv.include('frm_buffer.lua') nerv.include('seq_buffer.lua') diff --git a/nerv/io/seq_buffer.lua b/nerv/io/seq_buffer.lua index ad1b3f7..029e7b8 100644 --- a/nerv/io/seq_buffer.lua +++ b/nerv/io/seq_buffer.lua @@ -5,7 +5,10 @@ function SeqBuffer:__init(global_conf, buffer_conf) self.batch_size = buffer_conf.batch_size self.chunk_size = buffer_conf.chunk_size - self.readers = buffer_conf.readers + self.readers = {} + for _, v in ipairs(buffer_conf.readers) do + table.insert(self.readers, v.reader) + end self.nn_act_default = buffer_conf.nn_act_default if self.nn_act_default == nil then self.nn_act_default = 0 @@ -29,7 +32,7 @@ function SeqBuffer:new_mini_batch() end function SeqBuffer:saturate(batch) - if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then + if self.queue[self.head] ~= nil and self.queue[self.head].seq_length[batch] ~= 0 then return true end local data = {} |