summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/tnn_ptb_main.lua19
1 files changed, 10 insertions, 9 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua
index ef248d5..c875274 100644
--- a/nerv/examples/lmptb/tnn_ptb_main.lua
+++ b/nerv/examples/lmptb/tnn_ptb_main.lua
@@ -163,15 +163,16 @@ test_fn = data_dir .. '/ptb.test.txt.adds'
vocab_fn = data_dir .. '/vocab'
global_conf = {
- lrate = 1, wcost = 1e-6, momentum = 0,
+ lrate = 1, wcost = 1e-5, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
- hidden_size = 300,
+ hidden_size = 400,
chunk_size = 15,
batch_size = 10,
- max_iter = 30,
+ max_iter = 35,
+ decay_iter = 16,
param_random = function() return (math.random() / 5 - 0.1) end,
train_fn = train_fn,
@@ -222,7 +223,7 @@ test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/so
vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
global_conf = {
- lrate = 1, wcost = 1e-6, momentum = 0,
+ lrate = 1, wcost = 1e-5, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.CuMatrixFloat,
nn_act_default = 0,
@@ -328,7 +329,7 @@ for iter = start_iter, global_conf.max_iter, 1 do
result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
ppl_rec[iter].valid = result:ppl_all("rnn")
ppl_rec[iter].lr = global_conf.lrate
- if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
+ if ((ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) and iter > global_conf.decay_iter) then
global_conf.lrate = (global_conf.lrate * 0.6)
end
if (ppl_rec[iter].valid < ppl_last) then
@@ -337,10 +338,10 @@ for iter = start_iter, global_conf.max_iter, 1 do
else
printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre)
os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter))
- if (lr_half == true) then
- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre)
- break
- end
+ --if (lr_half == true) then
+ -- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre)
+ -- break
+ --end
end
if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
lr_half = true