path: root/nerv/tnn/tnn.lua
Diffstat (limited to 'nerv/tnn/tnn.lua')
-rw-r--r--  nerv/tnn/tnn.lua  596
1 file changed, 0 insertions(+), 596 deletions(-)
diff --git a/nerv/tnn/tnn.lua b/nerv/tnn/tnn.lua
deleted file mode 100644
index d527fe6..0000000
--- a/nerv/tnn/tnn.lua
+++ /dev/null
@@ -1,596 +0,0 @@
-local TNN = nerv.class("nerv.TNN")
-
-local function parse_id(str)
- --used to parse layerid[portid],time
- local id, port, time, _
- _, _, id, port, time = string.find(str, "([a-zA-Z0-9_]+)%[([0-9]+)%][,]*([0-9]*)")
- if id == nil or port == nil then
- _, _, id, port, time = string.find(str, "(.+)%[([0-9]+)%][,]*([0-9]*)")
- if not (id == "<input>" or id == "<output>") then
- nerv.error("wrong format of connection id")
- end
- end
- --print(str, id, port, time)
- port = tonumber(port)
-    if (time == nil or time == "") then
- time = 0
- else
- time = tonumber(time)
- end
-    --for now, the time field does not need to be parsed or returned
- return id, port
-end
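-
---Usage sketch (illustrative; the layer ids are hypothetical):
---  parse_id("affine0[1],2") --> "affine0", 1
---  parse_id("<input>[2]")   --> "<input>", 2 (a missing time defaults to 0)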
-
-local function discover(id, layers, layer_repo)
- local ref = layers[id]
- if id == "<input>" or id == "<output>" then
- return nil
- end
- if ref == nil then
- local layer = layer_repo:get_layer(id)
- local dim_in, dim_out = layer:get_dim()
- ref = {
- layer = layer,
- id = layer.id,
-            inputs_m = {}, --storage for computation, inputs_m[time][port]
-            inputs_b = {}, --inputs_b[time][port], whether this input has been computed
-            inputs_matbak_p = {}, --inputs_matbak_p[port], a back-up space for some cross-border computation
- outputs_m = {},
- outputs_b = {},
- err_inputs_m = {},
-            err_inputs_matbak_p = {}, --err_inputs_matbak_p[port], a back-up space for some cross-border computation
- err_inputs_b = {},
- err_outputs_m = {},
- err_outputs_b = {},
-            i_conns_p = {}, --list of incoming connections
-            o_conns_p = {}, --list of outgoing connections
- dim_in = dim_in, --list of dimensions of ports
- dim_out = dim_out,
- }
- layers[id] = ref
- end
- return ref
-end
-
-nerv.TNN.FC = {} --flag constants
-nerv.TNN.FC.SEQ_START = 4
-nerv.TNN.FC.SEQ_END = 8
-nerv.TNN.FC.HAS_INPUT = 1
-nerv.TNN.FC.HAS_LABEL = 2
-nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --this instance has both input and label
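-
---Flag usage sketch (illustrative): the flags are bit masks, so a frame that
---starts a normal sequence satisfies exactly the tests below:
---  local f = bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_NORM)
---  bit.band(f, nerv.TNN.FC.HAS_INPUT) > 0 --true, the frame carries input
---  bit.band(f, nerv.TNN.FC.SEQ_END) > 0   --false, no sequence ends here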
-
-function TNN.make_initial_store(st, p, dim, batch_size, chunk_size, extend_t, global_conf, st_c, p_c, t_c)
-    --Fill st with matrix storage covering times (1-extend_t)..(chunk_size+extend_t)
- if (type(st) ~= "table") then
- nerv.error("st should be a table")
- end
-    for i = 1 - extend_t - 2, chunk_size + extend_t + 2 do --intentionally allocate a margin of extra time steps
- if (st[i] == nil) then
- st[i] = {}
- end
- st[i][p] = global_conf.cumat_type(batch_size, dim)
- st[i][p]:fill(0)
- if (st_c ~= nil) then
- if (st_c[i + t_c] == nil) then
- st_c[i + t_c] = {}
- end
- st_c[i + t_c][p_c] = st[i][p]
- end
- end
- collectgarbage("collect") --free the old one to save memory
-end
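-
---Storage-layout sketch (illustrative; gconf stands for a global_conf providing
---cumat_type): after the call below, st[t][1] is a zeroed 16x100 matrix for
---every allocated time t, and st_c[t - 1][1] aliases st[t][1], which is how a
---connection with time -1 reads its source:
---  local st, st_c = {}, {}
---  TNN.make_initial_store(st, 1, 100, 16, 20, 5, gconf, st_c, 1, -1)
---  assert(st[0][1] == st_c[-1][1])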
-
-function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current feed
- if (t < 1 or t > self.chunk_size) then
- return true
- end
- if (self.feeds_now.flagsPack_now[t] == 0 or self.feeds_now.flagsPack_now[t] == nil) then
- return true
- end
- return false
-end
-
-function TNN:__init(id, global_conf, layer_conf)
- self.clip_t = layer_conf.clip_t
- if self.clip_t == nil then
- self.clip_t = 0
- end
- if self.clip_t > 0 then
- nerv.info("tnn(%s) will clip gradient across time with %f...", id, self.clip_t)
- end
-
- self.extend_t = layer_conf.extend_t --TNN will allocate storage of time for 1-extend_t .. chunk_size+extend_t
- if self.extend_t == nil then
- self.extend_t = 5
- end
- nerv.info("tnn(%s) will extend storage beyond MB border for time steps %d...", id, self.extend_t)
-
- local layers = {}
-    local inputs_p = {} --maps a port of the TNN to its layer ref and port
- local outputs_p = {}
- local dim_in = layer_conf.dim_in
- local dim_out = layer_conf.dim_out
- local parsed_conns = {}
-
-    for id, _ in pairs(layer_conf.sub_layers.layers) do --caution: this also includes layers that are not connected
- discover(id, layers, layer_conf.sub_layers)
- end
-
- for _, ll in pairs(layer_conf.connections) do
- local id_from, port_from = parse_id(ll[1])
- local id_to, port_to = parse_id(ll[2])
- local time_to = ll[3]
-
-        print(id_from, id_to, time_to) --log each parsed connection
-
- local ref_from = discover(id_from, layers, layer_conf.sub_layers)
- local ref_to = discover(id_to, layers, layer_conf.sub_layers)
-
- if (id_from == "<input>") then
- if (dim_in[port_from] ~= ref_to.dim_in[port_to] or time_to ~= 0) then
-                nerv.error("mismatched dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
- end
- inputs_p[port_from] = {["ref"] = ref_to, ["port"] = port_to}
- ref_to.inputs_m[port_to] = {} --just a place holder
- elseif (id_to == "<output>") then
- if (dim_out[port_to] ~= ref_from.dim_out[port_from] or time_to ~= 0) then
-                nerv.error("mismatched dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
- end
- outputs_p[port_to] = {["ref"] = ref_from, ["port"] = port_from}
- ref_from.outputs_m[port_from] = {} --just a place holder
- else
- local conn_now = {
- ["src"] = {["ref"] = ref_from, ["port"] = port_from},
- ["dst"] = {["ref"] = ref_to, ["port"] = port_to},
- ["time"] = time_to
- }
- if (ref_to.dim_in[port_to] ~= ref_from.dim_out[port_from]) then
-                nerv.error("mismatched dimension or wrong time %s,%s,%d", ll[1], ll[2], ll[3])
- end
- table.insert(parsed_conns, conn_now)
- ref_to.i_conns_p[conn_now.dst.port] = conn_now
- ref_from.o_conns_p[conn_now.src.port] = conn_now
- end
- end
-
- for id, ref in pairs(layers) do
-        print(id, "#dim_in:", #ref.dim_in, "#dim_out:", #ref.dim_out, "#i_conns_p:", #ref.i_conns_p, "#o_conns_p:", #ref.o_conns_p)
- end
-
- self.layers = layers
- self.inputs_p = inputs_p
- self.outputs_p = outputs_p
- self.id = id
- self.dim_in = dim_in
- self.dim_out = dim_out
- self.parsed_conns = parsed_conns
- self.gconf = global_conf
-end
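-
---Connection-spec sketch (illustrative; the layer ids are hypothetical). Each
---entry of layer_conf.connections is {from, to, time}; a non-zero time makes
---the link recurrent, feeding the source output at t to the destination at
---t + time, while "<input>"/"<output>" links must use time 0:
---  connections = {
---      {"<input>[1]", "affine0[1]", 0},
---      {"sigmoid0[1]", "affine0[2]", 1}, --delayed by one time step
---      {"affine0[1]", "<output>[1]", 0},
---  }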
-
-function TNN:init(batch_size, chunk_size)
- self.batch_size = batch_size
- self.chunk_size = chunk_size
- for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
-        local ref_from, port_from, ref_to, port_to, time
- ref_from, port_from = conn.src.ref, conn.src.port
- ref_to, port_to = conn.dst.ref, conn.dst.port
- time = conn.time
-
- local dim = ref_from.dim_out[port_from]
- if (dim == 0) then
- nerv.error("layer %s has a zero dim port", ref_from.layer.id)
- end
-
-        nerv.info("TNN initializing storage %s->%s", ref_from.layer.id, ref_to.layer.id)
- ref_to.inputs_matbak_p[port_to] = self.gconf.cumat_type(batch_size, dim)
- self.make_initial_store(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.inputs_m, port_to, time)
- ref_from.err_inputs_matbak_p[port_from] = self.gconf.cumat_type(batch_size, dim)
- self.make_initial_store(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.extend_t, self.gconf, ref_to.err_outputs_m, port_to, time)
- end
-
- self.outputs_m = {}
- self.err_inputs_m = {}
- for i = 1, #self.dim_out do --Init storage for output ports
- local ref = self.outputs_p[i].ref
- local p = self.outputs_p[i].port
- self.make_initial_store(ref.outputs_m, p, self.dim_out[i], batch_size, chunk_size, self.extend_t, self.gconf, self.outputs_m, i, 0)
- self.make_initial_store(ref.err_inputs_m, p, self.dim_out[i], batch_size, chunk_size, self.extend_t, self.gconf, self.err_inputs_m, i, 0)
- end
-
- self.inputs_m = {}
- self.err_outputs_m = {}
- for i = 1, #self.dim_in do --Init storage for input ports
- local ref = self.inputs_p[i].ref
- local p = self.inputs_p[i].port
- self.make_initial_store(ref.inputs_m, p, self.dim_in[i], batch_size, chunk_size, self.extend_t, self.gconf, self.inputs_m, i, 0)
- self.make_initial_store(ref.err_outputs_m, p, self.dim_in[i], batch_size, chunk_size, self.extend_t, self.gconf, self.err_outputs_m, i, 0)
- end
-
- for id, ref in pairs(self.layers) do --Calling init for child layers
- for i = 1, #ref.dim_in do
- if (ref.inputs_m[i] == nil or ref.err_outputs_m[i] == nil) then
- nerv.error("dangling input port %d of layer %s", i, id)
- end
- end
- for i = 1, #ref.dim_out do
- if (ref.outputs_m[i] == nil or ref.err_inputs_m[i] == nil) then
- nerv.error("dangling output port %d of layer %s", i, id)
- end
- end
- -- initialize sub layers
-        nerv.info("TNN initializing sub-layer %s", ref.id)
- ref.layer:init(batch_size, chunk_size)
- collectgarbage("collect")
- end
-
- local flags_now = {}
- local flagsPack_now = {}
- for i = 1, chunk_size do
- flags_now[i] = {}
- flagsPack_now[i] = 0
- end
-
- self.feeds_now = {} --feeds is for the reader to fill
- self.feeds_now.inputs_m = self.inputs_m
- self.feeds_now.flags_now = flags_now
- self.feeds_now.flagsPack_now = flagsPack_now
-
- self:flush_all()
-end
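-
---Note: init must be called once before any propagation, e.g. tnn:init(16, 20)
---for a batch_size of 16 and a chunk_size of 20 time steps.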
-
---[[
-function DAGLayer:batch_resize(batch_size)
- self.gconf.batch_size = batch_size
-
- for i, conn in ipairs(self.parsed_conn) do
- local _, output_dim
- local ref_from, port_from, ref_to, port_to
- ref_from, port_from = unpack(conn[1])
- ref_to, port_to = unpack(conn[2])
- _, output_dim = ref_from.layer:get_dim()
-
- if ref_from.outputs[port_from]:nrow() ~= batch_size and output_dim[port_from] > 0 then
- local mid = self.gconf.cumat_type(batch_size, output_dim[port_from])
- local err_mid = mid:create()
-
- ref_from.outputs[port_from] = mid
- ref_to.inputs[port_to] = mid
-
- ref_from.err_inputs[port_from] = err_mid
- ref_to.err_outputs[port_to] = err_mid
- end
- end
- for id, ref in pairs(self.layers) do
- ref.layer:batch_resize(batch_size)
- end
- collectgarbage("collect")
-end
-]]--
-
-function TNN:flush_all() --flush all stored history and activations
- local _, ref
- for _, ref in pairs(self.layers) do
- for i = 1, #ref.dim_in do
- for t = 1 - self.extend_t, self.chunk_size + self.extend_t do
- ref.inputs_m[t][i]:fill(self.gconf.nn_act_default)
- if (ref.inputs_b[t] == nil) then
- ref.inputs_b[t] = {}
- end
- ref.inputs_b[t][i] = false
- ref.err_outputs_m[t][i]:fill(0)
- if (ref.err_outputs_b[t] == nil) then
- ref.err_outputs_b[t] = {}
- end
- ref.err_outputs_b[t][i] = false
- end
- end
- for i = 1, #ref.dim_out do
- for t = 1 - self.extend_t, self.chunk_size + self.extend_t do
- ref.outputs_m[t][i]:fill(self.gconf.nn_act_default)
- if (ref.outputs_b[t] == nil) then
- ref.outputs_b[t] = {}
- end
- ref.outputs_b[t][i] = false
- ref.err_inputs_m[t][i]:fill(0)
- if (ref.err_inputs_b[t] == nil) then
- ref.err_inputs_b[t] = {}
- end
- ref.err_inputs_b[t][i] = false
- end
- end
- end
-end
-
---reader: a reader object providing get_batch
---Returns: bool, whether there is a new feed
---Returns: feeds, a table filled with the reader's feeds
-function TNN:getfeed_from_reader(reader)
- local feeds_now = self.feeds_now
- local got_new = reader:get_batch(feeds_now)
- return got_new, feeds_now
-end
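-
---Reading-loop sketch (illustrative; tnn and reader are assumed to be built
---elsewhere):
---  while true do
---      local got_new, feeds = tnn:getfeed_from_reader(reader)
---      if not got_new then break end
---      --...process this chunk using feeds.inputs_m and feeds.flags_now...
---  end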
-
-function TNN:move_right_to_nextmb(list_t) --copy chunk-tail output activations into times (1-extend_t)..0 as history for the next minibatch
- if list_t == nil then
- list_t = {}
- for i = self.extend_t, 1, -1 do
- list_t[i] = 1 - i
- end
- end
- for i = 1, #list_t do
-        local t = list_t[i]
- if t < 1 - self.extend_t or t > 0 then
- nerv.error("MB move range error")
- end
- for id, ref in pairs(self.layers) do
- for p = 1, #ref.dim_out do
- ref.outputs_m[t][p]:copy_fromd(ref.outputs_m[t + self.chunk_size][p])
- end
- end
- end
-end
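-
---Shift sketch: with the default list_t and extend_t = 5, the copies are
---  outputs_m[0][p]  <-- outputs_m[chunk_size][p]
---  outputs_m[-1][p] <-- outputs_m[chunk_size - 1][p]
---  ...down to outputs_m[-4][p] <-- outputs_m[chunk_size - 4][p]
---so the next minibatch sees the tail of this one as its history.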
-
-function TNN:net_propagate() --propagate according to feeds_now
- for t = 1, self.chunk_size, 1 do
- for id, ref in pairs(self.layers) do
- for p = 1, #ref.dim_out do
- ref.outputs_b[t][p] = false
- end
- for p = 1, #ref.dim_in do
- ref.inputs_b[t][p] = false
- end
- end
- end
-
- local feeds_now = self.feeds_now
-    for t = 1, self.chunk_size do --some layers may not have inputs at every time in 1..chunk_size
-        for id, ref in pairs(self.layers) do
-            if #ref.dim_in > 0 then --some layers exist only to hold parameters
- self:propagate_dfs(ref, t)
- end
- end
- end
- for t = 1, self.chunk_size do
- if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
- for i = 1, #self.dim_in do
- local ref = self.inputs_p[i].ref
- local p = self.inputs_p[i].port
- ref.inputs_b[t][p] = true
- self:propagate_dfs(ref, t)
- end
- end
- end
-
- local flag_out = true
- for t = 1, self.chunk_size do --check whether every output has been computed
- if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
- for i = 1, #self.dim_out do
- local ref = self.outputs_p[i].ref
- if (ref.outputs_b[t][1] ~= true) then
- flag_out = false
- break
- end
- end
- end
- end
-
- if (flag_out == false) then
-        nerv.error("something is wrong: some labeled output was not propagated")
- end
-end
-
---ref: the TNN_ref of a layer
---t: the current time to propagate
-function TNN:propagate_dfs(ref, t)
- if (self:out_of_feedrange(t)) then
- return
- end
-    if (ref.outputs_b[t][1] == true) then --already propagated; port 1 is arbitrary since all ports are flagged together
- return
- end
-
- --print("debug dfs", ref.layer.id, t)
-
-    local flag = true --whether all inputs are ready
- for _, conn in pairs(ref.i_conns_p) do
- local p = conn.dst.port
- if (not (ref.inputs_b[t][p] or self:out_of_feedrange(t - conn.time))) then
- flag = false
- break
- end
- end
- if (flag == false) then
- return
- end
-
- --ok, do propagate
- --print("debug ok, propagating");
-    --the MB move changes the bordering history, so it is wiser to flush the input activations here
- if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border history
- for i = 1, self.batch_size do
- local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
- local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
- if (seq_start > 0 or seq_end > 0) then
- for p, conn in pairs(ref.i_conns_p) do
- if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to default
- ref.inputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
- end
- end
- end
- end
- end
- self.gconf.timer:tic("tnn_actual_layer_propagate")
- ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t], t) --propagate!
- self.gconf.timer:toc("tnn_actual_layer_propagate")
- --[[
- if (bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --restore cross-border history
- for i = 1, self.batch_size do
- local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
- local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
- if (seq_start > 0 or seq_end > 0) then
- for p, conn in pairs(ref.o_conns_p) do
- if ((ref.o_conns_p[p].time > 0 and seq_end > 0) or (ref.o_conns_p[p].time < 0 and seq_start > 0)) then
- ref.outputs_m[t][p][i - 1]:fill(self.gconf.nn_act_default)
- end
- end
- end
- end
- end
- ]]--
-    --mark all outputs of this layer as computed
- for i = 1, #ref.dim_out do
- if (ref.outputs_b[t][i] == true) then
- nerv.error("this time's outputs_b should be false")
- end
- ref.outputs_b[t][i] = true
- end
-
- --try dfs for further layers
- for _, conn in pairs(ref.o_conns_p) do
- --print("debug dfs-searching", conn.dst.ref.layer.id)
- conn.dst.ref.inputs_b[t + conn.time][conn.dst.port] = true
- self:propagate_dfs(conn.dst.ref, t + conn.time)
- end
-end
-
---do_update: bool, false to back-propagate error signals, true to update the parameters
-function TNN:net_backpropagate(do_update) --back-propagate according to feeds_now
- if do_update == nil then
- nerv.error("do_update should not be nil")
- end
- for t = 1, self.chunk_size, 1 do
- for id, ref in pairs(self.layers) do
- for p = 1, #ref.dim_out do
- ref.err_inputs_b[t][p] = false
- end
- for p = 1, #ref.dim_in do
- ref.err_outputs_b[t][p] = false
- end
- end
- end
-
- local feeds_now = self.feeds_now
-    for t = 1, self.chunk_size do --some layers may not have outputs at every time in 1..chunk_size
-        for id, ref in pairs(self.layers) do
-            if #ref.dim_out > 0 then --some layers exist only to hold parameters
- self:backpropagate_dfs(ref, t, do_update)
- end
- end
- end
- for t = 1, self.chunk_size do
- if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then
- for i = 1, #self.dim_out do
- local ref = self.outputs_p[i].ref
- local p = self.outputs_p[i].port
- ref.err_inputs_b[t][p] = true
- self:backpropagate_dfs(ref, t, do_update)
- end
- end
- end
-
- local flag_out = true
-    for t = 1, self.chunk_size do --check whether every input error has been computed
- if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0 then
- for i = 1, #self.dim_in do
- local ref = self.inputs_p[i].ref
- if ref.err_outputs_b[t][1] ~= true then
- flag_out = false
- break
- end
- end
- end
- end
- if (flag_out == false) then
-        nerv.error("something is wrong: some input error was not back-propagated")
- end
-end
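-
---One training step over a chunk (illustrative sketch of the intended call
---order; propagate first, then back-propagate errors, then update):
---  local got_new, feeds = tnn:getfeed_from_reader(reader)
---  if got_new then
---      tnn:net_propagate()
---      tnn:net_backpropagate(false) --compute and propagate error signals
---      tnn:net_backpropagate(true)  --update the parameters
---      tnn:move_right_to_nextmb()
---  end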
-
---ref: the TNN_ref of a layer
---t: the current time to back-propagate
-function TNN:backpropagate_dfs(ref, t, do_update)
- if do_update == nil then
- nerv.error("got a nil do_update")
- end
- if self:out_of_feedrange(t) then
- return
- end
-    if ref.err_outputs_b[t][1] == true then --already back-propagated; port 1 is arbitrary since all ports are flagged together
- return
- end
-
- --print("debug dfs", ref.layer.id, t)
-
-    local flag = true --whether all error inputs are ready
- for _, conn in pairs(ref.o_conns_p) do
- local p = conn.src.port
- if (not (ref.err_inputs_b[t][p] or self:out_of_feedrange(t + conn.time))) then
- flag = false
- break
- end
- end
- if (flag == false) then
- return
- end
-
- --ok, do back_propagate
- --print("debug ok, back-propagating(or updating)")
- if (do_update == false) then
- self.gconf.timer:tic("tnn_actual_layer_backpropagate")
- ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
- self.gconf.timer:toc("tnn_actual_layer_backpropagate")
- if self.clip_t > 0 then
- for _, conn in pairs(ref.i_conns_p) do
- local p = conn.dst.port --port for ref
- if conn.time ~= 0 then
- --print("debug clip_t tnn", ref.id, "port:", p, "clip:", self.clip_t)
- ref.err_outputs_m[t][p]:clip(-self.clip_t, self.clip_t)
- end
- end
- end
- else
- --print(ref.err_inputs_m[t][1])
- self.gconf.timer:tic("tnn_actual_layer_update")
- ref.layer:update(ref.err_inputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
- self.gconf.timer:toc("tnn_actual_layer_update")
- end
-
- if (do_update == false and bit.band(self.feeds_now.flagsPack_now[t], bit.bor(nerv.TNN.FC.SEQ_START, nerv.TNN.FC.SEQ_END)) > 0) then --flush cross-border errors
- for i = 1, self.batch_size do
- local seq_start = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_START)
- local seq_end = bit.band(self.feeds_now.flags_now[t][i], nerv.TNN.FC.SEQ_END)
- if (seq_start > 0 or seq_end > 0) then
- for p, conn in pairs(ref.i_conns_p) do
- if ((ref.i_conns_p[p].time > 0 and seq_start > 0) or (ref.i_conns_p[p].time < 0 and seq_end > 0)) then --cross-border, set to zero
- ref.err_outputs_m[t][p][i - 1]:fill(0)
- end
- end
- end
- end
- end
-
- for i = 1, #ref.dim_in do
- if (ref.err_outputs_b[t][i] == true) then
-            nerv.error("this time's err_outputs_b should be false")
- end
- ref.err_outputs_b[t][i] = true
- end
-
- --try dfs for further layers
- for _, conn in pairs(ref.i_conns_p) do
- --print("debug dfs-searching", conn.src.ref.layer.id)
- conn.src.ref.err_inputs_b[t - conn.time][conn.src.port] = true
- self:backpropagate_dfs(conn.src.ref, t - conn.time, do_update)
- end
-end
-
---Return: nerv.ParamRepo
-function TNN:get_params()
- local param_repos = {}
- for id, ref in pairs(self.layers) do
- table.insert(param_repos, ref.layer:get_params())
- end
- return nerv.ParamRepo.merge(param_repos)
-end
-