summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/lib/matrix/cumatrix.c2
-rw-r--r--nerv/matrix/cumatrix.c2
-rw-r--r--nerv/nerv44
3 files changed, 32 insertions, 16 deletions
diff --git a/nerv/lib/matrix/cumatrix.c b/nerv/lib/matrix/cumatrix.c
index ff2ea22..537fabb 100644
--- a/nerv/lib/matrix/cumatrix.c
+++ b/nerv/lib/matrix/cumatrix.c
@@ -78,10 +78,10 @@ void nerv_cuda_context_destroy(CuContext *context, Status *status) {
void nerv_cuda_context_select_gpu(CuContext *context,
int dev, Status *status) {
+ CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status);
free_cuda_handles(context, status);
if (status->err_code != NERV_NORMAL)
return;
- CUDA_SAFE_SYNC_CALL(cudaSetDevice(dev), status);
new_cuda_handles(context, status);
if (status->err_code != NERV_NORMAL)
return;
diff --git a/nerv/matrix/cumatrix.c b/nerv/matrix/cumatrix.c
index f6a4ed5..7d10895 100644
--- a/nerv/matrix/cumatrix.c
+++ b/nerv/matrix/cumatrix.c
@@ -9,7 +9,7 @@ const char *nerv_cuda_context_tname = "nerv.CuContext";
int nerv_cuda_context_lua_select_gpu(lua_State *L) {
Status status;
nerv_cuda_context_select_gpu(luaT_checkudata(L, 1, nerv_cuda_context_tname),
- luaL_checkinteger(L, 1), &status);
+ luaL_checkinteger(L, 2), &status);
NERV_LUA_CHECK_STATUS(L, status);
return 0;
}
diff --git a/nerv/nerv b/nerv/nerv
index 4dd448c..9295290 100644
--- a/nerv/nerv
+++ b/nerv/nerv
@@ -1,13 +1,20 @@
#! /usr/bin/env luajit
require 'nerv'
-nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
-nerv.info("automatically initialize a default CuContext...")
-nerv.CuMatrix._default_context = nerv.CuContext()
-nerv.info("the default CuContext is ok")
+local options = {{"help", "h", "bool", default = false, desc = "print this help message"},
+ {"use-cpu", "c", "bool", default = false, desc = "use CPU by default (instead of gpu by default)"},
+ {"select-gpu", nil, "int", default = nil, desc = "select the GPU for computation, fallback to auto mode if not specified"}}
-nerv.info("automatically initialize a default MContext...")
-nerv.MMatrix._default_context = nerv.MContext()
-nerv.info("the default MContext is ok")
+local function print_help()
+ nerv.printf("Usage: <nerv_prog> [options] script.lua\n")
+ nerv.print_usage(options)
+end
+
+nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
+arg, opts = nerv.parse_args(arg, options)
+if #arg < 1 or opts["help"].val then
+ print_help()
+ return
+end
-- only for backward compatibilty, will be removed in the future
local function _add_profile_method(cls)
@@ -15,13 +22,24 @@ local function _add_profile_method(cls)
cls.print_profile = function () c:print_profile() end
cls.clear_profile = function () c:clear_profile() end
end
-_add_profile_method(nerv.CuMatrix)
-_add_profile_method(nerv.MMatrix)
-
-if #arg < 1 then
- return
+if not opts["use-cpu"].val then
+ nerv.info("automatically initialize a default CuContext...")
+ nerv.CuMatrix._default_context = nerv.CuContext()
+ nerv.info("the default CuContext is ok")
+ _add_profile_method(nerv.CuMatrix)
+ nerv.CuMatrix.select_gpu =
+ function (dev) nerv.CuMatrix._default_context:select_gpu(dev) end
+ if opts["select-gpu"].val then
+ nerv.CuMatrix.select_gpu(opts["select-gpu"].val)
+ end
end
+
+nerv.info("automatically initialize a default MContext...")
+nerv.MMatrix._default_context = nerv.MContext()
+nerv.info("the default MContext is ok")
+_add_profile_method(nerv.MMatrix)
+
local script = arg[1]
local script_arg = {}
for i = 2, #arg do
@@ -29,5 +47,3 @@ for i = 2, #arg do
end
arg = script_arg
dofile(script)
-nerv.CuMatrix.print_profile()
-nerv.MMatrix.print_profile()