Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
 nerv/examples/lmptb/lm_trainer.lua | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 9ef4794..58d5bfc 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -17,8 +17,19 @@ function nerv.BiasParam:update_by_gradient(gradient)
end
--Returns: LMResult
-function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train)
- local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
+ if p_conf == nil then
+ p_conf = {}
+ end
+ local reader
+ if p_conf.one_sen_report == true then --report log prob sentence by sentence
+ if do_train == true then
+ nerv.warning("LMTrainer.lm_process_file_rnn: one_sen_report is true while do_train is also true, which is unusual")
+ end
+ end
+ reader = nerv.LMSeqReader(global_conf, 1, global_conf.max_sen_len, global_conf.vocab)
+ else
+ reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+ end
reader:open_file(fn)
local result = nerv.LMResult(global_conf, global_conf.vocab)
result:init("rnn")
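
For reference, a minimal caller sketch of the extended signature. This is not part of the commit: global_conf, tnn, and the file names are hypothetical placeholders; only the p_conf table and its one_sen_report flag come from the diff above.

    -- Hypothetical usage sketch: evaluate a test file with per-sentence
    -- log-prob reporting. With one_sen_report set, lm_process_file_rnn
    -- builds its reader with batch size 1 and chunk size
    -- global_conf.max_sen_len, as the added lines above show.
    local p_conf = {one_sen_report = true}
    local result = LMTrainer.lm_process_file_rnn(global_conf, "ptb.test.txt", tnn, false, p_conf)

    -- A training pass is unchanged: p_conf may simply be omitted,
    -- falling back to global_conf.batch_size / global_conf.chunk_size.
    LMTrainer.lm_process_file_rnn(global_conf, "ptb.train.txt", tnn, true)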