summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/tnn_ptb_main.lua10
1 files changed, 7 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua
index 3e5ab2d..00fc12d 100644
--- a/nerv/examples/lmptb/tnn_ptb_main.lua
+++ b/nerv/examples/lmptb/tnn_ptb_main.lua
@@ -157,7 +157,7 @@ local set = arg[1] --"test"
if (set == "ptb") then
data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
-train_fn = data_dir .. '/ptb.train.txt.adds'
+train_fn = data_dir .. '/ptb.valid.txt.adds'
valid_fn = data_dir .. '/ptb.valid.txt.adds'
test_fn = data_dir .. '/ptb.test.txt.adds'
vocab_fn = data_dir .. '/vocab'
@@ -299,12 +299,10 @@ for iter = start_iter, global_conf.max_iter, 1 do
ppl_rec[iter].lr = global_conf.lrate
if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
global_conf.lrate = (global_conf.lrate * 0.6)
- lr_half = true
end
if (ppl_rec[iter].valid < ppl_last) then
printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter)
paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil)
- ppl_last = ppl_rec[iter].valid
else
printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre)
os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter))
@@ -313,6 +311,12 @@ for iter = start_iter, global_conf.max_iter, 1 do
break
end
end
+ if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
+ lr_half = true
+ end
+ if (ppl_rec[iter].valid < ppl_last) then
+ ppl_last = ppl_rec[iter].valid
+ end
printf("\n")
nerv.LMUtil.wait(2)
end