aboutsummaryrefslogblamecommitdiff
path: root/nerv/init.lua
blob: ff944b8e5f30a0d393a877318339471ef97756f0 (plain) (tree)
1
2
3
4
5
6
7
8





                                                                            
                 
 


                                                                  
                                            

                                         
 



                                          



                                             



                                  


                                            
                              






                                                       

   


                                                     

                                                                      

   




                                                                            
                            
                

                                              

   




                                                                            


                                                 
                                              

   




                                                                    























                                                                              

                            











                                                                
       


                             




                                                                      

   




                                                                              
                            



                                                 
       






                                                              
   
 


                                     



                                                   


                                                









                                        

   


                                               








                                                             




                                                                               

                                                      
                                                   

   
                                                








                                                                              

                                                                                                        





                                                                               
                                                  






                                                      
                           








                                                            


                              



                                                
                                                                            
                                         









                                                                       
                                                                          














                                                              
                                                     











                                                                       
                                                         




                                                   
                                                                                       






                                                                   
                                                            


                                                 
                                                                       












                                                                            
                            




                     



























                                                                             
                                                                     




                                 




                                 
 








                                

                                                                  



                               
--- NERV: a Lua-based toolkit for high-performance deep learning.
-- This file contains misc utility functions of NERV and finally initializes
-- NERV by including `init.lua` of other basic modules.
-- @author Ted Yin <[email protected]>
-- @module nerv

require 'libnerv'

--- Placeholder for functionality that has not been implemented.
-- Calling it always raises an error (via `nerv.error`) whose message says the
-- method is not implemented, so a user invoking a missing function gets a
-- readable message instead of an "attempt to call a nil value".
function nerv.error_method_not_implemented()
    local msg = "method not implemented"
    nerv.error(msg)
end

--- Duplicate all subsequent `nerv.printf` output into a log file.
-- The file is opened in write mode, so an existing file is truncated. A log
-- file previously stored in `nerv._logfile` is closed first (unless it is
-- `io.stdout`/`io.stderr` or already closed).
-- @param filename the path to the log file; pass nil to stop logging
-- @raise an error (via `nerv.error`) if the file cannot be opened
function nerv.set_logfile(filename)
    local old = nerv._logfile
    if io.type(old) == "file" and old ~= io.stdout and old ~= io.stderr then
        old:close()
    end
    nerv._logfile = nil
    if filename == nil then
        return
    end
    local f, err = io.open(filename, "w")
    if not f then
        -- previously the failure was swallowed and logging silently stayed off
        nerv.error("cannot open log file %s: %s", tostring(filename), tostring(err))
    end
    nerv._logfile = f
end

--- Build a string from a printf-style template, like C's `sprintf`.
-- This is a thin wrapper around `string.format`.
-- @param fmt the format string
-- @param ... the values substituted into `fmt`
-- @return the resulting string
function nerv.sprintf(fmt, ...)
    local formatted = string.format(fmt, ...)
    return formatted
end

--- Print a formatted string to stderr (and to the log file, if one is set).
-- Output goes to `io.stderr`, not stdout. When `nerv._logfile` is set (see
-- `nerv.set_logfile`), the same text is also written to it and the file is
-- flushed after every write.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.printf(fmt, ...)
    local line = nerv.sprintf(fmt, ...)
    io.stderr:write(line)
    -- duplicate all output to the log file, if set
    local logfile = nerv._logfile
    if logfile then
        logfile:write(line)
        logfile:flush()
    end
end

--- Raise an error with a formatted message.
-- The message is prefixed with "[nerv] internal error: " and terminated by a
-- newline. The error is raised at level 2, so the position Lua prepends to the
-- message is that of the code calling `nerv.error`, not of this helper.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.error(fmt, ...)
    error(nerv.sprintf("[nerv] internal error: " .. fmt .. "\n", ...), 2)
end

--- Print an informational message tagged with "info" and a timestamp.
-- Normal users should call this, rather than `nerv.printf`, for any
-- notification output. `fmt` is spliced verbatim into the line template, so
-- its format directives are still expanded with `...`.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.info(fmt, ...)
    local stamp = os.date("%H:%M:%S %F")
    local template = string.format("(%s)[nerv] info: %s\n", stamp, fmt)
    nerv.printf(template, ...)
end

--- Print a warning message tagged with "warning" and a timestamp.
-- Normal users should call this, rather than `nerv.printf`, for any warning
-- output. `fmt` is spliced verbatim into the line template, so its format
-- directives are still expanded with `...`.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.warning(fmt, ...)
    local stamp = os.date("%H:%M:%S %F")
    local template = string.format("(%s)[nerv] warning: %s\n", stamp, fmt)
    nerv.printf(template, ...)
end

--- Create a class (Torch-compatible).
-- Registers the class `tname` via `nerv.newmetatable`, optionally deriving it
-- from `parenttname`. The registration receives a constructor, which builds a
-- new object and runs its `__init` method (if any) with the constructor's
-- arguments, and a factory, which builds an object without initializing it.
-- @param tname the class name
-- @param parenttname the parent class name (from which it inherits), or nil
-- @return the created class, and the parent's metatable when `parenttname`
-- is given (nil otherwise)
function nerv.class(tname, parenttname)
    -- bare instance: an empty table bound to the class metatable
    local function factory()
        local obj = {}
        nerv.setmetatable(obj, tname)
        return obj
    end

    -- full instance: a bare instance whose initializer (if present) is run
    local function constructor(...)
        local obj = factory()
        if obj.__init then
            obj:__init(...)
        end
        return obj
    end

    local class_mt = nerv.newmetatable(tname, parenttname, constructor, nil, factory)
    local parent_mt = nil
    if parenttname then
        parent_mt = nerv.getmetatable(parenttname)
    end
    return class_mt, parent_mt
end

--- Serialize a single value as a piece of Lua source text.
-- Strings are quoted and escaped so that the result stays on one line and
-- reads back as the same string. Tables are delegated to `table.tostring`.
-- Numbers (including infinities and NaN) and booleans are written as
-- expressions that evaluate back to the same value.
-- @param v the value to serialize
-- @return the Lua source text, or nil if `v` cannot be serialized (e.g. a
-- function, userdata, thread or nil)
function table.val_to_str(v)
    local t = type(v)
    if t == "string" then
        -- escape backslashes first so the escapes introduced below survive;
        -- raw newlines and carriage returns are not allowed inside a quoted
        -- Lua string
        v = string.gsub(v, "\\", "\\\\")
        v = string.gsub(v, "\n", "\\n")
        v = string.gsub(v, "\r", "\\r")
        -- if the only quote characters present are double quotes, wrap the
        -- string in single quotes so no quote needs escaping
        if string.match(string.gsub(v, "[^'\"]", ""), '^"+$') then
            return "'" .. v .. "'"
        end
        return '"' .. string.gsub(v, '"', '\\"') .. '"'
    elseif t == "table" then
        return table.tostring(v)
    elseif t == "number" then
        -- tostring would give "inf"/"nan", which read back as nil globals
        if v ~= v then
            return "0/0"
        elseif v == math.huge then
            return "1/0"
        elseif v == -math.huge then
            return "-1/0"
        end
        return tostring(v)
    elseif t == "boolean" then
        return tostring(v)
    end
    return nil -- failed to serialize
end

-- Reserved words of Lua (5.1 through 5.4). They cannot be used as bare field
-- names in a table constructor: `{end=1}` is a syntax error.
local LUA_RESERVED_WORDS = {
    ["and"] = true, ["break"] = true, ["do"] = true, ["else"] = true,
    ["elseif"] = true, ["end"] = true, ["false"] = true, ["for"] = true,
    ["function"] = true, ["goto"] = true, ["if"] = true, ["in"] = true,
    ["local"] = true, ["nil"] = true, ["not"] = true, ["or"] = true,
    ["repeat"] = true, ["return"] = true, ["then"] = true, ["true"] = true,
    ["until"] = true, ["while"] = true,
}

--- Serialize a table key for use inside a table constructor.
-- A string that is a valid Lua identifier (ASCII letters, digits and
-- underscores, not starting with a digit, and not a reserved word) is emitted
-- bare. Every other key is emitted in bracket form, e.g. `["end"]` or `[1]`,
-- using `table.val_to_str`.
-- @param k the key
-- @return the key as Lua source text
function table.key_to_str (k)
    if type(k) == "string" and not LUA_RESERVED_WORDS[k]
            and string.match(k, "^[_A-Za-z][_A-Za-z0-9]*$") then
        return k
    end
    return "[" .. table.val_to_str(k) .. "]"
end

--- Get the string representation of a table, which can be executed as a valid
-- piece of Lua code.
-- The sequence part (as walked by `ipairs`) is written positionally and all
-- other entries as `key=value`. Nested tables are serialized recursively.
-- There is no cycle detection, so cyclic tables are not supported. The order
-- of the non-sequence entries follows `pairs` and is unspecified.
-- @param tbl the table
-- @return the string representation which will result in a Lua table entity
-- when evaluated
-- @raise an error if some value cannot be serialized (e.g. a function)
function table.tostring(tbl)
    local result, done = {}, {}
    for k, v in ipairs(tbl) do
        local s = table.val_to_str(v)
        if s == nil then
            -- dropping the value would silently shift every later element
            error(string.format(
                "table.tostring: cannot serialize %s value at index %d",
                type(v), k), 2)
        end
        result[#result + 1] = s
        done[k] = true
    end
    for k, v in pairs(tbl) do
        if not done[k] then
            local s = table.val_to_str(v)
            if s == nil then
                error(string.format(
                    "table.tostring: cannot serialize %s value at key %s",
                    type(v), tostring(k)), 2)
            end
            result[#result + 1] = table.key_to_str(k) .. "=" .. s
        end
    end
    return "{" .. table.concat(result, ",") .. "}"
end

--- Get the class by name.
-- `tname` is compiled and evaluated as a Lua expression (e.g. a dotted global
-- path such as "nerv.SomeClass"), so it must come from trusted code and
-- never from external input.
-- @param tname the name of the class
-- @return the class entity (nil if the expression evaluates to nil)
-- @raise an error if `tname` is not a valid Lua expression
function nerv.get_type(tname)
    -- `loadstring` exists in Lua 5.1/LuaJIT; Lua 5.2+ only provides `load`
    local compile = loadstring or load
    return assert(compile("return " .. tname))()
end

--- Check whether an object belongs to a class.
-- Starting from the metatable of `obj`, repeatedly takes the metatable of the
-- current metatable, and reports whether the metatable registered under
-- `tname` is encountered before the chain ends.
-- @param obj the object ("class instance")
-- @param tname the class name ("type name")
-- @return true if the class metatable is in the chain, false otherwise
function nerv.is_type(obj, tname)
    local target = nerv.getmetatable(tname)
    -- tail-recursive walk up the metatable chain
    local function search(mt)
        if not mt then
            return false
        end
        if mt == target then
            return true
        end
        return search(getmetatable(mt))
    end
    return search(getmetatable(obj))
end

--- Strip the last component from a file path.
-- @param filename the path to a file
-- @return the prefix of `filename` up to and including its last "/", or ''
-- when `filename` contains no "/"
function nerv.dirname(filename)
    -- greedy match: capture everything through the final slash
    local dir = filename:match("^(.*/)")
    if dir == nil then
        return ''
    end
    return dir
end

--- Include a script file (chunk) into the current script.
-- An analogy to `#include` in C. The effect is the same as executing
-- `dofile(filename)` at the current line, except that a relative `filename`
-- is resolved against the directory of the calling script rather than the
-- current working directory.
-- @param filename the path to a file, either relative to the calling script's
-- directory or absolute (starting with "/"), in which case it is used as is
-- @return all values returned by the chunk
function nerv.include(filename)
    if filename:sub(1, 1) == "/" then
        return dofile(filename)
    end
    local source = debug.getinfo(2, "S").source
    -- chunks loaded from files report their source as "@<path>"; any other
    -- source (e.g. code loaded from a string) has no meaningful directory
    local dir = ''
    if source:sub(1, 1) == "@" then
        dir = nerv.dirname(source:sub(2))
    end
    return dofile(dir .. filename)
end

--- Parse the command-line options and arguments.
-- The specification `options` is a list of tables, each one describing
-- exactly one available option. For a spec `v`, `v[1]` is the full name of the
-- option, `v[2]` its one-character short form (only boolean options may have
-- one) and `v[3]` the type of the value controlled by the option ("boolean",
-- "int", "number" or "string"). The `default` and `desc` keys can also be
-- specified to set the default value and the description of the option.
--
-- An example of specification:
-- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"},
-- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no is not present"},
-- {"ccc", nil, "int", default = 0, desc = "ccc expects an integral value"}}
--
-- Accepted forms on the command line are:
--
-- * `-abc`: one or more bundled short boolean options.
-- * `--aaa`: sets a boolean option to true.
-- * `--name=value`: everything after the first "=" is the value, and it may
--   itself contain "="; boolean options take "yes" or "no".
-- @param argv the argument list to be parsed
-- @param options the specification of options, as described above
-- @param unordered if true, options are recognized anywhere in `argv`;
-- otherwise every token after the first non-option argument is treated as a
-- plain argument
-- @return args, opts The non-option arguments (a list of strings) and the
-- parsed options. `opts` maps each full option name to a table with the fields
-- `type`, `desc` and `val`. The parsed value can be accessed by
-- `opts["aaa"].val` (which is `true` if "--aaa" or "-a" is specified).
function nerv.parse_args(argv, options, unordered)
    local is_opt_exp = "^[-](.*)$"
    -- bundled short options, e.g. "-ab"
    local sim_opt_exp = "^[-]([a-zA-Z]+)$"
    -- a long option without a value, e.g. "--aaa"
    local opt_exp = "^[-][-]([^=]+)$"
    -- a long option with a value, e.g. "--ccc=1"; the name stops at the first
    -- "=" and the rest of the token (possibly empty) is the value
    local opt_with_val_exp = "^[-][-]([^=]+)=(.*)$"
    local opts = {}
    local sopts = {}
    local args = {}
    local arg_start = false

    local function err()
        nerv.error("invalid format of option specification")
    end

    -- Convert the textual value `v` according to the option's type and store
    -- it in `opt.val`.
    local function assign_value(opt, v)
        if opt.type == "boolean" then
            if v == "yes" then
                opt.val = true
            elseif v == "no" then
                opt.val = false
            else
                nerv.error("boolean value should be \"yes\" or \"no\"")
            end
        elseif opt.type == "int" then
            local t = tonumber(v)
            opt.val = t
            if t == nil or math.floor(t) ~= t then
                nerv.error("int value is expected")
            end
        elseif opt.type == "number" then
            local t = tonumber(v)
            opt.val = t
            if t == nil then
                nerv.error("numeric value is expected")
            end
        elseif opt.type == "string" then
            opt.val = v
        else
            nerv.error("unrecognized type %s", opt.type)
        end
    end

    -- build lookup tables for the long (opts) and short (sopts) forms; both
    -- forms of one option share the same meta table
    for _, spec in ipairs(options) do
        if type(spec) ~= "table" or
            (spec[1] == nil and spec[2] == nil) or
            spec[3] == nil then
            err()
        end
        local opt_full = spec[1]
        local opt_short = spec[2]
        local opt_type = spec[3]
        local opt_meta = {type = opt_type,
                          desc = spec.desc or "",
                          val = spec.default}
        if opt_short ~= nil then
            if type(opt_short) ~= "string" or #opt_short ~= 1 then err() end
            if opt_type ~= "boolean" then
                nerv.error("only boolean option could have short form")
            end
            sopts[opt_short] = opt_meta
        end
        if opt_full ~= nil then
            if type(opt_full) ~= "string" then err() end
            opts[opt_full] = opt_meta
        end
    end

    for _, token in ipairs(argv) do
        if ((not arg_start) or unordered) and token:match(is_opt_exp) then
            local shorts = token:match(sim_opt_exp)
            if shorts then
                for c in shorts:gmatch(".") do
                    if sopts[c] then
                        sopts[c].val = true
                    else
                        nerv.error("invalid option -%s", c)
                    end
                end
            else
                local name = token:match(opt_exp)
                if name then
                    if opts[name] == nil then
                        nerv.error("invalid option %s", token)
                    end
                    if opts[name].type ~= "boolean" then
                        nerv.error("invalid option --%s: " ..
                                    "a %s value needs to be specified",
                                    name, opts[name].type)
                    else
                        opts[name].val = true
                    end
                else
                    local key, val = token:match(opt_with_val_exp)
                    if key then
                        if opts[key] == nil then
                            nerv.error("invalid option %s", token)
                        end
                        assign_value(opts[key], val)
                    else
                        nerv.error("unrecognized option %s", token)
                    end
                end
            end
        else
            table.insert(args, token)
            arg_start = true
        end
    end
    return args, opts
end

--- Print usage information of the command-line options.
-- Prints, through `nerv.printf`, a table with one row per option: long form,
-- short form, type, default value and description. Every column except the
-- last is padded to its widest cell, with the header included.
-- @param options the list of options used in `parse_args`
function nerv.print_usage(options)
    local header = {"Option", "Abbr.", "Type", "Default", "Desc."}
    -- the exact text of every cell, so that widths match what is printed
    -- (including the "--"/"-" prefixes); options may lack either name
    local rows = {}
    for _, v in ipairs(options) do
        local opt_full = v[1]
        local opt_short = v[2]
        rows[#rows + 1] = {
            (opt_full and '--' .. opt_full) or "",
            (opt_short and '-' .. opt_short) or "",
            tostring(v[3] or ""),
            (v.default ~= nil and tostring(v.default)) or "",
            v.desc or "",
        }
    end
    -- each padded column gets its own width (previously all were derived
    -- from the option-name width by mistake)
    local widths = {}
    for col = 1, 4 do
        local w = #header[col]
        for _, row in ipairs(rows) do
            w = math.max(w, #row[col])
        end
        widths[col] = w
    end
    -- pad manually: string.format rejects field widths above 99
    local function format_row(cells)
        local parts = {}
        for col = 1, 4 do
            local cell = cells[col]
            parts[col] = cell .. string.rep(" ", widths[col] - #cell)
        end
        parts[5] = cells[5]
        return "\t" .. table.concat(parts, "\t") .. "\n"
    end
    nerv.printf("\n")
    nerv.printf("%s", format_row(header))
    for _, row in ipairs(rows) do
        nerv.printf("%s", format_row(row))
    end
    nerv.printf("\n")
end

--- Append the elements of the sequence `tbl2` to the end of `tbl1`, in place.
-- Elements are taken as `ipairs` yields them, i.e. up to the first nil.
-- Extending a table with itself appends one copy of its elements.
-- @param tbl1 the table to be extended (modified in place)
-- @param tbl2 the sequence whose elements are appended
function table.extend(tbl1, tbl2)
    local src = tbl2
    if rawequal(tbl1, tbl2) then
        -- iterating a table while appending to it never terminates, so take
        -- a snapshot of the elements first
        src = {}
        for i, v in ipairs(tbl2) do
            src[i] = v
        end
    end
    for _, v in ipairs(src) do
        table.insert(tbl1, v)
    end
end

--- Create a sequence holding `len` copies of `fill`.
-- Table values are not copied: every slot refers to the same `fill` object.
-- @param len the number of elements
-- @param fill the value of every element (default 0; `false` is kept as is)
-- @return the new table
function table.vector(len, fill)
    -- only nil selects the default; `fill or 0` would also replace false
    if fill == nil then
        fill = 0
    end
    local v = {}
    for i = 1, len do
        v[i] = fill
    end
    return v
end

-- the following lines trigger the initialization of basic modules

nerv.include('matrix/init.lua')
nerv.include('io/init.lua')
nerv.include('layer/init.lua')
nerv.include('nn/init.lua')