diff options
Diffstat (limited to 'nerv/examples/lmptb/grulm_ptb_main.lua')
-rw-r--r-- | nerv/examples/lmptb/grulm_ptb_main.lua | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/grulm_ptb_main.lua b/nerv/examples/lmptb/grulm_ptb_main.lua index ef5d7f9..4a3f39f 100644 --- a/nerv/examples/lmptb/grulm_ptb_main.lua +++ b/nerv/examples/lmptb/grulm_ptb_main.lua @@ -198,6 +198,7 @@ qdata_dir = root_dir .. '/ptb/questionGen/gen' global_conf = { lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5, cumat_type = nerv.CuMatrixFloat, + select_gpu = 0, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, @@ -259,7 +260,7 @@ global_conf = { elseif (set == "twitter") then data_dir = root_dir .. '/twitter_new/DATA' -train_fn = data_dir .. '/twitter.choose2.adds' +train_fn = data_dir .. '/twitter.choose.adds' valid_fn = data_dir .. '/twitter.valid.adds' test_fn = data_dir .. '/comm.test.choose-ppl.adds' vocab_fn = data_dir .. '/twitter.choose.train.vocab' @@ -359,7 +360,14 @@ commands = nerv.SUtil.parse_commands_set(commands_str) if start_lr ~= nil then global_conf.lrate = start_lr end - + +nerv.printf("detecting gconf.select_gpu...\n") +if global_conf.select_gpu then + nerv.printf("select gpu to %d\n", global_conf.select_gpu) + global_conf.cumat_type.select_gpu(global_conf.select_gpu) + nerv.LMUtil.wait(1) +end + nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) nerv.LMUtil.wait(2) os.execute("mkdir -p "..global_conf.work_dir) @@ -388,10 +396,10 @@ nerv.LMUtil.wait(2) math.randomseed(1) -local vocab = nerv.LMVocab() +local vocab = nerv.LMVocab(global_conf) global_conf["vocab"] = vocab nerv.printf("%s building vocab...\n", global_conf.sche_log_pre) -global_conf.vocab:build_file(global_conf.vocab_fn, false) +global_conf.vocab:build_file(global_conf.vocab_fn) ppl_rec = {} local final_iter = -1 |