path: root/nerv/nn
Diffstat (limited to 'nerv/nn')
-rw-r--r--  nerv/nn/init.lua       |   2
-rw-r--r--  nerv/nn/layer_dag.lua  | 352
-rw-r--r--  nerv/nn/network.lua    | 500
3 files changed, 501 insertions, 353 deletions
diff --git a/nerv/nn/init.lua b/nerv/nn/init.lua
index cbaf52b..1037d05 100644
--- a/nerv/nn/init.lua
+++ b/nerv/nn/init.lua
@@ -1,3 +1,3 @@
nerv.include('layer_repo.lua')
nerv.include('param_repo.lua')
-nerv.include('layer_dag.lua')
+nerv.include('network.lua')
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
deleted file mode 100644
index f999752..0000000
--- a/nerv/nn/layer_dag.lua
+++ /dev/null
@@ -1,352 +0,0 @@
-local DAGLayer = nerv.class("nerv.DAGLayer", "nerv.Layer")
-
-local function parse_id(str)
- local id, port, _
- _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]")
- if id == nil or port == nil then
- _, _, id, port = string.find(str, "(.+)%[([0-9]+)%]")
- if not (id == "<input>" or id == "<output>") then
- nerv.error("wrong format of connection id")
- end
- end
- port = tonumber(port)
- return id, port
-end
-
-local function discover(id, layers, layer_repo)
- local ref = layers[id]
- if id == "<input>" or id == "<output>" then
- return nil
- end
- if ref == nil then
- local layer = layer_repo:get_layer(id)
- local dim_in, dim_out = layer:get_dim()
- ref = {
- layer = layer,
- inputs = {},
- outputs = {},
- err_inputs = {},
- err_outputs = {},
- next_layers = {},
- input_len = #dim_in,
- output_len = #dim_out,
- in_deg = 0,
- visited = false
- }
- layers[id] = ref
- end
- return ref
-end
-
-local function touch_list_by_idx(list, idx)
- if list[idx] == nil then
- list[idx] = {}
- end
-end
-
-function DAGLayer:__init(id, global_conf, layer_conf)
- local layers = {}
- local inputs = {}
- local outputs = {}
- local dim_in = layer_conf.dim_in
- local dim_out = layer_conf.dim_out
- local parsed_conn = {}
- for from, to in pairs(layer_conf.connections) do
- local id_from, port_from = parse_id(from)
- local id_to, port_to = parse_id(to)
- local ref_from = discover(id_from, layers, layer_conf.sub_layers)
- local ref_to = discover(id_to, layers, layer_conf.sub_layers)
- local input_dim, output_dim, _
- if ref_from then
- touch_list_by_idx(ref_from.outputs, 1)
- if ref_from.outputs[1][port_from] ~= nil then
- nerv.error("%s has already been attached", from)
- end
- end
- if ref_to then
- touch_list_by_idx(ref_to.inputs, 1)
- if ref_to.inputs[1][port_to] ~= nil then
- nerv.error("%s has already been attached", to)
- end
- end
- if id_from == "<input>" then
- input_dim, _ = ref_to.layer:get_dim()
- if dim_in[port_from] ~= input_dim[port_to] then
- nerv.error("mismatching data dimension between %s and %s", from, to)
- end
- inputs[port_from] = {ref_to, port_to}
- ref_to.inputs[1][port_to] = inputs -- just a place holder
- elseif id_to == "<output>" then
- _, output_dim = ref_from.layer:get_dim()
- if output_dim[port_from] ~= dim_out[port_to] then
- nerv.error("mismatching data dimension between %s and %s", from, to)
- end
- outputs[port_to] = {ref_from, port_from}
- ref_from.outputs[1][port_from] = outputs -- just a place holder
- else
- _, output_dim = ref_from.layer:get_dim()
- input_dim, _ = ref_to.layer:get_dim()
- if output_dim[port_from] ~= input_dim[port_to] then
- nerv.error("mismatching data dimension between %s and %s", from, to)
- end
-
- table.insert(parsed_conn,
- {{ref_from, port_from}, {ref_to, port_to}})
- table.insert(ref_from.next_layers, ref_to) -- add edge
- ref_to.in_deg = ref_to.in_deg + 1 -- increase the in-degree of the target layer
- end
- end
-
- -- topology sort
- local queue = {}
- local l = 1
- local r = 1
- for id, ref in pairs(layers) do
- if ref.in_deg == 0 then
- table.insert(queue, ref)
- nerv.info("adding source layer: %s", id)
- r = r + 1
- end
- end
- if l == r then
- nerv.error("loop detected")
- end
- while l < r do
- local cur = queue[l]
- cur.visited = true
- l = l + 1
- for _, nl in pairs(cur.next_layers) do
- nl.in_deg = nl.in_deg - 1
- if nl.in_deg == 0 then
- table.insert(queue, nl)
- r = r + 1
- end
- end
- end
- for i = 1, #queue do
- nerv.info("enqueued layer: %s %s", queue[i].layer, queue[i].layer.id)
- end
-
- for id, ref in pairs(layers) do
-        -- check whether the graph is connected
- if ref.visited == false then
- nerv.warning("layer %s is ignored", id)
- end
- end
-
- nerv.Layer.__init(self, id, global_conf, layer_conf)
- self.layers = layers
- self.inputs = inputs
- self.outputs = outputs
- self.parsed_conn = parsed_conn
- self.queue = queue
-end
-
-function DAGLayer:bind_params()
- -- do nothing (instead of rebinding params for each layer)
-end
-
-function DAGLayer:init(batch_size, chunk_size)
- if chunk_size == nil then
- chunk_size = 1
- end
- for i, conn in ipairs(self.parsed_conn) do
- local _, output_dim
- local ref_from, port_from, ref_to, port_to
- ref_from, port_from = unpack(conn[1])
- ref_to, port_to = unpack(conn[2])
- _, output_dim = ref_from.layer:get_dim()
- local dim = 1
- if output_dim[port_from] > 0 then
- dim = output_dim[port_from]
- end
-
- for t = 1, chunk_size do
- local mid = self.mat_type(batch_size, dim)
- local err_mid = mid:create()
- touch_list_by_idx(ref_to.inputs, t)
- touch_list_by_idx(ref_from.outputs, t)
- touch_list_by_idx(ref_from.err_inputs, t)
- touch_list_by_idx(ref_to.err_outputs, t)
-
- ref_from.outputs[t][port_from] = mid
- ref_to.inputs[t][port_to] = mid
-
- ref_from.err_inputs[t][port_from] = err_mid
- ref_to.err_outputs[t][port_to] = err_mid
- end
- end
- for id, ref in pairs(self.layers) do
- for i = 1, ref.input_len do
- if ref.inputs[1][i] == nil then
- nerv.error("dangling input port %d of layer %s", i, id)
- end
- end
- for i = 1, ref.output_len do
- if ref.outputs[1][i] == nil then
- nerv.error("dangling output port %d of layer %s", i, id)
- end
- end
- -- initialize sub layers
- ref.layer:init(batch_size, chunk_size)
- end
- for i = 1, #self.dim_in do
- if self.inputs[i] == nil then
- nerv.error("dangling port %d of layer <input>", i)
- end
- end
- for i = 1, #self.dim_out do
- if self.outputs[i] == nil then
- nerv.error("dangling port %d of layer <output>", i)
- end
- end
-end
-
-function DAGLayer:batch_resize(batch_size, chunk_size)
- if chunk_size == nil then
- chunk_size = 1
- end
-
- for i, conn in ipairs(self.parsed_conn) do
- local _, output_dim
- local ref_from, port_from, ref_to, port_to
- ref_from, port_from = unpack(conn[1])
- ref_to, port_to = unpack(conn[2])
- _, output_dim = ref_from.layer:get_dim()
-
- if ref_from.outputs[1][port_from]:nrow() ~= batch_size
- and output_dim[port_from] > 0 then
- for t = 1, chunk_size do
- local mid = self.mat_type(batch_size, output_dim[port_from])
- local err_mid = mid:create()
-
- ref_from.outputs[t][port_from] = mid
- ref_to.inputs[t][port_to] = mid
-
- ref_from.err_inputs[t][port_from] = err_mid
- ref_to.err_outputs[t][port_to] = err_mid
- end
- end
- end
- for id, ref in pairs(self.layers) do
- ref.layer:batch_resize(batch_size, chunk_size)
- end
- collectgarbage("collect")
-end
-
-function DAGLayer:set_inputs(input, t)
- for i = 1, #self.dim_in do
- if input[i] == nil then
- nerv.error("some input is not provided");
- end
- local layer = self.inputs[i][1]
- local port = self.inputs[i][2]
- touch_list_by_idx(layer.inputs, t)
- layer.inputs[t][port] = input[i]
- end
-end
-
-function DAGLayer:set_outputs(output, t)
- for i = 1, #self.dim_out do
- if output[i] == nil then
- nerv.error("some output is not provided");
- end
- local layer = self.outputs[i][1]
- local port = self.outputs[i][2]
- touch_list_by_idx(layer.outputs, t)
- layer.outputs[t][port] = output[i]
- end
-end
-
-function DAGLayer:set_err_inputs(bp_err, t)
- for i = 1, #self.dim_out do
- local layer = self.outputs[i][1]
- local port = self.outputs[i][2]
- touch_list_by_idx(layer.err_inputs, t)
- layer.err_inputs[t][port] = bp_err[i]
- end
-end
-
-function DAGLayer:set_err_outputs(next_bp_err, t)
- for i = 1, #self.dim_in do
- local layer = self.inputs[i][1]
- local port = self.inputs[i][2]
- touch_list_by_idx(layer.err_outputs, t)
- layer.err_outputs[t][port] = next_bp_err[i]
- end
-end
-
-function DAGLayer:update(bp_err, input, output, t)
- if t == nil then
- t = 1
- end
- self:set_err_inputs(bp_err, t)
- self:set_inputs(input, t)
- self:set_outputs(output, t)
- for id, ref in pairs(self.queue) do
- ref.layer:update(ref.err_inputs[t], ref.inputs[t], ref.outputs[t], t)
- end
-end
-
-function DAGLayer:propagate(input, output, t)
- if t == nil then
- t = 1
- end
- self:set_inputs(input, t)
- self:set_outputs(output, t)
- local ret = false
- for i = 1, #self.queue do
- local ref = self.queue[i]
- ret = ref.layer:propagate(ref.inputs[t], ref.outputs[t], t)
- end
- return ret
-end
-
-function DAGLayer:back_propagate(bp_err, next_bp_err, input, output, t)
- if t == nil then
- t = 1
- end
- self:set_err_outputs(next_bp_err, t)
- self:set_err_inputs(bp_err, t)
- self:set_inputs(input, t)
- self:set_outputs(output, t)
- for i = #self.queue, 1, -1 do
- local ref = self.queue[i]
- ref.layer:back_propagate(ref.err_inputs[t], ref.err_outputs[t], ref.inputs[t], ref.outputs[t], t)
- end
-end
-
-function DAGLayer:get_params()
- local param_repos = {}
- for id, ref in pairs(self.queue) do
- table.insert(param_repos, ref.layer:get_params())
- end
- return nerv.ParamRepo.merge(param_repos, self.loc_type)
-end
-
-DAGLayer.PORT_TYPES = {
- INPUT = {},
- OUTPUT = {},
- ERR_INPUT = {},
- ERR_OUTPUT = {}
-}
-
-function DAGLayer:get_intermediate(id, port_type)
- if id == "<input>" or id == "<output>" then
- nerv.error("an actual real layer id is expected")
- end
- local layer = self.layers[id]
- if layer == nil then
- nerv.error("layer id %s not found", id)
- end
- if port_type == DAGLayer.PORT_TYPES.INPUT then
- return layer.inputs
- elseif port_type == DAGLayer.PORT_TYPES.OUTPUT then
- return layer.outputs
- elseif port_type == DAGLayer.PORT_TYPES.ERR_INPUT then
- return layer.err_inputs
- elseif port_type == DAGLayer.PORT_TYPES.ERR_OUTPUT then
- return layer.err_outputs
- end
- nerv.error("unrecognized port type")
-end
diff --git a/nerv/nn/network.lua b/nerv/nn/network.lua
new file mode 100644
index 0000000..2cb83ce
--- /dev/null
+++ b/nerv/nn/network.lua
@@ -0,0 +1,500 @@
+local network = nerv.class('nerv.Network')
+
+function network:__init(id, global_conf, network_conf)
+ self.id = id
+ self.network = network_conf.network
+ self.dim_in = self.network.dim_in
+ self.dim_out = self.network.dim_out
+ self.gconf = global_conf
+ if self.gconf.use_cpu then
+ self.mat_type = self.gconf.mmat_type
+ else
+ self.mat_type = self.gconf.cumat_type
+ end
+ self.clip = network_conf.clip
+ self.nn_act_default = network_conf.nn_act_default
+ if self.nn_act_default == nil then
+ self.nn_act_default = 0
+ end
+ self.layers = {}
+ self.input_conn = {}
+ self.output_conn = {}
+ self.socket = self:compile(self.network)
+ for i = 1, #self.dim_in do
+ local edge = self.socket.inputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if self.input_conn[id][port] ~= nil then
+ nerv.error('duplicate edge')
+ end
+ self.input_conn[id][port] = {0, i, time}
+ end
+ for i = 1, #self.dim_out do
+ local edge = self.socket.outputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if self.output_conn[id][port] ~= nil then
+ nerv.error('duplicate edge')
+ end
+ self.output_conn[id][port] = {0, i, time}
+ end
+ self.delay = 0
+ for i = 1, #self.layers do
+ local dim_in, _ = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ local time = self.input_conn[i][j][3]
+ if math.abs(time) > self.delay then
+ self.delay = math.abs(time)
+ end
+ end
+ end
+end
+
+function network:compile(layer)
+ local socket = {inputs = {}, outputs = {}}
+ if not nerv.is_type(layer, 'nerv.GraphLayer') then
+ table.insert(self.layers, layer)
+ local id = #self.layers
+ self.input_conn[id] = {}
+ self.output_conn[id] = {}
+ local dim_in, dim_out = layer:get_dim()
+ for i = 1, #dim_in do
+ socket.inputs[i] = {id, i, 0}
+ end
+ for i = 1, #dim_out do
+ socket.outputs[i] = {id, i, 0}
+ end
+ else
+ local sublayer_socket = {}
+ for id, sublayer in pairs(layer.layers) do
+ if id ~= '<input>' then
+ sublayer_socket[sublayer.id] = self:compile(sublayer.layer)
+ end
+ end
+ for _, edge in pairs(layer.connections) do
+ -- id = 0 means <input> or <output>
+ local id_from, port_from = edge[1], edge[2]
+ local id_to, port_to = edge[3], edge[4]
+ local time = edge[5]
+ if id_from == 0 then
+ if socket.inputs[port_from] ~= nil then
+ nerv.error('duplicate input socket')
+ end
+ local input = sublayer_socket[id_to].inputs[port_to]
+ local id, port, t = input[1], input[2], input[3] + time
+ socket.inputs[port_from] = {id, port, t}
+ else
+ local output = sublayer_socket[id_from].outputs[port_from]
+ local id, port, t = output[1], output[2], output[3] + time
+ if id_to == 0 then
+ if socket.outputs[port_to] ~= nil then
+ nerv.error('duplicate output socket')
+ end
+ socket.outputs[port_to] = {id, port, t}
+ else
+ local input = sublayer_socket[id_to].inputs[port_to]
+ local id1, port1, t1 = input[1], input[2], input[3]
+ if self.input_conn[id1][port1] ~= nil or self.output_conn[id][port] ~= nil then
+ nerv.error('duplicate edge')
+ end
+ self.input_conn[id1][port1] = {id, port, t + t1}
+ self.output_conn[id][port] = {id1, port1, t + t1}
+ end
+ end
+ end
+ end
+ return socket
+end
+
+function network:init(batch_size, chunk_size)
+ self.batch_size = batch_size
+ self.chunk_size = chunk_size
+
+ self:topsort()
+
+ self:make_initial_store()
+ collectgarbage('collect')
+end
+
+function network:epoch_init()
+ for i = 1, #self.layers do
+ self.layers[i]:init(self.batch_size, self.chunk_size)
+ end
+end
+
+function network:topsort()
+ nerv.info('network topology sort')
+ local degree = {}
+ for t = 1, self.chunk_size do
+ degree[t] = {}
+ for i = 1, #self.layers do
+ degree[t][i] = 0
+ end
+ end
+
+ for t = 1, self.chunk_size do
+ for i = 1, #self.layers do
+ local _, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_out do
+ if self.output_conn[i][j] ~= nil then
+ local edge = self.output_conn[i][j]
+ local id, time = edge[1], edge[3] + t
+ if time >= 1 and time <= self.chunk_size and id ~= 0 then
+ degree[time][id] = degree[time][id] + 1
+ end
+ end
+ end
+ end
+ end
+
+ self.queue = {}
+ local l = 1
+ local r = 0
+ for t = 1, self.chunk_size do
+ for i = 1, #self.layers do
+ if degree[t][i] == 0 then
+ r = r + 1
+ self.queue[r] = {chunk = t, id = i}
+ end
+ end
+ end
+ while l <= r do
+ local t, i = self.queue[l].chunk, self.queue[l].id
+ l = l + 1
+ local _, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_out do
+ if self.output_conn[i][j] ~= nil then
+ local edge = self.output_conn[i][j]
+ local id, time = edge[1], edge[3] + t
+ if time >= 1 and time <= self.chunk_size and id ~= 0 then
+ degree[time][id] = degree[time][id] - 1
+ if degree[time][id] == 0 then
+ r = r + 1
+ self.queue[r] = {chunk = time, id = id}
+ end
+ end
+ end
+ end
+ end
+
+ if r ~= self.chunk_size * #self.layers then
+ nerv.error('loop detected')
+ end
+end
+
+function network:make_initial_store()
+    nerv.info('network initializing storage')
+
+ -- allocate memory
+ local memory = {}
+ local err_memory = {}
+ for t = 1 - self.delay, self.chunk_size + self.delay do
+ memory[t] = {}
+ err_memory[t] = {}
+ for i = 1, #self.layers do
+ memory[t][i] = {}
+ err_memory[t][i] = {}
+ local dim_in, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ err_memory[t][i][j] = self.mat_type(self.batch_size, dim_in[j])
+ err_memory[t][i][j]:fill(0)
+ end
+ for j = 1, #dim_out do
+ memory[t][i][j] = self.mat_type(self.batch_size, dim_out[j])
+ memory[t][i][j]:fill(self.nn_act_default)
+ end
+ end
+ -- memory[t][0] stores network input
+ memory[t][0] = {}
+ for j = 1, #self.dim_in do
+ memory[t][0][j] = self.mat_type(self.batch_size, self.dim_in[j])
+ memory[t][0][j]:fill(self.nn_act_default)
+ end
+ -- err_memory[t][0] stores network err_input
+ err_memory[t][0] = {}
+ for j = 1, #self.dim_out do
+ err_memory[t][0][j] = self.mat_type(self.batch_size, self.dim_out[j])
+ err_memory[t][0][j]:fill(0)
+ end
+ end
+
+ -- connect memory and reference
+ self.input = {}
+ self.output = {}
+ self.err_input = {}
+ self.err_output = {}
+ for t = 1, self.chunk_size do
+ self.input[t] = {}
+ self.output[t] = {}
+ self.err_input[t] = {}
+ self.err_output[t] = {}
+ for i = 1, #self.layers do
+ self.input[t][i] = {}
+ self.output[t][i] = {}
+ self.err_input[t][i] = {}
+ self.err_output[t][i] = {}
+ local dim_in, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ local edge = self.input_conn[i][j]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if id ~= 0 or t - time < 1 or t - time > self.chunk_size then
+ self.input[t][i][j] = memory[t - time][id][port]
+ end
+ if id ~= 0 then
+ self.err_output[t][i][j] = err_memory[t][i][j]
+ end
+ end
+ for j = 1, #dim_out do
+ local edge = self.output_conn[i][j]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if id ~= 0 then
+ self.output[t][i][j] = memory[t][i][j]
+ end
+ if id ~= 0 or t + time < 1 or t + time > self.chunk_size then
+ self.err_input[t][i][j] = err_memory[t + time][id][port]
+ end
+ end
+ end
+ end
+
+ -- check dangling reference
+ for t = 1, self.chunk_size do
+ for i = 1, #self.dim_in do
+ local edge = self.socket.inputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t + time >= 1 and t + time <= self.chunk_size then
+ if self.input[t + time][id][port] ~= nil then
+ nerv.error('input reference not nil')
+ end
+ self.input[t + time][id][port] = true -- just a place holder
+ if self.err_output[t + time][id][port] ~= nil then
+ nerv.error('err_output reference not nil')
+ end
+ self.err_output[t + time][id][port] = true -- just a place holder
+ end
+ end
+ for i = 1, #self.dim_out do
+ local edge = self.socket.outputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t - time >= 1 and t - time <= self.chunk_size then
+ if self.output[t - time][id][port] ~= nil then
+ nerv.error('output reference not nil')
+ end
+ self.output[t - time][id][port] = true -- just a place holder
+ if self.err_input[t - time][id][port] ~= nil then
+                    nerv.error('err_input reference not nil')
+ end
+ self.err_input[t - time][id][port] = true -- just a place holder
+ end
+ end
+ end
+ for t = 1, self.chunk_size do
+ for i = 1, #self.layers do
+ local dim_in, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ if self.input[t][i][j] == nil then
+ nerv.error('input reference dangling')
+ end
+ if self.err_output[t][i][j] == nil then
+ nerv.error('err_output reference dangling')
+ end
+ end
+ for j = 1, #dim_out do
+ if self.output[t][i][j] == nil then
+ nerv.error('output reference dangling')
+ end
+ if self.err_input[t][i][j] == nil then
+ nerv.error('err_input reference dangling')
+ end
+ end
+ end
+ end
+
+    -- allocate references for the legacy (previous mini-batch) outputs
+ self.legacy = {}
+ for t = 1 - self.delay, 0 do
+ self.legacy[t] = {}
+ for i = 1, #self.layers do
+ self.legacy[t][i] = {}
+ local _, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_out do
+ self.legacy[t][i][j] = memory[t][i][j]
+ end
+ end
+ end
+end
+
+function network:set_input(input)
+ for t = 1, self.chunk_size do
+ for i = 1, #self.dim_in do
+ local edge = self.socket.inputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t + time >= 1 and t + time <= self.chunk_size then
+ self.input[t + time][id][port] = input[t][i]
+ end
+ end
+ end
+end
+
+function network:set_output(output)
+ for t = 1, self.chunk_size do
+ for i = 1, #self.dim_out do
+ local edge = self.socket.outputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t - time >= 1 and t - time <= self.chunk_size then
+ self.output[t - time][id][port] = output[t][i]
+ end
+ end
+ end
+end
+
+function network:set_err_input(err_input)
+ for t = 1, self.chunk_size do
+ for i = 1, #self.dim_out do
+ local edge = self.socket.outputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t - time >= 1 and t - time <= self.chunk_size then
+ self.err_input[t - time][id][port] = err_input[t][i]
+ end
+ end
+ end
+end
+
+function network:set_err_output(err_output)
+ for t = 1, self.chunk_size do
+ for i = 1, #self.dim_in do
+ local edge = self.socket.inputs[i]
+ local id, port, time = edge[1], edge[2], edge[3]
+ if t + time >= 1 and t + time <= self.chunk_size then
+ self.err_output[t + time][id][port] = err_output[t][i]
+ end
+ end
+ end
+end
+
+--[[
+    [info] is a table describing the current mini-batch. It must contain these fields:
+    [input], [output] : matrix arrays holding the network input and output
+    [seq_length] : a table with the length of every sequence in the batch
+    [new_seq] : a table with the batch indices of sequences that start in this mini-batch
+    [do_train] : a boolean indicating whether training is performed
+    if [do_train] is true, these fields must also be contained:
+    [err_input], [err_output] : matrix arrays holding the network err_input and err_output
+    (a usage sketch follows after this diff)
+--]]
+function network:mini_batch_init(info)
+ self.info = info
+ self:set_input(self.info.input)
+ self:set_output(self.info.output)
+
+ -- calculate border
+ self.max_length = 0
+ self.border = {}
+ for i = 1, self.chunk_size do
+ self.border[i] = {}
+ end
+ for i = 1, self.batch_size do
+ if self.info.seq_length[i] > self.max_length then
+ self.max_length = self.info.seq_length[i]
+ end
+ for t = 1, self.delay do
+ local chunk = self.info.seq_length[i] + t
+ if chunk > self.chunk_size then
+ break
+ end
+ table.insert(self.border[chunk], i)
+ end
+ end
+
+ -- copy legacy
+ for t = 1 - self.delay, 0 do
+ for i = 1, #self.layers do
+ local _, dim_out = self.layers[i]:get_dim()
+ for j = 1, #dim_out do
+ if t + self.chunk_size >= 1 and self.output_conn[i][j][1] ~= 0 then
+ self.legacy[t][i][j]:copy_from(self.output[t + self.chunk_size][i][j])
+ end
+ for k = 1, #self.info.new_seq do
+ local batch = self.info.new_seq[k]
+ self.legacy[t][i][j][batch - 1]:fill(self.nn_act_default)
+ end
+ end
+ end
+ end
+
+ if self.info.do_train then
+ self:set_err_input(self.info.err_input)
+ self:set_err_output(self.info.err_output)
+
+ -- flush border gradient
+ for t = self.max_length + 1, self.max_length + self.delay do
+ if t > self.chunk_size then
+ break
+ end
+ for i = 1, #self.layers do
+ local dim_in, _ = self.layers[i]:get_dim()
+ for j = 1, #dim_in do
+ self.err_output[t][i][j]:fill(0)
+ end
+ end
+ end
+ end
+end
+
+function network:propagate()
+ for i = 1, #self.queue do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if t <= self.max_length then
+ self.layers[id]:propagate(self.input[t][id], self.output[t][id], t)
+ end
+ -- flush border activation
+ for j = 1, #self.border[t] do
+ local batch = self.border[t][j]
+ local _, dim_out = self.layers[id]:get_dim()
+ for k = 1, #dim_out do
+ self.output[t][id][k][batch - 1]:fill(self.nn_act_default)
+ end
+ end
+ end
+end
+
+function network:back_propagate()
+ for i = #self.queue, 1, -1 do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if t <= self.max_length then
+ -- flush border gradient
+ for j = 1, #self.border[t] do
+ local batch = self.border[t][j]
+ local _, dim_out = self.layers[id]:get_dim()
+ for k = 1, #dim_out do
+ self.err_input[t][id][k][batch - 1]:fill(0)
+ end
+ end
+ self.layers[id]:back_propagate(self.err_input[t][id], self.err_output[t][id], self.input[t][id], self.output[t][id], t)
+ if self.clip ~= nil then
+ local dim_in, _ = self.layers[id]:get_dim()
+ for j = 1, #dim_in do
+ self.err_output[t][id][j]:clip(-self.clip, self.clip)
+ end
+ end
+ end
+ end
+end
+
+function network:update()
+ for i = 1, #self.queue do
+ local t, id = self.queue[i].chunk, self.queue[i].id
+ if t <= self.max_length then
+ self.layers[id]:update(self.err_input[t][id], self.input[t][id], self.output[t][id], t)
+ end
+ end
+end
+
+function network:set_attr(name, value)
+ self.network:set_attr(name, value)
+end
+
+function network:get_sublayer(id)
+ return self.network:get_sublayer(id)
+end
+
+function network:get_params()
+ return self.network:get_params()
+end
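
Usage sketch (illustrative, not part of the patch above): how a caller might drive the new nerv.Network and fill the [info] table documented at mini_batch_init. Only the constructor arguments, method names and [info] fields are taken from network.lua; gconf, the graph layer and the matrices are assumed placeholders the caller provides.

-- gconf, graph and the *_mats matrices are placeholders assumed to exist
local net = nerv.Network('net', gconf, {network = graph, clip = 5, nn_act_default = 0})
net:init(2, 4)                     -- batch_size = 2, chunk_size = 4
net:epoch_init()

local info = {
    input = input_mats,            -- input[t][i]: matrix for network input port i at time step t
    output = output_mats,          -- output[t][i]: matrix receiving network output port i at time step t
    seq_length = {4, 3},           -- length of each of the 2 sequences in the batch
    new_seq = {2},                 -- batch indices whose sequence starts in this mini-batch
    do_train = true,
    err_input = err_input_mats,    -- gradients w.r.t. the network output
    err_output = err_output_mats   -- buffers receiving gradients w.r.t. the network input
}
net:mini_batch_init(info)
net:propagate()
net:back_propagate()
net:update()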