diff options
Diffstat (limited to 'nerv/nn/network.lua')
-rw-r--r-- | nerv/nn/network.lua | 80 |
1 files changed, 71 insertions, 9 deletions
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua index cf6a4d3..19fa9d3 100644 --- a/nerv/nn/network.lua +++ b/nerv/nn/network.lua @@ -1,5 +1,43 @@ +--- Implements the concept of computable but opaque networks built ("compiled") +-- from nested layers +-- @author Qi Liu <liuq901@163.com> +-- @author Ted Yin <ted.sybil@gmail.com> + +--- The class describing a computable but opaque network built from nested +-- layers. +-- @type nerv.Network + local network = nerv.class('nerv.Network') +--- The constructor. +-- @param id the identifier of the network (currently having no effects) +-- @param global_conf a table describing the computation state and providing +-- with some global settings +-- +-- The following fields in `global_conf` will be used: +-- +-- * `use_cpu`: whether to use CPU for the computation +-- * `mmat_type`: the class used for creating matrices in CPU computation +-- * `cumat_type` (if `use_cpu = false`): the class used for creating matrices +-- in GPU computation +-- +-- The following fields in `global_conf` will be altered: +-- +-- * `mask`: an array of `chunk_size` length containing column binary vectors +-- indicating whether each frame in a *batch matrix* (i.e. one matrix in a BPTT +-- chunk/"mini-batch") contains valid data (1 indicates data, 0 indicates +-- holes) +-- +-- @param network_conf a table providing with settings dedicated for the +-- network. Available fields include: +-- +-- * `network`: a `nerv.Layer` instance describing the structure of the network +-- to be compiled +-- * `clip`: a `number` value indicating the clipping threshold (i.e. 
preserve +-- the values within [-clip, +clip]) +-- * `nn_act_default`: a `number` value indicating the value used for filling +-- "holes" in activation values of a batch matrix (0 by default) + function network:__init(id, global_conf, network_conf) self.id = id self.network = network_conf.network @@ -150,6 +188,11 @@ function network:compile(layer) return socket end +--- Initialize the network for training. +-- To be called before all the epochs, will resolve the structure of the +-- network and allocate the memory for storing temporary values +-- @param batch_size The size of a batch matrix +-- @param chunk_size The size of a BPTT chunk function network:init(batch_size, chunk_size) self.batch_size = batch_size self.chunk_size = chunk_size @@ -167,6 +210,8 @@ function network:init(batch_size, chunk_size) end end +--- Initialize the internal state of the network for the new epoch. +-- To be called before each new epoch function network:epoch_init() self.timestamp = 0 for i = 1, #self.layers do @@ -457,15 +502,29 @@ function network:set_err_output(err_output) end end ---[[ - [info] is a table that contains information of current mini-batch. These fields must be contained: - [input], [output] : matrix array which stores the network input and output - [seq_length] : a table contains the length of every sequences - [new_seq]: a table contains the batch number of new sequences - [do_train]: a bool value indicates do train or not - if [do_train] is true, these fileds also must be contained: - [err_input], [err_output] : matrix array which stores the network err_input and err_output ---]] +--- Initialize the internal state of the network for the new mini-batch (a BPTT chunk). +-- To be called before each propagation/back-propagation +-- @param info a table containing information needed for the current mini-batch computation. 
The following fields must be supplied: +-- +-- * `input`: an array containing `chunk_size` number of row-major batch +-- matrices with `batch_size` rows +-- * `output`: similar to `input`, but the matrices have different number of +-- columns (depending on the width of the output, which is typically 1 for +-- criteria, i.e. single column indicating the error), used to hold the output of the network +-- * `seq_length`: a table containing the length (number of frames) of each sequence (utterance) +-- * `new_seq`: a table containing the indices of batch matrix rows that are the +-- first frames of a sequence +-- * `do_train`: a bool value indicating whether to update the network +-- +-- If `do_train` is true, two additional fields are required: +-- +-- * `err_input`: an array with the same structure as `output` but containing the initial +-- values for computing errors in back-propagation (when the width of the +-- output is 1, `gconf.mask` is typically used here to ignore the invalid +-- values produced by "holes" in the mini-batch). +-- * `err_output`: an array with the same structure as `input`. Although we +-- are mostly not interested in its values, just allocate this to unify +-- the computation and ease the implementation function network:mini_batch_init(info) self.info = info self:set_input(self.info.input) @@ -573,6 +632,7 @@ function network:mini_batch_init(info) end end +--- Perform a propagation. function network:propagate() for i = 1, #self.queue do local t, id = self.queue[i].chunk, self.queue[i].id @@ -590,6 +650,7 @@ function network:propagate() end end +--- Perform a backward propagation to calculate gradients used for update. function network:back_propagate() for i = #self.queue, 1, -1 do local t, id = self.queue[i].chunk, self.queue[i].id @@ -614,6 +675,7 @@ function network:back_propagate() end end +--- Update the parameters bound to each layer. function network:update() for i = 1, #self.layers do self.layers[i]:update() |