diff options
author | txh18 <[email protected]> | 2015-12-02 20:29:56 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-12-02 20:29:56 +0800 |
commit | 103a4291349c0f55155ca97bd236fc7784d286ff (patch) | |
tree | f9b4c7e021779ba803791148cec6dcea28053e76 /nerv/examples/lmptb/lm_trainer.lua | |
parent | 094fc872d3e62c5f0950ac1747f130e30a08bee8 (diff) |
function name change in LMTrainer
Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r-- | nerv/examples/lmptb/lm_trainer.lua | 5 |
1 file changed, 4 insertions, 1 deletion
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua index 185bc6d..a203cc6 100644 --- a/nerv/examples/lmptb/lm_trainer.lua +++ b/nerv/examples/lmptb/lm_trainer.lua @@ -17,11 +17,14 @@ function nerv.BiasParam:update_by_gradient(gradient) end --Returns: LMResult -function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train) +function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train) local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab) reader:open_file(fn) local result = nerv.LMResult(global_conf, global_conf.vocab) result:init("rnn") + if global_conf.dropout_rate ~= nil then + nerv.info("LMTrainer.lm_process_file_rnn: dropout_rate is %f", global_conf.dropout_rate) + end global_conf.timer:flush() tnn:flush_all() --caution: will also flush the inputs from the reader! |