blob: 4dd448ceedb0396a862aba66a2d6d4c1233e0fc0 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
#! /usr/bin/env luajit
require 'nerv'
nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (alpha) ***\n")

--- Install a freshly constructed context as a matrix class's default context.
-- Logs `begin_msg`, builds `nerv[context_name]()`, stores it in
-- `nerv[matrix_name]._default_context`, then logs `done_msg`.
-- Both names are resolved in `nerv` only when the assignment executes,
-- i.e. after `begin_msg` has been logged, as in a hand-written sequence.
-- @param matrix_name key in `nerv` of the matrix class (e.g. "CuMatrix")
-- @param context_name key in `nerv` of the context constructor
-- @param begin_msg message logged via `nerv.info` before construction
-- @param done_msg message logged via `nerv.info` after installation
local function _init_default_context(matrix_name, context_name, begin_msg, done_msg)
    nerv.info(begin_msg)
    nerv[matrix_name]._default_context = nerv[context_name]()
    nerv.info(done_msg)
end

-- GPU-side matrices first, then host-side matrices.
_init_default_context("CuMatrix", "CuContext",
    "automatically initialize a default CuContext...",
    "the default CuContext is ok")
_init_default_context("MMatrix", "MContext",
    "automatically initialize a default MContext...",
    "the default MContext is ok")
-- Only for backward compatibility; will be removed in the future.
--- Expose `print_profile` / `clear_profile` as plain (dot-called) functions
-- on a matrix class, each forwarding to the class's default context.
-- The context is read once, when this helper runs: reassigning
-- `cls._default_context` afterwards does not affect the installed functions.
-- @param cls matrix class table; its `_default_context` should already be set
local function _add_profile_method(cls)
    local ctx = cls._default_context
    function cls.print_profile()
        ctx:print_profile()
    end
    function cls.clear_profile()
        ctx:clear_profile()
    end
end
_add_profile_method(nerv.CuMatrix)
_add_profile_method(nerv.MMatrix)

-- Without a script argument the launcher only initializes NERV and stops.
-- `arg` may be absent when this chunk is not started as a script, so guard
-- against nil before taking its length.
if arg == nil or #arg < 1 then
    return
end

-- arg[1] names the user script; the remaining entries are its arguments.
local script = arg[1]

-- Rebuild `arg` for the user script the way the stand-alone interpreter
-- does: arg[0] holds the script name and arg[1..n] its own arguments.
local script_arg = { [0] = script }
for i = 2, #arg do
    script_arg[i - 1] = arg[i]
end
arg = script_arg

-- loadfile + call (rather than dofile) also hands the arguments to the
-- chunk as its varargs (`...`), matching `luajit script.lua a b c`.
-- assert() re-raises a load error with the same message dofile would.
local chunk = assert(loadfile(script))
local unpack_args = table.unpack or unpack -- LuaJIT/5.1: global unpack
chunk(unpack_args(script_arg, 1, #script_arg))

-- Report profiling data gathered while the script ran.
nerv.CuMatrix.print_profile()
nerv.MMatrix.print_profile()
|