diff options
Diffstat (limited to 'nerv/nn/layer_dag.lua')
-rw-r--r-- | nerv/nn/layer_dag.lua | 146 |
1 files changed, 91 insertions, 55 deletions
diff --git a/nerv/nn/layer_dag.lua b/nerv/nn/layer_dag.lua index 6ad7ae9..6896878 100644 --- a/nerv/nn/layer_dag.lua +++ b/nerv/nn/layer_dag.lua @@ -2,7 +2,7 @@ local DAGLayer = nerv.class("nerv.DAGLayer", "nerv.Layer") local function parse_id(str) local id, port, _ - _, _, id, port = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%]") + _, _, id, port = string.find(str, "([a-zA-Z0-9_.]+)%[([0-9]+)%]") if id == nil or port == nil then _, _, id, port = string.find(str, "(.+)%[([0-9]+)%]") if not (id == "<input>" or id == "<output>") then @@ -38,6 +38,12 @@ local function discover(id, layers, layer_repo) return ref end +local function touch_list_by_idx(list, idx) + if list[idx] == nil then + list[idx] = {} + end +end + function DAGLayer:__init(id, global_conf, layer_conf) local layers = {} local inputs = {} @@ -51,11 +57,17 @@ function DAGLayer:__init(id, global_conf, layer_conf) local ref_from = discover(id_from, layers, layer_conf.sub_layers) local ref_to = discover(id_to, layers, layer_conf.sub_layers) local input_dim, output_dim, _ - if ref_from and ref_from.outputs[port_from] ~= nil then - nerv.error("%s has already been attached", from) + if ref_from then + touch_list_by_idx(ref_from.outputs, 1) + if ref_from.outputs[1][port_from] ~= nil then + nerv.error("%s has already been attached", from) + end end - if ref_to and ref_to.inputs[port_to] ~= nil then - nerv.error("%s has already been attached", to) + if ref_to then + touch_list_by_idx(ref_to.inputs, 1) + if ref_to.inputs[1][port_to] ~= nil then + nerv.error("%s has already been attached", to) + end end if id_from == "<input>" then input_dim, _ = ref_to.layer:get_dim() @@ -63,14 +75,14 @@ function DAGLayer:__init(id, global_conf, layer_conf) nerv.error("mismatching data dimension between %s and %s", from, to) end inputs[port_from] = {ref_to, port_to} - ref_to.inputs[port_to] = inputs -- just a place holder + ref_to.inputs[1][port_to] = inputs -- just a place holder elseif id_to == "<output>" then _, output_dim = ref_from.layer:get_dim() if output_dim[port_from] ~= dim_out[port_to] then nerv.error("mismatching data dimension between %s and %s", from, to) end outputs[port_to] = {ref_from, port_from} - ref_from.outputs[port_from] = outputs -- just a place holder + ref_from.outputs[1][port_from] = outputs -- just a place holder else _, output_dim = ref_from.layer:get_dim() input_dim, _ = ref_to.layer:get_dim() @@ -104,7 +116,7 @@ function DAGLayer:__init(id, global_conf, layer_conf) cur.visited = true l = l + 1 for _, nl in pairs(cur.next_layers) do - nl.in_deg = nl.in_deg - 1 + nl.in_deg = nl.in_deg - 1 if nl.in_deg == 0 then table.insert(queue, nl) r = r + 1 @@ -138,7 +150,10 @@ function DAGLayer:__init(id, global_conf, layer_conf) end end -function DAGLayer:init(batch_size) +function DAGLayer:init(batch_size, chunk_size) + if chunk_size == nil then + chunk_size = 1 + end for i, conn in ipairs(self.parsed_conn) do local _, output_dim local ref_from, port_from, ref_to, port_to @@ -149,28 +164,35 @@ function DAGLayer:init(batch_size) if output_dim[port_from] > 0 then dim = output_dim[port_from] end - local mid = self.mat_type(batch_size, dim) - local err_mid = mid:create() - ref_from.outputs[port_from] = mid - ref_to.inputs[port_to] = mid + for t = 1, chunk_size do + local mid = self.mat_type(batch_size, dim) + local err_mid = mid:create() + touch_list_by_idx(ref_to.inputs, t) + touch_list_by_idx(ref_from.outputs, t) + touch_list_by_idx(ref_from.err_inputs, t) + touch_list_by_idx(ref_to.err_outputs, t) + + ref_from.outputs[t][port_from] = mid + ref_to.inputs[t][port_to] = mid - ref_from.err_inputs[port_from] = err_mid - ref_to.err_outputs[port_to] = err_mid + ref_from.err_inputs[t][port_from] = err_mid + ref_to.err_outputs[t][port_to] = err_mid + end end for id, ref in pairs(self.layers) do for i = 1, ref.input_len do - if ref.inputs[i] == nil then + if ref.inputs[1][i] == nil then nerv.error("dangling input port %d of layer %s", i, id) end end for i = 1, ref.output_len do - if ref.outputs[i] == nil then + if ref.outputs[1][i] == nil then nerv.error("dangling output port %d of layer %s", i, id) end end -- initialize sub layers - ref.layer:init(batch_size) + ref.layer:init(batch_size, chunk_size) end for i = 1, #self.dim_in do if self.inputs[i] == nil then @@ -184,8 +206,10 @@ function DAGLayer:init(batch_size) end end -function DAGLayer:batch_resize(batch_size) - self.gconf.batch_size = batch_size +function DAGLayer:batch_resize(batch_size, chunk_size) + if chunk_size == nil then + chunk_size = 1 + end for i, conn in ipairs(self.parsed_conn) do local _, output_dim @@ -194,93 +218,105 @@ function DAGLayer:batch_resize(batch_size) ref_to, port_to = unpack(conn[2]) _, output_dim = ref_from.layer:get_dim() - if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then - local mid = self.mat_type(batch_size, output_dim[port_from]) - local err_mid = mid:create() + if ref_from.outputs[1][port_from]:nrow() ~= batch_size + and output_dim[port_from] > 0 then + for t = 1, chunk_size do + local mid = self.mat_type(batch_size, output_dim[port_from]) + local err_mid = mid:create() - ref_from.outputs[port_from] = mid - ref_to.inputs[port_to] = mid + ref_from.outputs[t][port_from] = mid + ref_to.inputs[t][port_to] = mid - ref_from.err_inputs[port_from] = err_mid - ref_to.err_outputs[port_to] = err_mid + ref_from.err_inputs[t][port_from] = err_mid + ref_to.err_outputs[t][port_to] = err_mid + end end end for id, ref in pairs(self.layers) do - ref.layer:batch_resize(batch_size) + ref.layer:batch_resize(batch_size, chunk_size) end collectgarbage("collect") end -function DAGLayer:set_inputs(input) +function DAGLayer:set_inputs(input, t) for i = 1, #self.dim_in do if input[i] == nil then nerv.error("some input is not provided"); end local layer = self.inputs[i][1] local port = self.inputs[i][2] - layer.inputs[port] = input[i] + touch_list_by_idx(layer.inputs, t) + layer.inputs[t][port] = input[i] end end -function DAGLayer:set_outputs(output) +function DAGLayer:set_outputs(output, t) for i = 1, #self.dim_out do if output[i] == nil then nerv.error("some output is not provided"); end local layer = self.outputs[i][1] local port = self.outputs[i][2] - layer.outputs[port] = output[i] + touch_list_by_idx(layer.outputs, t) + layer.outputs[t][port] = output[i] end end -function DAGLayer:set_err_inputs(bp_err) +function DAGLayer:set_err_inputs(bp_err, t) for i = 1, #self.dim_out do local layer = self.outputs[i][1] local port = self.outputs[i][2] - layer.err_inputs[port] = bp_err[i] + touch_list_by_idx(layer.err_inputs, t) + layer.err_inputs[t][port] = bp_err[i] end end -function DAGLayer:set_err_outputs(next_bp_err) +function DAGLayer:set_err_outputs(next_bp_err, t) for i = 1, #self.dim_in do local layer = self.inputs[i][1] local port = self.inputs[i][2] - layer.err_outputs[port] = next_bp_err[i] + touch_list_by_idx(layer.err_outputs, t) + layer.err_outputs[t][port] = next_bp_err[i] end end -function DAGLayer:update(bp_err, input, output) - self:set_err_inputs(bp_err) - self:set_inputs(input) - self:set_outputs(output) - -- print("update") +function DAGLayer:update(bp_err, input, output, t) + if t == nil then + t = 1 + end + self:set_err_inputs(bp_err, t) + self:set_inputs(input, t) + self:set_outputs(output, t) for id, ref in pairs(self.queue) do - -- print(ref.layer.id) - ref.layer:update(ref.err_inputs, ref.inputs, ref.outputs) + ref.layer:update(ref.err_inputs[t], ref.inputs[t], ref.outputs[t], t) end end -function DAGLayer:propagate(input, output) - self:set_inputs(input) - self:set_outputs(output) +function DAGLayer:propagate(input, output, t) + if t == nil then + t = 1 + end + self:set_inputs(input, t) + self:set_outputs(output, t) local ret = false for i = 1, #self.queue do local ref = self.queue[i] - -- print(ref.layer.id) - ret = ref.layer:propagate(ref.inputs, ref.outputs) + ret = ref.layer:propagate(ref.inputs[t], ref.outputs[t], t) end return ret end -function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) - self:set_err_outputs(next_bp_err) - self:set_err_inputs(bp_err) - self:set_inputs(input) - self:set_outputs(output) +function DAGLayer:back_propagate(bp_err, next_bp_err, input, output, t) + if t == nil then + t = 1 + end + self:set_err_outputs(next_bp_err, t) + self:set_err_inputs(bp_err, t) + self:set_inputs(input, t) + self:set_outputs(output, t) for i = #self.queue, 1, -1 do local ref = self.queue[i] - -- print(ref.layer.id) - ref.layer:back_propagate(ref.err_inputs, ref.err_outputs, ref.inputs, ref.outputs) + ref.layer:back_propagate(ref.err_inputs[t], ref.err_outputs[t], ref.inputs[t], ref.outputs[t], t) end end |