local TNN = nerv.class("nerv.TNN")

local function parse_id(str)
    --used to parse layerid[portid],time
    local id, port, time, _
    _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)")
    if id == nil or port == nil then
        _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)")
        if not (id == "<input>" or id == "<output>") then
            nerv.error("wrong format of connection id")
        end
    end
    --print(str, id, port, time)
    port = tonumber(port)
    if (time == nil) then
        time = 0
    else
        time = tonumber(time)
    end
    --for now, time does not need to be parsed here
    return id, port
end

local function discover(id, layers, layer_repo)
    local ref = layers[id]
    if id == "<input>" or id == "<output>" then
        return nil
    end
    if ref == nil then
        local layer = layer_repo:get_layer(id)
        local dim_in, dim_out = layer:get_dim()
        ref = {
            layer = layer,
            id = layer.id,
            inputs_m = {}, --storage for computation, inputs_m[time][port]
            inputs_b = {}, --inputs_b[time][port], whether this input has been computed
            inputs_matbak_p = {}, --back-up space to handle some cross-border computation, inputs_matbak_p[port]
            outputs_m = {},
            outputs_b = {},
            err_inputs_m = {},
            err_inputs_matbak_p = {}, --back-up space to handle some cross-border computation
            err_inputs_b = {},
            err_outputs_m = {},
            err_outputs_b = {},
            i_conns_p = {}, --list of incoming connections
            o_conns_p = {}, --list of outgoing connections
            dim_in = dim_in, --list of dimensions of ports
            dim_out = dim_out,
        }
        layers[id] = ref
    end
    return ref
end

nerv.TNN.FC = {} --flag constants
nerv.TNN.FC.SEQ_START = 4
nerv.TNN.FC.SEQ_END = 8
nerv.TNN.FC.HAS_INPUT = 1
nerv.TNN.FC.HAS_LABEL = 2
nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --this instance has both input and label

function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, global_conf, st_c, p_c, t_c)
    --Fill st with a table of matrix storage from time (1-extend_t)..(chunk_size+extend_t)
    if (type(st) ~= "table") then
        nerv.error("st should be a table")
    end
    for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do --intentionally allocate more time steps
        if (st[i] == nil) then
            st[i] = {}
        end
        st[i][p] = global_conf.cumat_type(batch_size, dim)
        st[i][p]:fill(0)
        if (st_c ~= nil) then
            if (st_c[i + t_c] == nil) then
                st_c[i + t_c] = {}
            end
            st_c[i + t_c][p_c] = st[i][p]
        end
    end
    collectgarbage("collect") --free the old storage to save memory
end

function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed
    if (t < 1 or t > self.chunk_size) then
        return true
    end
    if (self.feeds_now.flagsPack_now[t] == 0 or self.feeds_now.flagsPack_now[t] == nil) then
        return true
    end
    return false
end
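--[[
A hedged note on the storage layout set up by make_initial_store above: st[i][p]
is one batch_size x dim matrix per time step i and port p, allocated for
i = (1-extend_t-2) .. (chunk_size+extend_t+2), and when st_c/p_c/t_c are given,
st_c[i + t_c][p_c] is made to alias the same matrix. This aliasing is how a
delayed connection shares a source layer's output buffer with the destination
layer's input at a shifted time index. For example (illustrative numbers only),
with chunk_size = 5, extend_t = 2 and a connection delay t_c = 1:

    st holds matrices for times -3 .. 9
    st_c[t + 1][p_c] == st[t][p]   -- the input the destination sees at time t+1
                                   -- is the source's output at time t
]]--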
function TNN:__init(id, global_conf, layer_conf)
    self.clip_t = layer_conf.clip_t
    if self.clip_t == nil then
        self.clip_t = 0
    end
    if self.clip_t > 0 then
        nerv.info("tnn(%s) will clip gradient across time with %f...", id, self.clip_t)
    end
    self.extend_t = layer_conf.extend_t --TNN will allocate storage of time for 1-extend_t .. chunk_size+extend_t
    if self.extend_t == nil then
        self.extend_t = 5
    end
    nerv.info("tnn(%s) will extend storage beyond MB border for time steps %d...", id, self.extend_t)

    local layers = {}
    local inputs_p = {} --map: port of the TNN to layer ref and port
    local outputs_p = {}
    local dim_in = layer_conf.dim_in
    local dim_out = layer_conf.dim_out
    local parsed_conns = {}
    local _

    for id, _ in pairs(layer_conf.sub_layers.layers) do --caution: with this line, layers that are not connected will also be included
        discover(id, layers, layer_conf.sub_layers)
    end

    for _, ll in pairs(layer_conf.connections) do
        local id_from, port_from = parse_id(ll[1])
        local id_to, port_to = parse_id(ll[2])
        local time_to = ll[3]

        print(id_from, id_to, time_to)

        local ref_from = discover(id_from, layers, layer_conf.sub_layers)
        local ref_to = discover(id_to, layers, layer_conf.sub_layers)

        if (id_from == "<input>") then
            if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then
                nerv.error("mismatched dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
            end
            inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to}
            ref_to.inputs_m[port_to] = {} --just a place holder
        elseif (id_to == "<output>") then
            if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then
                nerv.error("mismatched dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
            end
            outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from}
            ref_from.outputs_m[port_from] = {} --just a place holder
        else
            local conn_now = {
                ["src"] = {["ref"] = ref_from, ["port"] = port_from},
                ["dst"] = {["ref"] = ref_to, ["port"] = port_to},
                ["time"] = time_to
            }
            if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then
                nerv.error("mismatched dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
            end
            table.insert(parsed_conns, conn_now)
            ref_to.i_conns_p[conn_now.dst.port] = conn_now
            ref_from.o_conns_p[conn_now.src.port] = conn_now
        end
    end

    for id, ref in pairs(layers) do
        print(id, "#dim_in:", #ref.dim_in, "#dim_out:", #ref.dim_out, "#i_conns_p:", #ref.i_conns_p, "#o_conns_p", #ref.o_conns_p)
    end

    self.layers = layers
    self.inputs_p = inputs_p
    self.outputs_p = outputs_p
    self.id = id
    self.dim_in = dim_in
    self.dim_out = dim_out
    self.parsed_conns = parsed_conns
    self.gconf = global_conf
end
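--[[
A hedged usage sketch of the constructor above. The layer ids ("affine0",
"sigmoid0", "recurrent0"), the repo variable layer_repo and the dimensions are
hypothetical; what the constructor does rely on is that each connection entry
is {from, to, time}, that "<input>"/"<output>" name the TNN's own ports (these
must use time 0), and that a non-zero time gives the delay of a recurrent link.

    local tnn = nerv.TNN("tnn", gconf,
        {
            dim_in = {220}, dim_out = {100},
            sub_layers = layer_repo,   -- a nerv.LayerRepo holding the child layers
            clip_t = 5, extend_t = 4,
            connections = {
                {"<input>[1]", "affine0[1]", 0},
                {"affine0[1]", "sigmoid0[1]", 0},
                {"sigmoid0[1]", "recurrent0[1]", 0},
                {"recurrent0[1]", "recurrent0[2]", 1}, --feed back with a one-step delay
                {"recurrent0[1]", "<output>[1]", 0},
            },
        })
]]--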
function TNN:init(batch_size, chunk_size)
    self.batch_size = batch_size
    self.chunk_size = chunk_size

    for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
        local _, output_dim
        local ref_from, port_from, ref_to, port_to, time
        ref_from, port_from = conn.src.ref, conn.src.port
        ref_to, port_to = conn.dst.ref, conn.dst.port
        time = conn.time

        local dim = ref_from.dim_out[port_from]
        if (dim == 0) then
            nerv.error("layer %s has a zero dim port", ref_from.layer.id)
        end

        nerv.info("TNN initing storage %s->%s", ref_from.layer.id, ref_to.layer.id)
        ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim)
        self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
        ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
        self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
    end

    self.outputs_m = {}
    self.err_inputs_m = {}
    for i = 1, #self.dim_out do --Init storage for output ports
        local ref = self.outputs_p[i].ref
        local p = self.outputs_p[i].port
        self.make_initial_store(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.extend_t, self.gconf, self.outputs_m, i, 0)
        self.make_initial_store(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.extend_t, self.gconf, self.err_inputs_m, i, 0)
    end

    self.inputs_m = {}
    self.err_outputs_m = {}
    for i = 1, #self.dim_in do --Init storage for input ports
        local ref = self.inputs_p[i].ref
        local p = self.inputs_p[i].port
        self.make_initial_store(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.extend_t, self.gconf, self.inputs_m, i, 0)
        self.make_initial_store(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.extend_t, self.gconf, self.err_outputs_m, i, 0)
    end

    for id, ref in pairs(self.layers) do --Calling init for child layers
        for i = 1, #ref.dim_in do
            if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then
                nerv.error("dangling input port %d of layer %s", i, id)
            end
        end
        for i = 1, #ref.dim_out do
            if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then
                nerv.error("dangling output port %d of layer %s", i, id)
            end
        end
        -- initialize sub layers
        nerv.info("TNN initing sub-layer %s", ref.id)
        ref.layer:init(batch_size, chunk_size)
        collectgarbage("collect")
    end

    local flags_now = {}
    local flagsPack_now = {}
    for i = 1, chunk_size do
        flags_now[i] = {}
        flagsPack_now[i] = 0
    end

    self.feeds_now = {} --feeds is for the reader to fill
    self.feeds_now.inputs_m = self.inputs_m
    self.feeds_now.flags_now = flags_now
    self.feeds_now.flagsPack_now = flagsPack_now

    self:flush_all()
end

--[[
function DAGLayer:batch_resize(batch_size)
    self.gconf.batch_size = batch_size

    for i, conn in ipairs(self.parsed_conn) do
        local _, output_dim
        local ref_from, port_from, ref_to, port_to
        ref_from, port_from = unpack(conn[1])
        ref_to, port_to = unpack(conn[2])
        _, output_dim = ref_from.layer:get_dim()

        if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
            local mid = self.gconf.cumat_type(batch_size, output_dim[port_from])
            local err_mid = mid:create()

            ref_from.outputs[port_from] = mid
            ref_to.inputs[port_to] = mid

            ref_from.err_inputs[port_from] = err_mid
            ref_to.err_outputs[port_to] = err_mid
        end
    end
    for id, ref in pairs(self.layers) do
        ref.layer:batch_resize(batch_size)
    end
    collectgarbage("collect")
end
]]--

function TNN:flush_all() --flush all history and activation
    local _, ref
    for _, ref in pairs(self.layers) do
        for i = 1, #ref.dim_in do
            for t = 1 - self.extend_t, self.chunk_size + self.extend_t do
                ref.inputs_m[t][i]:fill(self.gconf.nn_act_default)
                if (ref.inputs_b[t] == nil) then
                    ref.inputs_b[t] = {}
                end
                ref.inputs_b[t][i] = false
                ref.err_outputs_m[t][i]:fill(0)
                if (ref.err_outputs_b[t] == nil) then
                    ref.err_outputs_b[t] = {}
                end
                ref.err_outputs_b[t][i] = false
            end
        end
        for i = 1, #ref.dim_out do
            for t = 1 - self.extend_t, self.chunk_size + self.extend_t do
                ref.outputs_m[t][i]:fill(self.gconf.nn_act_default)
                if (ref.outputs_b[t] == nil) then
                    ref.outputs_b[t] = {}
                end
                ref.outputs_b[t][i] = false
                ref.err_inputs_m[t][i]:fill(0)
                if (ref.err_inputs_b[t] == nil) then
                    ref.err_inputs_b[t] = {}
                end
                ref.err_inputs_b[t][i] = false
            end
        end
    end
end

--reader: some reader
--Returns: bool, whether there is a new feed
--Returns: feeds, a table that will be filled with the reader's feeds
function TNN:getfeed_from_reader(reader)
    local feeds_now = self.feeds_now
    local got_new = reader:get_batch(feeds_now)
    return got_new, feeds_now
end
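--[[
A hedged sketch of the reader contract assumed by getfeed_from_reader: the
reader's get_batch(feeds) is expected to fill feeds.inputs_m[t][p] (one
batch_size x dim_in[p] matrix per time step and input port), the per-stream
flags feeds.flags_now[t][i], and their bitwise-or feeds.flagsPack_now[t].
The DummyReader class below is hypothetical and only meant to show the shape
of the data; a real reader would also provide labels for the output layers.

    local DummyReader = nerv.class("nerv.DummyReader")
    function DummyReader:__init(gconf)
        self.gconf = gconf
    end
    function DummyReader:get_batch(feeds)
        local chunk_size, batch_size = self.gconf.chunk_size, self.gconf.batch_size
        for t = 1, chunk_size do
            feeds.inputs_m[t][1]:fill(0)           --real features would go here
            local pack = 0
            for i = 1, batch_size do
                local f = nerv.TNN.FC.SEQ_NORM     --HAS_INPUT + HAS_LABEL
                if t == 1 then f = bit.bor(f, nerv.TNN.FC.SEQ_START) end
                feeds.flags_now[t][i] = f
                pack = bit.bor(pack, f)
            end
            feeds.flagsPack_now[t] = pack
        end
        return true                                --false once the data is exhausted
    end
]]--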
function TNN:move_right_to_nextmb(list_t) --move the output history at the end of the chunk to time steps 1-extend_t..0 for the next minibatch
    if list_t == nil then
        list_t = {}
        for i = self.extend_t, 1, -1 do
            list_t[i] = 1 - i
        end
    end
    for i = 1, #list_t do
        local t = list_t[i]
        if t < 1 - self.extend_t or t > 0 then
            nerv.error("MB move range error")
        end
        for id, ref in pairs(self.layers) do
            for p = 1, #ref.dim_out do
                ref.outputs_m[t][p]:copy_fromd(ref.outputs_m[t + self.chunk_size][p])
            end
        end
    end
end

function TNN:net_propagate() --propagate according to feeds_now
    for t = 1, self.chunk_size, 1 do
        for id, ref in pairs(self.layers) do
            for p = 1, #ref.dim_out do
                ref.outputs_b[t][p] = false
            end
            for p = 1, #ref.dim_in do
                ref.inputs_b[t][p] = false
            end
        end
    end

    local feeds_now = self.feeds_now
    for t = 1, self.chunk_size do --some layers may not have inputs from time 1..chunk_size
        for id, ref in pairs(self.layers) do
            self:propagate_dfs(ref, t)
        end
    end
    for t = 1, self.chunk_size do
        if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
            for i = 1, #self.dim_in do
                local ref = self.inputs_p[i].ref
                local p = self.inputs_p[i].port
                ref.inputs_b[t][p] = true
                self:propagate_dfs(ref, t)
            end
        end
    end

    local flag_out = true
    for t = 1, self.chunk_size do --check whether every output has been computed
        if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
            for i = 1, #self.dim_out do
                local ref = self.outputs_p[i].ref
                if (ref.outputs_b[t][1] ~= true) then
                    flag_out = false
                    break
                end
            end
        end
    end
    if (flag_out == false) then
        nerv.error("something is wrong, some labeled output is not propagated")
    end
end
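--[[
A hedged worked example of the scheduling rule used by propagate_dfs below.
Suppose a layer "r" (hypothetical) has two incoming connections: port 1 from
"sigmoid0[1]" with time 0, and port 2 from r's own output with time 1. At
time t, r is considered ready to propagate when

    inputs_b[t][1] is true                              -- the time-0 input has been computed, and
    inputs_b[t][2] is true or out_of_feedrange(t - 1)   -- the delayed input was computed, or

the delayed input refers to a time outside the current feed range, in which
case the stored history (possibly carried over from the previous minibatch by
move_right_to_nextmb) is used as-is.
]]--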
--ref: the TNN_ref of a layer
--t: the current time to propagate
function TNN:propagate_dfs(ref, t)
    if (self:out_of_feedrange(t)) then
        return
    end
    if (ref.outputs_b[t][1] == true) then --already propagated, 1 is just a random port
        return
    end

    --print("debug dfs", ref.layer.id, t)

    local flag = true --whether all inputs are ready
    for _, conn in pairs(ref.i_conns_p) do
        local p = conn.dst.port
        if (not (ref.inputs_b[t][p] or self:out_of_feedrange(t - conn.time))) then
            flag = false
            break
        end
    end
    if (flag == false) then
        return
    end

    --ok, do propagate
    --print("debug ok, propagating");
    --The MB moving will cause bordering history to be changed, so it is wiser to flush the input activation
    if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history
        for i = 1, self.batch_size do
            local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
            local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
            if (seq_start > 0 or seq_end > 0) then
                for p, conn in pairs(ref.i_conns_p) do
                    if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to default
                        ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
                    end
                end
            end
        end
    end
    self.gconf.timer:tic("tnn_actual_layer_propagate")
    ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
    self.gconf.timer:toc("tnn_actual_layer_propagate")
    --[[
    if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
        for i = 1, self.batch_size do
            local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
            local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
            if (seq_start > 0 or seq_end > 0) then
                for p, conn in pairs(ref.o_conns_p) do
                    if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then
                        ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
                    end
                end
            end
        end
    end
    ]]--
    --set the output flags; the input flags of downstream layers are set in the dfs loop below
    for i = 1, #ref.dim_out do
        if (ref.outputs_b[t][i] == true) then
            nerv.error("this time's outputs_b should be false")
        end
        ref.outputs_b[t][i] = true
    end

    --try dfs for further layers
    for _, conn in pairs(ref.o_conns_p) do
        --print("debug dfs-searching", conn.dst.ref.layer.id)
        conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true
        self:propagate_dfs(conn.dst.ref, t + conn.time)
    end
end

--do_update: bool, whether we are doing back-propagation or updating the parameters
function TNN:net_backpropagate(do_update) --back-propagate according to feeds_now
    if do_update == nil then
        nerv.error("do_update should not be nil")
    end
    for t = 1, self.chunk_size, 1 do
        for id, ref in pairs(self.layers) do
            for p = 1, #ref.dim_out do
                ref.err_inputs_b[t][p] = false
            end
            for p = 1, #ref.dim_in do
                ref.err_outputs_b[t][p] = false
            end
        end
    end

    local feeds_now = self.feeds_now
    for t = 1, self.chunk_size do --some layers may not have outputs from time 1..chunk_size
        for id, ref in pairs(self.layers) do
            self:backpropagate_dfs(ref, t, do_update)
        end
    end
    for t = 1, self.chunk_size do
        if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then
            for i = 1, #self.dim_out do
                local ref = self.outputs_p[i].ref
                local p = self.outputs_p[i].port
                ref.err_inputs_b[t][p] = true
                self:backpropagate_dfs(ref, t, do_update)
            end
        end
    end

    local flag_out = true
    for t = 1, self.chunk_size do --check whether every input's error has been computed
        if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0 then
            for i = 1, #self.dim_in do
                local ref = self.inputs_p[i].ref
                if ref.err_outputs_b[t][1] ~= true then
                    flag_out = false
                    break
                end
            end
        end
    end
    if (flag_out == false) then
        nerv.error("something is wrong, some input is not back-propagated")
    end
end
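--[[
net_backpropagate appears to be designed to be called twice per minibatch: once
with do_update == false so that errors are pushed backwards through time (with
optional clipping of delayed connections by clip_t), and once with
do_update == true so that each sub-layer applies its parameter update from the
stored err_inputs_m. A minimal hedged sketch, assuming a tnn that has already
been propagated and whose output errors have been filled:

    tnn:net_backpropagate(false) --BPTT pass: compute and store gradients
    tnn:net_backpropagate(true)  --update pass: let every sub-layer update its parameters
]]--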
--ref: the TNN_ref of a layer
--t: the current time to back-propagate
function TNN:backpropagate_dfs(ref, t, do_update)
    if do_update == nil then
        nerv.error("got a nil do_update")
    end
    if self:out_of_feedrange(t) then
        return
    end
    if ref.err_outputs_b[t][1] == true then --already back-propagated, 1 is just a random port
        return
    end

    --print("debug dfs", ref.layer.id, t)

    local flag = true --whether all output errors are ready
    for _, conn in pairs(ref.o_conns_p) do
        local p = conn.src.port
        if (not (ref.err_inputs_b[t][p] or self:out_of_feedrange(t + conn.time))) then
            flag = false
            break
        end
    end
    if (flag == false) then
        return
    end

    --ok, do back-propagate
    --print("debug ok, back-propagating(or updating)")
    if (do_update == false) then
        self.gconf.timer:tic("tnn_actual_layer_backpropagate")
        ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
        self.gconf.timer:toc("tnn_actual_layer_backpropagate")
        if self.clip_t > 0 then
            for _, conn in pairs(ref.i_conns_p) do
                local p = conn.dst.port --port for ref
                if conn.time ~= 0 then
                    --print("debug clip_t tnn", ref.id, "port:", p, "clip:", self.clip_t)
                    ref.err_outputs_m[t][p]:clip(-self.clip_t, self.clip_t)
                end
            end
        end
    else
        --print(ref.err_inputs_m[t][1])
        self.gconf.timer:tic("tnn_actual_layer_update")
        ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
        self.gconf.timer:toc("tnn_actual_layer_update")
    end

    if (do_update == false and bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors
        for i = 1, self.batch_size do
            local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
            local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
            if (seq_start > 0 or seq_end > 0) then
                for p, conn in pairs(ref.i_conns_p) do
                    if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero
                        ref.err_outputs_m[t][p][i - 1]:fill(0)
                    end
                end
            end
        end
    end

    for i = 1, #ref.dim_in do
        if (ref.err_outputs_b[t][i] == true) then
            nerv.error("this time's err_outputs_b should be false")
        end
        ref.err_outputs_b[t][i] = true
    end

    --try dfs for the source layers
    for _, conn in pairs(ref.i_conns_p) do
        --print("debug dfs-searching", conn.src.ref.layer.id)
        conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true
        self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update)
    end
end

--Return: nerv.ParamRepo
function TNN:get_params()
    local param_repos = {}
    for id, ref in pairs(self.layers) do
        table.insert(param_repos, ref.layer:get_params())
    end
    return nerv.ParamRepo.merge(param_repos)
end
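--[[
A hedged end-to-end sketch of how a TNN is typically driven, based only on the
methods defined in this file; the reader and gconf objects are assumed to be
set up elsewhere and the variable names are hypothetical.

    tnn:init(gconf.batch_size, gconf.chunk_size)
    while true do
        local got_new, feeds = tnn:getfeed_from_reader(reader)
        if not got_new then break end
        tnn:net_propagate()          --forward through the chunk
        --(compute the errors at the output ports into err_inputs_m here)
        tnn:net_backpropagate(false) --BPTT pass
        tnn:net_backpropagate(true)  --parameter update pass
        tnn:move_right_to_nextmb()   --carry history across the minibatch border
    end
    local params = tnn:get_params()  --collect the parameters of all sub-layers
]]--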