diff options
author | txh18 <cloudygooseg@gmail.com> | 2015-10-23 20:58:27 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2015-10-23 20:58:27 +0800 |
commit | 8e7d9453520840b6a5e269a101ca72b4a7ab36fa (patch) | |
tree | 1c4278c0939d8fa07487aac6993f677f0480323f /nerv/examples/lmptb/main.lua | |
parent | 1234c026869ab052e898cc2541143fe4a22312b6 (diff) |
lmptb can be run after merge from master
Diffstat (limited to 'nerv/examples/lmptb/main.lua')
-rw-r--r-- | nerv/examples/lmptb/main.lua | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/main.lua b/nerv/examples/lmptb/main.lua index 8764998..74ce407 100644 --- a/nerv/examples/lmptb/main.lua +++ b/nerv/examples/lmptb/main.lua @@ -220,6 +220,7 @@ function propagateFile(global_conf, dagL, fn, config) end if (result["rnn"].cn_w % global_conf.log_w_num == 0) then printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date()) + printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn")); for key, value in pairs(global_conf.timer.rec) do printf("\t [global_conf.timer]: time spent on %s:%.5fs\n", key, value) end @@ -255,10 +256,11 @@ end local set = "ptb" if (set == "ptb") then - train_fn = "/slfs1/users/txh18/workspace/nerv-project/nerv/nerv/examples/lmptb/PTBdata/ptb.train.txt" - valid_fn = "/slfs1/users/txh18/workspace/nerv-project/nerv/nerv/examples/lmptb/PTBdata/ptb.valid.txt" - test_fn = "/slfs1/users/txh18/workspace/nerv-project/nerv/nerv/examples/lmptb/PTBdata/ptb.test.txt" - work_dir_base = "/slfs1/users/txh18/workspace/nerv-project/lmptb-work" + data_dir = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata" + train_fn = data_dir.."/ptb.train.txt" + valid_fn = data_dir.."/ptb.valid.txt" + test_fn = data_dir.."/ptb.test.txt" + work_dir_base = "/home/slhome/txh18/workspace/nerv/lmptb-work" global_conf = { lrate = 1, wcost = 1e-6, momentum = 0, cumat_type = nerv.CuMatrixFloat, @@ -275,7 +277,7 @@ if (set == "ptb") then valid_fn = valid_fn, test_fn = test_fn, sche_log_pre = "[SCHEDULER]:", - log_w_num = 50000, --give a message when log_w_num words have been processed + log_w_num = 10000, --give a message when log_w_num words have been processed timer = nerv.Timer() } global_conf.work_dir = work_dir_base.."/h"..global_conf.hidden_size.."bp"..global_conf.bptt.."slr"..global_conf.lrate..os.date("_%bD%dH%H") @@ -323,6 +325,7 @@ os.execute("mkdir -p "..global_conf.work_dir) scheduler = " printf(\"===INITIAL VALIDATION===\\n\") \ dagL, paramRepo = load_net(global_conf) \ + printf(\"===INITIAL VALIDATION===\\n\") \ local result = propagateFile(global_conf, dagL, global_conf.valid_fn, {do_train = false, report_word = false}) \ ppl_rec = {} \ lr_rec = {} \ |