diff options
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index 00fc12d..e9631ba 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -157,21 +157,22 @@ local set = arg[1] --"test" if (set == "ptb") then data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata' -train_fn = data_dir .. '/ptb.valid.txt.adds' +train_fn = data_dir .. '/ptb.train.txt.adds' valid_fn = data_dir .. '/ptb.valid.txt.adds' test_fn = data_dir .. '/ptb.test.txt.adds' vocab_fn = data_dir .. '/vocab' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 1, wcost = 1e-5, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, - hidden_size = 300, + hidden_size = 400, chunk_size = 15, batch_size = 10, - max_iter = 30, + max_iter = 35, + decay_iter = 16, param_random = function() return (math.random() / 5 - 0.1) end, train_fn = train_fn, @@ -191,7 +192,7 @@ test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/so vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' global_conf = { - lrate = 1, wcost = 1e-6, momentum = 0, + lrate = 1, wcost = 1e-5, momentum = 0, cumat_type = nerv.CuMatrixFloat, mmat_type = nerv.CuMatrixFloat, nn_act_default = 0, @@ -297,7 +298,7 @@ for iter = start_iter, global_conf.max_iter, 1 do result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! ppl_rec[iter].valid = result:ppl_all("rnn") ppl_rec[iter].lr = global_conf.lrate - if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then + if ((ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) and iter > global_conf.decay_iter) then global_conf.lrate = (global_conf.lrate * 0.6) end if (ppl_rec[iter].valid < ppl_last) then @@ -306,10 +307,10 @@ for iter = start_iter, global_conf.max_iter, 1 do else printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) - if (lr_half == true) then - printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) - break - end + --if (lr_half == true) then + -- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) + -- break + --end end if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then lr_half = true |