diff options
-rw-r--r-- | nerv/lib/matrix/cumatrix.c | 2 | ||||
-rw-r--r-- | nerv/matrix/cumatrix.c | 2 | ||||
-rw-r--r-- | nerv/nerv | 44 |
3 files changed, 32 insertions, 16 deletions
diff --git a/nerv/lib/matrix/cumatrix.c b/nerv/lib/matrix/cumatrix.c index ff2ea22..537fabb 100644 --- a/nerv/lib/matrix/cumatrix.c +++ b/nerv/lib/matrix/cumatrix.c @@ -78,10 +78,10 @@ void nerv_cuda_context_destroy(CuContext *context, Status *status) { void nerv_cuda_context_select_gpu(CuContext *context, int dev, Status *status) { + CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status); free_cuda_handles(context, status); if (status->err_code != NERV_NORMAL) return; - CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status); new_cuda_handles(context, status); if (status->err_code != NERV_NORMAL) return; diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c index f6a4ed5..7d10895 100644 --- a/nerv/matrix/cumatrix.c +++ b/nerv/matrix/cumatrix.c @@ -9,7 +9,7 @@ const char *nerv_cuda_context_tname = "nerv.CuContext"; int nerv_cuda_context_lua_select_gpu(lua_State *L) { Status status; nerv_cuda_context_select_gpu(luaT_checkudata(L, 1, nerv_cuda_context_tname), - luaL_checkinteger(L, 1), &status); + luaL_checkinteger(L, 2), &status); NERV_LUA_CHECK_STATUS(L, status); return 0; } @@ -1,13 +1,20 @@ #! /usr/bin/env luajit require 'nerv' -nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") -nerv.info("automatically initialize a default CuContext...") -nerv.CuMatrix._default_context = nerv.CuContext() -nerv.info("the default CuContext is ok") +local options = {{"help", "h", "bool", default = false, desc = "print this help message"}, + {"use-cpu", "c", "bool", default = false, desc = "use CPU by default (instead of gpu by default)"}, + {"select-gpu", nil, "int", default = nil, desc = "select the GPU for computation, fallback to auto mode if not specified"}} -nerv.info("automatically initialize a default MContext...") -nerv.MMatrix._default_context = nerv.MContext() -nerv.info("the default MContext is ok") +local function print_help() + nerv.printf("Usage: <nerv_prog> [options] script.lua\n") + nerv.print_usage(options) +end + +nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") +arg, opts = nerv.parse_args(arg, options) +if #arg < 1 or opts["help"].val then + print_help() + return +end -- only for backward compatibilty, will be removed in the future local function _add_profile_method(cls) @@ -15,13 +22,24 @@ local function _add_profile_method(cls) cls.print_profile = function () c:print_profile() end cls.clear_profile = function () c:clear_profile() end end -_add_profile_method(nerv.CuMatrix) -_add_profile_method(nerv.MMatrix) - -if #arg < 1 then - return +if not opts["use-cpu"].val then + nerv.info("automatically initialize a default CuContext...") + nerv.CuMatrix._default_context = nerv.CuContext() + nerv.info("the default CuContext is ok") + _add_profile_method(nerv.CuMatrix) + nerv.CuMatrix.select_gpu = + function (dev) nerv.CuMatrix._default_context:select_gpu(dev) end + if opts["select-gpu"].val then + nerv.CuMatrix.select_gpu(opts["select-gpu"].val) + end end + +nerv.info("automatically initialize a default MContext...") +nerv.MMatrix._default_context = nerv.MContext() +nerv.info("the default MContext is ok") +_add_profile_method(nerv.MMatrix) + local script = arg[1] local script_arg = {} for i = 2, #arg do @@ -29,5 +47,3 @@ for i = 2, #arg do end arg = script_arg dofile(script) -nerv.CuMatrix.print_profile() -nerv.MMatrix.print_profile() |