diff options
author | txh18 <cloudygooseg@gmail.com> | 2016-01-19 19:28:42 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2016-01-19 19:28:42 +0800 |
commit | ae051c0195374ddfab6a0e693d2c2cfa34f24e99 (patch) | |
tree | d287fb6cfaa787342fe449da33769074e24a45f5 /nerv/examples/lmptb/sample_grulm_ptb_main.lua | |
parent | acc2b72b08aa7fe16b85fb47b0308e3da5adae66 (diff) |
added lm_sampler, todo:test it
Diffstat (limited to 'nerv/examples/lmptb/sample_grulm_ptb_main.lua')
-rw-r--r-- | nerv/examples/lmptb/sample_grulm_ptb_main.lua | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua index 86209da..b6351bd 100644 --- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua +++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua @@ -5,6 +5,7 @@ require 'lmptb.layer.init' --require 'tnn.init' require 'lmptb.lmseqreader' require 'lm_trainer' +require 'lm_sampler' --[[global function rename]]-- --local printf = nerv.printf @@ -134,7 +135,7 @@ function prepare_tnn(global_conf, layerRepo) end function prepare_dagL(global_conf, layerRepo) - nerv.printf("%s Generate and initing TNN ...\n", global_conf.sche_log_pre) + nerv.printf("%s Generate and initing dagL ...\n", global_conf.sche_log_pre) --input: input_w, input_w, ... input_w_now, last_activation local connections_t = { @@ -376,15 +377,12 @@ commands = nerv.SUtil.parse_commands_set(commands_str) if start_lr ~= nil then global_conf.lrate = start_lr end - -nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) -nerv.LMUtil.wait(2) -os.execute("mkdir -p "..global_conf.work_dir) -os.execute("cp " .. global_conf.train_fn .. " " .. global_conf.train_fn_shuf) +--[[ --redirecting log outputs! nerv.SUtil.log_redirect(global_conf.log_fn) nerv.LMUtil.wait(2) +]]-- ----------------printing options--------------------------------- nerv.printf("%s printing global_conf...\n", global_conf.sche_log_pre) @@ -424,6 +422,8 @@ if commands["sampling"] == 1 then nerv.printf("===SAMPLE===\n") global_conf.sche_log_pre = "[SCHEDULER SAMPLING]:" local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample) + local sampler = nerv.LMSampler(global_conf) + sampler:load_dagL(dagL) --global_conf.dropout_rate = 0 --LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["sampling"] |