author     txh18 <cloudygooseg@gmail.com>    2015-11-03 22:56:41 +0800
committer  txh18 <cloudygooseg@gmail.com>    2015-11-03 22:56:41 +0800
commit     26db912e38c3446961831d17be6b4508ec508bca (patch)
tree       fbcf471b7dc2d9921ab15dcf986316874dd35640
parent     d18122af2f57b8dd81db49385484f0e51d167a23 (diff)
working on TNN:net_propagate
-rw-r--r--  nerv/examples/lmptb/lmptb/lmseqreader.lua   9
-rw-r--r--  nerv/examples/lmptb/m-tests/dagl_test.lua  26
-rw-r--r--  nerv/examples/lmptb/rnn/tnn.lua            93
3 files changed, 102 insertions(+), 26 deletions(-)
diff --git a/nerv/examples/lmptb/lmptb/lmseqreader.lua b/nerv/examples/lmptb/lmptb/lmseqreader.lua
index 006b5cb..6cbd0e9 100644
--- a/nerv/examples/lmptb/lmptb/lmseqreader.lua
+++ b/nerv/examples/lmptb/lmptb/lmseqreader.lua
@@ -83,6 +83,7 @@ function LMReader:get_batch(feeds)
local inputs_m = feeds.inputs_m --port 1 : word_id, port 2 : label
local flags = feeds.flags_now
+ local flagsPack = feeds.flagsPack_now
local got_new = false
for i = 1, self.batch_size, 1 do
@@ -120,6 +121,14 @@ function LMReader:get_batch(feeds)
end
end
end
+
+ for j = 1, self.chunk_size, 1 do
+ flagsPack[j] = 0
+ for i = 1, self.batch_size, 1 do
+ flagsPack[j] = bit.bor(flagsPack[j], flags[j][i])
+ end
+ end
+
if (got_new == false) then
return false
else
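The flagsPack_now field added here collapses the per-sequence flags at each time step into one bitmask, so a single bit.band can later test a whole time step instead of scanning the batch. A standalone sketch of the idea (LuaBitOp's bit module as in the reader; HAS_INPUT/HAS_LABEL stand in for nerv.TNN.FC.*):

local bit = require("bit")

local HAS_INPUT, HAS_LABEL = 1, 2  -- stand-ins for nerv.TNN.FC.*
-- flags[time][batch_slot], as LMReader:get_batch fills them
local flags = { {HAS_INPUT, 0}, {bit.bor(HAS_INPUT, HAS_LABEL), HAS_INPUT} }
local flagsPack = {}
for t = 1, #flags do
    flagsPack[t] = 0
    for i = 1, #flags[t] do
        flagsPack[t] = bit.bor(flagsPack[t], flags[t][i]) -- OR across the batch
    end
end
print(bit.band(flagsPack[1], HAS_INPUT) > 0) --> true: some sequence at t=1 has input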
diff --git a/nerv/examples/lmptb/m-tests/dagl_test.lua b/nerv/examples/lmptb/m-tests/dagl_test.lua
index a50107d..6bd11c8 100644
--- a/nerv/examples/lmptb/m-tests/dagl_test.lua
+++ b/nerv/examples/lmptb/m-tests/dagl_test.lua
@@ -2,7 +2,8 @@ require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
-require 'rnn.layer_tdag'
+require 'lmptb.lmseqreader'
+require 'rnn.tnn'
--[[global function rename]]--
printf = nerv.printf
@@ -128,13 +129,14 @@ function prepare_dagLayer(global_conf, layerRepo)
return tnn
end
-train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
-test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
+local train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
+local test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
-global_conf = {
+local global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.CuMatrixFloat,
+ nn_act_default = 0,
hidden_size = 20,
chunk_size = 5,
@@ -160,3 +162,19 @@ local layerRepo = prepare_layers(global_conf, paramRepo)
local tnn = prepare_dagLayer(global_conf, layerRepo)
tnn:init(global_conf.batch_size, global_conf.chunk_size)
+local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+reader:open_file(global_conf.train_fn)
+
+local batch_num = 1
+while (1) do
+ local r, feeds
+ r, feeds = tnn:getFeedFromReader(reader)
+ if (r == false) then break end
+ for j = 1, global_conf.chunk_size, 1 do
+ for i = 1, global_conf.batch_size, 1 do
+ printf("%s[L(%s)] ", feeds.inputs_s[j][i], feeds.labels_s[j][i]) --vocab:get_word_str(input[i][j]).id
+ end
+ printf("\n")
+ end
+ printf("\n")
+end
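For orientation, these are the fields of the feeds table that the loop reads, as set up in TNN:init and filled by the reader (inputs_s/labels_s are string views the reader attaches alongside the matrices):

-- layout of the feeds table (illustrative; field names from this commit)
-- feeds.inputs_m[t][port]  : matrices filled by the reader (port 1: word_id, port 2: label)
-- feeds.flags_now[t][i]    : flag bits for batch slot i at time step t
-- feeds.flagsPack_now[t]   : bit.bor of flags_now[t][*] across the batch
-- feeds.inputs_s[t][i]     : word string, printed by the loop above
-- feeds.labels_s[t][i]     : label string, printed by the loop above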
diff --git a/nerv/examples/lmptb/rnn/tnn.lua b/nerv/examples/lmptb/rnn/tnn.lua
index 3f192b5..8037918 100644
--- a/nerv/examples/lmptb/rnn/tnn.lua
+++ b/nerv/examples/lmptb/rnn/tnn.lua
@@ -32,10 +32,14 @@ local function discover(id, layers, layer_repo)
local dim_in, dim_out = layer:get_dim()
ref = {
layer = layer,
- inputs_m = {}, --storage for computation, inputs_m[port][time]
+ inputs_m = {}, --storage for computation, inputs_m[time][port]
+ inputs_b = {}, --inputs_b[time][port], whether this input has been computed
outputs_m = {},
+ outputs_b = {},
err_inputs_m = {},
+ err_inputs_b = {},
err_outputs_m = {},
+ err_outputs_b = {},
conns_i = {}, --list of incoming connections
conns_o = {}, --list of outgoing connections
dim_in = dim_in, --list of dimensions of ports
@@ -53,7 +57,7 @@ nerv.TNN.FC.HAS_INPUT = 1
nerv.TNN.FC.HAS_LABEL = 2
nerv.TNN.FC.SEQ_NORM = bit.bor(nerv.TNN.FC.HAS_INPUT, nerv.TNN.FC.HAS_LABEL) --This instance has both input and label
-function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c)
+function TNN.makeInitialStore(st, p, dim, batch_size, chunk_size, global_conf, st_c, p_c)
--Return a table of matrix storage from time (1-chunk_size)..(2*chunk_size)
if (type(st) ~= "table") then
nerv.error("st should be a table")
@@ -73,7 +77,7 @@ function DAGLayer.makeInitialStore(st, p, dim, batch_size, chunk_size, global_co
end
end
-function DAGLayer:__init(id, global_conf, layer_conf)
+function TNN:__init(id, global_conf, layer_conf)
local layers = {}
local inputs_p = {} --map:port of the TDAGLayer to layer ref and port
local outputs_p = {}
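The makeInitialStore routine above allocates one matrix per (time, port) over the extended window (1-chunk_size)..(2*chunk_size), and lets the destination side (st_c, p_c) alias the very same matrices, so a connection's output storage and its consumer's input storage are shared. A standalone sketch of that pattern (mat_type stands in for gconf.cumat_type):

local function make_initial_store(st, p, dim, batch_size, chunk_size, mat_type, st_c, p_c)
    for t = 1 - chunk_size, chunk_size * 2 do  -- one chunk of past and future context
        if st[t] == nil then st[t] = {} end
        st[t][p] = mat_type(batch_size, dim)   -- one matrix per (time, port)
        if st_c ~= nil then
            if st_c[t] == nil then st_c[t] = {} end
            st_c[t][p_c] = st[t][p]            -- consumer port aliases the same matrix
        end
    end
end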
@@ -129,7 +133,9 @@ function DAGLayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
end
-function DAGLayer:init(batch_size, chunk_size)
+function TNN:init(batch_size, chunk_size)
+ self.batch_size = batch_size
+ self.chunk_size = chunk_size
for i, conn in ipairs(self.parsed_conns) do --init storage for connections inside the NN
local _, output_dim
local ref_from, port_from, ref_to, port_to
@@ -141,8 +147,10 @@ function DAGLayer:init(batch_size, chunk_size)
end
print("TNN initing storage", ref_from.layer.id, "->", ref_to.layer.id)
- self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.inputs_m, port_to)
- self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, global_conf, ref_to.err_outputs_m, port_to)
+ ref_to.inputs_p_matbak[port_to] = self.gconf.cumat_type(batch_size, dim)
+ self.makeInitialStore(ref_from.outputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.inputs_m, port_to)
+ ref_from.err_inputs_p_matbak[port_from] = self.gconf.cumat_type(batch_size, dim)
+ self.makeInitialStore(ref_from.err_inputs_m, port_from, dim, batch_size, chunk_size, self.gconf, ref_to.err_outputs_m, port_to)
end
@@ -180,13 +188,18 @@ function DAGLayer:init(batch_size, chunk_size)
end
local flags_now = {}
+ local flagsPack_now = {}
for i = 1, chunk_size do
flags_now[i] = {}
+ flagsPack_now[i] = 0
end
self.feeds_now = {} --feeds is for the reader to fill
self.feeds_now.inputs_m = self.inputs_m
self.feeds_now.flags_now = flags_now
+ self.feeds_now.flagsPack_now = flagsPack_now
+
+ self:flush_all()
end
--[[
@@ -218,13 +231,61 @@ function DAGLayer:batch_resize(batch_size)
end
]]--
+function TNN:flush_all() --flush all history and activations
+ local _, ref
+ for _, ref in pairs(self.layers) do
+ for i = 1, #ref.dim_in do
+ for t = 1 - self.chunk_size, self.chunk_size * 2 do
+ ref.inputs_m[t][i]:fill(self.gconf.nn_act_default)
+ if (ref.inputs_b[t] == nil) then
+ ref.inputs_b[t] = {}
+ end
+ ref.inputs_b[t][i] = false
+ ref.err_outputs_m[t][i]:fill(0)
+ if (ref.err_outputs_b[t] == nil) then
+ ref.err_outputs_b[t] = {}
+ end
+ ref.err_outputs_b[t][i] = false
+ end
+ end
+ for i = 1, #ref.dim_out do
+ for t = 1 - self.chunk_size, self.chunk_size * 2 do
+ ref.outputs_m[t][i]:fill(self.gconf.nn_act_default)
+ if (ref.outputs_b[t] == nil) then
+ ref.outputs_b[t] = {}
+ end
+ ref.outputs_b[t][i] = false
+ ref.err_inputs_m[t][i]:fill(0)
+ if (ref.err_inputs_b[t] == nil) then
+ ref.err_inputs_b[t] = {}
+ end
+ ref.err_inputs_b[t][i] = false
+ end
+ end
+ end
+end
+
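flush_all resets the network between runs: activations return to nn_act_default, error signals to zero, and every readiness flag (the *_b tables, created lazily) to false. A hypothetical helper condensing the per-port reset pattern used above:

local function reset_port(mats, bools, t, i, fill_value)
    mats[t][i]:fill(fill_value)               -- clear the stored matrix
    if bools[t] == nil then bools[t] = {} end -- *_b tables are created lazily
    bools[t][i] = false                       -- mark the port as not yet computed
end
-- e.g. reset_port(ref.inputs_m, ref.inputs_b, t, i, self.gconf.nn_act_default)
--      reset_port(ref.err_outputs_m, ref.err_outputs_b, t, i, 0)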
--reader: the reader to fetch the next batch from
--Returns: bool, whether a new feed was obtained
--Returns: feeds, a table that will be filled with the reader's feeds
-function DAGLayer:getFeedFromReader(reader)
- local feeds = self.feeds_now
- local got_new = reader:get_batch(feeds)
- return got_new, feeds
+function TNN:getFeedFromReader(reader)
+ local feeds_now = self.feeds_now
+ local got_new = reader:get_batch(feeds_now)
+ return got_new, feeds_now
+end
+
+function TNN:net_propagate() --propagate according to feeds_now
+ local feeds_now = self.feeds_now
+ for t = 1, self.chunk_size do
+ if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+ for i = 1, #self.dim_in do
+ local ref = self.inputs_p[i].ref
+ local p = self.inputs_p[i].port
+ ref.inputs_b[t][p] = true
+ end
+ --TODO
+ end
+ end
end
function DAGLayer:update(bp_err, input, output)
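net_propagate is left unfinished in this commit (the TODO above marks where the per-layer work will go). One plausible continuation, consistent with the readiness flags just introduced, purely a sketch and not the committed implementation (no repeat-guard, no forwarding of flags across connections):

-- hypothetical: propagate every layer whose input ports at time t are all ready
local function try_propagate_time(self, t)
    for _, ref in pairs(self.layers) do
        local ready = true
        for i = 1, #ref.dim_in do
            if not ref.inputs_b[t][i] then
                ready = false
            end
        end
        if ready then
            ref.layer:propagate(ref.inputs_m[t], ref.outputs_m[t])
            for i = 1, #ref.dim_out do
                ref.outputs_b[t][i] = true -- outputs at time t are now computed
            end
        end
    end
end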
@@ -238,18 +299,6 @@ function DAGLayer:update(bp_err, input, output)
end
end
-function DAGLayer:propagate(input, output)
- self:set_inputs(input)
- self:set_outputs(output)
- local ret = false
- for i = 1, #self.queue do
- local ref = self.queue[i]
- -- print(ref.layer.id)
- ret = ref.layer:propagate(ref.inputs, ref.outputs)
- end
- return ret
-end
-
function DAGLayer:back_propagate(bp_err, next_bp_err, input, output)
self:set_err_outputs(next_bp_err)
self:set_err_inputs(bp_err)