aboutsummaryrefslogtreecommitdiff
path: root/nerv/tnn/tnn.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/tnn/tnn.lua')
-rw-r--r--nerv/tnn/tnn.lua565
1 files changed, 565 insertions, 0 deletions
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
new file mode 100644
index 0000000..56c9dc0
--- /dev/null
+++ b/nerv/tnn/tnn.lua
@@ -0,0 +1,565 @@
+local TNN = nerv.class("nerv.TNN")
+
+local function parse_id(str)
+ --used to parse layerid[portid],time
+ local id, port, time, _
+ _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)")
+ if id == nil or port == nil then
+ _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)")
+ if not (id == "<input>" or id == "<output>") then
+ nerv.error("wrong format of connection id")
+ end
+ end
+ --print(str, id, port, time)
+ port = tonumber(port)
+ if (time == nil) then
+ time = 0
+ else
+ time = tonumber(time)
+ end
+ --now time don't need to be parsed
+ return id, port
+end
+
+local function discover(id, layers, layer_repo)
+ local ref = layers[id]
+ if id == "<input>" or id == "<output>" then
+ return nil
+ end
+ if ref == nil then
+ local layer = layer_repo:get_layer(id)
+ local dim_in, dim_out = layer:get_dim()
+ ref = {
+ layer = layer,
+ id = layer.id,
+ inputs_m = {}, --storage for computation, inputs_m[time][port]
+ inputs_b = {}, --inputs_g[time][port], whether this input can been computed
+ inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation, inputs_p_matbak[port]
+ outputs_m = {},
+ outputs_b = {},
+ err_inputs_m = {},
+ err_inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation
+ err_inputs_b = {},
+ err_outputs_m = {},
+ err_outputs_b = {},
+ i_conns_p = {}, --list of inputing connections
+ o_conns_p = {}, --list of outputing connections
+ dim_in = dim_in, --list of dimensions of ports
+ dim_out = dim_out,
+ }
+ layers[id] = ref
+ end
+ return ref
+end
+
+nerv.TNN.FC = {} --flag const
+nerv.TNN.FC.SEQ_START = 4
+nerv.TNN.FC.SEQ_END = 8
+nerv.TNN.FC.HAS_INPUT = 1
+nerv.TNN.FC.HAS_LABEL = 2
+nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label
+
+function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c, t_c)
+ --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size)
+ if (type(st) ~= "table") then
+ nerv.error("st should be a table")
+ end
+ for i = 1 - chunk_size - 1, chunk_size * 2 + 1 do --intentionally allocated more time, should be [1-chunk_size, chunk_size*2]
+ if (st[i] == nil) then
+ st[i] = {}
+ end
+ st[i][p] = global_conf.cumat_type(batch_size, dim)
+ st[i][p]:fill(0)
+ if (st_c ~= nil) then
+ if (st_c[i + t_c] == nil) then
+ st_c[i + t_c] = {}
+ end
+ st_c[i + t_c][p_c] = st[i][p]
+ end
+ end
+end
+
+function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed
+ if (t < 1 or t > self.chunk_size) then
+ return true
+ end
+ if (self.feeds_now.flagsPack_now[t] == 0 or self.feeds_now.flagsPack_now[t] == nil) then
+ return true
+ end
+ return false
+end
+
+function TNN:__init(id, global_conf, layer_conf)
+ self.clip_t = layer_conf.clip_t
+ if self.clip_t == nil then
+ self.clip_t = 0
+ end
+ if self.clip_t > 0 then
+ nerv.info("tnn(%s) will clip gradient across time with %f...", id, self.clip_t)
+ end
+ local layers = {}
+ local inputs_p = {} --map:port of the TNN to layer ref and port
+ local outputs_p = {}
+ local dim_in = layer_conf.dim_in
+ local dim_out = layer_conf.dim_out
+ local parsed_conns = {}
+ local _
+
+ for _, ll in pairs(layer_conf.connections) do
+ local id_from, port_from = parse_id(ll[1])
+ local id_to, port_to = parse_id(ll[2])
+ local time_to = ll[3]
+
+ print(id_from, id_to, time_to)
+
+ local ref_from = discover(id_from, layers, layer_conf.sub_layers)
+ local ref_to = discover(id_to, layers, layer_conf.sub_layers)
+
+ if (id_from == "<input>") then
+ if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then
+ nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
+ end
+ inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to}
+ ref_to.inputs_m[port_to] = {} --just a place holder
+ elseif (id_to == "<output>") then
+ if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then
+ nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
+ end
+ outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from}
+ ref_from.outputs_m[port_from] = {} --just a place holder
+ else
+ local conn_now = {
+ ["src"] = {["ref"] = ref_from, ["port"] = port_from},
+ ["dst"] = {["ref"] = ref_to, ["port"] = port_to},
+ ["time"] = time_to
+ }
+ if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then
+ nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
+ end
+ table.insert(parsed_conns, conn_now)
+ ref_to.i_conns_p[conn_now.dst.port] = conn_now
+ ref_from.o_conns_p[conn_now.src.port] = conn_now
+ end
+ end
+
+ for id, ref in pairs(layers) do
+ print(id, "#dim_in:", #ref.dim_in, "#dim_out:", #ref.dim_out, "#i_conns_p:", #ref.i_conns_p, "#o_conns_p", #ref.o_conns_p)
+ end
+
+ self.layers = layers
+ self.inputs_p = inputs_p
+ self.outputs_p = outputs_p
+ self.id = id
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.parsed_conns = parsed_conns
+ self.gconf = global_conf
+end
+
+function TNN:init(batch_size, chunk_size)
+ self.batch_size = batch_size
+ self.chunk_size = chunk_size
+ for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
+ local _, output_dim
+ local ref_from, port_from, ref_to, port_to, time
+ ref_from, port_from = conn.src.ref, conn.src.port
+ ref_to, port_to = conn.dst.ref, conn.dst.port
+ time = conn.time
+
+ local dim = ref_from.dim_out[port_from]
+ if (dim == 0) then
+ nerv.error("layer %s has a zero dim port", ref_from.layer.id)
+ end
+
+ print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
+ ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim)
+ self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to, time)
+ ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
+ self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to, time)
+
+ end
+
+ self.outputs_m = {}
+ self.err_inputs_m = {}
+ for i = 1, #self.dim_out do --Init storage for output ports
+ local ref = self.outputs_p[i].ref
+ local p = self.outputs_p[i].port
+ self.make_initial_store(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i, 0)
+ self.make_initial_store(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i, 0)
+ end
+
+ self.inputs_m = {}
+ self.err_outputs_m = {}
+ for i = 1, #self.dim_in do --Init storage for input ports
+ local ref = self.inputs_p[i].ref
+ local p = self.inputs_p[i].port
+ self.make_initial_store(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i, 0)
+ self.make_initial_store(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i, 0)
+ end
+
+ for id, ref in pairs(self.layers) do --Calling init for child layers
+ for i = 1, #ref.dim_in do
+ if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then
+ nerv.error("dangling input port %d of layer %s", i, id)
+ end
+ end
+ for i = 1, #ref.dim_out do
+ if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then
+ nerv.error("dangling output port %d of layer %s", i, id)
+ end
+ end
+ -- initialize sub layers
+ ref.layer:init(batch_size, chunk_size)
+ end
+
+ local flags_now = {}
+ local flagsPack_now = {}
+ for i = 1, chunk_size do
+ flags_now[i] = {}
+ flagsPack_now[i] = 0
+ end
+
+ self.feeds_now = {} --feeds is for the reader to fill
+ self.feeds_now.inputs_m = self.inputs_m
+ self.feeds_now.flags_now = flags_now
+ self.feeds_now.flagsPack_now = flagsPack_now
+
+ self:flush_all()
+end
+
+--[[
+function DAGLayer:batch_resize(batch_size)
+ self.gconf.batch_size = batch_size
+
+ for i, conn in ipairs(self.parsed_conn) do
+ local _, output_dim
+ local ref_from, port_from, ref_to, port_to
+ ref_from, port_from = unpack(conn[1])
+ ref_to, port_to = unpack(conn[2])
+ _, output_dim = ref_from.layer:get_dim()
+
+ if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
+ local mid = self.gconf.cumat_type(batch_size, output_dim[port_from])
+ local err_mid = mid:create()
+
+ ref_from.outputs[port_from] = mid
+ ref_to.inputs[port_to] = mid
+
+ ref_from.err_inputs[port_from] = err_mid
+ ref_to.err_outputs[port_to] = err_mid
+ end
+ end
+ for id, ref in pairs(self.layers) do
+ ref.layer:batch_resize(batch_size)
+ end
+ collectgarbage("collect")
+end
+]]--
+
+function TNN:flush_all() --flush all history and activation
+ local _, ref
+ for _, ref in pairs(self.layers) do
+ for i = 1, #ref.dim_in do
+ for t = 1 - self.chunk_size, self.chunk_size * 2 do
+ ref.inputs_m[t][i]:fill(self.gconf.nn_act_default)
+ if (ref.inputs_b[t] == nil) then
+ ref.inputs_b[t] = {}
+ end
+ ref.inputs_b[t][i] = false
+ ref.err_outputs_m[t][i]:fill(0)
+ if (ref.err_outputs_b[t] == nil) then
+ ref.err_outputs_b[t] = {}
+ end
+ ref.err_outputs_b[t][i] = false
+ end
+ end
+ for i = 1, #ref.dim_out do
+ for t = 1 - self.chunk_size, self.chunk_size * 2 do
+ ref.outputs_m[t][i]:fill(self.gconf.nn_act_default)
+ if (ref.outputs_b[t] == nil) then
+ ref.outputs_b[t] = {}
+ end
+ ref.outputs_b[t][i] = false
+ ref.err_inputs_m[t][i]:fill(0)
+ if (ref.err_inputs_b[t] == nil) then
+ ref.err_inputs_b[t] = {}
+ end
+ ref.err_inputs_b[t][i] = false
+ end
+ end
+ end
+end
+
+--reader: some reader
+--Returns: bool, whether has new feed
+--Returns: feeds, a table that will be filled with the reader's feeds
+function TNN:getfeed_from_reader(reader)
+ local feeds_now = self.feeds_now
+ local got_new = reader:get_batch(feeds_now)
+ return got_new, feeds_now
+end
+
+function TNN:move_right_to_nextmb(list_t) --move output history activations of 1..chunk_size to 1-chunk_size..0
+ if list_t == nil then
+ list_t = {}
+ for i = 1, self.chunk_size do
+ list_t[i] = i - self.chunk_size
+ end
+ end
+ for i = 1, #list_t do
+ t = list_t[i]
+ if t < 1 - self.chunk_size or t > 0 then
+ nerv.error("MB move range error")
+ end
+ for id, ref in pairs(self.layers) do
+ for p = 1, #ref.dim_out do
+ ref.outputs_m[t][p]:copy_fromd(ref.outputs_m[t + self.chunk_size][p])
+ end
+ end
+ end
+end
+
+function TNN:net_propagate() --propagate according to feeds_now
+ for t = 1, self.chunk_size, 1 do
+ for id, ref in pairs(self.layers) do
+ for p = 1, #ref.dim_out do
+ ref.outputs_b[t][p] = false
+ end
+ for p = 1, #ref.dim_in do
+ ref.inputs_b[t][p] = false
+ end
+ end
+ end
+
+ local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+ for i = 1, #self.dim_in do
+ local ref = self.inputs_p[i].ref
+ local p = self.inputs_p[i].port
+ ref.inputs_b[t][p] = true
+ self:propagate_dfs(ref, t)
+ end
+ end
+ end
+
+ local flag_out = true
+ for t = 1, self.chunk_size do --check whether every output has been computed
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
+ for i = 1, #self.dim_out do
+ local ref = self.outputs_p[i].ref
+ if (ref.outputs_b[t][1] ~= true) then
+ flag_out = false
+ break
+ end
+ end
+ end
+ end
+ if (flag_out == false) then
+ nerv.error("some thing wrong, some labeled output is not propagated")
+ end
+end
+
+--ref: the TNN_ref of a layer
+--t: the current time to propagate
+function TNN:propagate_dfs(ref, t)
+ if (self:out_of_feedrange(t)) then
+ return
+ end
+ if (ref.outputs_b[t][1] == true) then --already propagated, 1 is just a random port
+ return
+ end
+
+ --print("debug dfs", ref.layer.id, t)
+
+ local flag = true --whether have all inputs
+ for _, conn in pairs(ref.i_conns_p) do
+ local p = conn.dst.port
+ if (not (ref.inputs_b[t][p] or self:out_of_feedrange(t - conn.time))) then
+ flag = false
+ break
+ end
+ end
+ if (flag == false) then
+ return
+ end
+
+ --ok, do propagate
+ --print("debug ok, propagating");
+ --The MB moving will cause bordering history to be changed, so it is more wise to flush the input activation
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history
+ for i = 1, self.batch_size do
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ if (seq_start > 0 or seq_end > 0) then
+ for p, conn in pairs(ref.i_conns_p) do
+ if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to default
+ ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
+ end
+ end
+ end
+ end
+ end
+ self.gconf.timer:tic("tnn_actual_layer_propagate")
+ ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
+ self.gconf.timer:toc("tnn_actual_layer_propagate")
+ --[[
+ if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
+ for i = 1, self.batch_size do
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ if (seq_start > 0 or seq_end > 0) then
+ for p, conn in pairs(ref.o_conns_p) do
+ if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then
+ ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
+ end
+ end
+ end
+ end
+ end
+ ]]--
+ --set input flag for future layers
+ for i = 1, #ref.dim_out do
+ if (ref.outputs_b[t][i] == true) then
+ nerv.error("this time's outputs_b should be false")
+ end
+ ref.outputs_b[t][i] = true
+ end
+
+ --try dfs for further layers
+ for _, conn in pairs(ref.o_conns_p) do
+ --print("debug dfs-searching", conn.dst.ref.layer.id)
+ conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true
+ self:propagate_dfs(conn.dst.ref, t + conn.time)
+ end
+end
+
+--do_update: bool, whether we are doing back-propagate or updating the parameters
+function TNN:net_backpropagate(do_update) --propagate according to feeds_now
+ if do_update == nil then
+ nerv.error("do_update should not be nil")
+ end
+ for t = 1, self.chunk_size, 1 do
+ for id, ref in pairs(self.layers) do
+ for p = 1, #ref.dim_out do
+ ref.err_inputs_b[t][p] = false
+ end
+ for p = 1, #ref.dim_in do
+ ref.err_outputs_b[t][p] = false
+ end
+ end
+ end
+
+ local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do
+ if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then
+ for i = 1, #self.dim_out do
+ local ref = self.outputs_p[i].ref
+ local p = self.outputs_p[i].port
+ ref.err_inputs_b[t][p] = true
+ self:backpropagate_dfs(ref, t, do_update)
+ end
+ end
+ end
+
+ local flag_out = true
+ for t = 1, self.chunk_size do --check whether every output has been computed
+ if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0 then
+ for i = 1, #self.dim_in do
+ local ref = self.inputs_p[i].ref
+ if ref.err_outputs_b[t][1] ~= true then
+ flag_out = false
+ break
+ end
+ end
+ end
+ end
+ if (flag_out == false) then
+ nerv.error("some thing wrong, some input is not back_propagated")
+ end
+end
+
+--ref: the TNN_ref of a layer
+--t: the current time to propagate
+function TNN:backpropagate_dfs(ref, t, do_update)
+ if self:out_of_feedrange(t) then
+ return
+ end
+ if ref.err_outputs_b[t][1] == true then --already back_propagated, 1 is just a random port
+ return
+ end
+
+ --print("debug dfs", ref.layer.id, t)
+
+ local flag = true --whether have all inputs
+ for _, conn in pairs(ref.o_conns_p) do
+ local p = conn.src.port
+ if (not (ref.err_inputs_b[t][p] or self:out_of_feedrange(t + conn.time))) then
+ flag = false
+ break
+ end
+ end
+ if (flag == false) then
+ return
+ end
+
+ --ok, do back_propagate
+ --print("debug ok, back-propagating(or updating)")
+ if (do_update == false) then
+ self.gconf.timer:tic("tnn_actual_layer_backpropagate")
+ ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
+ self.gconf.timer:toc("tnn_actual_layer_backpropagate")
+ if self.clip_t > 0 then
+ for _, conn in pairs(ref.i_conns_p) do
+ local p = conn.dst.port --port for ref
+ if conn.time ~= 0 then
+ --print("debug clip_t tnn", ref.id, "port:", p, "clip:", self.clip_t)
+ ref.err_outputs_m[t][p]:clip(-self.clip_t, self.clip_t)
+ end
+ end
+ end
+ else
+ --print(ref.err_inputs_m[t][1])
+ self.gconf.timer:tic("tnn_actual_layer_update")
+ ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
+ self.gconf.timer:toc("tnn_actual_layer_update")
+ end
+
+ if (do_update == false and bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors
+ for i = 1, self.batch_size do
+ local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
+ local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
+ if (seq_start > 0 or seq_end > 0) then
+ for p, conn in pairs(ref.i_conns_p) do
+ if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero
+ ref.err_outputs_m[t][p][i - 1]:fill(0)
+ end
+ end
+ end
+ end
+ end
+
+ for i = 1, #ref.dim_in do
+ if (ref.err_outputs_b[t][i] == true) then
+ nerv.error("this time's outputs_b should be false")
+ end
+ ref.err_outputs_b[t][i] = true
+ end
+
+ --try dfs for further layers
+ for _, conn in pairs(ref.i_conns_p) do
+ --print("debug dfs-searching", conn.src.ref.layer.id)
+ conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true
+ self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update)
+ end
+end
+
+--Return: nerv.ParamRepo
+function TNN:get_params()
+ local param_repos = {}
+ for id, ref in pairs(self.layers) do
+ table.insert(param_repos, ref.layer:get_params())
+ end
+ return nerv.ParamRepo.merge(param_repos)
+end
+