diff options
author | txh18 <[email protected]> | 2015-11-27 22:24:22 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-11-27 22:24:22 +0800 |
commit | b4207a46686e899b797e70f0ace352107bbc0d54 (patch) | |
tree | 90e0c82328d9b2c6b25995eb96ea7b5828d90c18 /nerv/examples/lmptb/tnn | |
parent | f0ac603cbfc5bbb95dad885d35822f0f747b0ab2 (diff) |
added clip_t for tnn
Diffstat (limited to 'nerv/examples/lmptb/tnn')
-rw-r--r-- | nerv/examples/lmptb/tnn/tnn.lua | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/nerv/examples/lmptb/tnn/tnn.lua b/nerv/examples/lmptb/tnn/tnn.lua index c87f963..db6cdd7 100644 --- a/nerv/examples/lmptb/tnn/tnn.lua +++ b/nerv/examples/lmptb/tnn/tnn.lua @@ -31,6 +31,7 @@ local function discover(id, layers, layer_repo) local dim_in, dim_out = layer:get_dim() ref = { layer = layer, + id = layer.id, inputs_m = {}, --storage for computation, inputs_m[time][port] inputs_b = {}, --inputs_g[time][port], whether this input can been computed inputs_matbak_p = {}, --which is a back-up space to handle some cross-border computation, inputs_p_matbak[port] @@ -89,6 +90,10 @@ function TNN:out_of_feedrange(t) --out of chunk, or no input, for the current fe end function TNN:__init(id, global_conf, layer_conf) + self.clip_t = layer_conf.clip_t + if self.clip_t > 0 then + nerv.info("tnn(%s) will clip gradient across time with %f...", id, self.clip_t) + end local layers = {} local inputs_p = {} --map:port of the TNN to layer ref and port local outputs_p = {} @@ -429,7 +434,7 @@ end --do_update: bool, whether we are doing back-propagate or updating the parameters function TNN:net_backpropagate(do_update) --propagate according to feeds_now - if (do_update == nil) then + if do_update == nil then nerv.error("do_update should not be nil") end for t = 1, self.chunk_size, 1 do @@ -445,7 +450,7 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now local feeds_now = self.feeds_now for t = 1, self.chunk_size do - if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0) then + if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_LABEL) > 0 then for i = 1, #self.dim_out do local ref = self.outputs_p[i].ref local p = self.outputs_p[i].port @@ -457,10 +462,10 @@ function TNN:net_backpropagate(do_update) --propagate according to feeds_now local flag_out = true for t = 1, self.chunk_size do --check whether every output has been computed - if (bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0) then + if bit.band(feeds_now.flagsPack_now[t], nerv.TNN.FC.HAS_INPUT) > 0 then for i = 1, #self.dim_in do local ref = self.inputs_p[i].ref - if (ref.err_outputs_b[t][1] ~= true) then + if ref.err_outputs_b[t][1] ~= true then flag_out = false break end @@ -475,10 +480,10 @@ end --ref: the TNN_ref of a layer --t: the current time to propagate function TNN:backpropagate_dfs(ref, t, do_update) - if (self:out_of_feedrange(t)) then + if self:out_of_feedrange(t) then return end - if (ref.err_outputs_b[t][1] == true) then --already back_propagated, 1 is just a random port + if ref.err_outputs_b[t][1] == true then --already back_propagated, 1 is just a random port return end @@ -501,7 +506,16 @@ function TNN:backpropagate_dfs(ref, t, do_update) if (do_update == false) then self.gconf.timer:tic("tnn_actual_layer_backpropagate") ref.layer:back_propagate(ref.err_inputs_m[t], ref.err_outputs_m[t], ref.inputs_m[t], ref.outputs_m[t], t) - self.gconf.timer:toc("tnn_actual_layer_backpropagate") + self.gconf.timer:toc("tnn_actual_layer_backpropagate") + if self.clip_t > 0 then + for _, conn in pairs(ref.i_conns_p) do + local p = conn.dst.port --port for ref + if conn.time ~= 0 then + --print("debug clip_t tnn", ref.id, "port:", p, "clip:", self.clip_t) + ref.err_outputs_m[t][p]:clip(-self.clip_t, self.clip_t) + end + end + end else --print(ref.err_inputs_m[t][1]) self.gconf.timer:tic("tnn_actual_layer_update") |