aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <cloudygooseg@gmail.com>2015-12-04 17:53:31 +0800
committertxh18 <cloudygooseg@gmail.com>2015-12-04 17:53:31 +0800
commitcce6efcdfbe50a59e260cb5d55ae2c77326dc67c (patch)
tree074e41d6090c71cb6aa88a0ab5919878b3393de2
parent618450eb71817ded45c422f35d8fede2d52a66b2 (diff)
added testout command for lstmlm
-rw-r--r--nerv/examples/lmptb/lm_trainer.lua15
-rw-r--r--nerv/examples/lmptb/lstmlm_ptb_main.lua44
2 files changed, 48 insertions, 11 deletions
diff --git a/nerv/examples/lmptb/lm_trainer.lua b/nerv/examples/lmptb/lm_trainer.lua
index 9ef4794..58d5bfc 100644
--- a/nerv/examples/lmptb/lm_trainer.lua
+++ b/nerv/examples/lmptb/lm_trainer.lua
@@ -17,8 +17,19 @@ function nerv.BiasParam:update_by_gradient(gradient)
end
--Returns: LMResult
-function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train)
- local reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+function LMTrainer.lm_process_file_rnn(global_conf, fn, tnn, do_train, p_conf)
+ if p_conf == nil then
+ p_conf = {}
+ end
+ local reader
+ if p_conf.one_sen_report == true then --report log prob one by one sentence
+ if do_train == true then
+ nerv.warning("LMTrainer.lm_process_file_rnn: warning, one_sen_report is true while do_train is also true, strange")
+ end
+ reader = nerv.LMSeqReader(global_conf, 1, global_conf.max_sen_len, global_conf.vocab)
+ else
+ reader = nerv.LMSeqReader(global_conf, global_conf.batch_size, global_conf.chunk_size, global_conf.vocab)
+ end
reader:open_file(fn)
local result = nerv.LMResult(global_conf, global_conf.vocab)
result:init("rnn")
diff --git a/nerv/examples/lmptb/lstmlm_ptb_main.lua b/nerv/examples/lmptb/lstmlm_ptb_main.lua
index 681c308..9f02324 100644
--- a/nerv/examples/lmptb/lstmlm_ptb_main.lua
+++ b/nerv/examples/lmptb/lstmlm_ptb_main.lua
@@ -195,20 +195,23 @@ local set = arg[1] --"test"
if (set == "ptb") then
-data_dir = '/home/slhome/txh18/workspace/nerv/nerv/nerv/examples/lmptb/PTBdata'
+root_dir = '/home/slhome/txh18/workspace'
+data_dir = root_dir .. '/ptb/DATA'
train_fn = data_dir .. '/ptb.train.txt.adds'
valid_fn = data_dir .. '/ptb.valid.txt.adds'
test_fn = data_dir .. '/ptb.test.txt.adds'
vocab_fn = data_dir .. '/vocab'
+qdata_dir = root_dir .. '/ptb/questionGen/gen'
+
global_conf = {
- lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 2,
+ lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5,
cumat_type = nerv.CuMatrixFloat,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
- hidden_size = 650,
- layer_num = 2,
+ hidden_size = 300,
+ layer_num = 1,
chunk_size = 15,
batch_size = 20,
max_iter = 45,
@@ -221,10 +224,11 @@ global_conf = {
valid_fn = valid_fn,
test_fn = test_fn,
vocab_fn = vocab_fn,
+ max_sen_len = 90,
sche_log_pre = "[SCHEDULER]:",
log_w_num = 40000, --give a message when log_w_num words have been processed
timer = nerv.Timer(),
- work_dir_base = '/home/slhome/txh18/workspace/nerv/play/ptbEXP/tnn_lstm_test'
+ work_dir_base = '/home/slhome/txh18/workspace/ptb/EXP-nerv/lstmlm_v1.0'
}
elseif (set == "msr_sc") then
@@ -303,6 +307,9 @@ local commands_str = "train:test"
local commands = {}
local test_iter = -1
+--for testout(question)
+local q_file = "ptb.test.txt.q10rs1_Msss.adds"
+
if arg[2] ~= nil then
nerv.printf("%s applying arg[2](%s)...\n", global_conf.sche_log_pre, arg[2])
loadstring(arg[2])()
@@ -311,21 +318,22 @@ else
nerv.printf("%s no user setting, all default...\n", global_conf.sche_log_pre)
end
-global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'l' .. global_conf.layer_num --.. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost
+global_conf.work_dir = global_conf.work_dir_base .. 'h' .. global_conf.hidden_size .. 'l' .. global_conf.layer_num .. 'ch' .. global_conf.chunk_size .. 'ba' .. global_conf.batch_size .. 'slr' .. global_conf.lrate .. 'wc' .. global_conf.wcost
global_conf.train_fn_shuf = global_conf.work_dir .. '/train_fn_shuf'
global_conf.train_fn_shuf_bak = global_conf.train_fn_shuf .. '_bak'
global_conf.param_fn = global_conf.work_dir .. "/params"
global_conf.dropout_list = nerv.SUtil.parse_schedule(global_conf.dropout_str)
-global_conf.log_fn = global_conf.work_dir .. '/lstm_tnn_' .. commands_str .. '_log'
+global_conf.log_fn = global_conf.work_dir .. '/log_lstm_tnn_' .. commands_str ..os.date("_TT%X_%m_%d",os.time())
commands = nerv.SUtil.parse_commands_set(commands_str)
-nerv.printf("%s creating work_dir...\n", global_conf.sche_log_pre)
-nerv.LMUtil.wait(1)
+nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir)
+nerv.LMUtil.wait(2)
os.execute("mkdir -p "..global_conf.work_dir)
os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf)
--redirecting log outputs!
nerv.SUtil.log_redirect(global_conf.log_fn)
+nerv.LMUtil.wait(2)
----------------printing options---------------------------------
nerv.printf("%s printing global_conf...\n", global_conf.sche_log_pre)
@@ -441,3 +449,21 @@ if commands["test"] == 1 then
global_conf.dropout_rate = 0
LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
end --if commands["test"]
+
+if commands["testout"] == 1 then
+ nerv.printf("===TEST OUT===\n")
+ nerv.printf("q_file:\t%s\n", q_file)
+ local q_fn = qdata_dir .. q_file
+ global_conf.sche_log_pre = "[SCHEDULER FINAL_TEST]:"
+ if final_iter ~= -1 and test_iter == -1 then
+ test_iter = final_iter
+ end
+ if test_iter == -1 then
+ test_iter = "final"
+ end
+ tnn = load_net(global_conf, test_iter)
+ global_conf.dropout_rate = 0
+ LMTrainer.lm_process_file_rnn(global_conf, q_fn, tnn, false) --false update!
+end --if commands["testout"]
+
+