aboutsummaryrefslogtreecommitdiff
path: root/nerv/nerv
blob: 4dd448ceedb0396a862aba66a2d6d4c1233e0fc0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#! /usr/bin/env luajit
require('nerv')
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")

-- Give a matrix class a freshly constructed default context, logging before
-- and after construction.
--   matrix_cls  : the matrix class table that receives `_default_context`
--   new_context : constructor invoked with no arguments
--   label       : human-readable context name used in the log messages
local function _install_default_context(matrix_cls, new_context, label)
    nerv.info("automatically initialize a default " .. label .. "...")
    matrix_cls._default_context = new_context()
    nerv.info("the default " .. label .. " is ok")
end

_install_default_context(nerv.CuMatrix, nerv.CuContext, "CuContext")
_install_default_context(nerv.MMatrix, nerv.MContext, "MContext")

-- Only for backward compatibility; will be removed in the future.
-- Older code calls `print_profile()` / `clear_profile()` as class-level
-- functions on the matrix classes; forward those calls to the class's
-- default context.
--   cls : a matrix class table that carries a `_default_context` field
local function _add_profile_method(cls)
    -- Look the context up on every call instead of capturing it once, so a
    -- later reassignment of `cls._default_context` is honoured rather than
    -- silently profiling a stale context.
    cls.print_profile = function ()
        return cls._default_context:print_profile()
    end
    cls.clear_profile = function ()
        return cls._default_context:clear_profile()
    end
end
_add_profile_method(nerv.CuMatrix)
_add_profile_method(nerv.MMatrix)


-- Run the user script named by the first command-line argument, if any.
-- The remaining arguments reach that script the way the stand-alone
-- interpreter would deliver them: through the global `arg` table (with the
-- script path at arg[0]) and as the chunk's varargs (`...`).
if #arg < 1 then
    return
end
local script = arg[1]
local script_arg = {[0] = script}
for i = 2, #arg do
    script_arg[i - 1] = arg[i]
end
-- Scripts read their parameters from the global `arg`, so replace it.
arg = script_arg
local unpack_args = unpack or table.unpack
-- loadfile + call (instead of dofile) lets the arguments be passed as `...`;
-- assert re-raises a load failure with the same message dofile would.
local chunk = assert(loadfile(script))
chunk(unpack_args(script_arg))
-- Report accumulated profiling data once the script has finished.
nerv.CuMatrix.print_profile()
nerv.MMatrix.print_profile()