author    txh18 <[email protected]>  2015-11-13 15:17:07 +0800
committer txh18 <[email protected]>  2015-11-13 15:17:07 +0800
commit    ee81548f79c496f4e2f3e12325150bb96ecb432e (patch)
tree      025cc547058b79c630c0bb39eb29621016c5d07a
parent    d5111ef959c2871bfd8977889cf632acce2660d5 (diff)
save a parameter file for every iteration
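
Parameter files are now written once per iteration as <param_fn>.<iter> instead of overwriting a single file: prepare_parameters() takes an iteration index instead of a first_time flag (-1 generates fresh parameters and writes <param_fn>.0), load_net() loads the file of the requested iteration, and when the validation PPL does not improve the previous iteration's file is copied forward so that <param_fn>.<iter> always exists. Training also stops early once the learning rate has already been halved and PPL still fails to improve, the learning rate is recorded inside ppl_rec instead of a separate lr_rec table, and the vocabulary is now built from a dedicated vocab_fn. A short usage sketch follows the diff below.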
-rw-r--r--  nerv/examples/lmptb/tnn_ptb_main.lua  56
1 file changed, 35 insertions(+), 21 deletions(-)
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua
index 5cc92c4..a59a44b 100644
--- a/nerv/examples/lmptb/tnn_ptb_main.lua
+++ b/nerv/examples/lmptb/tnn_ptb_main.lua
@@ -14,10 +14,11 @@ local LMTrainer = nerv.LMTrainer
--global_conf: table
--first_time: bool
--Returns: a ParamRepo
-function prepare_parameters(global_conf, first_time)
+function prepare_parameters(global_conf, iter)
printf("%s preparing parameters...\n", global_conf.sche_log_pre)
- if (first_time) then
+ if (iter == -1) then --first time
+ printf("%s first time, generating parameters...\n", global_conf.sche_log_pre)
ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf)
ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1)
ltp_ih.trans:generate(global_conf.param_random)
@@ -38,17 +39,20 @@ function prepare_parameters(global_conf, first_time)
bp_o.trans = global_conf.cumat_type(1, global_conf.vocab:size())
bp_o.trans:generate(global_conf.param_random)
- local f = nerv.ChunkFile(global_conf.param_fn, 'w')
+ local f = nerv.ChunkFile(global_conf.param_fn .. '.0', 'w')
f:write_chunk(ltp_ih)
f:write_chunk(ltp_hh)
f:write_chunk(ltp_ho)
f:write_chunk(bp_h)
f:write_chunk(bp_o)
f:close()
+
+ return nil
end
+ printf("%s loading parameter from file %s...\n", global_conf.sche_log_pre, global_conf.param_fn .. '.' .. tostring(iter))
local paramRepo = nerv.ParamRepo()
- paramRepo:import({global_conf.param_fn}, nil, global_conf)
+ paramRepo:import({global_conf.param_fn .. '.' .. tostring(iter)}, nil, global_conf)
printf("%s preparing parameters end.\n", global_conf.sche_log_pre)
@@ -139,8 +143,8 @@ function prepare_tnn(global_conf, layerRepo)
return tnn
end
-function load_net(global_conf)
- local paramRepo = prepare_parameters(global_conf, false)
+function load_net(global_conf, next_iter)
+ local paramRepo = prepare_parameters(global_conf, next_iter)
local layerRepo = prepare_layers(global_conf, paramRepo)
local tnn = prepare_tnn(global_conf, layerRepo)
return tnn, paramRepo
@@ -152,9 +156,10 @@ local set = arg[1] --"test"
if (set == "ptb") then
data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
-train_fn = data_dir .. '/ptb.train.txt.adds'
+train_fn = data_dir .. '/ptb.valid.txt.adds'
valid_fn = data_dir .. '/ptb.valid.txt.adds'
test_fn = data_dir .. '/ptb.test.txt.adds'
+vocab_fn = data_dir .. '/vocab'
global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
@@ -171,6 +176,7 @@ global_conf = {
train_fn = train_fn,
valid_fn = valid_fn,
test_fn = test_fn,
+ vocab_fn = vocab_fn,
sche_log_pre = "[SCHEDULER]:",
log_w_num = 40000, --give a message when log_w_num words have been processed
timer = nerv.Timer()
@@ -181,6 +187,7 @@ else
valid_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
+vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text'
global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
@@ -197,6 +204,7 @@ global_conf = {
train_fn = train_fn,
valid_fn = valid_fn,
test_fn = test_fn,
+ vocab_fn = vocab_fn,
sche_log_pre = "[SCHEDULER]:",
log_w_num = 10, --give a message when log_w_num words have been processed
timer = nerv.Timer()
@@ -221,28 +229,30 @@ os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf)
local vocab = nerv.LMVocab()
global_conf["vocab"] = vocab
-global_conf.vocab:build_file(global_conf.train_fn, false)
+printf("%s building vocab...\n", global_conf.sche_log_pre)
+global_conf.vocab:build_file(global_conf.vocab_fn, false)
-prepare_parameters(global_conf, true) --randomly generate parameters
+prepare_parameters(global_conf, -1) --randomly generate parameters
print("===INITIAL VALIDATION===")
-local tnn, paramRepo = load_net(global_conf)
+local tnn, paramRepo = load_net(global_conf, 0)
local result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
nerv.LMUtil.wait(3)
ppl_rec = {}
-lr_rec = {}
ppl_rec[0] = {}
ppl_rec[0].valid = result:ppl_all("rnn")
ppl_last = ppl_rec[0].valid
ppl_rec[0].train = 0
ppl_rec[0].test = 0
-lr_rec[0] = 0
+ppl_rec[0].lr = 0
print()
local lr_half = false
+local final_iter
for iter = 1, global_conf.max_iter, 1 do
- tnn, paramRepo = load_net(global_conf)
- printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate)
+ final_iter = iter
global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:"
+ tnn, paramRepo = load_net(global_conf, iter - 1)
+ printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate)
result = LMTrainer.lm_process_file(global_conf, global_conf.train_fn_shuf, tnn, true) --true update!
ppl_rec[iter] = {}
ppl_rec[iter].train = result:ppl_all("rnn")
@@ -256,29 +266,33 @@ for iter = 1, global_conf.max_iter, 1 do
printf("===VALIDATION %d===\n", iter)
result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update!
ppl_rec[iter].valid = result:ppl_all("rnn")
- lr_rec[iter] = global_conf.lrate
+ ppl_rec[iter].lr = global_conf.lrate
if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then
global_conf.lrate = (global_conf.lrate * 0.6)
lr_half = true
end
if (ppl_rec[iter].valid < ppl_last) then
- printf("%s saving net to file %s...\n", global_conf.sche_log_pre, global_conf.param_fn)
- paramRepo:export(global_conf.param_fn, nil)
+ printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter)
+ paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil)
ppl_last = ppl_rec[iter].valid
else
- printf("%s PPL did not improve, rejected...\n", global_conf.sche_log_pre)
- if (lr_halg == true) then
+ printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre)
+ os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter))
+ if (lr_half == true) then
printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre)
+ break
end
end
printf("\n")
nerv.LMUtil.wait(2)
end
printf("===VALIDATION PPL record===\n")
-for i = 0, #ppl_rec do printf("<ITER%d LR%.5f train:%.3f valid:%.3f test:%.3f> \n", i, lr_rec[i], ppl_rec[i].train, ppl_rec[i].valid, ppl_rec[i].test) end
+for i, _ in pairs(ppl_rec) do
+ printf("<ITER%d LR%.5f train:%.3f valid:%.3f test:%.3f> \n", i, ppl_rec[i].lr, ppl_rec[i].train, ppl_rec[i].valid, ppl_rec[i].test)
+end
printf("\n")
printf("===FINAL TEST===\n")
global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:"
-tnn, paramRepo = load_net(global_conf)
+tnn, paramRepo = load_net(global_conf, final_iter)
LMTrainer.lm_process_file(global_conf, global_conf.test_fn, tnn, false) --false update!
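
For reference, a minimal Lua sketch (not code from the repository) of the checkpoint naming scheme this commit introduces; param_fn, the iteration index and the copy-on-reject step mirror the scheduler loop above, while the helper param_file and the improved flag are hypothetical illustrations.

-- Minimal sketch of the per-iteration parameter file naming; only the
-- naming convention and the copy-on-reject step mirror tnn_ptb_main.lua,
-- the rest is illustrative.
local param_fn = "params"                      -- stands in for global_conf.param_fn

local function param_file(iter)
    return param_fn .. '.' .. tostring(iter)   -- e.g. "params.3"
end

-- iteration 0: prepare_parameters(global_conf, -1) generates random
-- parameters and writes them to param_file(0), i.e. "params.0"
print(param_file(0))

-- iteration iter: load_net(global_conf, iter - 1) reads the previous file
local iter = 3
print(param_file(iter - 1))                    -- "params.2"

-- after training on this iteration: either export the improved parameters,
-- or copy the previous file forward so param_file(iter) always exists
local improved = false                         -- placeholder for the PPL check
if improved then
    -- paramRepo:export(param_file(iter), nil)
else
    os.execute('cp ' .. param_file(iter - 1) .. ' ' .. param_file(iter))
end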