aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/main.lua
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-10-23 20:58:27 +0800
committertxh18 <cloudygooseg@gmail.com>2015-10-23 20:58:27 +0800
commit8e7d9453520840b6a5e269a101ca72b4a7ab36fa (patch)
tree1c4278c0939d8fa07487aac6993f677f0480323f /nerv/examples/lmptb/main.lua
parent1234c026869ab052e898cc2541143fe4a22312b6 (diff)
lmptb can be run after merge from master
Diffstat (limited to 'nerv/examples/lmptb/main.lua')
-rw-r--r--nerv/examples/lmptb/main.lua13
1 files changed, 8 insertions, 5 deletions
diff --git a/nerv/examples/lmptb/main.lua b/nerv/examples/lmptb/main.lua
index 8764998..74ce407 100644
--- a/nerv/examples/lmptb/main.lua
+++ b/nerv/examples/lmptb/main.lua
@@ -220,6 +220,7 @@ function propagateFile(global_conf, dagL, fn, config)
end
if (result["rnn"].cn_w % global_conf.log_w_num == 0) then
printf("%s %d words processed %s.\n", global_conf.sche_log_pre, result["rnn"].cn_w, os.date())
+ printf("\t%s log prob per sample :%f.\n", global_conf.sche_log_pre, result:logp_sample("rnn"));
for key, value in pairs(global_conf.timer.rec) do
printf("\t [global_conf.timer]: time spent on %s:%.5fs\n", key, value)
end
@@ -255,10 +256,11 @@ end
local set = "ptb"
if (set == "ptb") then
- train_fn = "/slfs1/users/txh18/workspace/nerv-project/nerv/nerv/examples/lmptb/PTBdata/ptb.train.txt"
- valid_fn = "/slfs1/users/txh18/workspace/nerv-project/nerv/nerv/examples/lmptb/PTBdata/ptb.valid.txt"
- test_fn = "/slfs1/users/txh18/workspace/nerv-project/nerv/nerv/examples/lmptb/PTBdata/ptb.test.txt"
- work_dir_base = "/slfs1/users/txh18/workspace/nerv-project/lmptb-work"
+ data_dir = "/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata"
+ train_fn = data_dir.."/ptb.train.txt"
+ valid_fn = data_dir.."/ptb.valid.txt"
+ test_fn = data_dir.."/ptb.test.txt"
+ work_dir_base = "/home/slhome/txh18/workspace/nerv/lmptb-work"
global_conf = {
lrate = 1, wcost = 1e-6, momentum = 0,
cumat_type = nerv.CuMatrixFloat,
@@ -275,7 +277,7 @@ if (set == "ptb") then
valid_fn = valid_fn,
test_fn = test_fn,
sche_log_pre = "[SCHEDULER]:",
- log_w_num = 50000, --give a message when log_w_num words have been processed
+ log_w_num = 10000, --give a message when log_w_num words have been processed
timer = nerv.Timer()
}
global_conf.work_dir = work_dir_base.."/h"..global_conf.hidden_size.."bp"..global_conf.bptt.."slr"..global_conf.lrate..os.date("_%bD%dH%H")
@@ -323,6 +325,7 @@ os.execute("mkdir -p "..global_conf.work_dir)
scheduler = " printf(\"===INITIAL VALIDATION===\\n\") \
dagL, paramRepo = load_net(global_conf) \
+ printf(\"===INITIAL VALIDATION===\\n\") \
local result = propagateFile(global_conf, dagL, global_conf.valid_fn, {do_train = false, report_word = false}) \
ppl_rec = {} \
lr_rec = {} \