diff options
author | txh18 <[email protected]> | 2015-11-13 15:17:07 +0800 |
---|---|---|
committer | txh18 <[email protected]> | 2015-11-13 15:17:07 +0800 |
commit | ee81548f79c496f4e2f3e12325150bb96ecb432e (patch) | |
tree | 025cc547058b79c630c0bb39eb29621016c5d07a | |
parent | d5111ef959c2871bfd8977889cf632acce2660d5 (diff) |
saving param file for every iter
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 56 |
1 files changed, 35 insertions, 21 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index 5cc92c4..a59a44b 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -14,10 +14,11 @@ local LMTrainer = nerv.LMTrainer --global_conf: table --first_time: bool --Returns: a ParamRepo -function prepare_parameters(global_conf, first_time) +function prepare_parameters(global_conf, iter) printf("%s preparing parameters...\n", global_conf.sche_log_pre) - if (first_time) then + if (iter == -1) then --first time + printf("%s first time, generating parameters...\n", global_conf.sche_log_pre) ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf) ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1) ltp_ih.trans:generate(global_conf.param_random) @@ -38,17 +39,20 @@ function prepare_parameters(global_conf, first_time) bp_o.trans = global_conf.cumat_type(1, global_conf.vocab:size()) bp_o.trans:generate(global_conf.param_random) - local f = nerv.ChunkFile(global_conf.param_fn, 'w') + local f = nerv.ChunkFile(global_conf.param_fn .. '.0', 'w') f:write_chunk(ltp_ih) f:write_chunk(ltp_hh) f:write_chunk(ltp_ho) f:write_chunk(bp_h) f:write_chunk(bp_o) f:close() + + return nil end + printf("%s loading parameter from file %s...\n", global_conf.sche_log_pre, global_conf.param_fn .. '.' .. tostring(iter)) local paramRepo = nerv.ParamRepo() - paramRepo:import({global_conf.param_fn}, nil, global_conf) + paramRepo:import({global_conf.param_fn .. '.' .. tostring(iter)}, nil, global_conf) printf("%s preparing parameters end.\n", global_conf.sche_log_pre) @@ -139,8 +143,8 @@ function prepare_tnn(global_conf, layerRepo) return tnn end -function load_net(global_conf) - local paramRepo = prepare_parameters(global_conf, false) +function load_net(global_conf, next_iter) + local paramRepo = prepare_parameters(global_conf, next_iter) local layerRepo = prepare_layers(global_conf, paramRepo) local tnn = prepare_tnn(global_conf, layerRepo) return tnn, paramRepo @@ -152,9 +156,10 @@ local set = arg[1] --"test" if (set == "ptb") then data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata' -train_fn = data_dir .. '/ptb.train.txt.adds' +train_fn = data_dir .. '/ptb.valid.txt.adds' valid_fn = data_dir .. '/ptb.valid.txt.adds' test_fn = data_dir .. '/ptb.test.txt.adds' +vocab_fn = data_dir .. '/vocab' global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, @@ -171,6 +176,7 @@ global_conf = { train_fn = train_fn, valid_fn = valid_fn, test_fn = test_fn, + vocab_fn = vocab_fn, sche_log_pre = "[SCHEDULER]:", log_w_num = 40000, --give a message when log_w_num words have been processed timer = nerv.Timer() @@ -181,6 +187,7 @@ else valid_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' train_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' test_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' +vocab_fn = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/m-tests/some-text' global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, @@ -197,6 +204,7 @@ global_conf = { train_fn = train_fn, valid_fn = valid_fn, test_fn = test_fn, + vocab_fn = vocab_fn, sche_log_pre = "[SCHEDULER]:", log_w_num = 10, --give a message when log_w_num words have been processed timer = nerv.Timer() @@ -221,28 +229,30 @@ os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf) local vocab = nerv.LMVocab() global_conf["vocab"] = vocab -global_conf.vocab:build_file(global_conf.train_fn, false) +printf("%s building vocab...\n", global_conf.sche_log_pre) +global_conf.vocab:build_file(global_conf.vocab_fn, false) -prepare_parameters(global_conf, true) --randomly generate parameters +prepare_parameters(global_conf, -1) --randomly generate parameters print("===INITIAL VALIDATION===") -local tnn, paramRepo = load_net(global_conf) +local tnn, paramRepo = load_net(global_conf, 0) local result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! nerv.LMUtil.wait(3) ppl_rec = {} -lr_rec = {} ppl_rec[0] = {} ppl_rec[0].valid = result:ppl_all("rnn") ppl_last = ppl_rec[0].valid ppl_rec[0].train = 0 ppl_rec[0].test = 0 -lr_rec[0] = 0 +ppl_rec[0].lr = 0 print() local lr_half = false +local final_iter for iter = 1, global_conf.max_iter, 1 do - tnn, paramRepo = load_net(global_conf) - printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) + final_iter = iter global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:" + tnn, paramRepo = load_net(global_conf, iter - 1) + printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) result = LMTrainer.lm_process_file(global_conf, global_conf.train_fn_shuf, tnn, true) --true update! ppl_rec[iter] = {} ppl_rec[iter].train = result:ppl_all("rnn") @@ -256,29 +266,33 @@ for iter = 1, global_conf.max_iter, 1 do printf("===VALIDATION %d===\n", iter) result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! ppl_rec[iter].valid = result:ppl_all("rnn") - lr_rec[iter] = global_conf.lrate + ppl_rec[iter].lr = global_conf.lrate if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then global_conf.lrate = (global_conf.lrate * 0.6) lr_half = true end if (ppl_rec[iter].valid < ppl_last) then - printf("%s saving net to file %s...\n", global_conf.sche_log_pre, global_conf.param_fn) - paramRepo:export(global_conf.param_fn, nil) + printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter) + paramRepo:export(global_conf.param_fn .. '.' .. tostring(iter), nil) ppl_last = ppl_rec[iter].valid else - printf("%s PPL did not improve, rejected...\n", global_conf.sche_log_pre) - if (lr_halg == true) then + printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) + os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) + if (lr_half == true) then printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) + break end end printf("\n") nerv.LMUtil.wait(2) end printf("===VALIDATION PPL record===\n") -for i = 0, #ppl_rec do printf("<ITER%d LR%.5f train:%.3f valid:%.3f test:%.3f> \n", i, lr_rec[i], ppl_rec[i].train, ppl_rec[i].valid, ppl_rec[i].test) end +for i, _ in pairs(ppl_rec) do + printf("<ITER%d LR%.5f train:%.3f valid:%.3f test:%.3f> \n", i, ppl_rec[i].lr, ppl_rec[i].train, ppl_rec[i].valid, ppl_rec[i].test) +end printf("\n") printf("===FINAL TEST===\n") global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:" -tnn, paramRepo = load_net(global_conf) +tnn, paramRepo = load_net(global_conf, final_iter) LMTrainer.lm_process_file(global_conf, global_conf.test_fn, tnn, false) --false update! |