diff options
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r-- | nerv/examples/lmptb/rnn/tnn.lua | 93 |
1 files changed, 71 insertions, 22 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua index 3f192b5..8037918 100644 --- a/nerv/examples/lmptb/rnn/tnn.lua +++ b/nerv/examples/lmptb/rnn/tnn.lua @@ -32,10 +32,14 @@ local function discover(id, layers, layer_repo) local dim_in, dim_out = layer:get_dim() ref = { layer = layer, - inputs_m = {}, --storage for computation, inputs_m[port][time] + inputs_m = {}, --storage for computation, inputs_m[time][port] + inputs_b = {}, --inputs_g[time][port], whether this input can been computed outputs_m = {}, + outputs_b = {}, err_inputs_m = {}, + err_inputs_b = {}, err_outputs_m = {}, + err_outputs_b = {}, conns_i = {}, --list of inputing connections conns_o = {}, --list of outputing connections dim_in = dim_in, --list of dimensions of ports @@ -53,7 +57,7 @@ nerv.TNN.FC.HAS_INPUT = 1 nerv.TNN.FC.HAS_LABEL = 2 nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label -function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c) +function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c) --Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size) if (type(st) ~= "table") then nerv.error("st should be a table") @@ -73,7 +77,7 @@ function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_co end end -function DAGLayer:__init(id, global_conf, layer_conf) +function TNN:__init(id, global_conf, layer_conf) local layers = {} local inputs_p = {} --map:port of the TDAGLayer to layer ref and port local outputs_p = {} @@ -129,7 +133,9 @@ function DAGLayer:__init(id, global_conf, layer_conf) self.gconf = global_conf end -function DAGLayer:init(batch_size, chunk_size) +function TNN:init(batch_size, chunk_size) + self.batch_size = batch_size + self.chunk_size = chunk_size for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN local _, output_dim local ref_from, port_from, ref_to, port_to @@ -141,8 +147,10 @@ function DAGLayer:init(batch_size, chunk_size) end print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id) - self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.inputs_m, port_to) - self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.err_outputs_m, port_to) + ref_to.inputs_p_matbak[port_to] = self.gconf.cumat_type(batch_size, dim) + self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to) + ref_from.err_inputs_p_matbak[port_from] = self.gconf.cumat_type(batch_size, dim) + self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to) end @@ -180,13 +188,18 @@ function DAGLayer:init(batch_size, chunk_size) end local flags_now = {} + local flagsPack_now = {} for i = 1, chunk_size do flags_now[i] = {} + flagsPack_now[i] = 0 end self.feeds_now = {} --feeds is for the reader to fill self.feeds_now.inputs_m = self.inputs_m self.feeds_now.flags_now = flags_now + self.feeds_now.flagsPack_now = flagsPack_now + + self:flush_all() end --[[ @@ -218,13 +231,61 @@ function DAGLayer:batch_resize(batch_size) end ]]-- +function TNN:flush_all() --flush all history and activation + local _, ref + for _, ref in pairs(self.layers) do + for i = 1, #ref.dim_in do + for t = 1 - self.chunk_size, self.chunk_size * 2 do + ref.inputs_m[t][i]:fill(self.gconf.nn_act_default) + if (ref.inputs_b[t] == nil) then + ref.inputs_b[t] = {} + end + ref.inputs_b[t][i] = false + ref.err_outputs_m[t][i]:fill(0) + if (ref.err_outputs_b[t] == nil) then + ref.err_outputs_b[t] = {} + end + ref.err_outputs_b[t][i] = false + end + end + for i = 1, #ref.dim_out do + for t = 1 - self.chunk_size, self.chunk_size * 2 do + ref.outputs_m[t][i]:fill(self.gconf.nn_act_default) + if (ref.outputs_b[t] == nil) then + ref.outputs_b[t] = {} + end + ref.outputs_b[t][i] = false + ref.err_inputs_m[t][i]:fill(0) + if (ref.err_inputs_b[t] == nil) then + ref.err_inputs_b[t] = {} + end + ref.err_inputs_b[t][i] = false + end + end + end +end + --reader: some reader --Returns: bool, whether has new feed --Returns: feeds, a table that will be filled with the reader's feeds -function DAGLayer:getFeedFromReader(reader) - local feeds = self.feeds_now - local got_new = reader:get_batch(feeds) - return got_new, feeds +function TNN:getFeedFromReader(reader) + local feeds_now = self.feeds_now + local got_new = reader:get_batch(feeds_now) + return got_new, feeds_now +end + +function TNN:net_propagate() --propagate according to feeds_now + local feeds_now = self.feeds_now + for t = 1, chunk_size do + if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then + for i = 1, #self.dim_in do + local ref = inputs_p[i].ref + local p = inputs_p[i].port + ref.inputs_b[t][p] = true + end + --TODO + end + end end function DAGLayer:update(bp_err, input, output) @@ -238,18 +299,6 @@ function DAGLayer:update(bp_err, input, output) end end -function DAGLayer:propagate(input, output) - self:set_inputs(input) - self:set_outputs(output) - local ret = false - for i = 1, #self.queue do - local ref = self.queue[i] - -- print(ref.layer.id) - ret = ref.layer:propagate(ref.inputs, ref.outputs) - end - return ret -end - function DAGLayer:back_propagate(bp_err, next_bp_err, input, output) self:set_err_outputs(next_bp_err) self:set_err_inputs(bp_err) |