diff options
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index 3e5ab2d..00fc12d 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -157,7 +157,7 @@ local set = arg[1] --"test" if (set == "ptb") then data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata' -train_fn = data_dir .. '/ptb.train.txt.adds' +train_fn = data_dir .. '/ptb.valid.txt.adds' valid_fn = data_dir .. '/ptb.valid.txt.adds' test_fn = data_dir .. '/ptb.test.txt.adds' vocab_fn = data_dir .. '/vocab' @@ -299,12 +299,10 @@ for iter = start_iter, global_conf.max_iter, 1 do ppl_rec[iter].lr = global_conf.lrate if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then global_conf.lrate = (global_conf.lrate * 0.6) - lr_half = true end if (ppl_rec[iter].valid < ppl_last) then printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter) paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil) - ppl_last = ppl_rec[iter].valid else printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) @@ -313,6 +311,12 @@ for iter = start_iter, global_conf.max_iter, 1 do break end end + if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then + lr_half = true + end + if (ppl_rec[iter].valid < ppl_last) then + ppl_last = ppl_rec[iter].valid + end printf("\n") nerv.LMUtil.wait(2) end |