diff options
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r-- | nerv/examples/lmptb/rnn/tnn.lua | 539 |
1 files changed, 539 insertions, 0 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua new file mode 100644 index 0000000..d6bf42e --- /dev/null +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -0,0 +1,539 @@ +local TNN = nerv.class("nerv.TNN", "nerv.Layer") + +local function parse_id(str) + --used to parse layerid[portid],time + local id, port, time, _ + _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)") + if id == nil or port == nil then + _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)") + if not (id == "<input>" or id == "<output>") then + nerv.error("wrong format of connection id") + end + end + --print(str, id, port, time) + port = tonumber(port) + if (time == nil) then + time = 0 + else + time = tonumber(time) + end + --now time don't need to be parsed + return id, port +end + +local function discover(id, layers, layer_repo) + local ref = layers[id] + if id == "<input>" or id == "<output>" then + return nil + end + if ref == nil then + local layer = layer_repo:get_layer(id) + local dim_in, dim_out = layer:get_dim() + ref = { + layer = layer, + inputs_m = {}, --storage for computation, inputs_m[time][port] + inputs_b = {}, --inputs_g[time][port], whether this input can been computed + inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation, inputs_p_matbak[port] + outputs_m = {}, + outputs_b = {}, + err_inputs_m = {}, + err_inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation + err_inputs_b = {}, + err_outputs_m = {}, + err_outputs_b = {}, + i_conns_p = {}, --list of inputing connections + o_conns_p = {}, --list of outputing connections + dim_in = dim_in, --list of dimensions of ports + dim_out = dim_out, + } + layers[id] = ref + end + return ref +end + +nerv.TNN.FC = {} --flag const +nerv.TNN.FC.SEQ_START = 4 +nerv.TNN.FC.SEQ_END = 8 +nerv.TNN.FC.HAS_INPUT = 1 +nerv.TNN.FC.HAS_LABEL = 2 +nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label + +function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c, t_c) + --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size) + if (type(st) ~= "table") then + nerv.error("st should be a table") + end + for i = 1 - chunk_size - 1, chunk_size * 2 + 1 do --intentionally allocated more time, should be [1-chunk_size, chunk_size*2] + if (st[i] == nil) then + st[i] = {} + end + st[i][p] = global_conf.cumat_type(batch_size, dim) + st[i][p]:fill(0) + if (st_c ~= nil) then + if (st_c[i + t_c] == nil) then + st_c[i + t_c] = {} + end + st_c[i + t_c][p_c] = st[i][p] + end + end +end + +function TNN:outOfFeedRange(t) --out of chunk, or no input, for the current feed + if (t < 1 or t > self.chunk_size) then + return true + end + if (self.feeds_now.flagsPack_now[t] == 0 or self.feeds_now.flagsPack_now[t] == nil) then + return true + end + return false +end + +function TNN:__init(id, global_conf, layer_conf) + local layers = {} + local inputs_p = {} --map:port of the TNN to layer ref and port + local outputs_p = {} + local dim_in = layer_conf.dim_in + local dim_out = layer_conf.dim_out + local parsed_conns = {} + local _ + + for _, ll in pairs(layer_conf.connections) do + local id_from, port_from = parse_id(ll[1]) + local id_to, port_to = parse_id(ll[2]) + local time_to = ll[3] + + print(id_from, id_to, time_to) + + local ref_from = discover(id_from, layers, layer_conf.sub_layers) + local ref_to = discover(id_to, layers, layer_conf.sub_layers) + + if (id_from == "<input>") then + if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then + nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) + end + inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to} + ref_to.inputs_m[port_to] = {} --just a place holder + elseif (id_to == "<output>") then + if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then + nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) + end + outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from} + ref_from.outputs_m[port_from] = {} --just a place holder + else + local conn_now = { + ["src"] = {["ref"] = ref_from, ["port"] = port_from}, + ["dst"] = {["ref"] = ref_to, ["port"] = port_to}, + ["time"] = time_to + } + if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then + nerv.error("mismatch dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3]) + end + table.insert(parsed_conns, conn_now) + ref_to.i_conns_p[conn_now.dst.port] = conn_now + ref_from.o_conns_p[conn_now.src.port] = conn_now + end + end + + for id, ref in pairs(layers) do + print(id, "#dim_in:", #ref.dim_in, "#dim_out:", #ref.dim_out, "#i_conns_p:", #ref.i_conns_p, "#o_conns_p", #ref.o_conns_p) + end + + self.layers = layers + self.inputs_p = inputs_p + self.outputs_p = outputs_p + self.id = id + self.dim_in = dim_in + self.dim_out = dim_out + self.parsed_conns = parsed_conns + self.gconf = global_conf +end + +function TNN:init(batch_size, chunk_size) + self.batch_size = batch_size + self.chunk_size = chunk_size + for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN + local _, output_dim + local ref_from, port_from, ref_to, port_to, time + ref_from, port_from = conn.src.ref, conn.src.port + ref_to, port_to = conn.dst.ref, conn.dst.port + time = conn.time + + local dim = ref_from.dim_out[port_from] + if (dim == 0) then + nerv.error("layer %s has a zero dim port", ref_from.layer.id) + end + + print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id) + ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim) + self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to, time) + ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim) + self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to, time) + + end + + self.outputs_m = {} + self.err_inputs_m = {} + for i = 1, #self.dim_out do --Init storage for output ports + local ref = self.outputs_p[i].ref + local p = self.outputs_p[i].port + self.makeInitialStore(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.outputs_m, i, 0) + self.makeInitialStore(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.gconf, self.err_inputs_m, i, 0) + end + + self.inputs_m = {} + self.err_outputs_m = {} + for i = 1, #self.dim_in do --Init storage for input ports + local ref = self.inputs_p[i].ref + local p = self.inputs_p[i].port + self.makeInitialStore(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.inputs_m, i, 0) + self.makeInitialStore(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.gconf, self.err_outputs_m, i, 0) + end + + for id, ref in pairs(self.layers) do --Calling init for child layers + for i = 1, #ref.dim_in do + if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then + nerv.error("dangling input port %d of layer %s", i, id) + end + end + for i = 1, #ref.dim_out do + if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then + nerv.error("dangling output port %d of layer %s", i, id) + end + end + -- initialize sub layers + ref.layer:init(batch_size, chunk_size) + end + + local flags_now = {} + local flagsPack_now = {} + for i = 1, chunk_size do + flags_now[i] = {} + flagsPack_now[i] = 0 + end + + self.feeds_now = {} --feeds is for the reader to fill + self.feeds_now.inputs_m = self.inputs_m + self.feeds_now.flags_now = flags_now + self.feeds_now.flagsPack_now = flagsPack_now + + self:flush_all() +end + +--[[ +function DAGLayer:batch_resize(batch_size) + self.gconf.batch_size = batch_size + + for i, conn in ipairs(self.parsed_conn) do + local _, output_dim + local ref_from, port_from, ref_to, port_to + ref_from, port_from = unpack(conn[1]) + ref_to, port_to = unpack(conn[2]) + _, output_dim = ref_from.layer:get_dim() + + if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then + local mid = self.gconf.cumat_type(batch_size, output_dim[port_from]) + local err_mid = mid:create() + + ref_from.outputs[port_from] = mid + ref_to.inputs[port_to] = mid + + ref_from.err_inputs[port_from] = err_mid + ref_to.err_outputs[port_to] = err_mid + end + end + for id, ref in pairs(self.layers) do + ref.layer:batch_resize(batch_size) + end + collectgarbage("collect") +end +]]-- + +function TNN:flush_all() --flush all history and activation + local _, ref + for _, ref in pairs(self.layers) do + for i = 1, #ref.dim_in do + for t = 1 - self.chunk_size, self.chunk_size * 2 do + ref.inputs_m[t][i]:fill(self.gconf.nn_act_default) + if (ref.inputs_b[t] == nil) then + ref.inputs_b[t] = {} + end + ref.inputs_b[t][i] = false + ref.err_outputs_m[t][i]:fill(0) + if (ref.err_outputs_b[t] == nil) then + ref.err_outputs_b[t] = {} + end + ref.err_outputs_b[t][i] = false + end + end + for i = 1, #ref.dim_out do + for t = 1 - self.chunk_size, self.chunk_size * 2 do + ref.outputs_m[t][i]:fill(self.gconf.nn_act_default) + if (ref.outputs_b[t] == nil) then + ref.outputs_b[t] = {} + end + ref.outputs_b[t][i] = false + ref.err_inputs_m[t][i]:fill(0) + if (ref.err_inputs_b[t] == nil) then + ref.err_inputs_b[t] = {} + end + ref.err_inputs_b[t][i] = false + end + end + end +end + +--reader: some reader +--Returns: bool, whether has new feed +--Returns: feeds, a table that will be filled with the reader's feeds +function TNN:getFeedFromReader(reader) + local feeds_now = self.feeds_now + local got_new = reader:get_batch(feeds_now) + return got_new, feeds_now +end + +function TNN:moveRightToNextMB() --move output history activations of 1..chunk_size to 1-chunk_size..0 + for t = 1, self.chunk_size, 1 do + for id, ref in pairs(self.layers) do + for p = 1, #ref.dim_out do + ref.outputs_m[t - self.chunk_size][p]:copy_fromd(ref.outputs_m[t][p]) + end + end + end +end + +function TNN:net_propagate() --propagate according to feeds_now + for t = 1, self.chunk_size, 1 do + for id, ref in pairs(self.layers) do + for p = 1, #ref.dim_out do + ref.outputs_b[t][p] = false + end + for p = 1, #ref.dim_in do + ref.inputs_b[t][p] = false + end + end + end + + local feeds_now = self.feeds_now + for t = 1, self.chunk_size do + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then + for i = 1, #self.dim_in do + local ref = self.inputs_p[i].ref + local p = self.inputs_p[i].port + ref.inputs_b[t][p] = true + self:propagate_dfs(ref, t) + end + end + end + + local flag_out = true + for t = 1, self.chunk_size do --check whether every output has been computed + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then + for i = 1, #self.dim_out do + local ref = self.outputs_p[i].ref + if (ref.outputs_b[t][1] ~= true) then + flag_out = false + break + end + end + end + end + if (flag_out == false) then + nerv.error("some thing wrong, some labeled output is not propagated") + end +end + +--ref: the TNN_ref of a layer +--t: the current time to propagate +function TNN:propagate_dfs(ref, t) + if (self:outOfFeedRange(t)) then + return + end + if (ref.outputs_b[t][1] == true) then --already propagated, 1 is just a random port + return + end + + --print("debug dfs", ref.layer.id, t) + + local flag = true --whether have all inputs + for _, conn in pairs(ref.i_conns_p) do + local p = conn.dst.port + if (not (ref.inputs_b[t][p] or self:outOfFeedRange(t - conn.time))) then + flag = false + break + end + end + if (flag == false) then + return + end + + --ok, do propagate + --print("debug ok, propagating"); + --[[ + if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history + for i = 1, self.batch_size do + local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) + local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END) + if (seq_start > 0 or seq_end > 0) then + for p, conn in pairs(ref.i_conns_p) do + if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to default + ref.inputs_matbak_p[p][i - 1]:copy_fromd(ref.inputs_m[t][p][i - 1]) + ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default) + end + end + end + end + end + ]]-- + self.gconf.timer:tic("tnn_actual_layer_propagate") + ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate! + self.gconf.timer:toc("tnn_actual_layer_propagate") + + if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history + for i = 1, self.batch_size do + local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) + local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END) + if (seq_start > 0 or seq_end > 0) then + for p, conn in pairs(ref.o_conns_p) do + if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then + ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default) + end + end + end + end + end + --set input flag for future layers + for i = 1, #ref.dim_out do + if (ref.outputs_b[t][i] == true) then + nerv.error("this time's outputs_b should be false") + end + ref.outputs_b[t][i] = true + end + + --try dfs for further layers + for _, conn in pairs(ref.o_conns_p) do + --print("debug dfs-searching", conn.dst.ref.layer.id) + conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true + self:propagate_dfs(conn.dst.ref, t + conn.time) + end +end + +--do_update: bool, whether we are doing back-propagate or updating the parameters +function TNN:net_backpropagate(do_update) --propagate according to feeds_now + if (do_update == nil) then + nerv.error("do_update should not be nil") + end + for t = 1, self.chunk_size, 1 do + for id, ref in pairs(self.layers) do + for p = 1, #ref.dim_out do + ref.err_inputs_b[t][p] = false + end + for p = 1, #ref.dim_in do + ref.err_outputs_b[t][p] = false + end + end + end + + local feeds_now = self.feeds_now + for t = 1, self.chunk_size do + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then + for i = 1, #self.dim_out do + local ref = self.outputs_p[i].ref + local p = self.outputs_p[i].port + ref.err_inputs_b[t][p] = true + self:backpropagate_dfs(ref, t, do_update) + end + end + end + + local flag_out = true + for t = 1, self.chunk_size do --check whether every output has been computed + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then + for i = 1, #self.dim_in do + local ref = self.inputs_p[i].ref + if (ref.err_outputs_b[t][1] ~= true) then + flag_out = false + break + end + end + end + end + if (flag_out == false) then + nerv.error("some thing wrong, some input is not back_propagated") + end +end + +--ref: the TNN_ref of a layer +--t: the current time to propagate +function TNN:backpropagate_dfs(ref, t, do_update) + if (self:outOfFeedRange(t)) then + return + end + if (ref.err_outputs_b[t][1] == true) then --already back_propagated, 1 is just a random port + return + end + + --print("debug dfs", ref.layer.id, t) + + local flag = true --whether have all inputs + for _, conn in pairs(ref.o_conns_p) do + local p = conn.src.port + if (not (ref.err_inputs_b[t][p] or self:outOfFeedRange(t + conn.time))) then + flag = false + break + end + end + if (flag == false) then + return + end + + --ok, do back_propagate + --print("debug ok, back-propagating(or updating)") + if (do_update == false) then + self.gconf.timer:tic("tnn_actual_layer_backpropagate") + ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t) + self.gconf.timer:toc("tnn_actual_layer_backpropagate") + else + --print(ref.err_inputs_m[t][1]) + self.gconf.timer:tic("tnn_actual_layer_update") + ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t) + self.gconf.timer:toc("tnn_actual_layer_update") + end + + if (do_update == false and bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors + for i = 1, self.batch_size do + local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START) + local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END) + if (seq_start > 0 or seq_end > 0) then + for p, conn in pairs(ref.i_conns_p) do + if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero + ref.err_outputs_m[t][p][i - 1]:fill(0) + end + end + end + end + end + + for i = 1, #ref.dim_in do + if (ref.err_outputs_b[t][i] == true) then + nerv.error("this time's outputs_b should be false") + end + ref.err_outputs_b[t][i] = true + end + + --try dfs for further layers + for _, conn in pairs(ref.i_conns_p) do + --print("debug dfs-searching", conn.src.ref.layer.id) + conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true + self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update) + end +end + +--Return: nerv.ParamRepo +function TNN:get_params() + local param_repos = {} + for id, ref in pairs(self.queue) do + table.insert(param_repos, ref.layer:get_params()) + end + return nerv.ParamRepo.merge(param_repos) +end + |