path: root/nerv/examples/lmptb/lm_trainer.lua
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua | 56
1 file changed, 37 insertions(+), 19 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 44862dc..9ef4794 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -2,41 +2,55 @@ require 'lmptb.lmvocab'
 require 'lmptb.lmfeeder'
 require 'lmptb.lmutil'
 require 'lmptb.layer.init'
-require 'rnn.init'
+--require 'tnn.init'
 require 'lmptb.lmseqreader'
 local LMTrainer = nerv.class('nerv.LMTrainer')
-local printf = nerv.printf
+--local printf = nerv.printf
+
+--The bias param update in nerv doesn't have wcost added
+function nerv.BiasParam:update_by_gradient(gradient)
+    local gconf = self.gconf
+    local l2 = 1 - gconf.lrate * gconf.wcost
+    self:_update_by_gradient(gradient, l2, l2)
+end
 --Returns: LMResult
-function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
+function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train)
     local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
     reader:open_file(fn)
     local result = nerv.LMResult(global_conf, global_conf.vocab)
     result:init("rnn")
-
+    if global_conf.dropout_rate ~= nil then
+        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
+    end
+
     global_conf.timer:flush()
     tnn:flush_all() --caution: will also flush the inputs from the reader!
     local next_log_wcn = global_conf.log_w_num
+    local neto_bakm = global_conf.mmat_type(global_conf.batch_size, 1) --space backup matrix for network output
     while (1) do
         global_conf.timer:tic('most_out_loop_lmprocessfile')
         local r, feeds
-
-        r, feeds = tnn:getFeedFromReader(reader)
-        if (r == false) then break end
+        global_conf.timer:tic('tnn_beforeprocess')
+        r, feeds = tnn:getfeed_from_reader(reader)
+        if r == false then
+            break
+        end
         for t = 1, global_conf.chunk_size do
             tnn.err_inputs_m[t][1]:fill(1)
             for i = 1, global_conf.batch_size do
-                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
+                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                     tnn.err_inputs_m[t][1][i - 1][0] = 0
                 end
             end
         end
+        global_conf.timer:toc('tnn_beforeprocess')
         --[[
         for j = 1, global_conf.chunk_size, 1 do
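
The nerv.BiasParam:update_by_gradient override added in the hunk above exists because the stock bias update skips the wcost (L2 weight decay) term that other parameters receive. The following plain-Lua sketch shows the update rule the l2 factor implements, assuming _update_by_gradient(gradient, alpha, beta) decays the stored parameter by those factors while applying the learning-rate-scaled gradient step; all scalar values are made up for illustration, and the real code operates on matrices rather than scalars:

    -- Sketch of an L2-decayed SGD step on a single scalar "bias".
    local gconf = { lrate = 0.1, wcost = 1e-4 }  -- example hyperparameters
    local l2 = 1 - gconf.lrate * gconf.wcost     -- decay factor, slightly below 1
    local bias, grad = 0.5, 0.02                 -- made-up parameter and gradient
    bias = bias * l2 - gconf.lrate * grad        -- shrink toward zero, then step
    print(string.format("updated bias: %f", bias))
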
@@ -50,29 +64,33 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
         tnn:net_propagate()
-        if (do_train == true) then
+        if do_train == true then
             tnn:net_backpropagate(false)
             tnn:net_backpropagate(true)
         end
-
+
+        global_conf.timer:tic('tnn_afterprocess')
         for t = 1, global_conf.chunk_size, 1 do
+            tnn.outputs_m[t][1]:copy_toh(neto_bakm)
             for i = 1, global_conf.batch_size, 1 do
                 if (feeds.labels_s[t][i] ~= global_conf.vocab.null_token) then
-                    result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+                    --result:add("rnn", feeds.labels_s[t][i], math.exp(tnn.outputs_m[t][1][i - 1][0]))
+                    result:add("rnn", feeds.labels_s[t][i], math.exp(neto_bakm[i - 1][0]))
                 end
             end
         end
+        tnn:move_right_to_nextmb({0}) --only copy for time 0
+        global_conf.timer:toc('tnn_afterprocess')
-        tnn:moveRightToNextMB()
         global_conf.timer:toc('most_out_loop_lmprocessfile')
         --print log
-        if (result["rnn"].cn_w > next_log_wcn) then
+        if result["rnn"].cn_w > next_log_wcn then
             next_log_wcn = next_log_wcn + global_conf.log_w_num
-            printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
-            printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
+            nerv.printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+            nerv.printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))
             for key, value in pairs(global_conf.timer.rec) do
-                printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
+                nerv.printf("\t [global_conf.timer]: time spent on %s:%.5f clock time\n", key, value)
             end
             global_conf.timer:flush()
             nerv.LMUtil.wait(0.1)
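
The neto_bakm buffer introduced in this hunk is a performance fix: the old code read tnn.outputs_m[t][1][i - 1][0] directly, which indexes a device (CUDA) matrix one element at a time, while the new code copies the whole output column to a host matrix once per timestep with copy_toh and reads elements from that. A rough sketch of the pattern, assuming global_conf.mmat_type is a host matrix constructor such as nerv.MMatrixFloat and that the output lives in a nerv.CuMatrixFloat (the sizes are made up):

    -- Bulk device-to-host copy, then cheap per-element host reads.
    local batch_size = 4
    local dev_out  = nerv.CuMatrixFloat(batch_size, 1)  -- network output on the GPU
    local host_bak = nerv.MMatrixFloat(batch_size, 1)   -- reusable host-side buffer
    dev_out:copy_toh(host_bak)                          -- one transfer per timestep
    for i = 0, batch_size - 1 do
        local logp = host_bak[i][0]                     -- host memory read
        -- accumulate math.exp(logp) into the result, as the loop above does
    end
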
@@ -90,9 +108,9 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
         --break --debug
     end
-    printf("%s Displaying result:\n", global_conf.sche_log_pre)
-    printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
-    printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
+    nerv.printf("%s Displaying result:\n", global_conf.sche_log_pre)
+    nerv.printf("%s %s\n", global_conf.sche_log_pre, result:status("rnn"))
+    nerv.printf("%s Doing on %s end.\n", global_conf.sche_log_pre, fn)
     return result
 end
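
Last, tnn:move_right_to_nextmb({0}) replaces the old moveRightToNextMB(): it presumably carries recurrent state across minibatch chunks, and judging by the in-diff comment ("only copy for time 0") the {0} argument restricts the copy to time slot 0, the history slot for the next chunk. A toy illustration of the idea, with all names and sizes invented (the real TNN shifts its internal output matrices):

    -- Toy sketch of carrying hidden state from the end of one chunk
    -- into the time-0 history slot of the next chunk.
    local chunk_size = 3
    local state = {}                       -- state[t] = hidden vector at step t
    for t = 0, chunk_size do state[t] = {0, 0} end
    state[chunk_size] = {0.3, -0.7}        -- pretend the forward pass set this
    state[0] = {state[chunk_size][1], state[chunk_size][2]}  -- "move right"
    print(state[0][1], state[0][2])        --> 0.3    -0.7
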