aboutsummaryrefslogtreecommitdiff
path: root/nerv/nn/network.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r--nerv/nn/network.lua80
1 files changed, 71 insertions, 9 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
index cf6a4d3..19fa9d3 100644
--- a/nerv/nn/network.lua
+++ b/nerv/nn/network.lua
@@ -1,5 +1,43 @@
+--- Implements the concept of computable but opaque networks built ("compiled")
+-- from nested layers
+-- @author Qi Liu <liuq901@163.com>
+-- @author Ted Yin <ted.sybil@gmail.com>
+
+--- The class describing a computable but opaque network built from nested
+-- layers.
+-- @type nerv.Network
+
local network = nerv.class('nerv.Network')
+--- The constructor.
+-- @param id the identifier of the network (currently having no effects)
+-- @param global_conf a table describing the computation state and providing
+-- some global settings
+--
+-- The following fields in `global_conf` will be used:
+--
+-- * `use_cpu`: whether to use CPU for the computation
+-- * `mmat_type`: the class used for creating matrices in CPU computation
+-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices
+-- in GPU computation
+--
+-- The following fields in `global_conf` will be altered:
+--
+-- * `mask`: an array of `chunk_size` length containing column binary vectors
+-- indicating whether each frame in a *batch matrix* (i.e. one matrix in a BPTT
+-- chunk/"mini-batch") contains valid data (1 indicates data, 0 indicates
+-- holes)
+--
+-- @param network_conf a table providing settings dedicated to the
+-- network. Available fields include:
+--
+-- * `network`: a `nerv.Layer` instance describing the structure of the network
+-- to be compiled
+-- * `clip`: a `number` value indicating the clipping threshold (i.e. preserve
+-- the values within [-clip, +clip])
+-- * `nn_act_default`: a `number` value indicating the value used for filling
+-- "holes" in activation values of a batch matrix (0 by default)
+
function network:__init(id, global_conf, network_conf)
self.id = id
self.network = network_conf.network
@@ -150,6 +188,11 @@ function network:compile(layer)
return socket
end
+--- Initialize the network for training.
+-- To be called before all the epochs, will resolve the structure of the
+-- network and allocate the memory for storing temporary values
+-- @param batch_size The size of a batch matrix
+-- @param chunk_size The size of a BPTT chunk
function network:init(batch_size, chunk_size)
self.batch_size = batch_size
self.chunk_size = chunk_size
@@ -167,6 +210,8 @@ function network:init(batch_size, chunk_size)
end
end
+--- Initialize the internal state of the network for the new epoch.
+-- To be called before each new epoch
function network:epoch_init()
self.timestamp = 0
for i = 1, #self.layers do
@@ -457,15 +502,29 @@ function network:set_err_output(err_output)
end
end
---[[
- [info] is a table that contains information of current mini-batch. These fields must be contained:
- [input], [output] : matrix array which stores the network input and output
- [seq_length] : a table contains the length of every sequences
- [new_seq]: a table contains the batch number of new sequences
- [do_train]: a bool value indicates do train or not
- if [do_train] is true, these fileds also must be contained:
- [err_input], [err_output] : matrix array which stores the network err_input and err_output
---]]
+--- Initialize the internal state of the network for the new mini-batch (a BPTT chunk).
+-- To be called before each propagation/back-propagation
+-- @param info a table containing information needed for the current mini-batch computation. The following fields must be supplied:
+--
+-- * `input`: an array containing `chunk_size` number of row-major batch
+-- matrices with `batch_size` rows
+-- * `output`: similar to `input`, but the matrices have different number of
+-- columns (depending on the width of the output, which is typically 1 for
+-- criteria, i.e. single column indicating the error), used to hold the output of the network
+-- * `seq_length` : a table containing the length (number of frames) of each sequence (utterance)
+-- * `new_seq`: a table containing the indices of batch matrix rows that are the
+-- first frames of a sequence
+-- * `do_train`: a bool value indicating whether to update the network
+--
+-- If `do_train` is true, two additional fields are required:
+--
+-- * `err_input`: an array with the same structure as `output` but containing the initial
+-- values for computing errors in back-propagation (when the width of the
+-- output is 1, `gconf.mask` is typically used here to ignore the invalid
+-- values produced by "holes" in the mini-batch).
+-- * `err_output`: an array with the same structure as `input`. Although we
+-- are mostly not interested in its values, just allocate this to unify
+-- the computation and ease the implementation
function network:mini_batch_init(info)
self.info = info
self:set_input(self.info.input)
@@ -573,6 +632,7 @@ function network:mini_batch_init(info)
end
end
+--- Perform a propagation.
function network:propagate()
for i = 1, #self.queue do
local t, id = self.queue[i].chunk, self.queue[i].id
@@ -590,6 +650,7 @@ function network:propagate()
end
end
+--- Perform a backward propagation to calculate gradients used for update.
function network:back_propagate()
for i = #self.queue, 1, -1 do
local t, id = self.queue[i].chunk, self.queue[i].id
@@ -614,6 +675,7 @@ function network:back_propagate()
end
end
+--- Update the parameters bound to each layer.
function network:update()
for i = 1, #self.layers do
self.layers[i]:update()