aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/rnn/tnn.lua
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-11-03 22:56:41 +0800
committertxh18 <cloudygooseg@gmail.com>2015-11-03 22:56:41 +0800
commit26db912e38c3446961831d17be6b4508ec508bca (patch)
treefbcf471b7dc2d9921ab15dcf986316874dd35640 /nerv/examples/lmptb/rnn/tnn.lua
parentd18122af2f57b8dd81db49385484f0e51d167a23 (diff)
working on TNN:net_propagate
Diffstat (limited to 'nerv/examples/lmptb/rnn/tnn.lua')
-rw-r--r--nerv/examples/lmptb/rnn/tnn.lua93
1 files changed, 71 insertions, 22 deletions
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 3f192b5..8037918 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -32,10 +32,14 @@ local function discover(id, layers, layer_repo)
local dim_in, dim_out = layer:get_dim()
ref = {
layer = layer,
- inputs_m = {}, --storage for computation, inputs_m[port][time]
+ inputs_m = {}, --storage for computation, inputs_m[time][port]
+ inputs_b = {}, --inputs_g[time][port], whether this input can been computed
outputs_m = {},
+ outputs_b = {},
err_inputs_m = {},
+ err_inputs_b = {},
err_outputs_m = {},
+ err_outputs_b = {},
conns_i = {}, --list of inputing connections
conns_o = {}, --list of outputing connections
dim_in = dim_in, --list of dimensions of ports
@@ -53,7 +57,7 @@ nerv.TNN.FC.HAS_INPUT = 1
nerv.TNN.FC.HAS_LABEL = 2
nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance have both input and label
-function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c)
+function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c)
--Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size)
if (type(st) ~= "table") then
nerv.error("st should be a table")
@@ -73,7 +77,7 @@ function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_co
end
end
-function DAGLayer:__init(id, global_conf, layer_conf)
+function TNN:__init(id, global_conf, layer_conf)
local layers = {}
local inputs_p = {} --map:port of the TDAGLayer to layer ref and port
local outputs_p = {}
@@ -129,7 +133,9 @@ function DAGLayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
end
-function DAGLayer:init(batch_size, chunk_size)
+function TNN:init(batch_size, chunk_size)
+ self.batch_size = batch_size
+ self.chunk_size = chunk_size
for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
local _, output_dim
local ref_from, port_from, ref_to, port_to
@@ -141,8 +147,10 @@ function DAGLayer:init(batch_size, chunk_size)
end
print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
- self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.inputs_m, port_to)
- self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.err_outputs_m, port_to)
+ ref_to.inputs_p_matbak[port_to] = self.gconf.cumat_type(batch_size, dim)
+ self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to)
+ ref_from.err_inputs_p_matbak[port_from] = self.gconf.cumat_type(batch_size, dim)
+ self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to)
end
@@ -180,13 +188,18 @@ function DAGLayer:init(batch_size, chunk_size)
end
local flags_now = {}
+ local flagsPack_now = {}
for i = 1, chunk_size do
flags_now[i] = {}
+ flagsPack_now[i] = 0
end
self.feeds_now = {} --feeds is for the reader to fill
self.feeds_now.inputs_m = self.inputs_m
self.feeds_now.flags_now = flags_now
+ self.feeds_now.flagsPack_now = flagsPack_now
+
+ self:flush_all()
end
--[[
@@ -218,13 +231,61 @@ function DAGLayer:batch_resize(batch_size)
end
]]--
+function TNN:flush_all() --flush all history and activation
+ local _, ref
+ for _, ref in pairs(self.layers) do
+ for i = 1, #ref.dim_in do
+ for t = 1 - self.chunk_size, self.chunk_size * 2 do
+ ref.inputs_m[t][i]:fill(self.gconf.nn_act_default)
+ if (ref.inputs_b[t] == nil) then
+ ref.inputs_b[t] = {}
+ end
+ ref.inputs_b[t][i] = false
+ ref.err_outputs_m[t][i]:fill(0)
+ if (ref.err_outputs_b[t] == nil) then
+ ref.err_outputs_b[t] = {}
+ end
+ ref.err_outputs_b[t][i] = false
+ end
+ end
+ for i = 1, #ref.dim_out do
+ for t = 1 - self.chunk_size, self.chunk_size * 2 do
+ ref.outputs_m[t][i]:fill(self.gconf.nn_act_default)
+ if (ref.outputs_b[t] == nil) then
+ ref.outputs_b[t] = {}
+ end
+ ref.outputs_b[t][i] = false
+ ref.err_inputs_m[t][i]:fill(0)
+ if (ref.err_inputs_b[t] == nil) then
+ ref.err_inputs_b[t] = {}
+ end
+ ref.err_inputs_b[t][i] = false
+ end
+ end
+ end
+end
+
--reader: some reader
--Returns: bool, whether has new feed
--Returns: feeds, a table that will be filled with the reader's feeds
-function DAGLayer:getFeedFromReader(reader)
- local feeds = self.feeds_now
- local got_new = reader:get_batch(feeds)
- return got_new, feeds
+function TNN:getFeedFromReader(reader)
+ local feeds_now = self.feeds_now
+ local got_new = reader:get_batch(feeds_now)
+ return got_new, feeds_now
+end
+
+function TNN:net_propagate() --propagate according to feeds_now
+ local feeds_now = self.feeds_now
+ for t = 1, chunk_size do
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+ for i = 1, #self.dim_in do
+ local ref = inputs_p[i].ref
+ local p = inputs_p[i].port
+ ref.inputs_b[t][p] = true
+ end
+ --TODO
+ end
+ end
end
function DAGLayer:update(bp_err, input, output)
@@ -238,18 +299,6 @@ function DAGLayer:update(bp_err, input, output)
end
end
-function DAGLayer:propagate(input, output)
- self:set_inputs(input)
- self:set_outputs(output)
- local ret = false
- for i = 1, #self.queue do
- local ref = self.queue[i]
- -- print(ref.layer.id)
- ret = ref.layer:propagate(ref.inputs, ref.outputs)
- end
- return ret
-end
-
function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
self:set_err_outputs(next_bp_err)
self:set_err_inputs(bp_err)