From df737041e4a9f3f55978cc74db9a9cea27fa9fa0 Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 5 Jun 2015 10:58:57 +0800 Subject: add profiling; add ce accuracy; several other changes --- io/chunk_file.c | 4 ++-- io/sgd_buffer.lua | 14 +++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) (limited to 'io') diff --git a/io/chunk_file.c b/io/chunk_file.c index ce346c5..4e987b7 100644 --- a/io/chunk_file.c +++ b/io/chunk_file.c @@ -268,7 +268,7 @@ int nerv_chunk_file_handle_destroy(lua_State *L) { return 0; } -static int nerv_chunk_destroy(lua_State *L) { +static int nerv_chunk_info_destroy(lua_State *L) { ChunkInfo *pci = luaT_checkudata(L, 1, nerv_chunk_info_tname); free(pci); return 0; @@ -298,7 +298,7 @@ void nerv_chunk_file_init(lua_State *L) { luaT_newmetatable(L, nerv_chunk_file_handle_tname, NULL, NULL, nerv_chunk_file_handle_destroy, NULL); luaT_newmetatable(L, nerv_chunk_info_tname, NULL, - NULL, nerv_chunk_destroy, NULL); + NULL, nerv_chunk_info_destroy, NULL); luaT_newmetatable(L, nerv_chunk_data_tname, NULL, NULL, nerv_chunk_data_destroy, NULL); } diff --git a/io/sgd_buffer.lua b/io/sgd_buffer.lua index dadcf67..bf72744 100644 --- a/io/sgd_buffer.lua +++ b/io/sgd_buffer.lua @@ -4,6 +4,10 @@ function SGDBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf self.buffer_size = math.floor(buffer_conf.buffer_size / global_conf.batch_size) * global_conf.batch_size + self.randomize = buffer_conf.randomize + if self.randomize == nil then + self.randomize = false + end self.head = 0 self.tail = 0 self.readers = {} @@ -35,7 +39,9 @@ function SGDBuffer:saturate() nerv.error("buffer size is too small to contain leftovers") end buff.data:copy_from(buff.leftover, 0, lrow) + buff.leftover = nil end + nerv.utils.printf("leftover: %d\n", lrow) reader.tail = lrow reader.has_leftover = false end @@ -73,6 +79,8 @@ function SGDBuffer:saturate() end self.tail = math.min(self.tail, reader.tail) end + self.rand_map = 
nerv.MMatrixInt.perm_gen(self.tail) -- generate shuffled index + collectgarbage("collect") return self.tail >= self.gconf.batch_size end @@ -90,7 +98,11 @@ function SGDBuffer:get_data() for i, reader in ipairs(self.readers) do for id, buff in pairs(reader.buffs) do local batch = self.gconf.cumat_type(batch_size, buff.width) - batch:copy_fromh(buff.data, self.head, self.head + batch_size) + if self.randomize then + batch:copy_rows_fromh_by_idx(buff.data, self.rand_map, self.head) + else + batch:copy_fromh(buff.data, self.head, self.head + batch_size) + end res[id] = batch end end -- cgit v1.2.3