From b99fe996dedccada79772d0a061d6b47e54899dd Mon Sep 17 00:00:00 2001 From: txh18 Date: Fri, 29 Jan 2016 20:25:09 +0800 Subject: select gpu code from mfy --- nerv/examples/lmptb/grulm_ptb_main.lua | 10 +++++++++- nerv/lib/matrix/cumatrix.c | 6 ++++++ nerv/matrix/cumatrix.c | 9 +++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/nerv/examples/lmptb/grulm_ptb_main.lua b/nerv/examples/lmptb/grulm_ptb_main.lua index 6095b12..838a665 100644 --- a/nerv/examples/lmptb/grulm_ptb_main.lua +++ b/nerv/examples/lmptb/grulm_ptb_main.lua @@ -198,6 +198,7 @@ qdata_dir = root_dir .. '/ptb/questionGen/gen' global_conf = { lrate = 0.15, wcost = 1e-5, momentum = 0, clip_t = 5, cumat_type = nerv.CuMatrixFloat, + select_gpu = 0, mmat_type = nerv.MMatrixFloat, nn_act_default = 0, @@ -359,7 +360,14 @@ commands = nerv.SUtil.parse_commands_set(commands_str) if start_lr ~= nil then global_conf.lrate = start_lr end - + +nerv.printf("detecting gconf.select_gpu...\n") +if global_conf.select_gpu then + nerv.printf("select gpu to %d\n", global_conf.select_gpu) + global_conf.cumat_type.select_gpu(global_conf.select_gpu) + nerv.LMUtil.wait(1) +end + nerv.printf("%s creating work_dir(%s)...\n", global_conf.sche_log_pre, global_conf.work_dir) nerv.LMUtil.wait(2) os.execute("mkdir -p "..global_conf.work_dir) diff --git a/nerv/lib/matrix/cumatrix.c b/nerv/lib/matrix/cumatrix.c index a8ed075..533dade 100644 --- a/nerv/lib/matrix/cumatrix.c +++ b/nerv/lib/matrix/cumatrix.c @@ -7,6 +7,12 @@ static cublasHandle_t cublas_handle; static cudaEvent_t profile_start, profile_stop; static HashMap *profile; +void nerv_cumatrix_select_gpu(int dev, Status *status) { + fprintf(stderr, "** selecting GPU %d\n", dev); + NERV_SET_STATUS(status, NERV_NORMAL, 0); + CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status); +} + void nerv_cumatrix_print_profile() { size_t i; fprintf(stderr, "*** [nerv cumatrix profile] **\n"); diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c index bf92f92..7f22d68 100644 --- a/nerv/matrix/cumatrix.c +++ b/nerv/matrix/cumatrix.c @@ -8,6 +8,14 @@ static cublasHandle_t cublas_handle; static cudaEvent_t profile_start, profile_stop; static HashMap *profile; +static int select_gpu(lua_State *L) { + Status status; + int dev = luaL_checkinteger(L, 1); + nerv_cumatrix_select_gpu(dev, &status); + NERV_LUA_CHECK_STATUS(L, status); + return 0; +} + static int print_profile(lua_State *L) { nerv_cumatrix_print_profile(); return 0; @@ -21,6 +29,7 @@ static int clear_profile(lua_State *L) { static const luaL_Reg cumatrix_methods[] = { {"print_profile", print_profile}, {"clear_profile", clear_profile}, + {"select_gpu", select_gpu}, {NULL, NULL} }; -- cgit v1.2.3