diff options
author | txh18 <[email protected]> | 2015-11-13 19:25:25 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-11-13 19:25:25 +0800 |
commit | fc5c05c70eb09b797c45f2d4913d7b2c5d418874 (patch) | |
tree | f4c48c3120cdc20ba1bdfdf417cf1b5b68323684 | |
parent | 982cb4b9f4a41a31668142ba7f168dc986b969f4 (diff) |
small bug: lr_half
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index 3e5ab2d..00fc12d 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -157,7 +157,7 @@ local set = arg[1] --"test" if (set == "ptb") then data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata' -train_fn = data_dir .. '/ptb.train.txt.adds' +train_fn = data_dir .. '/ptb.valid.txt.adds' valid_fn = data_dir .. '/ptb.valid.txt.adds' test_fn = data_dir .. '/ptb.test.txt.adds' vocab_fn = data_dir .. '/vocab' @@ -299,12 +299,10 @@ for iter = start_iter, global_conf.max_iter, 1 do ppl_rec[iter].lr = global_conf.lrate if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then global_conf.lrate = (global_conf.lrate * 0.6) - lr_half = true end if (ppl_rec[iter].valid < ppl_last) then printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter) paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil) - ppl_last = ppl_rec[iter].valid else printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) @@ -313,6 +311,12 @@ for iter = start_iter, global_conf.max_iter, 1 do break end end + if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then + lr_half = true + end + if (ppl_rec[iter].valid < ppl_last) then + ppl_last = ppl_rec[iter].valid + end printf("\n") nerv.LMUtil.wait(2) end |