about summary refs log tree commit diff
path: root/nerv/nn/network.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r-- nerv/nn/network.lua | 14
1 files changed, 8 insertions, 6 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 35e11e3..2cb83ce 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -109,12 +109,14 @@ function network:init(batch_size, chunk_size)
self.chunk_size = chunk_size
self:topsort()
-
+
self:make_initial_store()
collectgarbage('collect')
+end
+function network:epoch_init()
for i = 1, #self.layers do
- self.layers[i]:init(batch_size, chunk_size)
+ self.layers[i]:init(self.batch_size, self.chunk_size)
end
end
@@ -123,7 +125,7 @@ function network:topsort()
local degree = {}
for t = 1, self.chunk_size do
degree[t] = {}
- for i = 1, #self.layers do
+ for i = 1, #self.layers do
degree[t][i] = 0
end
end
@@ -154,7 +156,7 @@ function network:topsort()
end
end
end
- while l<=r do
+ while l <= r do
local t, i = self.queue[l].chunk, self.queue[l].id
l = l + 1
local _, dim_out = self.layers[i]:get_dim()
@@ -214,7 +216,7 @@ function network:make_initial_store()
end
end
- -- connect memory and reference
+ -- connect memory and reference
self.input = {}
self.output = {}
self.err_input = {}
@@ -420,7 +422,7 @@ function network:mini_batch_init(info)
if self.info.do_train then
self:set_err_input(self.info.err_input)
self:set_err_output(self.info.err_output)
-
+
-- flush border gradient
for t = self.max_length + 1, self.max_length + self.delay do
if t > self.chunk_size then