summaryrefslogtreecommitdiff
path: root/io
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-06-03 23:00:45 +0800
committerDeterminant <[email protected]>2015-06-03 23:00:45 +0800
commitea6f2990f99dd9ded6a0e74d75a3ec84900a2518 (patch)
tree03b4ea34fa373189bf6b2b017bf54793d5c89f8e /io
parentbb56a806e0636a0b20117b1644701d63e2bfaefb (diff)
demo now works (without random shuffle)
Diffstat (limited to 'io')
-rw-r--r--io/init.lua22
-rw-r--r--io/sgd_buffer.lua99
2 files changed, 121 insertions, 0 deletions
diff --git a/io/init.lua b/io/init.lua
index 4a663a7..9bbd51a 100644
--- a/io/init.lua
+++ b/io/init.lua
@@ -28,3 +28,25 @@ function nerv.ChunkFile:read_chunk(id, global_conf)
chunk:read(self:get_chunkdata(id))
return chunk
end
+
+local DataReader = nerv.class("nerv.DataReader")
+
+function DataReader:__init(global_conf, reader_conf)
+ nerv.error_method_not_implemented()
+end
+
+function DataReader:get_data()
+ nerv.error_method_not_implemented()
+end
+
+local DataBuffer = nerv.class("nerv.DataBuffer")
+
+function DataBuffer:__init(global_conf, buffer_conf)
+ nerv.error_method_not_implemented()
+end
+
+function DataBuffer:get_batch()
+ nerv.error_method_not_implemented()
+end
+
+require 'io.sgd_buffer'
diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua
new file mode 100644
index 0000000..dadcf67
--- /dev/null
+++ b/io/sgd_buffer.lua
@@ -0,0 +1,99 @@
+local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer")
+
+function SGDBuffer:__init(global_conf, buffer_conf)
+ self.gconf = global_conf
+ self.buffer_size = math.floor(buffer_conf.buffer_size /
+ global_conf.batch_size) * global_conf.batch_size
+ self.head = 0
+ self.tail = 0
+ self.readers = {}
+ for i, reader_spec in ipairs(buffer_conf.readers) do
+ local buffs = {}
+ for id, width in pairs(reader_spec.data) do
+ buffs[id] = {data = global_conf.mmat_type(self.buffer_size, width),
+ leftover = {},
+ width = width}
+ end
+ table.insert(self.readers, {buffs = buffs,
+ reader = reader_spec.reader,
+ tail = 0,
+ has_leftover = false})
+ end
+end
+
+function SGDBuffer:saturate()
+ local buffer_size = self.buffer_size
+ self.head = 0
+ self.tail = buffer_size
+ for i, reader in ipairs(self.readers) do
+ reader.tail = 0
+ if reader.has_leftover then
+ local lrow
+ for id, buff in pairs(reader.buffs) do
+ lrow = buff.leftover:nrow()
+ if lrow > buffer_size then
+ nerv.error("buffer size is too small to contain leftovers")
+ end
+ buff.data:copy_from(buff.leftover, 0, lrow)
+ end
+ reader.tail = lrow
+ reader.has_leftover = false
+ end
+ while reader.tail < buffer_size do
+ local data = reader.reader:get_data()
+ if data == nil then
+ break
+ end
+ local drow = nil
+ for id, d in pairs(data) do
+ if drow == nil then
+ drow = d:nrow()
+ elseif d:nrow() ~= drow then
+ nerv.error("reader provides with inconsistent rows of data")
+ end
+ end
+ local remain = buffer_size - reader.tail
+ if drow > remain then
+ for id, buff in pairs(reader.buffs) do
+ local d = data[id]
+ if d == nil then
+ nerv.error("reader does not provide data for %s", id)
+ end
+ buff.leftover = self.gconf.mmat_type(drow - remain,
+ buff.width)
+ buff.leftover:copy_from(d, remain, drow)
+ end
+ drow = remain
+ reader.has_leftover = true
+ end
+ for id, buff in pairs(reader.buffs) do
+ buff.data:copy_from(data[id], 0, drow, reader.tail)
+ end
+ reader.tail = reader.tail + drow
+ end
+ self.tail = math.min(self.tail, reader.tail)
+ end
+ return self.tail >= self.gconf.batch_size
+end
+
+function SGDBuffer:get_data()
+ local batch_size = self.gconf.batch_size
+ if self.head >= self.tail then -- buffer is empty
+ if not self:saturate() then
+ return nil -- the remaining data cannot build a batch
+ end
+ end
+ if self.head + batch_size > self.tail then
+ return nil -- the remaining data cannot build a batch
+ end
+ local res = {}
+ for i, reader in ipairs(self.readers) do
+ for id, buff in pairs(reader.buffs) do
+ local batch = self.gconf.cumat_type(batch_size, buff.width)
+ batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ res[id] = batch
+ end
+ end
+ self.head = self.head + batch_size
+ return res
+end