aboutsummaryrefslogtreecommitdiff
path: root/nerv/examples/lmptb/grulm_ptb_main.lua
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/examples/lmptb/grulm_ptb_main.lua')
-rw-r--r--nerv/examples/lmptb/grulm_ptb_main.lua16
1 files changed, 12 insertions, 4 deletions
diff --git a/nerv/examples/lmptb/grulm_ptb_main.lua b/nerv/examples/lmptb/grulm_ptb_main.lua
index ef5d7f9..4a3f39f 100644
--- a/nerv/examples/lmptb/grulm_ptb_main.lua
+++ b/nerv/examples/lmptb/grulm_ptb_main.lua
@@ -198,6 +198,7 @@ qdata_dir = root_dir .. '/ptb/questionGen/gen'
global_conf = {
lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5,
cumat_type = nerv.CuMatrixFloat,
+ select_gpu = 0,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
@@ -259,7 +260,7 @@ global_conf = {
elseif (set == "twitter") then
data_dir = root_dir .. '/twitter_new/DATA'
-train_fn = data_dir .. '/twitter.choose2.adds'
+train_fn = data_dir .. '/twitter.choose.adds'
valid_fn = data_dir .. '/twitter.valid.adds'
test_fn = data_dir .. '/comm.test.choose-ppl.adds'
vocab_fn = data_dir .. '/twitter.choose.train.vocab'
@@ -359,7 +360,14 @@ commands = nerv.SUtil.parse_commands_set(commands_str)
if start_lr ~= nil then
global_conf.lrate = start_lr
end
-
+
+nerv.printf("detecting gconf.select_gpu...\n")
+if global_conf.select_gpu then
+ nerv.printf("select gpu to %d\n", global_conf.select_gpu)
+ global_conf.cumat_type.select_gpu(global_conf.select_gpu)
+ nerv.LMUtil.wait(1)
+end
+
nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir)
nerv.LMUtil.wait(2)
os.execute("mkdir -p "..global_conf.work_dir)
@@ -388,10 +396,10 @@ nerv.LMUtil.wait(2)
math.randomseed(1)
-local vocab = nerv.LMVocab()
+local vocab = nerv.LMVocab(global_conf)
global_conf["vocab"] = vocab
nerv.printf("%s building vocab...\n", global_conf.sche_log_pre)
-global_conf.vocab:build_file(global_conf.vocab_fn, false)
+global_conf.vocab:build_file(global_conf.vocab_fn)
ppl_rec = {}
local final_iter = -1