diff options
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index ef248d5..c875274 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -163,15 +163,16 @@ test_fn = data_dir .. '/ptb.test.txt.adds' vocab_fn = data_dir .. '/vocab' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 1, wcost = 1e-5, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, - hidden_size = 300, + hidden_size = 400, chunk_size = 15, batch_size = 10, - max_iter = 30, + max_iter = 35, + decay_iter = 16, param_random = function() return (math.random() / 5 - 0.1) end, train_fn = train_fn, @@ -222,7 +223,7 @@ test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/so vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 1, wcost = 1e-5, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.CuMatrixFloat, nn_act_default = 0, @@ -328,7 +329,7 @@ for iter = start_iter, global_conf.max_iter, 1 do result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! ppl_rec[iter].valid = result:ppl_all("rnn") ppl_rec[iter].lr = global_conf.lrate - if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then + if ((ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) and iter > global_conf.decay_iter) then global_conf.lrate = (global_conf.lrate * 0.6) end if (ppl_rec[iter].valid < ppl_last) then @@ -337,10 +338,10 @@ for iter = start_iter, global_conf.max_iter, 1 do else printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) - if (lr_half == true) then - printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) - break - end + --if (lr_half == true) then + -- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) + -- break + --end end if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then lr_half = true |