Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua | 5 +++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 185bc6d..a203cc6 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -17,11 +17,14 @@ function nerv.BiasParam:update_by_gradient(gradient)
 end
 
 --Returns: LMResult
-function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
+function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train)
     local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
     reader:open_file(fn)
     local result = nerv.LMResult(global_conf, global_conf.vocab)
     result:init("rnn")
+    if global_conf.dropout_rate ~= nil then
+        nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate)
+    end
 
     global_conf.timer:flush()
     tnn:flush_all() --caution: will also flush the inputs from the reader!
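
Below is a minimal usage sketch (not part of the commit above) showing how the renamed entry point might be driven. It assumes a global_conf shaped like the one visible in the diff; the concrete values, the file names, and the vocab/tnn/timer setup are hypothetical, only the field names batch_size, chunk_size, vocab, dropout_rate, and timer and the function signature come from the diff itself.

-- Hypothetical driver for LMTrainer.lm_process_file_rnn (sketch, not from this repo).
-- vocab, tnn, and timer are assumed to have been built beforehand by setup code.
local global_conf = {
    batch_size   = 10,       -- parallel sequences per batch (assumed value)
    chunk_size   = 15,       -- chunk length fed to the TNN (assumed value)
    vocab        = vocab,    -- vocabulary object, built elsewhere
    dropout_rate = 0.5,      -- optional; when set, lm_process_file_rnn logs it
    timer        = timer,    -- timer object; flushed at the start of each pass
}

-- One training pass (do_train = true), then a validation pass (do_train = false);
-- per the "--Returns: LMResult" comment, each call yields an LMResult.
LMTrainer.lm_process_file_rnn(global_conf, "ptb.train.txt", tnn, true)
local result = LMTrainer.lm_process_file_rnn(global_conf, "ptb.valid.txt", tnn, false)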