aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/lstmlm_ptb_main.lua15
-rw-r--r--nerv/examples/lmptb/tnn/tnn.lua28
2 files changed, 29 insertions, 14 deletions
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua
index 42b541f..69f26f5 100644
--- a/nerv/examples/lmptb/lstmlm_ptb_main.lua
+++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua
@@ -153,8 +153,9 @@ function prepare_tnn(global_conf, layerRepo)
end
]]--
- local tnn = nerv.TNN("TNN", global_conf, {["dim_in"] = {1, global_conf.vocab:size()}, ["dim_out"] = {1}, ["sub_layers"] = layerRepo,
- ["connections"] = connections_t,
+ local tnn = nerv.TNN("TNN", global_conf, {["dim_in"] = {1, global_conf.vocab:size()},
+ ["dim_out"] = {1}, ["sub_layers"] = layerRepo,
+ ["connections"] = connections_t, ["clip_t"] = global_conf.clip_t,
})
tnn:init(global_conf.batch_size, global_conf.chunk_size)
@@ -183,12 +184,12 @@ test_fn = data_dir .. '/ptb.test.txt.adds'
vocab_fn = data_dir .. '/vocab'
global_conf = {
- lrate = 1, wcost = 1e-6, momentum = 0,
+ lrate = 0.001, wcost = 1e-6, momentum = 0, clip_t = 0.01,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
- hidden_size = 400, --set to 400 for a stable good test PPL
+ hidden_size = 200, --set to 400 for a stable good test PPL
chunk_size = 15,
batch_size = 10,
max_iter = 35,
@@ -200,9 +201,9 @@ global_conf = {
test_fn = test_fn,
vocab_fn = vocab_fn,
sche_log_pre = "[SCHEDULER]:",
- log_w_num = 40000, --give a message when log_w_num words have been processed
+ log_w_num = 400, --give a message when log_w_num words have been processed
timer = nerv.Timer(),
- work_dir_base = '/home/slhome/txh18/workspace/nerv/play/ptbEXP/tnn_test'
+ work_dir_base = '/home/slhome/txh18/workspace/nerv/play/ptbEXP/tnn_lstm_test'
}
elseif (set == "msr_sc") then
@@ -278,7 +279,7 @@ else
printf("%s no user setting, all default...\n", global_conf.sche_log_pre)
end
-global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost
+global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size --.. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost
global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf'
global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak'
global_conf.param_fn = global_conf.work_dir .. "/params"
diff --git a/nerv/examples/lmptb/tnn/tnn.lua b/nerv/examples/lmptb/tnn/tnn.lua
index c87f963..db6cdd7 100644
--- a/nerv/examples/lmptb/tnn/tnn.lua
+++ b/nerv/examples/lmptb/tnn/tnn.lua
@@ -31,6 +31,7 @@ local function discover(id, layers, layer_repo)
local dim_in, dim_out = layer:get_dim()
ref = {
layer = layer,
+ id = layer.id,
inputs_m = {}, --storage for computation, inputs_m[time][port]
inputs_b = {}, --inputs_g[time][port], whether this input can been computed
inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation, inputs_p_matbak[port]
@@ -89,6 +90,10 @@ function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current fe
end
function TNN:__init(id, global_conf, layer_conf)
+ self.clip_t = layer_conf.clip_t
+ if self.clip_t > 0 then
+ nerv.info("tnn(%s) will clip gradient across time with %f...", id, self.clip_t)
+ end
local layers = {}
local inputs_p = {} --map:port of the TNN to layer ref and port
local outputs_p = {}
@@ -429,7 +434,7 @@ end
--do_update: bool, whether we are doing back-propagate or updating the parameters
function TNN:net_backpropagate(do_update) --propagate according to feeds_now
- if (do_update == nil) then
+ if do_update == nil then
nerv.error("do_update should not be nil")
end
for t = 1, self.chunk_size, 1 do
@@ -445,7 +450,7 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
local feeds_now = self.feeds_now
for t = 1, self.chunk_size do
- if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then
+ if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then
for i = 1, #self.dim_out do
local ref = self.outputs_p[i].ref
local p = self.outputs_p[i].port
@@ -457,10 +462,10 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now
local flag_out = true
for t = 1, self.chunk_size do --check whether every output has been computed
- if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then
+ if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0 then
for i = 1, #self.dim_in do
local ref = self.inputs_p[i].ref
- if (ref.err_outputs_b[t][1] ~= true) then
+ if ref.err_outputs_b[t][1] ~= true then
flag_out = false
break
end
@@ -475,10 +480,10 @@ end
--ref: the TNN_ref of a layer
--t: the current time to propagate
function TNN:backpropagate_dfs(ref, t, do_update)
- if (self:out_of_feedrange(t)) then
+ if self:out_of_feedrange(t) then
return
end
- if (ref.err_outputs_b[t][1] == true) then --already back_propagated, 1 is just a random port
+ if ref.err_outputs_b[t][1] == true then --already back_propagated, 1 is just a random port
return
end
@@ -501,7 +506,16 @@ function TNN:backpropagate_dfs(ref, t, do_update)
if (do_update == false) then
self.gconf.timer:tic("tnn_actual_layer_backpropagate")
ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t)
- self.gconf.timer:toc("tnn_actual_layer_backpropagate")
+ self.gconf.timer:toc("tnn_actual_layer_backpropagate")
+ if self.clip_t > 0 then
+ for _, conn in pairs(ref.i_conns_p) do
+ local p = conn.dst.port --port for ref
+ if conn.time ~= 0 then
+ --print("debug clip_t tnn", ref.id, "port:", p, "clip:", self.clip_t)
+ ref.err_outputs_m[t][p]:clip(-self.clip_t, self.clip_t)
+ end
+ end
+ end
else
--print(ref.err_inputs_m[t][1])
self.gconf.timer:tic("tnn_actual_layer_update")