aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/examples/lmptb/grulm_ptb_main.lua10
-rw-r--r--nerv/lib/matrix/cumatrix.c6
-rw-r--r--nerv/matrix/cumatrix.c9
3 files changed, 24 insertions, 1 deletions
diff --git a/nerv/examples/lmptb/grulm_ptb_main.lua b/nerv/examples/lmptb/grulm_ptb_main.lua
index 6095b12..838a665 100644
--- a/nerv/examples/lmptb/grulm_ptb_main.lua
+++ b/nerv/examples/lmptb/grulm_ptb_main.lua
@@ -198,6 +198,7 @@ qdata_dir = root_dir .. '/ptb/questionGen/gen'
global_conf = {
lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5,
cumat_type = nerv.CuMatrixFloat,
+ select_gpu = 0,
mmat_type = nerv.MMatrixFloat,
nn_act_default = 0,
@@ -359,7 +360,14 @@ commands = nerv.SUtil.parse_commands_set(commands_str)
if start_lr ~= nil then
global_conf.lrate = start_lr
end
-
+
+nerv.printf("detecting gconf.select_gpu...\n")
+if global_conf.select_gpu then
+ nerv.printf("select gpu to %d\n", global_conf.select_gpu)
+ global_conf.cumat_type.select_gpu(global_conf.select_gpu)
+ nerv.LMUtil.wait(1)
+end
+
nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir)
nerv.LMUtil.wait(2)
os.execute("mkdir -p "..global_conf.work_dir)
diff --git a/nerv/lib/matrix/cumatrix.c b/nerv/lib/matrix/cumatrix.c
index a8ed075..533dade 100644
--- a/nerv/lib/matrix/cumatrix.c
+++ b/nerv/lib/matrix/cumatrix.c
@@ -7,6 +7,12 @@ static cublasHandle_t cublas_handle;
static cudaEvent_t profile_start, profile_stop;
static HashMap *profile;
+void nerv_cumatrix_select_gpu(int dev, Status *status) {
+ fprintf(stderr, "** selecting GPU %d\n", dev);
+ NERV_SET_STATUS(status, NERV_NORMAL, 0);
+ CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status);
+}
+
void nerv_cumatrix_print_profile() {
size_t i;
fprintf(stderr, "*** [nerv cumatrix profile] **\n");
diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c
index bf92f92..7f22d68 100644
--- a/nerv/matrix/cumatrix.c
+++ b/nerv/matrix/cumatrix.c
@@ -8,6 +8,14 @@ static cublasHandle_t cublas_handle;
static cudaEvent_t profile_start, profile_stop;
static HashMap *profile;
+static int select_gpu(lua_State *L) {
+ Status status;
+ int dev = luaL_checkinteger(L, 1);
+ nerv_cumatrix_select_gpu(dev, &status);
+ NERV_LUA_CHECK_STATUS(L, status);
+ return 0;
+}
+
static int print_profile(lua_State *L) {
nerv_cumatrix_print_profile();
return 0;
@@ -21,6 +29,7 @@ static int clear_profile(lua_State *L) {
static const luaL_Reg cumatrix_methods[] = {
{"print_profile", print_profile},
{"clear_profile", clear_profile},
+ {"select_gpu", select_gpu},
{NULL, NULL}
};