summaryrefslogtreecommitdiff
path: root/io
diff options
context:
space:
mode:
Diffstat (limited to 'io')
-rw-r--r--io/chunk_file.c4
-rw-r--r--io/sgd_buffer.lua14
2 files changed, 15 insertions, 3 deletions
diff --git a/io/chunk_file.c b/io/chunk_file.c
index ce346c5..4e987b7 100644
--- a/io/chunk_file.c
+++ b/io/chunk_file.c
@@ -268,7 +268,7 @@ int nerv_chunk_file_handle_destroy(lua_State *L) {
return 0;
}
-static int nerv_chunk_destroy(lua_State *L) {
+static int nerv_chunk_info_destroy(lua_State *L) {
ChunkInfo *pci = luaT_checkudata(L, 1, nerv_chunk_info_tname);
free(pci);
return 0;
@@ -298,7 +298,7 @@ void nerv_chunk_file_init(lua_State *L) {
luaT_newmetatable(L, nerv_chunk_file_handle_tname, NULL,
NULL, nerv_chunk_file_handle_destroy, NULL);
luaT_newmetatable(L, nerv_chunk_info_tname, NULL,
- NULL, nerv_chunk_destroy, NULL);
+ NULL, nerv_chunk_info_destroy, NULL);
luaT_newmetatable(L, nerv_chunk_data_tname, NULL,
NULL, nerv_chunk_data_destroy, NULL);
}
diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua
index dadcf67..bf72744 100644
--- a/io/sgd_buffer.lua
+++ b/io/sgd_buffer.lua
@@ -4,6 +4,10 @@ function SGDBuffer:__init(global_conf, buffer_conf)
self.gconf = global_conf
self.buffer_size = math.floor(buffer_conf.buffer_size /
global_conf.batch_size) * global_conf.batch_size
+ self.randomize = buffer_conf.randomize
+ if self.randomize == nil then
+ self.randomize = false
+ end
self.head = 0
self.tail = 0
self.readers = {}
@@ -35,7 +39,9 @@ function SGDBuffer:saturate()
nerv.error("buffer size is too small to contain leftovers")
end
buff.data:copy_from(buff.leftover, 0, lrow)
+ buff.leftover = nil
end
+ nerv.utils.printf("leftover: %d\n", lrow)
reader.tail = lrow
reader.has_leftover = false
end
@@ -73,6 +79,8 @@ function SGDBuffer:saturate()
end
self.tail = math.min(self.tail, reader.tail)
end
+ self.rand_map = nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index
+ collectgarbage("collect")
return self.tail >= self.gconf.batch_size
end
@@ -90,7 +98,11 @@ function SGDBuffer:get_data()
for i, reader in ipairs(self.readers) do
for id, buff in pairs(reader.buffs) do
local batch = self.gconf.cumat_type(batch_size, buff.width)
- batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ if self.randomize then
+ batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head)
+ else
+ batch:copy_fromh(buff.data, self.head, self.head + batch_size)
+ end
res[id] = batch
end
end