about | summary | refs | log | tree | commit | diff
path: root/nerv/nn/layer_dag.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r-- nerv/nn/layer_dag.lua 146
1 file changed, 91 insertions, 55 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua
index 6ad7ae9..6896878 100644
--- a/nerv/nn/layer_dag.lua
+++ b/nerv/nn/layer_dag.lua
@@ -2,7 +2,7 @@ local DAGLayer = nerv.class("nerv.DAGLayer", "nerv.Layer")
local function parse_id(str)
local id, port, _
- _, _, id, port = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%]")
+ _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]")
if id == nil or port == nil then
_, _, id, port = string.find(str, "(.+)%[([0-9]+)%]")
if not (id == "<input>" or id == "<output>") then
@@ -38,6 +38,12 @@ local function discover(id, layers, layer_repo)
return ref
end
+local function touch_list_by_idx(list, idx)
+ if list[idx] == nil then
+ list[idx] = {}
+ end
+end
+
function DAGLayer:__init(id, global_conf, layer_conf)
local layers = {}
local inputs = {}
@@ -51,11 +57,17 @@ function DAGLayer:__init(id, global_conf, layer_conf)
local ref_from = discover(id_from, layers, layer_conf.sub_layers)
local ref_to = discover(id_to, layers, layer_conf.sub_layers)
local input_dim, output_dim, _
- if ref_from and ref_from.outputs[port_from] ~= nil then
- nerv.error("%s has already been attached", from)
+ if ref_from then
+ touch_list_by_idx(ref_from.outputs, 1)
+ if ref_from.outputs[1][port_from] ~= nil then
+ nerv.error("%s has already been attached", from)
+ end
end
- if ref_to and ref_to.inputs[port_to] ~= nil then
- nerv.error("%s has already been attached", to)
+ if ref_to then
+ touch_list_by_idx(ref_to.inputs, 1)
+ if ref_to.inputs[1][port_to] ~= nil then
+ nerv.error("%s has already been attached", to)
+ end
end
if id_from == "<input>" then
input_dim, _ = ref_to.layer:get_dim()
@@ -63,14 +75,14 @@ function DAGLayer:__init(id, global_conf, layer_conf)
nerv.error("mismatching data dimension between %s and %s", from, to)
end
inputs[port_from] = {ref_to, port_to}
- ref_to.inputs[port_to] = inputs -- just a place holder
+ ref_to.inputs[1][port_to] = inputs -- just a place holder
elseif id_to == "<output>" then
_, output_dim = ref_from.layer:get_dim()
if output_dim[port_from] ~= dim_out[port_to] then
nerv.error("mismatching data dimension between %s and %s", from, to)
end
outputs[port_to] = {ref_from, port_from}
- ref_from.outputs[port_from] = outputs -- just a place holder
+ ref_from.outputs[1][port_from] = outputs -- just a place holder
else
_, output_dim = ref_from.layer:get_dim()
input_dim, _ = ref_to.layer:get_dim()
@@ -104,7 +116,7 @@ function DAGLayer:__init(id, global_conf, layer_conf)
cur.visited = true
l = l + 1
for _, nl in pairs(cur.next_layers) do
- nl.in_deg = nl.in_deg - 1
+ nl.in_deg = nl.in_deg - 1
if nl.in_deg == 0 then
table.insert(queue, nl)
r = r + 1
@@ -138,7 +150,10 @@ function DAGLayer:__init(id, global_conf, layer_conf)
end
end
-function DAGLayer:init(batch_size)
+function DAGLayer:init(batch_size, chunk_size)
+ if chunk_size == nil then
+ chunk_size = 1
+ end
for i, conn in ipairs(self.parsed_conn) do
local _, output_dim
local ref_from, port_from, ref_to, port_to
@@ -149,28 +164,35 @@ function DAGLayer:init(batch_size)
if output_dim[port_from] > 0 then
dim = output_dim[port_from]
end
- local mid = self.mat_type(batch_size, dim)
- local err_mid = mid:create()
- ref_from.outputs[port_from] = mid
- ref_to.inputs[port_to] = mid
+ for t = 1, chunk_size do
+ local mid = self.mat_type(batch_size, dim)
+ local err_mid = mid:create()
+ touch_list_by_idx(ref_to.inputs, t)
+ touch_list_by_idx(ref_from.outputs, t)
+ touch_list_by_idx(ref_from.err_inputs, t)
+ touch_list_by_idx(ref_to.err_outputs, t)
+
+ ref_from.outputs[t][port_from] = mid
+ ref_to.inputs[t][port_to] = mid
- ref_from.err_inputs[port_from] = err_mid
- ref_to.err_outputs[port_to] = err_mid
+ ref_from.err_inputs[t][port_from] = err_mid
+ ref_to.err_outputs[t][port_to] = err_mid
+ end
end
for id, ref in pairs(self.layers) do
for i = 1, ref.input_len do
- if ref.inputs[i] == nil then
+ if ref.inputs[1][i] == nil then
nerv.error("dangling input port %d of layer %s", i, id)
end
end
for i = 1, ref.output_len do
- if ref.outputs[i] == nil then
+ if ref.outputs[1][i] == nil then
nerv.error("dangling output port %d of layer %s", i, id)
end
end
-- initialize sub layers
- ref.layer:init(batch_size)
+ ref.layer:init(batch_size, chunk_size)
end
for i = 1, #self.dim_in do
if self.inputs[i] == nil then
@@ -184,8 +206,10 @@ function DAGLayer:init(batch_size)
end
end
-function DAGLayer:batch_resize(batch_size)
- self.gconf.batch_size = batch_size
+function DAGLayer:batch_resize(batch_size, chunk_size)
+ if chunk_size == nil then
+ chunk_size = 1
+ end
for i, conn in ipairs(self.parsed_conn) do
local _, output_dim
@@ -194,93 +218,105 @@ function DAGLayer:batch_resize(batch_size)
ref_to, port_to = unpack(conn[2])
_, output_dim = ref_from.layer:get_dim()
- if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
- local mid = self.mat_type(batch_size, output_dim[port_from])
- local err_mid = mid:create()
+ if ref_from.outputs[1][port_from]:nrow() ~= batch_size
+ and output_dim[port_from] > 0 then
+ for t = 1, chunk_size do
+ local mid = self.mat_type(batch_size, output_dim[port_from])
+ local err_mid = mid:create()
- ref_from.outputs[port_from] = mid
- ref_to.inputs[port_to] = mid
+ ref_from.outputs[t][port_from] = mid
+ ref_to.inputs[t][port_to] = mid
- ref_from.err_inputs[port_from] = err_mid
- ref_to.err_outputs[port_to] = err_mid
+ ref_from.err_inputs[t][port_from] = err_mid
+ ref_to.err_outputs[t][port_to] = err_mid
+ end
end
end
for id, ref in pairs(self.layers) do
- ref.layer:batch_resize(batch_size)
+ ref.layer:batch_resize(batch_size, chunk_size)
end
collectgarbage("collect")
end
-function DAGLayer:set_inputs(input)
+function DAGLayer:set_inputs(input, t)
for i = 1, #self.dim_in do
if input[i] == nil then
nerv.error("some input is not provided");
end
local layer = self.inputs[i][1]
local port = self.inputs[i][2]
- layer.inputs[port] = input[i]
+ touch_list_by_idx(layer.inputs, t)
+ layer.inputs[t][port] = input[i]
end
end
-function DAGLayer:set_outputs(output)
+function DAGLayer:set_outputs(output, t)
for i = 1, #self.dim_out do
if output[i] == nil then
nerv.error("some output is not provided");
end
local layer = self.outputs[i][1]
local port = self.outputs[i][2]
- layer.outputs[port] = output[i]
+ touch_list_by_idx(layer.outputs, t)
+ layer.outputs[t][port] = output[i]
end
end
-function DAGLayer:set_err_inputs(bp_err)
+function DAGLayer:set_err_inputs(bp_err, t)
for i = 1, #self.dim_out do
local layer = self.outputs[i][1]
local port = self.outputs[i][2]
- layer.err_inputs[port] = bp_err[i]
+ touch_list_by_idx(layer.err_inputs, t)
+ layer.err_inputs[t][port] = bp_err[i]
end
end
-function DAGLayer:set_err_outputs(next_bp_err)
+function DAGLayer:set_err_outputs(next_bp_err, t)
for i = 1, #self.dim_in do
local layer = self.inputs[i][1]
local port = self.inputs[i][2]
- layer.err_outputs[port] = next_bp_err[i]
+ touch_list_by_idx(layer.err_outputs, t)
+ layer.err_outputs[t][port] = next_bp_err[i]
end
end
-function DAGLayer:update(bp_err, input, output)
- self:set_err_inputs(bp_err)
- self:set_inputs(input)
- self:set_outputs(output)
- -- print("update")
+function DAGLayer:update(bp_err, input, output, t)
+ if t == nil then
+ t = 1
+ end
+ self:set_err_inputs(bp_err, t)
+ self:set_inputs(input, t)
+ self:set_outputs(output, t)
for id, ref in pairs(self.queue) do
- -- print(ref.layer.id)
- ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs)
+ ref.layer:update(ref.err_inputs[t], ref.inputs[t], ref.outputs[t], t)
end
end
-function DAGLayer:propagate(input, output)
- self:set_inputs(input)
- self:set_outputs(output)
+function DAGLayer:propagate(input, output, t)
+ if t == nil then
+ t = 1
+ end
+ self:set_inputs(input, t)
+ self:set_outputs(output, t)
local ret = false
for i = 1, #self.queue do
local ref = self.queue[i]
- -- print(ref.layer.id)
- ret = ref.layer:propagate(ref.inputs, ref.outputs)
+ ret = ref.layer:propagate(ref.inputs[t], ref.outputs[t], t)
end
return ret
end
-function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
- self:set_err_outputs(next_bp_err)
- self:set_err_inputs(bp_err)
- self:set_inputs(input)
- self:set_outputs(output)
+function DAGLayer:back_propagate(bp_err, next_bp_err, input, output, t)
+ if t == nil then
+ t = 1
+ end
+ self:set_err_outputs(next_bp_err, t)
+ self:set_err_inputs(bp_err, t)
+ self:set_inputs(input, t)
+ self:set_outputs(output, t)
for i = #self.queue, 1, -1 do
local ref = self.queue[i]
- -- print(ref.layer.id)
- ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs)
+ ref.layer:back_propagate(ref.err_inputs[t], ref.err_outputs[t], ref.inputs[t], ref.outputs[t], t)
end
end