aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortxh18 <[email protected]>2016-02-05 22:46:53 +0800
committertxh18 <[email protected]>2016-02-05 22:46:53 +0800
commit7a421571300417dba0d5c703d9a460ad19aeef14 (patch)
tree67062f09266ecef0580434435f4388749be3fc4c
parent3d7a2be2d8ac3083617df2b7194921971f0ac94e (diff)
enhanced m-test/lm_sample
-rw-r--r--nerv/examples/lmptb/m-tests/lm_sampler_test.lua45
1 files changed, 30 insertions, 15 deletions
diff --git a/nerv/examples/lmptb/m-tests/lm_sampler_test.lua b/nerv/examples/lmptb/m-tests/lm_sampler_test.lua
index 42a5787..0313d77 100644
--- a/nerv/examples/lmptb/m-tests/lm_sampler_test.lua
+++ b/nerv/examples/lmptb/m-tests/lm_sampler_test.lua
@@ -362,15 +362,12 @@ global_conf = {
end
-lr_half = false --can not be local, to be set by loadstring
-start_iter = -1
-start_lr = nil
-ppl_last = 100000
commands_str = "sampling" --"train:test"
commands = {}
-test_iter = -1
---for testout(question)
-q_file = "/home/slhome/txh18/workspace/ptb/questionGen/gen/ptb.test.txt.q10rs1_Msss.adds"
+test_iter = -1 --obselete
+random_seed = 1
+sample_num = 10
+out_fn = nil
if arg[2] ~= nil then
nerv.printf("%s applying arg[2](%s)...\n", global_conf.sche_log_pre, arg[2])
@@ -407,16 +404,16 @@ end
nerv.LMUtil.wait(2)
nerv.printf("%s printing training scheduling options...\n", global_conf.sche_log_pre)
-nerv.printf("lr_half:\t%s\n", tostring(lr_half))
-nerv.printf("start_iter:\t%s\n", tostring(start_iter))
-nerv.printf("ppl_last:\t%s\n", tostring(ppl_last))
nerv.printf("commands_str:\t%s\n", commands_str)
nerv.printf("test_iter:\t%s\n", tostring(test_iter))
+nerv.printf("random_seed:\t%s\n", tostring(random_seed))
+nerv.printf("sample_num:\t%s\n", tostring(sample_num))
+nerv.printf("out_fn:\t%s\n", tostring(out_fn))
nerv.printf("%s printing training scheduling end.\n", global_conf.sche_log_pre)
nerv.LMUtil.wait(2)
------------------printing options end------------------------------
-math.randomseed(1)
+math.randomseed(random_seed)
local vocab = nerv.LMVocab()
global_conf["vocab"] = vocab
@@ -438,15 +435,33 @@ if commands["sampling"] == 1 then
nerv.printf("===SAMPLE===\n")
global_conf.sche_log_pre = "[SCHEDULER SAMPLING]:"
local sampler = prepare_sampler(sm_conf)
- for k = 1, 1 do
- local res = sampler:lm_sample_rnn_dagL(10, {})
+ local out_fh = nil
+ if out_fn ~= nil then
+ out_fh = assert(io.open(out_fn, "w"))
+ nerv.printf("%s outputing samples to file \"%s\"...\n", global_conf.sche_log_pre, out_fn)
+ end
+ for k = 1, sample_num do
+ local res = sampler:lm_sample_rnn_dagL(1, {})
for i = 1, #res do
+ if out_fh == nil then nerv.printf("lm_sampler_output_sample: ") end
for j = 1, #res[i] do
- nerv.printf("%s(%f) ", res[i][j].w, res[i][j].p)
+ if out_fh == nil then
+ nerv.printf("%s %f ", res[i][j].w, res[i][j].p)
+ else
+ out_fh:write(nerv.sprintf("%s %f ", res[i][j].w, res[i][j].p))
+ end
+ end
+ if out_fh == nil then
+ nerv.printf("\n")
+ else
+ out_fh:write(nerv.sprintf("\n"))
end
- nerv.printf("\n")
end
+ if k % 100 == 0 and out_fh ~= nil then nerv.printf("%s %d sample done\n", global_conf.sche_log_pre, k) end
end
+
+ if out_fh ~= nil then out_fh:close() end
+ nerv.printf("%s complete,bye\n", global_conf.sche_log_pre)
--global_conf.dropout_rate = 0
--LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update!
end --if commands["sampling"]