--nerv.TNN: a network of sub-layers whose connections may carry a time offset.
--TNN is the local handle on the class created by nerv.class; the methods below are defined on it.
local TNN = nerv.class("nerv.TNN")
local function parse_id(str)
    --Parse a connection endpoint of the form "layerid[portid]".
    --An optional ",time" suffix is tolerated but ignored: the time offset of a
    --connection is taken from the third element of the connection entry.
    --Returns: id (string), port (number)
    --A layer id must consist of [a-zA-Z0-9_] only. Besides that, only the TNN
    --border markers are accepted: " " (TNN input side) and "" (TNN output side,
    --written as "[port]").
    --Anchored with "^": an unanchored search would silently accept e.g.
    --"a-b[1]" as layer "b".
    local id, port = string.match(str, "^([a-zA-Z0-9_]+)%[([0-9]+)%]")
    if id == nil or port == nil then
        --"(.*)" rather than "(.+)" so that the empty output marker "" is reachable
        id, port = string.match(str, "^(.*)%[([0-9]+)%]")
        if not (id == " " or id == "") then
            nerv.error("wrong format of connection id")
        end
    end
    return id, tonumber(port)
end
local function discover(id, layers, layer_repo)
    --Return the bookkeeping record of sub-layer `id`, creating it on first use.
    --The TNN border markers " " and "" are not layers: nil is returned for them.
    --A new record fetches the layer via layer_repo:get_layer(id), records its
    --port dimensions from layer:get_dim(), and is cached in layers[id].
    if id == "" or id == " " then
        return nil
    end
    local cached = layers[id]
    if cached ~= nil then
        return cached
    end
    local layer = layer_repo:get_layer(id)
    local dim_in, dim_out = layer:get_dim()
    local record = {
        layer = layer,
        id = layer.id,
        inputs_m = {}, --activation storage, inputs_m[time][port]
        inputs_b = {}, --inputs_b[time][port]: whether this input has been computed
        inputs_matbak_p = {}, --back-up space for cross-border computation, inputs_matbak_p[port]
        outputs_m = {}, --outputs_m[time][port]
        outputs_b = {}, --outputs_b[time][port]: whether this output has been computed
        err_inputs_m = {}, --errors arriving at the output ports, err_inputs_m[time][port]
        err_inputs_matbak_p = {}, --back-up space for cross-border computation, err_inputs_matbak_p[port]
        err_inputs_b = {},
        err_outputs_m = {}, --errors sent back through the input ports, err_outputs_m[time][port]
        err_outputs_b = {},
        i_conns_p = {}, --incoming connections, keyed by input port
        o_conns_p = {}, --outgoing connections, keyed by output port
        dim_in = dim_in, --list of input port dimensions
        dim_out = dim_out, --list of output port dimensions
    }
    layers[id] = record
    return record
end
--Per-time-step feed flags (bit masks; combined with bit.bor, tested with bit.band).
--HAS_INPUT: TNN inputs are fed at this step; HAS_LABEL: TNN outputs are checked/back-propagated;
--SEQ_START / SEQ_END: per-row sequence border markers in flags_now.
do
    local FC = {
        HAS_INPUT = 1,
        HAS_LABEL = 2,
        SEQ_START = 4,
        SEQ_END = 8,
    }
    nerv.TNN.FC = FC
    FC.SEQ_NORM = bit.bor(FC.HAS_INPUT, FC.HAS_LABEL) --This instance have both input and label
end
--Allocate zero-filled matrices for port p of a time-indexed store, in place.
--For every time step i in (-1-extend_t)..(chunk_size+extend_t+2) (two steps more on
--each side than the nominal 1-extend_t..chunk_size+extend_t range), st[i][p] becomes a
--new global_conf.cumat_type(batch_size, dim) filled with 0.
--If st_c is given, the very same matrix is also stored as st_c[i+t_c][p_c], so the two
--stores share memory with a time offset of t_c.
--Returns nothing. Runs a full garbage collection at the end to release replaced matrices.
function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, global_conf, st_c, p_c, t_c)
    if type(st) ~= "table" then
        nerv.error("st should be a table")
    end
    local first_t = -1 - extend_t
    local last_t = chunk_size + extend_t + 2 --intentionally allocated more time
    for i = first_t, last_t do
        local slot = st[i]
        if slot == nil then
            slot = {}
            st[i] = slot
        end
        local mat = global_conf.cumat_type(batch_size, dim)
        slot[p] = mat
        mat:fill(0)
        if st_c ~= nil then
            local shifted = i + t_c
            local shared_slot = st_c[shifted]
            if shared_slot == nil then
                shared_slot = {}
                st_c[shifted] = shared_slot
            end
            shared_slot[p_c] = mat
        end
    end
    collectgarbage("collect") --free the old one to save memory
end
--Return true when time step t cannot be used in the current feed: either t lies
--outside 1..chunk_size, or its packed flag (feeds_now.flagsPack_now[t]) is 0 or missing.
function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed
    if t < 1 or t > self.chunk_size then
        return true
    end
    local pack = self.feeds_now.flagsPack_now[t]
    return pack == nil or pack == 0
end
--Build the TNN structure (no storage is allocated here; see TNN:init).
--layer_conf fields read:
--  clip_t (optional, default 0): if > 0, errors sent back through connections with a
--      non-zero time offset are clipped to [-clip_t, clip_t] in backpropagate_dfs
--  extend_t (optional, default 5): extra time steps of storage on each side of the chunk
--  dim_in / dim_out: dimensions of the TNN's input / output ports
--  sub_layers: layer repository; every layer in sub_layers.layers is registered
--  connections: list of {from, to, time}; endpoints are "layerid[port]", where a " "
--      source denotes a TNN input port and a "" destination a TNN output port;
--      connections touching the TNN border must have time 0
function TNN:__init(id, global_conf, layer_conf)
    self.clip_t = layer_conf.clip_t
    if self.clip_t == nil then
        self.clip_t = 0
    end
    if self.clip_t > 0 then
        nerv.info("tnn(%s) will clip gradient across time with %f...", id, self.clip_t)
    end
    self.extend_t = layer_conf.extend_t --TNN will allocate storage of time for 1-extend_t .. chunk_size+extend_t
    if self.extend_t == nil then
        self.extend_t = 5
    end
    nerv.info("tnn(%s) will extend storage beyond MB border for time steps %d...", id, self.extend_t)
    local layers = {}
    local inputs_p = {} --map: TNN input port -> {ref = sub-layer record, port = its input port}
    local outputs_p = {} --map: TNN output port -> {ref = sub-layer record, port = its output port}
    local dim_in = layer_conf.dim_in
    local dim_out = layer_conf.dim_out
    local parsed_conns = {}
    --caution: every layer of the repo is registered, even one no connection uses
    for layer_id in pairs(layer_conf.sub_layers.layers) do
        discover(layer_id, layers, layer_conf.sub_layers)
    end
    for _, ll in pairs(layer_conf.connections) do
        local id_from, port_from = parse_id(ll[1])
        local id_to, port_to = parse_id(ll[2])
        local time_to = ll[3]
        --the offset is used arithmetically later (t + conn.time), so reject a missing one now
        if type(time_to) ~= "number" then
            nerv.error("connection %s,%s has no numeric time offset", tostring(ll[1]), tostring(ll[2]))
        end
        nerv.info("tnn(%s) connection [%s] -> [%s], time %s", id, id_from, id_to, tostring(time_to))
        local ref_from = discover(id_from, layers, layer_conf.sub_layers)
        local ref_to = discover(id_to, layers, layer_conf.sub_layers)
        if (id_from == " ") then
            if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then
                nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
            end
            inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to}
            ref_to.inputs_m[port_to] = {} --just a place holder
        elseif (id_to == "") then
            if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then
                nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
            end
            outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from}
            ref_from.outputs_m[port_from] = {} --just a place holder
        else
            local conn_now = {
                ["src"] = {["ref"] = ref_from, ["port"] = port_from},
                ["dst"] = {["ref"] = ref_to, ["port"] = port_to},
                ["time"] = time_to
            }
            if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then
                nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
            end
            table.insert(parsed_conns, conn_now)
            --one connection per port: a later connection on the same port replaces the earlier one
            ref_to.i_conns_p[conn_now.dst.port] = conn_now
            ref_from.o_conns_p[conn_now.src.port] = conn_now
        end
    end
    for layer_id, ref in pairs(layers) do
        nerv.info("tnn(%s) layer %s #dim_in: %d #dim_out: %d #i_conns_p: %d #o_conns_p: %d",
            id, layer_id, #ref.dim_in, #ref.dim_out, #ref.i_conns_p, #ref.o_conns_p)
    end
    self.layers = layers
    self.inputs_p = inputs_p
    self.outputs_p = outputs_p
    self.id = id
    self.dim_in = dim_in
    self.dim_out = dim_out
    self.parsed_conns = parsed_conns
    self.gconf = global_conf
end
--Allocate all time-indexed storage for the given batch/chunk size and initialise the sub-layers.
--For each internal connection, the source layer's outputs_m / err_inputs_m at time t are
--the same matrix objects as the destination's inputs_m / err_outputs_m at time t+conn.time.
--TNN ports share storage with their sub-layer ports at the same time step
--(self.inputs_m, self.outputs_m, self.err_inputs_m, self.err_outputs_m).
--Raises an error for zero-dimension ports, unconnected TNN ports and dangling sub-layer ports.
--Finally builds self.feeds_now (flags_now[t] = {}, flagsPack_now[t] = 0 for t = 1..chunk_size)
--and calls flush_all.
function TNN:init(batch_size, chunk_size)
    self.batch_size = batch_size
    self.chunk_size = chunk_size
    for _, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
        local ref_from, port_from = conn.src.ref, conn.src.port
        local ref_to, port_to = conn.dst.ref, conn.dst.port
        local time = conn.time
        local dim = ref_from.dim_out[port_from]
        if (dim == 0) then
            nerv.error("layer %s has a zero dim port", ref_from.layer.id)
        end
        nerv.info("TNN initing storage %s->%s", ref_from.layer.id, ref_to.layer.id)
        ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim)
        self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
        ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
        self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
    end
    self.outputs_m = {}
    self.err_inputs_m = {}
    for i = 1, #self.dim_out do --Init storage for output ports
        local out_p = self.outputs_p[i]
        if out_p == nil then
            nerv.error("output port %d of tnn %s is not connected", i, tostring(self.id))
        end
        local ref, p = out_p.ref, out_p.port
        self.make_initial_store(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.extend_t, self.gconf, self.outputs_m, i, 0)
        self.make_initial_store(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.extend_t, self.gconf, self.err_inputs_m, i, 0)
    end
    self.inputs_m = {}
    self.err_outputs_m = {}
    for i = 1, #self.dim_in do --Init storage for input ports
        local in_p = self.inputs_p[i]
        if in_p == nil then
            nerv.error("input port %d of tnn %s is not connected", i, tostring(self.id))
        end
        local ref, p = in_p.ref, in_p.port
        self.make_initial_store(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.extend_t, self.gconf, self.inputs_m, i, 0)
        self.make_initial_store(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.extend_t, self.gconf, self.err_outputs_m, i, 0)
    end
    for id, ref in pairs(self.layers) do --Calling init for child layers
        for i = 1, #ref.dim_in do
            if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then
                nerv.error("dangling input port %d of layer %s", i, id)
            end
        end
        for i = 1, #ref.dim_out do
            if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then
                nerv.error("dangling output port %d of layer %s", i, id)
            end
        end
        -- initialize sub layers
        nerv.info("TNN initing sub-layer %s", ref.id)
        ref.layer:init(batch_size, chunk_size)
        collectgarbage("collect")
    end
    local flags_now = {}
    local flagsPack_now = {}
    for i = 1, chunk_size do
        flags_now[i] = {}
        flagsPack_now[i] = 0
    end
    self.feeds_now = {} --feeds is for the reader to fill
    self.feeds_now.inputs_m = self.inputs_m
    self.feeds_now.flags_now = flags_now
    self.feeds_now.flagsPack_now = flagsPack_now
    self:flush_all()
end
--[[
function DAGLayer:batch_resize(batch_size)
self.gconf.batch_size = batch_size
for i, conn in ipairs(self.parsed_conn) do
local _, output_dim
local ref_from, port_from, ref_to, port_to
ref_from, port_from = unpack(conn[1])
ref_to, port_to = unpack(conn[2])
_, output_dim = ref_from.layer:get_dim()
if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
local mid = self.gconf.cumat_type(batch_size, output_dim[port_from])
local err_mid = mid:create()
ref_from.outputs[port_from] = mid
ref_to.inputs[port_to] = mid
ref_from.err_inputs[port_from] = err_mid
ref_to.err_outputs[port_to] = err_mid
end
end
for id, ref in pairs(self.layers) do
ref.layer:batch_resize(batch_size)
end
collectgarbage("collect")
end
]]--
--Reset every sub-layer over time steps 1-extend_t .. chunk_size+extend_t:
--activations (inputs_m, outputs_m) are filled with gconf.nn_act_default, errors
--(err_outputs_m, err_inputs_m) with 0, and each matching "computed" flag table
--(*_b) is set to false, creating the per-time flag table when it is missing.
function TNN:flush_all() --flush all history and activation
    --per port and per time step: reset one activation/flag pair and one error/flag pair
    local function reset_ports(n_port, act_m, act_b, err_m, err_b)
        for port = 1, n_port do
            for t = 1 - self.extend_t, self.chunk_size + self.extend_t do
                act_m[t][port]:fill(self.gconf.nn_act_default)
                act_b[t] = act_b[t] or {}
                act_b[t][port] = false
                err_m[t][port]:fill(0)
                err_b[t] = err_b[t] or {}
                err_b[t][port] = false
            end
        end
    end
    for _, ref in pairs(self.layers) do
        reset_ports(#ref.dim_in, ref.inputs_m, ref.inputs_b, ref.err_outputs_m, ref.err_outputs_b)
        reset_ports(#ref.dim_out, ref.outputs_m, ref.outputs_b, ref.err_inputs_m, ref.err_inputs_b)
    end
end
--Ask `reader` to fill self.feeds_now with the next batch.
--reader: an object with a get_batch(feeds) method that fills `feeds` in place
--Returns: the first value returned by reader:get_batch (presumably a bool telling whether a new feed was obtained)
--Returns: self.feeds_now, the table the reader was asked to fill
function TNN:getfeed_from_reader(reader)
    local feeds = self.feeds_now
    return reader:get_batch(feeds), feeds
end
--Carry history across a minibatch border: for each t in list_t (each must lie in
--1-extend_t..0), copy every sub-layer output activation of step t+chunk_size into
--step t, so the tail of the previous chunk becomes the negative-time history of the next.
--list_t (optional): steps to fill, processed in order; defaults to 1-extend_t..0.
--Since step t reads step t+chunk_size, processing t in ascending order never reads a
--slot that was already overwritten; callers passing list_t should sort it ascending.
function TNN:move_right_to_nextmb(list_t) --move output history activations of chunk_size+1-extend_t..chunk_size to 1-extend_t..0
    if list_t == nil then
        --ascending (1-extend_t, ..., 0): when chunk_size < extend_t the source step
        --t+chunk_size is itself <= 0 and must be read before it is overwritten
        list_t = {}
        for i = 1, self.extend_t do
            list_t[i] = i - self.extend_t
        end
    end
    for i = 1, #list_t do
        local t = list_t[i]
        if t < 1 - self.extend_t or t > 0 then
            nerv.error("MB move range error")
        end
        for _, ref in pairs(self.layers) do
            for p = 1, #ref.dim_out do
                ref.outputs_m[t][p]:copy_fromd(ref.outputs_m[t + self.chunk_size][p])
            end
        end
    end
end
--Forward pass for the current feed (self.feeds_now) over time steps 1..chunk_size:
--1. clear the "computed" flags (outputs_b / inputs_b) of every sub-layer;
--2. try propagate_dfs on every sub-layer that has input ports, at every step;
--3. at each step flagged HAS_INPUT, mark every TNN input port as available and
--   continue the DFS from the sub-layer it feeds;
--4. raise an error if, at a step flagged HAS_LABEL, the sub-layer behind some TNN
--   output port has not been computed (checked on its port 1).
function TNN:net_propagate() --propagate according to feeds_now
    local layers = self.layers
    for t = 1, self.chunk_size do
        for _, ref in pairs(layers) do
            for port = 1, #ref.dim_out do
                ref.outputs_b[t][port] = false
            end
            for port = 1, #ref.dim_in do
                ref.inputs_b[t][port] = false
            end
        end
    end
    local feeds_now = self.feeds_now
    for t = 1, self.chunk_size do --some layer maybe do not have inputs from time 1..chunk_size
        for _, ref in pairs(layers) do
            if #ref.dim_in > 0 then --some layer is just there(only to save some parameter)
                self:propagate_dfs(ref, t)
            end
        end
    end
    local FC = nerv.TNN.FC
    for t = 1, self.chunk_size do
        if bit.band(feeds_now.flagsPack_now[t], FC.HAS_INPUT) > 0 then
            for port_i = 1, #self.dim_in do
                local entry = self.inputs_p[port_i]
                local ref = entry.ref
                ref.inputs_b[t][entry.port] = true
                self:propagate_dfs(ref, t)
            end
        end
    end
    --check whether every labeled output has been computed
    for t = 1, self.chunk_size do
        if bit.band(feeds_now.flagsPack_now[t], FC.HAS_LABEL) > 0 then
            for port_o = 1, #self.dim_out do
                if self.outputs_p[port_o].ref.outputs_b[t][1] ~= true then
                    nerv.error("some thing wrong, some labeled output is not propagated")
                end
            end
        end
    end
end
--Depth-first forward propagation of one sub-layer at one time step.
--ref: the TNN_ref of a layer (a record created by discover)
--t: the current time to propagate
--Returns early (doing nothing) when t is outside the current feed, when the layer was
--already computed at t, or when some incoming connection's input at t is not ready
--while its source step t-conn.time is inside the feed range.
--Otherwise: resets cross-border inputs of batch rows flagged SEQ_START/SEQ_END, calls
--ref.layer:propagate for step t, marks all outputs at t as computed, then marks the
--destination input of every outgoing connection (at t+conn.time) and recurses there.
function TNN:propagate_dfs(ref, t)
if (self:out_of_feedrange(t)) then
return
end
if (ref.outputs_b[t][1] == true) then --already propagated, 1 is just a random port
return
end
--print("debug dfs", ref.layer.id, t)
local flag = true --whether have all inputs
--an input counts as ready if it was computed, or if its source step lies outside the
--current feed (then whatever is stored in inputs_m[t][p] is used as is)
for _, conn in pairs(ref.i_conns_p) do
local p = conn.dst.port
if (not (ref.inputs_b[t][p] or self:out_of_feedrange(t - conn.time))) then
flag = false
break
end
end
if (flag == false) then
return
end
--ok, do propagate
--print("debug ok, propagating");
--The MB moving will cause bordering history to be changed, so it is more wise to flush the input activation
--For each batch row i flagged SEQ_START, inputs arriving through connections with time > 0
--(i.e. from earlier steps) are reset to nn_act_default; for rows flagged SEQ_END, the same
--is done for connections with time < 0 (i.e. from later steps).
if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history
for i = 1, self.batch_size do
local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
if (seq_start > 0 or seq_end > 0) then
for p, conn in pairs(ref.i_conns_p) do
if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to default
--NOTE(review): [i - 1] assumes the matrix type indexes rows from 0 — confirm
ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
end
end
end
end
end
self.gconf.timer:tic("tnn_actual_layer_propagate")
ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
self.gconf.timer:toc("tnn_actual_layer_propagate")
--[[
if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
for i = 1, self.batch_size do
local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
if (seq_start > 0 or seq_end > 0) then
for p, conn in pairs(ref.o_conns_p) do
if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then
ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
end
end
end
end
end
]]--
--set input flag for future layers
--(a second visit of the same layer at the same step is treated as an internal error)
for i = 1, #ref.dim_out do
if (ref.outputs_b[t][i] == true) then
nerv.error("this time's outputs_b should be false")
end
ref.outputs_b[t][i] = true
end
--try dfs for further layers
--an outgoing connection with offset conn.time feeds the destination at step t+conn.time
for _, conn in pairs(ref.o_conns_p) do
--print("debug dfs-searching", conn.dst.ref.layer.id)
conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true
self:propagate_dfs(conn.dst.ref, t + conn.time)
end
end
--Backward pass (do_update == false) or parameter update (do_update == true) for the
--current feed over time steps 1..chunk_size; the mirror image of net_propagate:
--1. clear err_inputs_b / err_outputs_b of every sub-layer;
--2. try backpropagate_dfs on every sub-layer that has output ports, at every step;
--3. at each step flagged HAS_LABEL, mark every TNN output port's error as available
--   and continue the DFS from the sub-layer behind it;
--4. raise an error if, at a step flagged HAS_INPUT, the sub-layer behind some TNN
--   input port has not been handled (checked on its port 1).
--do_update: bool, whether we are doing back-propagate or updating the parameters; must not be nil
function TNN:net_backpropagate(do_update) --propagate according to feeds_now
    if do_update == nil then
        nerv.error("do_update should not be nil")
    end
    local layers = self.layers
    for t = 1, self.chunk_size do
        for _, ref in pairs(layers) do
            for port = 1, #ref.dim_out do
                ref.err_inputs_b[t][port] = false
            end
            for port = 1, #ref.dim_in do
                ref.err_outputs_b[t][port] = false
            end
        end
    end
    local feeds_now = self.feeds_now
    for t = 1, self.chunk_size do --some layer maybe do not have outputs from time 1..chunk_size
        for _, ref in pairs(layers) do
            if #ref.dim_out > 0 then --some layer is just there(only to save some parameter)
                self:backpropagate_dfs(ref, t, do_update)
            end
        end
    end
    local FC = nerv.TNN.FC
    for t = 1, self.chunk_size do
        if bit.band(feeds_now.flagsPack_now[t], FC.HAS_LABEL) > 0 then
            for port_o = 1, #self.dim_out do
                local entry = self.outputs_p[port_o]
                local ref = entry.ref
                ref.err_inputs_b[t][entry.port] = true
                self:backpropagate_dfs(ref, t, do_update)
            end
        end
    end
    --check whether every fed input has been back-propagated
    for t = 1, self.chunk_size do
        if bit.band(feeds_now.flagsPack_now[t], FC.HAS_INPUT) > 0 then
            for port_i = 1, #self.dim_in do
                if self.inputs_p[port_i].ref.err_outputs_b[t][1] ~= true then
                    nerv.error("some thing wrong, some input is not back_propagated")
                end
            end
        end
    end
end
--Depth-first backward pass (or parameter update) of one sub-layer at one time step.
--ref: the TNN_ref of a layer (a record created by discover)
--t: the current time to propagate
--do_update: false to call ref.layer:back_propagate, true to call ref.layer:update; must not be nil
--Returns early when t is outside the current feed, when the layer was already handled at t
--(err_outputs_b[t][1]), or when the error of some outgoing connection is not yet available
--while its destination step t+conn.time is inside the feed range.
--When back-propagating, the errors sent back through connections with a non-zero time offset
--are clipped to [-clip_t, clip_t] if clip_t > 0. Cross-border errors of rows flagged
--SEQ_START/SEQ_END are zeroed. The DFS then continues into the source layer of every
--incoming connection at step t-conn.time.
function TNN:backpropagate_dfs(ref, t, do_update)
    if do_update == nil then
        nerv.error("got a nil do_update")
    end
    if self:out_of_feedrange(t) then
        return
    end
    if ref.err_outputs_b[t][1] == true then --already back_propagated, 1 is just a random port
        return
    end
    local flag = true --whether errors from all outgoing connections are available
    for _, conn in pairs(ref.o_conns_p) do
        local p = conn.src.port
        if (not (ref.err_inputs_b[t][p] or self:out_of_feedrange(t + conn.time))) then
            flag = false
            break
        end
    end
    if (flag == false) then
        return
    end
    --ok, do back_propagate (or update)
    if (do_update == false) then
        self.gconf.timer:tic("tnn_actual_layer_backpropagate")
        ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
        self.gconf.timer:toc("tnn_actual_layer_backpropagate")
        if self.clip_t > 0 then
            --only errors flowing back through time-delayed connections are clipped
            for _, conn in pairs(ref.i_conns_p) do
                local p = conn.dst.port --port for ref
                if conn.time ~= 0 then
                    ref.err_outputs_m[t][p]:clip(-self.clip_t, self.clip_t)
                end
            end
        end
    else
        self.gconf.timer:tic("tnn_actual_layer_update")
        ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
        self.gconf.timer:toc("tnn_actual_layer_update")
    end
    if (do_update == false and bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors
        for i = 1, self.batch_size do
            local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
            local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
            if (seq_start > 0 or seq_end > 0) then
                for p, conn in pairs(ref.i_conns_p) do
                    if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero
                        --NOTE(review): [i - 1] assumes the matrix type indexes rows from 0 — confirm
                        ref.err_outputs_m[t][p][i - 1]:fill(0)
                    end
                end
            end
        end
    end
    --a second visit of the same layer at the same step is an internal error
    for i = 1, #ref.dim_in do
        if (ref.err_outputs_b[t][i] == true) then
            nerv.error("this time's err_outputs_b should be false")
        end
        ref.err_outputs_b[t][i] = true
    end
    --try dfs for further layers: an incoming connection with offset conn.time
    --came from the source layer at step t-conn.time
    for _, conn in pairs(ref.i_conns_p) do
        conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true
        self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update)
    end
end
--Collect the parameters of every sub-layer and merge them into a single repository.
--Return: nerv.ParamRepo (result of nerv.ParamRepo.merge over the sub-layers' get_params())
function TNN:get_params()
    local repos = {}
    for _, ref in pairs(self.layers) do
        repos[#repos + 1] = ref.layer:get_params()
    end
    return nerv.ParamRepo.merge(repos)
end