diff options
author | Determinant <ted.sybil@gmail.com> | 2016-03-02 18:24:09 +0800 |
---|---|---|
committer | Determinant <ted.sybil@gmail.com> | 2016-03-02 18:24:09 +0800 |
commit | ad704f2623cc9e0a5d702434bfdebc345465ca12 (patch) | |
tree | 898d0688e913efc3ff098ba51e5c1a5488f5771d /nerv/io | |
parent | d3abc6459a776ff7fa3777f4f561bc4f5d5e2075 (diff) |
major changes in asr_trainer.lua; unified settings in `gconf`
Diffstat (limited to 'nerv/io')
-rw-r--r-- | nerv/io/sgd_buffer.lua | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/nerv/io/sgd_buffer.lua b/nerv/io/sgd_buffer.lua index 3cf4f5a..d78f6d1 100644 --- a/nerv/io/sgd_buffer.lua +++ b/nerv/io/sgd_buffer.lua @@ -2,8 +2,9 @@ local SGDBuffer = nerv.class("nerv.SGDBuffer", "nerv.DataBuffer") function SGDBuffer:__init(global_conf, buffer_conf) self.gconf = global_conf + self.batch_size = buffer_conf.batch_size self.buffer_size = math.floor(buffer_conf.buffer_size / - global_conf.batch_size) * global_conf.batch_size + self.batch_size) * self.batch_size self.randomize = buffer_conf.randomize self.consume = buffer_conf.consume local cumat_type = global_conf.cumat_type @@ -112,11 +113,11 @@ function SGDBuffer:saturate() end self.rand_map = self.perm_gen(self.tail) -- generate shuffled index collectgarbage("collect") - return self.tail >= self.gconf.batch_size + return self.tail >= self.batch_size end function SGDBuffer:get_data() - local batch_size = self.gconf.batch_size + local batch_size = self.batch_size if self.head >= self.tail then -- buffer is empty local t = os.clock() if (not self:saturate()) and (not self.consume) then |