aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/grulm_ptb_main.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/grulm_ptb_main.lua')
-rw-r--r--nerv/examples/lmptb/grulm_ptb_main.lua10
1 files changed, 9 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/grulm_ptb_main.lua b/nerv/examples/lmptb/grulm_ptb_main.lua
index 6095b12..838a665 100644
--- a/nerv/examples/lmptb/grulm_ptb_main.lua
+++ b/nerv/examples/lmptb/grulm_ptb_main.lua
@@ -198,6 +198,7 @@ qdata_dir = root_dir .. '/ptb/questionGen/gen'
global_conf = {
lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5,
cumat_type = nerv.CuMatrixFloat,
+ select_gpu = 0,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
@@ -359,7 +360,14 @@ commands = nerv.SUtil.parse_commands_set(commands_str)
if start_lr ~= nil then
global_conf.lrate = start_lr
end
-
+
+nerv.printf("detecting gconf.select_gpu...\n")
+if global_conf.select_gpu then
+ nerv.printf("select gpu to %d\n", global_conf.select_gpu)
+ global_conf.cumat_type.select_gpu(global_conf.select_gpu)
+ nerv.LMUtil.wait(1)
+end
+
nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir)
nerv.LMUtil.wait(2)
os.execute("mkdir -p "..global_conf.work_dir)