diff options
Diffstat (limited to 'nerv/examples/lmptb/tnn_ptb_main.lua')
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 19 |
1 file changed, 7 insertions, 12 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index f68311c..c37b217 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -17,7 +17,7 @@ local LMTrainer = nerv.LMTrainer function prepare_parameters(global_conf, iter) printf("%s preparing parameters...\n", global_conf.sche_log_pre) - if (iter == -1) then --first time + if iter == -1 then --first time printf("%s first time, generating parameters...\n", global_conf.sche_log_pre) ltp_ih = nerv.LinearTransParam("ltp_ih", global_conf) ltp_ih.trans = global_conf.cumat_type(global_conf.vocab:size(), global_conf.hidden_size) --index 0 is for zero, others correspond to vocab index(starting from 1) @@ -290,9 +290,11 @@ printf("%s building vocab...\n", global_conf.sche_log_pre) global_conf.vocab:build_file(global_conf.vocab_fn, false) ppl_rec = {} -if (start_iter == -1) then +if start_iter == -1 then prepare_parameters(global_conf, -1) --randomly generate parameters +end +if start_iter == -1 or start_iter == 0 then print("===INITIAL VALIDATION===") local tnn, paramRepo = load_net(global_conf, 0) local result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! @@ -309,9 +311,6 @@ if (start_iter == -1) then print() end -if (start_iter == 0) then - nerv.error("start_iter should not be zero") -end local final_iter for iter = start_iter, global_conf.max_iter, 1 do final_iter = iter --for final testing @@ -335,21 +334,17 @@ for iter = start_iter, global_conf.max_iter, 1 do if ((ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) and iter > global_conf.decay_iter) then global_conf.lrate = (global_conf.lrate * 0.6) end - if (ppl_rec[iter].valid < ppl_last) then + if ppl_rec[iter].valid < ppl_last then printf("%s PPL improves, saving net to file %s.%d...\n", global_conf.sche_log_pre, global_conf.param_fn, iter) paramRepo:export(global_conf.param_fn .. '.' .. 
tostring(iter), nil) else printf("%s PPL did not improve, rejected, copying param file of last iter...\n", global_conf.sche_log_pre) os.execute('cp ' .. global_conf.param_fn..'.'..tostring(iter - 1) .. ' ' .. global_conf.param_fn..'.'..tostring(iter)) - --if (lr_half == true) then - -- printf("%s LR is already halfing, end training...\n", global_conf.sche_log_pre) - -- break - --end end - if (ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true) then + if ppl_last / ppl_rec[iter].valid < 1.0003 or lr_half == true then lr_half = true end - if (ppl_rec[iter].valid < ppl_last) then + if ppl_rec[iter].valid < ppl_last then ppl_last = ppl_rec[iter].valid end printf("\n") |