diff options
Diffstat (limited to 'nerv/examples/lmptb/grulm_ptb_main.lua')
-rw-r--r-- | nerv/examples/lmptb/grulm_ptb_main.lua | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/grulm_ptb_main.lua b/nerv/examples/lmptb/grulm_ptb_main.lua index 6095b12..838a665 100644 --- a/nerv/examples/lmptb/grulm_ptb_main.lua +++ b/nerv/examples/lmptb/grulm_ptb_main.lua @@ -198,6 +198,7 @@ qdata_dir = root_dir .. '/ptb/questionGen/gen' global_conf = { lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5, cumat_type = nerv.CuMatrixFloat, + select_gpu = 0, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, @@ -359,7 +360,14 @@ commands = nerv.SUtil.parse_commands_set(commands_str) if start_lr ~= nil then global_conf.lrate = start_lr end - + +nerv.printf("detecting gconf.select_gpu...\n") +if global_conf.select_gpu then + nerv.printf("select gpu to %d\n", global_conf.select_gpu) + global_conf.cumat_type.select_gpu(global_conf.select_gpu) + nerv.LMUtil.wait(1) +end + nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) nerv.LMUtil.wait(2) os.execute("mkdir -p "..global_conf.work_dir) |