From a54332ce81129e81fbb1d041ec41aa5955868c5e Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 11 Mar 2016 17:33:35 +0800 Subject: adapt asr_trainer.lua to new architecture --- nerv/nn/network.lua | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'nerv/nn/network.lua') diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index 35e11e3..2cb83ce 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -109,12 +109,14 @@ function network:init(batch_size, chunk_size) self.chunk_size = chunk_size self:topsort() - + self:make_initial_store() collectgarbage('collect') +end +function network:epoch_init() for i = 1, #self.layers do - self.layers[i]:init(batch_size, chunk_size) + self.layers[i]:init(self.batch_size, self.chunk_size) end end @@ -123,7 +125,7 @@ function network:topsort() local degree = {} for t = 1, self.chunk_size do degree[t] = {} - for i = 1, #self.layers do + for i = 1, #self.layers do degree[t][i] = 0 end end @@ -154,7 +156,7 @@ function network:topsort() end end end - while l<=r do + while l <= r do local t, i = self.queue[l].chunk, self.queue[l].id l = l + 1 local _, dim_out = self.layers[i]:get_dim() @@ -214,7 +216,7 @@ function network:make_initial_store() end end - -- connect memory and reference + -- connect memory and reference self.input = {} self.output = {} self.err_input = {} @@ -420,7 +422,7 @@ function network:mini_batch_init(info) if self.info.do_train then self:set_err_input(self.info.err_input) self:set_err_output(self.info.err_output) - + -- flush border gradient for t = self.max_length + 1, self.max_length + self.delay do if t > self.chunk_size then -- cgit v1.2.3-70-g09d2