aboutsummaryrefslogblamecommitdiff
path: root/nerv/nerv
blob: 1b32a4e31eab3af24fd196c310aa06c67df1514b (plain) (tree)
1
2
3
4
5
6
7

                      


                                                                                                                                           
                                       
 




                                                            
                                                                                            




                                         






                                                                
 
                               
                                      
                                                                
                                                        



                                                                             


                         
   





                                                           






                                    
#! /usr/bin/env luajit
require 'nerv'
-- Options consumed by the launcher itself (parsed by nerv.parse_args).
-- Each entry is {long_name, short_name, type, default = ..., desc = ...};
-- the positional `nil` marks an option without a short form.
local options = {
    {"help", "h", "boolean",
        default = false,
        desc = "print this help message"},
    {"use-cpu", "c", "boolean",
        default = false,
        desc = "use CPU by default (instead of gpu by default)"},
    {"select-gpu", nil, "int",
        default = -1,
        desc = "select the GPU for computation, fallback to auto mode if not specified"},
}
-- Global (not local) on purpose: filled in below (e.g. econf.use_cpu) and
-- visible to the user script executed via dofile at the end.
econf = {}

--- Print the launcher usage line, then the description of every option
-- in the `options` table (formatted by nerv.print_usage).
local function print_help()
    local usage_line = "Usage: <nerv_prog> [options] script.lua\n"
    nerv.printf(usage_line)
    nerv.print_usage(options)
end

nerv.printf("*** NERV: A Lua-based toolkit for high-performance deep learning (beta) ***\n")
-- Split the command line into launcher options (`opts`) and the remaining
-- positional arguments (`arg`). Both are assigned as globals; the script
-- run by dofile() at the end executes in this same global environment.
-- NOTE(review): assumes nerv.parse_args returns the positional arguments
-- as a plain sequence (script path first) -- confirm against its definition.
arg, opts = nerv.parse_args(arg, options)

-- An explicit --help/-h is a successful request: show usage and stop.
if opts["help"].val then
    print_help()
    return
end

-- No script path is a usage error: show usage and exit non-zero so that
-- shell callers can detect the failure.
if #arg < 1 then
    print_help()
    os.exit(1)
end

--- Install class-level print_profile() / clear_profile() helpers that
-- forward to the class's default context.
-- Only for backward compatibility; will be removed in the future.
-- @param cls matrix class table whose `_default_context` is already set
local function _add_profile_method(cls)
    -- The context is captured once, here; reassigning cls._default_context
    -- later does not affect the installed helpers.
    local ctx = cls._default_context
    for _, method_name in ipairs({"print_profile", "clear_profile"}) do
        -- arguments are ignored and nothing is returned, as before
        cls[method_name] = function ()
            ctx[method_name](ctx)
        end
    end
end

-- GPU setup. Unless --use-cpu was given, create the default CuContext for
-- CuMatrix on the device chosen by --select-gpu (option default -1, which
-- per its description falls back to auto mode), and install the profile and
-- GPU-selection helpers on the class. econf.use_cpu records the path taken.
local want_gpu = not opts["use-cpu"].val
if want_gpu then
    local gpu_id = opts["select-gpu"].val
    nerv.info("automatically initialize a default CuContext...")
    nerv.CuMatrix._default_context = nerv.CuContext(gpu_id)
    nerv.info("the default CuContext is ok")
    _add_profile_method(nerv.CuMatrix)
    -- Unlike the profile helpers, this looks up _default_context on every call.
    function nerv.CuMatrix.select_gpu(id)
        nerv.CuMatrix._default_context:select_gpu(id)
    end
end
econf.use_cpu = not want_gpu

-- MMatrix gets a default MContext regardless of --use-cpu, together with
-- the backward-compatible profile helpers.
do
    nerv.info("automatically initialize a default MContext...")
    local mcontext = nerv.MContext()
    nerv.MMatrix._default_context = mcontext
    nerv.info("the default MContext is ok")
    _add_profile_method(nerv.MMatrix)
end

-- Hand control to the user script. After option parsing, arg[1] is the
-- script path and arg[2..n] are the script's own arguments. The global
-- `arg` is rebuilt to follow the standalone interpreter's convention
-- (script name at index 0, its arguments from index 1), so the script can
-- read arg[0] as it would when run by plain luajit.
local script = arg[1]
local script_arg = {[0] = script}
for i = 2, #arg do
    script_arg[i - 1] = arg[i]
end
arg = script_arg
-- Errors raised by the script propagate out of dofile unchanged.
dofile(script)