Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua  15
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 44862dc..7dd70e2 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -26,13 +26,15 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
         local r, feeds
-        r, feeds = tnn:getFeedFromReader(reader)
-        if (r == false) then break end
+        r, feeds = tnn:getfeed_from_reader(reader)
+        if r == false then
+            break
+        end
         for t = 1, global_conf.chunk_size do
             tnn.err_inputs_m[t][1]:fill(1)
             for i = 1, global_conf.batch_size do
-                if (bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0) then
+                if bit.band(feeds.flags_now[t][i], nerv.TNN.FC.HAS_LABEL) == 0 then
                     tnn.err_inputs_m[t][1][i - 1][0] = 0
                 end
             end
         end
@@ -50,7 +52,7 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
         tnn:net_propagate()
-        if (do_train == true) then
+        if do_train == true then
             tnn:net_backpropagate(false)
             tnn:net_backpropagate(true)
         end
@@ -62,12 +64,11 @@ function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
                 end
             end
         end
-
-        tnn:moveRightToNextMB()
+        tnn:move_right_to_nextmb()
         global_conf.timer:toc('most_out_loop_lmprocessfile')
         --print log
-        if (result["rnn"].cn_w > next_log_wcn) then
+        if result["rnn"].cn_w > next_log_wcn then
             next_log_wcn = next_log_wcn + global_conf.log_w_num
             printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
             printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"))