Diffstat (limited to 'nerv/examples/lmptb/lm_trainer.lua')
-rw-r--r--  nerv/examples/lmptb/lm_trainer.lua | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 62d8b50..185bc6d 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -2,13 +2,20 @@ require 'lmptb.lmvocab'
require 'lmptb.lmfeeder'
require 'lmptb.lmutil'
require 'lmptb.layer.init'
-require 'rnn.init'
+require 'tnn.init'
require 'lmptb.lmseqreader'
local LMTrainer = nerv.class('nerv.LMTrainer')
local printf = nerv.printf
+--The bias param update in nerv doesn't have wcost added, so apply the L2 decay here
+function nerv.BiasParam:update_by_gradient(gradient)
+ local gconf = self.gconf
+ local l2 = 1 - gconf.lrate * gconf.wcost
+ self:_update_by_gradient(gradient, l2, l2)
+end
+
--Returns: LMResult
function LMTrainer.lm_process_file(global_conf, fn, tnn, do_train)
local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
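
The override in this hunk exists because the stock nerv.BiasParam update skips the wcost (L2 weight-decay) term; the patch computes the shrink factor l2 = 1 - lrate * wcost and passes it to the internal _update_by_gradient. Below is a minimal standalone Lua sketch of the arithmetic this implies, assuming _update_by_gradient(gradient, alpha, beta) amounts to param := alpha * param - lrate * gradient (the actual nerv signature is not shown in this hunk); decayed_update and the gconf values are hypothetical names used only for illustration.

    -- Hypothetical sketch of SGD with L2 weight decay, as the override applies it.
    -- Not nerv's actual internals; plain Lua tables stand in for nerv matrices.
    local gconf = { lrate = 0.1, wcost = 1e-4 }  -- assumed config values

    local function decayed_update(param, gradient)
        -- Shrink factor from weight decay: each step pulls the param toward
        -- zero by a factor of (1 - lrate * wcost) before the gradient move.
        local l2 = 1 - gconf.lrate * gconf.wcost
        for i = 1, #param do
            param[i] = l2 * param[i] - gconf.lrate * gradient[i]
        end
        return param
    end

    -- Usage: a bias vector decays slightly in addition to the gradient step.
    local bias = decayed_update({0.5, -0.2}, {0.01, -0.03})
    print(bias[1], bias[2])

Without the override, a bias would only take the plain gradient step; with it, biases regularize the same way the other parameters do.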