aboutsummaryrefslogtreecommitdiff
path: root/nerv/nerv
blob: 4c20ec7804b31bdf84eb58e37e57644827b3ad3a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#! /usr/bin/env luajit
require 'nerv'
-- Options understood by the launcher itself. Each entry is
-- {long name, short name (or nil), value type, default = ..., desc = ...};
-- the table is handed to nerv.parse_args and nerv.print_usage.
local options = {
    {"help", "h", "boolean",
     default = false,
     desc = "print this help message"},
    {"use-cpu", "c", "boolean",
     default = false,
     desc = "use CPU by default (instead of gpu by default)"},
    {"select-gpu", nil, "int",
     default = -1,
     desc = "select the GPU for computation, fallback to auto mode if not specified"},
}
-- Environment configuration. Deliberately a global (no `local`), presumably so
-- that the user script run by this launcher can read it (e.g. `econf.use_cpu`).
econf = {}

--- Print the launcher's usage line, then the per-option help generated
-- from the `options` table by nerv.print_usage.
local function print_help()
    local printf, describe_options = nerv.printf, nerv.print_usage
    printf("Usage: <nerv_prog> [options] script.lua\n")
    describe_options(options)
end

nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")
-- `arg` is replaced by what nerv.parse_args returns; the script path is
-- expected at arg[1]. `arg` and `opts` stay global as in the original launcher.
arg, opts = nerv.parse_args(arg, options)
if opts["help"].val then
    -- Help was explicitly requested: printing usage is the expected outcome,
    -- so finish normally (exit status 0).
    print_help()
    return
end
if #arg < 1 then
    -- No script to run is a usage error: show usage, then exit with a
    -- non-zero status so shells and CI jobs can detect the failure.
    print_help()
    os.exit(1)
end

--- Attach class-level `print_profile` / `clear_profile` functions to `cls`.
-- Kept only for backward compatibility; slated for removal.
-- Both forward to the context stored in `cls._default_context` at the time
-- this function is called (later reassignments of that field are not seen).
-- @param cls matrix class table whose `_default_context` is already set
local function _add_profile_method(cls)
    local ctx = cls._default_context
    -- Build a zero-argument function that invokes `ctx:<method_name>()`.
    local function forward(method_name)
        return function () ctx[method_name](ctx) end
    end
    cls.print_profile = forward("print_profile")
    cls.clear_profile = forward("clear_profile")
end

-- Pick the compute mode. Unless --use-cpu was given, create the default
-- CuContext for CuMatrix (on the GPU chosen by --select-gpu) and attach the
-- legacy profiling helpers plus `select_gpu`. The choice is recorded in econf.
if opts["use-cpu"].val then
    econf.use_cpu = true
else
    local gpu_id = opts["select-gpu"].val
    nerv.info("automatically initialize a default CuContext...")
    nerv.CuMatrix._default_context = nerv.CuContext(gpu_id)
    nerv.info("the default CuContext is ok")
    _add_profile_method(nerv.CuMatrix)
    -- Resolves `_default_context` on every call instead of capturing it now.
    nerv.CuMatrix.select_gpu = function (id)
        nerv.CuMatrix._default_context:select_gpu(id)
    end
    econf.use_cpu = false
end

-- The MMatrix default context is created unconditionally (in both GPU and
-- CPU modes) and receives the same legacy profiling helpers.
do
    nerv.info("automatically initialize a default MContext...")
    local mmatrix_cls = nerv.MMatrix
    mmatrix_cls._default_context = nerv.MContext()
    nerv.info("the default MContext is ok")
    _add_profile_method(mmatrix_cls)
end

-- Hand control to the user script. arg[1] is the script path; the remaining
-- positional arguments become the script's own `arg` table. As in the
-- standalone Lua interpreter, arg[0] holds the script path and the arguments
-- are also passed to the chunk as `...`.
local script = arg[1]
local script_arg = { [0] = script }
for i = 2, #arg do
    script_arg[i - 1] = arg[i]
end
arg = script_arg
-- loadfile + call (rather than dofile) so the arguments can be forwarded as
-- varargs; assert re-raises load/syntax errors just as dofile would.
local chunk = assert(loadfile(script))
local unpack_args = table.unpack or unpack
chunk(unpack_args(script_arg))