diff options
author | txh18 <cloudygooseg@gmail.com> | 2016-01-19 21:30:26 +0800 |
---|---|---|
committer | txh18 <cloudygooseg@gmail.com> | 2016-01-19 21:30:26 +0800 |
commit | 37dec2610c92d03813c4e91ed58791ab60da6646 (patch) | |
tree | 61b9bc1d043883bb5d85dcb86cfb621396d75c41 /nerv/examples/lmptb/sample_grulm_ptb_main.lua | |
parent | ae051c0195374ddfab6a0e693d2c2cfa34f24e99 (diff) |
bug fixes for lm_sampler
Diffstat (limited to 'nerv/examples/lmptb/sample_grulm_ptb_main.lua')
-rw-r--r-- | nerv/examples/lmptb/sample_grulm_ptb_main.lua | 15 |
1 file changed, 12 insertions, 3 deletions
diff --git a/nerv/examples/lmptb/sample_grulm_ptb_main.lua b/nerv/examples/lmptb/sample_grulm_ptb_main.lua index b6351bd..9a13d36 100644 --- a/nerv/examples/lmptb/sample_grulm_ptb_main.lua +++ b/nerv/examples/lmptb/sample_grulm_ptb_main.lua @@ -178,7 +178,7 @@ function prepare_dagL(global_conf, layerRepo) dagL:init(global_conf.batch_size) nerv.printf("%s Initing DAGL end.\n", global_conf.sche_log_pre) - return tnn + return dagL end function load_net_tnn(global_conf, fn) @@ -191,8 +191,8 @@ end function load_net_dagL(global_conf, fn) prepare_parameters(global_conf, fn) local layerRepo = prepare_layers(global_conf) - local tnn = prepare_dagL(global_conf, layerRepo) - return tnn + local dagL = prepare_dagL(global_conf, layerRepo) + return dagL end local train_fn, valid_fn, test_fn @@ -424,6 +424,15 @@ if commands["sampling"] == 1 then local dagL = load_net_dagL(global_conf, global_conf.fn_to_sample) local sampler = nerv.LMSampler(global_conf) sampler:load_dagL(dagL) + for k = 1, 5 do + local res = sampler:lm_sample_rnn_dagL(10, {}) + for i = 1, #res do + for j = 1, #res[i] do + nerv.printf("%s ", res[i][j].w) + end + nerv.printf("\n") + end + end --global_conf.dropout_rate = 0 --LMTrainer.lm_process_file_rnn(global_conf, global_conf.test_fn, tnn, false) --false update! end --if commands["sampling"] |