diff options
-rw-r--r-- | nerv/examples/lmptb/tnn_ptb_main.lua | 64 |
1 files changed, 46 insertions, 18 deletions
diff --git a/nerv/examples/lmptb/tnn_ptb_main.lua b/nerv/examples/lmptb/tnn_ptb_main.lua index a59a44b..f978247 100644 --- a/nerv/examples/lmptb/tnn_ptb_main.lua +++ b/nerv/examples/lmptb/tnn_ptb_main.lua @@ -150,7 +150,8 @@ function load_net(global_conf, next_iter) return tnn, paramRepo end -local train_fn, valid_fn, test_fn, global_conf +local train_fn, valid_fn, test_fn +global_conf = {} local set = arg[1] --"test" if (set == "ptb") then @@ -217,11 +218,30 @@ global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf' global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak' global_conf.param_fn = global_conf.work_dir .. "/params" -printf("%s printing global_conf\n", global_conf.sche_log_pre) +lr_half = false --can not be local, to be set by loadstring +start_iter = -1 +ppl_last = 100000 +if (arg[2] ~= nil) then + printf("%s applying arg[2](%s)...\n", global_conf.sche_log_pre, arg[2]) + loadstring(arg[2])() + nerv.LMUtil.wait(0.5) +else + printf("%s not user setting, all default...\n", global_conf.sche_log_pre) +end + +----------------printing options--------------------------------- +printf("%s printing global_conf...\n", global_conf.sche_log_pre) for id, value in pairs(global_conf) do print(id, value) end nerv.LMUtil.wait(2) +printf("%s printing training scheduling options...\n", global_conf.sche_log_pre) +print("lr_half", lr_half) +print("start_iter", start_iter) +print("ppl_last", ppl_last) +printf("%s printing training scheduling end.\n", global_conf.sche_log_pre) +nerv.LMUtil.wait(2) +------------------printing options end------------------------------ printf("%s creating work_dir...\n", global_conf.sche_log_pre) os.execute("mkdir -p "..global_conf.work_dir) @@ -231,25 +251,33 @@ local vocab = nerv.LMVocab() global_conf["vocab"] = vocab printf("%s building vocab...\n", global_conf.sche_log_pre) global_conf.vocab:build_file(global_conf.vocab_fn, false) +ppl_rec = {} -prepare_parameters(global_conf, -1) --randomly generate parameters +if (start_iter == -1) then + prepare_parameters(global_conf, -1) --randomly generate parameters -print("===INITIAL VALIDATION===") -local tnn, paramRepo = load_net(global_conf, 0) -local result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! -nerv.LMUtil.wait(3) -ppl_rec = {} -ppl_rec[0] = {} -ppl_rec[0].valid = result:ppl_all("rnn") -ppl_last = ppl_rec[0].valid -ppl_rec[0].train = 0 -ppl_rec[0].test = 0 -ppl_rec[0].lr = 0 -print() -local lr_half = false + print("===INITIAL VALIDATION===") + local tnn, paramRepo = load_net(global_conf, 0) + local result = LMTrainer.lm_process_file(global_conf, global_conf.valid_fn, tnn, false) --false update! + nerv.LMUtil.wait(1) + ppl_rec[0] = {} + ppl_rec[0].valid = result:ppl_all("rnn") + ppl_last = ppl_rec[0].valid + ppl_rec[0].train = 0 + ppl_rec[0].test = 0 + ppl_rec[0].lr = 0 + + start_iter = 1 + + print() +end + +if (start_iter == 0) then + nerv.error("start_iter should not be zero") +end local final_iter -for iter = 1, global_conf.max_iter, 1 do - final_iter = iter +for iter = start_iter, global_conf.max_iter, 1 do + final_iter = iter --for final testing global_conf.sche_log_pre = "[SCHEDULER ITER"..iter.." LR"..global_conf.lrate.."]:" tnn, paramRepo = load_net(global_conf, iter - 1) printf("===ITERATION %d LR %f===\n", iter, global_conf.lrate) |