diff options
Diffstat (limited to 'nerv/nerv')
-rw-r--r-- | nerv/nerv | 44 |
1 files changed, 30 insertions, 14 deletions
@@ -1,13 +1,20 @@ #! /usr/bin/env luajit require 'nerv' -nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") -nerv.info("automatically initialize a default CuContext...") -nerv.CuMatrix._default_context = nerv.CuContext() -nerv.info("the default CuContext is ok") +local options = {{"help", "h", "bool", default = false, desc = "print this help message"}, + {"use-cpu", "c", "bool", default = false, desc = "use CPU by default (instead of gpu by default)"}, + {"select-gpu", nil, "int", default = nil, desc = "select the GPU for computation, fallback to auto mode if not specified"}} -nerv.info("automatically initialize a default MContext...") -nerv.MMatrix._default_context = nerv.MContext() -nerv.info("the default MContext is ok") +local function print_help() + nerv.printf("Usage: <nerv_prog> [options] script.lua\n") + nerv.print_usage(options) +end + +nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n") +arg, opts = nerv.parse_args(arg, options) +if #arg < 1 or opts["help"].val then + print_help() + return +end -- only for backward compatibilty, will be removed in the future local function _add_profile_method(cls) @@ -15,13 +22,24 @@ local function _add_profile_method(cls) cls.print_profile = function () c:print_profile() end cls.clear_profile = function () c:clear_profile() end end -_add_profile_method(nerv.CuMatrix) -_add_profile_method(nerv.MMatrix) - -if #arg < 1 then - return +if not opts["use-cpu"].val then + nerv.info("automatically initialize a default CuContext...") + nerv.CuMatrix._default_context = nerv.CuContext() + nerv.info("the default CuContext is ok") + _add_profile_method(nerv.CuMatrix) + nerv.CuMatrix.select_gpu = + function (dev) nerv.CuMatrix._default_context:select_gpu(dev) end + if opts["select-gpu"].val then + nerv.CuMatrix.select_gpu(opts["select-gpu"].val) + end end + +nerv.info("automatically initialize a default MContext...") +nerv.MMatrix._default_context = nerv.MContext() +nerv.info("the default MContext is ok") +_add_profile_method(nerv.MMatrix) + local script = arg[1] local script_arg = {} for i = 2, #arg do @@ -29,5 +47,3 @@ for i = 2, #arg do end arg = script_arg dofile(script) -nerv.CuMatrix.print_profile() -nerv.MMatrix.print_profile() |