aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/tnn_ptb_main.lua21
1 files changed, 11 insertions, 10 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua
index 00fc12d..e9631ba 100644
--- a/nerv/examples/lmptb/tnn_ptb_main.lua
+++ b/nerv/examples/lmptb/tnn_ptb_main.lua
@@ -157,21 +157,22 @@ local set = arg[1] --"test"
if (set == "ptb") then
data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
-train_fn = data_dir .. '/ptb.valid.txt.adds'
+train_fn = data_dir .. '/ptb.train.txt.adds'
valid_fn = data_dir .. '/ptb.valid.txt.adds'
test_fn = data_dir .. '/ptb.test.txt.adds'
vocab_fn = data_dir .. '/vocab'
global_conf = {
- lrate = 1, wcost = 1e-6, momentum = 0,
+ lrate = 1, wcost = 1e-5, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
- hidden_size = 300,
+ hidden_size = 400,
chunk_size = 15,
batch_size = 10,
- max_iter = 30,
+ max_iter = 35,
+ decay_iter = 16,
param_random = function() return (math.random() / 5 - 0.1) end,
train_fn = train_fn,
@@ -191,7 +192,7 @@ test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/so
vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
global_conf = {
- lrate = 1, wcost = 1e-6, momentum = 0,
+ lrate = 1, wcost = 1e-5, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.CuMatrixFloat,
nn_act_default = 0,
@@ -297,7 +298,7 @@ for iter = start_iter, global_conf.max_iter, 1 do
result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
ppl_rec[iter].valid = result:ppl_all("rnn")
ppl_rec[iter].lr = global_conf.lrate
- if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
+ if ((ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) and iter > global_conf.decay_iter) then
global_conf.lrate = (global_conf.lrate * 0.6)
end
if (ppl_rec[iter].valid < ppl_last) then
@@ -306,10 +307,10 @@ for iter = start_iter, global_conf.max_iter, 1 do
else
printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre)
os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter))
- if (lr_half == true) then
- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre)
- break
- end
+ --if (lr_half == true) then
+ -- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre)
+ -- break
+ --end
end
if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
lr_half = true