diff options
author | txh18 <[email protected]> | 2015-11-15 17:44:03 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-11-15 17:44:03 +0800 |
commit | 5760914d95059777c5e475f3c42d1b32983235a3 (patch) | |
tree | 4f372585470be7738af346be79192430cbd08615 | |
parent | aa3ede28f5848d6461a74add0dbf8dace807d8d8 (diff) | |
parent | 7dfd1aca7117c0a15cf2377741d0a18b212cf2cb (diff) |
merge lr schedule change
Merge branch 'txh18/rnnlm' of github.com:Nerv-SJTU/nerv into txh18/rnnlm
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index ef248d5..c875274 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -163,15 +163,16 @@ test_fn = data_dir .. '/ptb.test.txt.adds' vocab_fn = data_dir .. '/vocab' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 1, wcost = 1e-5, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, - hidden_size = 300, + hidden_size = 400, chunk_size = 15, batch_size = 10, - max_iter = 30, + max_iter = 35, + decay_iter = 16, param_random = function() return (math.random() / 5 - 0.1) end, train_fn = train_fn, @@ -222,7 +223,7 @@ test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/so vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 1, wcost = 1e-5, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.CuMatrixFloat, nn_act_default = 0, @@ -328,7 +329,7 @@ for iter = start_iter, global_conf.max_iter, 1 do result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! ppl_rec[iter].valid = result:ppl_all("rnn") ppl_rec[iter].lr = global_conf.lrate - if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then + if ((ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) and iter > global_conf.decay_iter) then global_conf.lrate = (global_conf.lrate * 0.6) end if (ppl_rec[iter].valid < ppl_last) then @@ -337,10 +338,10 @@ for iter = start_iter, global_conf.max_iter, 1 do else printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) - if (lr_half == true) then - printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) - break - end + --if (lr_half == true) then + -- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) + -- break + --end end if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then lr_half = true |