path: root/nerv/nn/network.lua
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r--  nerv/nn/network.lua  13
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index 5a6abb6..bf69ccc 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -1,5 +1,5 @@
--- Implements the concept of computable but opaque networks built ("compiled")
--- from nested layers
+-- from nested layers.
-- @author Qi Liu <liuq901@163.com>
-- @author Ted Yin <ted.sybil@gmail.com>
@@ -190,7 +190,7 @@ end
--- Initialize the network for training.
-- To be called before all the epochs, will resolve the structure of the
--- network and allocate the memory for storing temporary values
+-- network and allocate the memory for storing temporary values.
-- @param batch_size The size of a batch matrix
-- @param chunk_size The size of a BPTT chunk
function network:init(batch_size, chunk_size)
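
For orientation, a minimal usage sketch of the entry point above (illustrative only, not part of the patch; the `net` handle and both size values are assumptions):

    local batch_size = 128  -- rows in every batch matrix
    local chunk_size = 5    -- BPTT truncation length; 1 for feed-forward nets
    net:init(batch_size, chunk_size)  -- once, before the first epoch
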
@@ -211,7 +211,8 @@ function network:init(batch_size, chunk_size)
end
--- Initialize the internal state of the network for the new epoch.
--- To be called before each new epoch
+-- To be called before each new epoch.
+
function network:epoch_init()
self.timestamp = 0
for i = 1, #self.layers do
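
A hypothetical outer loop showing where this reset sits (`max_epoch` and the per-epoch batch iteration are assumed):

    for epoch = 1, max_epoch do
        net:epoch_init()  -- reset the internal timestamp before any data is seen
        -- ... iterate over this epoch's mini-batches here ...
    end
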
@@ -503,7 +504,7 @@ function network:set_err_output(err_output)
end
--- Initialize the internal state of the network for the new mini-batch (a BPTT chunk).
--- To be called before each propagation/back-propagation
+-- To be called before each propagation/back-propagation.
-- @param info a table containing information needed for the current mini-batch computation. The following fields must be supplied:
--
-- * `input`: an array containing `chunk_size` number of row-major batch
@@ -526,6 +527,7 @@ end
-- * `err_output`: an array with the same structure as `input`. Although we
-- are mostly not interested in its values, just allocate this to unify
-- the computation and ease the implementation
+
function network:mini_batch_init(info)
self.info = info
self:set_input(self.info.input)
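
A sketch of the `info` table this function consumes, filling in only the two fields visible in this hunk; `mat_type`, `batch_size`, `chunk_size`, and `dim` are stand-ins for whatever the caller actually uses, and any further required fields from the full doc comment are omitted here:

    local info = {input = {}, err_output = {}}
    for t = 1, chunk_size do
        info.input[t] = mat_type(batch_size, dim)      -- one batch matrix per time step
        info.err_output[t] = mat_type(batch_size, dim) -- same structure; values unused
    end
    net:mini_batch_init(info)
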
@@ -634,6 +636,7 @@ function network:mini_batch_init(info)
end
--- Perform a propagation.
+
function network:propagate()
for i = 1, #self.queue do
local t, id = self.queue[i].chunk, self.queue[i].id
@@ -652,6 +655,7 @@ function network:propagate()
end
--- Perform a backward propagation to calculate gradients used for update.
+
function network:back_propagate()
for i = #self.queue, 1, -1 do
local t, id = self.queue[i].chunk, self.queue[i].id
@@ -677,6 +681,7 @@ function network:back_propagate()
end
--- Update the parameters bound to each layer.
+
function network:update()
for i = 1, #self.layers do
self.layers[i]:update()
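
Taken together, the doc comments touched by this patch imply the following call order for one training run; a hypothetical end-to-end outline (`batches`, `max_epoch`, and the size values are assumed, not taken from this commit):

    net:init(batch_size, chunk_size)      -- once, before all epochs
    for epoch = 1, max_epoch do
        net:epoch_init()                  -- once per epoch
        for _, info in ipairs(batches) do -- one BPTT chunk at a time
            net:mini_batch_init(info)
            net:propagate()               -- forward pass
            net:back_propagate()          -- compute gradients
            net:update()                  -- apply them to the bound parameters
        end
    end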