--- Contains misc utility functions of NERV and finally initializes NERV by
-- including `init.lua` of other basic modules.
-- @author Ted Yin
-- @module nerv

require 'libnerv'

--- Display a friendly error message when user attempts to invoke a
-- non-implemented function.
function nerv.error_method_not_implemented()
    nerv.error("method not implemented")
end

--- Duplicate everything printed by `nerv.printf` (and therefore by
-- `nerv.info` and `nerv.warning`) into a log file.
-- A log file opened by an earlier call is closed first, so repeated calls do
-- not leak file handles.
-- @param filename the path of the log file; an existing file is truncated
function nerv.set_logfile(filename)
    if nerv._logfile then
        nerv._logfile:close()
        nerv._logfile = nil
    end
    local file, err = io.open(filename, "w")
    if not file then
        -- fail loudly instead of silently dropping the requested log
        nerv.error("cannot open log file: %s", tostring(err))
    end
    nerv._logfile = file
end

--- Format a string just like `sprintf` in C.
-- @param fmt the format string
-- @param ... args, the data to be formatted
-- @return the formatted string
function nerv.sprintf(fmt, ...)
    return string.format(fmt, ...)
end

--- Print a formatted string to stderr.
-- If a log file has been set with `nerv.set_logfile`, the same text is also
-- written to it and flushed immediately.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.printf(fmt, ...)
    local line = nerv.sprintf(fmt, ...)
    io.stderr:write(line)
    -- duplicate all output to the log file, if set
    if nerv._logfile then
        nerv._logfile:write(line)
        nerv._logfile:flush()
    end
end

--- Raise a global error with the formatted message.
-- The message is prefixed with "[nerv] internal error: " and ends with a
-- newline.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.error(fmt, ...)
    error(nerv.sprintf("[nerv] internal error: " .. fmt .. "\n", ...))
end

-- Build a printf-style format string of the form "(<time>)[nerv] <level>: ".
-- `fmt` is spliced in verbatim, so its `%` directives are still expanded by
-- the subsequent `nerv.printf` call.
local function log_format(level, fmt)
    return string.format("(%s)[nerv] %s: %s\n",
                         os.date("%H:%M:%S %F"), level, fmt)
end

--- Print a notification message that begins with a timestamp and "info".
-- Instead of using `nerv.printf`, normal users should use this to print any
-- notification information.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.info(fmt, ...)
    nerv.printf(log_format("info", fmt), ...)
end

--- Print a warning message that begins with a timestamp and "warning".
-- Instead of using `nerv.printf`, normal users should use this to print any
-- warnings.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.warning(fmt, ...)
    nerv.printf(log_format("warning", fmt), ...)
end

--- Create a class (Torch-compatible).
-- Use this to create a class in NERV.
-- NOTE(review): `constructor` and `factory` are handed to the C function
-- `nerv.newmetatable`; presumably calling the class invokes `constructor` —
-- confirm in libnerv.
-- @param tname the class name
-- @param parenttname the parent class name (from which it inherits), or nil
-- @return the created class, and the parent's metatable when `parenttname`
-- is given (nil otherwise)
function nerv.class(tname, parenttname)
    -- new instance of `tname`, initialized by its `__init` (if any)
    local function constructor(...)
        local self = {}
        nerv.setmetatable(self, tname)
        if self.__init then
            self:__init(...)
        end
        return self
    end
    -- new instance of `tname` without running `__init`
    local function factory()
        local self = {}
        nerv.setmetatable(self, tname)
        return self
    end
    local mt = nerv.newmetatable(tname, parenttname, constructor, nil, factory)
    local mpt
    if parenttname then
        mpt = nerv.getmetatable(parenttname)
    end
    return mt, mpt
end

-- Lua reserved words. They look like identifiers but cannot be used as bare
-- keys in a table constructor (`{end = 1}` does not parse). `goto` is only
-- reserved in newer Lua versions, but quoting it is always valid.
local lua_keywords = {}
for kw in ([[and break do else elseif end false for function goto if in
        local nil not or repeat return then true until while]]):gmatch("%a+") do
    lua_keywords[kw] = true
end

--- Serialize a single value to a Lua expression.
-- Strings are quoted and escaped, numbers and booleans are written literally,
-- tables are serialized recursively with `table.tostring`; any other value
-- (function, userdata, thread, nil) becomes "nil" because it cannot be
-- serialized.
-- @param v the value
-- @return a string that evaluates to `v` when loaded as Lua code
function table.val_to_str(v)
    local tv = type(v)
    if tv == "string" then
        -- escape backslashes first so the escapes added afterwards survive
        v = v:gsub("\\", "\\\\"):gsub("\n", "\\n"):gsub("\r", "\\r")
        -- a string containing double quotes but no single quote is wrapped
        -- in single quotes so no quote has to be escaped
        if v:find('"', 1, true) and not v:find("'", 1, true) then
            return "'" .. v .. "'"
        end
        return '"' .. v:gsub('"', '\\"') .. '"'
    elseif tv == "number" then
        -- tostring() yields "inf"/"nan", which would evaluate to nil
        if v ~= v then
            return "(0/0)"
        elseif v == math.huge then
            return "math.huge"
        elseif v == -math.huge then
            return "-math.huge"
        end
        return tostring(v)
    elseif tv == "boolean" then
        return tostring(v)
    elseif tv == "table" then
        return table.tostring(v)
    end
    return "nil" -- failed to serialize
end

--- Serialize a table key for use inside a table constructor.
-- Identifier-like strings that are not reserved words are written bare
-- (`key`); every other key is written in brackets (`[expr]`).
-- @param k the key
-- @return the key part of a `key=value` constructor entry
function table.key_to_str(k)
    if type(k) == "string" and k:match("^[_%a][_%w]*$")
            and not lua_keywords[k] then
        return k
    end
    return "[" .. table.val_to_str(k) .. "]"
end

--- Get the string representation of a table, which can be executed as a valid
-- piece of Lua code.
-- The sequence part (as visited by `ipairs`) is written positionally, every
-- other key as a `key=value` entry in `pairs` order. Cyclic tables are not
-- supported: there is no visited set, so they recurse without bound.
-- @param tbl the table
-- @return the string representation which will result in a Lua table entity
-- when evaluated
function table.tostring(tbl)
    local result, done = {}, {}
    for k, v in ipairs(tbl) do
        result[#result + 1] = table.val_to_str(v)
        done[k] = true
    end
    for k, v in pairs(tbl) do
        if not done[k] then
            result[#result + 1] = table.key_to_str(k) .. "=" .. table.val_to_str(v)
        end
    end
    return "{" .. table.concat(result, ",") .. "}"
end

--- Get the class by name.
-- `tname` is evaluated as a Lua expression (e.g. "nerv.SomeClass"), so it
-- must come from trusted code, never from external input.
-- @param tname the name of the class
-- @return the class entity
function nerv.get_type(tname)
    -- `loadstring` no longer exists in Lua 5.2+, where `load` accepts strings
    local compile = loadstring or load
    return assert(compile("return " .. tname))()
end

--- Check if the object is of the certain class.
-- @param obj the object ("class instance")
-- @param tname the class name ("type name")
-- @return true if the metatable registered for `tname` occurs anywhere in the
-- chain getmetatable(obj), getmetatable(getmetatable(obj)), ...; false
-- otherwise
function nerv.is_type(obj, tname)
    local target = nerv.getmetatable(tname)
    local mt = getmetatable(obj)
    while mt do
        if mt == target then
            return true
        end
        mt = getmetatable(mt)
    end
    return false
end

--- Strip last component from file name.
-- @param filename the path to a file
-- @return the part of `filename` up to and including its last "/", or ''
-- when it contains no "/"
function nerv.dirname(filename)
    return filename:match("^(.*/)") or ''
end

--- Include a script file (chunk) into the current script.
-- An analogy to `#include` in C. Note that the effect is the same as executing
-- `dofile(filename)` at the current line. A relative `filename` is resolved
-- against the directory of the calling script file; an absolute one (starting
-- with "/") is used as is, as is any path included from a chunk that was not
-- loaded from a file.
-- @param filename the path to a file
-- @return all values returned by the chunk
function nerv.include(filename)
    local path = filename
    if filename:sub(1, 1) ~= '/' then
        -- level 2 is the function that called nerv.include
        local source = debug.getinfo(2, "S").source
        -- only chunks loaded from files have a source of the form "@path";
        -- string chunks carry their code here, not a usable directory
        if source:sub(1, 1) == '@' then
            path = nerv.dirname(source:sub(2)) .. filename
        end
    end
    return dofile(path)
end

--- Parse the command-line options and arguments.
-- @param argv the argument list to be parsed
-- @param options The specification of options, should be a list of tables,
-- each one for exactly one available option, say `v`, with `v[1]`, `v[2]`,
-- `v[3]` indicating the full name of the option, the short form of the option
-- (only allowed for boolean options) and the type of the value controlled by
-- the option ("boolean", "int", "number" or "string"). `default` and `desc`
-- keys can also be specified to set the default value and description of the
-- option.
--
-- An example of specification:
--
--     {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"},
--      {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no is not present"},
--      {"ccc", nil, "int", default = 0, desc = "ccc expects an integral value"}}
--
-- @param unordered when true, options are still recognized after the first
-- non-option argument; otherwise every token after the first non-option
-- argument is treated as a plain argument
-- @return args, opts The non-option arguments and parsed options. `args` is a
-- list of the non-option tokens of `argv`; `opts` is a table mapping the full
-- name of each option to a table with the fields `type`, `desc`, `val` and
-- `specified`. Options that only have a short form are not included in `opts`.
-- The parsed value could be accessed by `opts["aaa"].val` (which is
-- `true` if "--aaa" or "-a" is specified). A value given as "--name=value"
-- may itself contain "=" or be empty.
function nerv.parse_args(argv, options, unordered)
    local is_opt_exp = "^[-](.*)$"
    -- one or more bundled single-letter flags, e.g. "-ab" or "-V"
    local sim_opt_exp = "^[-](%a+)$"
    local opt_exp = "^[-][-]([^=]+)$"
    -- the name stops at the first "="; the value is everything after it
    local opt_with_val_exp = "^[-][-]([^=]+)=(.*)$"
    local opts = {}
    local sopts = {}
    local args = {}
    local arg_start = false

    -- Reject a malformed specification entry.
    local function err()
        nerv.error("invalid format of option specification")
    end

    -- Validate one specification entry and index its metadata table under its
    -- full name (`opts`) and short name (`sopts`); both indexes share the same
    -- table, so "-a" and "--aaa" update one entry.
    local function register(v)
        if type(v) ~= "table" or (v[1] == nil and v[2] == nil) or v[3] == nil then
            err()
        end
        local opt_full, opt_short, opt_type = v[1], v[2], v[3]
        local opt_meta = {type = opt_type, desc = v.desc or "",
                          val = v.default, specified = false}
        if opt_short ~= nil then
            if type(opt_short) ~= "string" or #opt_short ~= 1 then
                err()
            end
            if opt_type ~= "boolean" then
                nerv.error("only boolean option could have short form")
            end
            sopts[opt_short] = opt_meta
        end
        if opt_full ~= nil then
            if type(opt_full) ~= "string" then
                err()
            end
            opts[opt_full] = opt_meta
        end
    end

    -- Convert `text` according to the declared type of the option described
    -- by `meta` and store the result as the option's value.
    local function set_value(meta, text)
        local t = meta.type
        local val
        if t == "boolean" then
            if text == "yes" then
                val = true
            elseif text == "no" then
                val = false
            else
                nerv.error("boolean value should be \"yes\" or \"no\"")
            end
        elseif t == "int" then
            val = tonumber(text)
            -- floor(nan) ~= nan, so NaN is rejected by the floor test too
            if val == nil or math.floor(val) ~= val
                    or math.abs(val) == math.huge then
                nerv.error("int value is expected")
            end
        elseif t == "number" then
            val = tonumber(text)
            if val == nil then
                nerv.error("numeric value is expected")
            end
        elseif t == "string" then
            val = text
        else
            nerv.error("unrecognized type %s", t)
        end
        meta.val = val
        meta.specified = true
    end

    -- Apply a single option token: "-abc", "--name" or "--name=value".
    local function parse_option(token)
        local flags = token:match(sim_opt_exp)
        if flags then
            for c in flags:gmatch(".") do
                local meta = sopts[c]
                if not meta then
                    nerv.error("invalid option -%s", c)
                end
                meta.val = true
                meta.specified = true
            end
            return
        end
        local name = token:match(opt_exp)
        if name then
            local meta = opts[name]
            if meta == nil then
                nerv.error("invalid option %s", token)
            end
            if meta.type ~= "boolean" then
                nerv.error("invalid option --%s: " ..
                           "a %s value needs to be specified", name, meta.type)
            end
            meta.val = true
            meta.specified = true
            return
        end
        local vname, text = token:match(opt_with_val_exp)
        if vname then
            local meta = opts[vname]
            if meta == nil then
                nerv.error("invalid option %s", token)
            end
            set_value(meta, text)
            return
        end
        nerv.error("unrecognized option %s", token)
    end

    for _, v in ipairs(options) do
        register(v)
    end
    for _, token in ipairs(argv) do
        if ((not arg_start) or unordered) and token:match(is_opt_exp) then
            parse_option(token)
        else
            table.insert(args, token)
            arg_start = true
        end
    end
    return args, opts
end

--- Print usage information of the command-line options.
-- Every column except the last is padded to the width of its widest cell
-- (header included); missing fields are printed as empty cells.
-- @param options the list of options used in `parse_args`
function nerv.print_usage(options)
    local headers = {"Option", "Abbr.", "Type", "Default", "Desc."}
    -- build every row first so widths are measured on exactly what is printed
    local rows = {}
    for _, v in ipairs(options) do
        rows[#rows + 1] = {
            (v[1] and ('--' .. v[1])) or "",
            (v[2] and ('-' .. v[2])) or "",
            (v[3] and tostring(v[3])) or "",
            (v.default ~= nil and tostring(v.default)) or "",
            tostring(v.desc or ""),
        }
    end
    local widths = {}
    for col = 1, 4 do
        local w = #headers[col]
        for _, row in ipairs(rows) do
            w = math.max(w, #row[col])
        end
        widths[col] = w
    end
    -- pad manually: string.format rejects widths of 100 or more
    local function print_row(cells)
        local parts = {}
        for col = 1, 4 do
            parts[col] = cells[col] .. string.rep(" ", widths[col] - #cells[col])
        end
        parts[5] = cells[5]
        nerv.printf("\t%s\n", table.concat(parts, "\t"))
    end
    nerv.printf("\n")
    print_row(headers)
    for _, row in ipairs(rows) do
        print_row(row)
    end
    nerv.printf("\n")
end

--- Append the elements of `tbl2` (as visited by `ipairs`) to `tbl1` in place.
-- @param tbl1 the list to be extended
-- @param tbl2 the list whose elements are appended
function table.extend(tbl1, tbl2)
    for _, v in ipairs(tbl2) do
        table.insert(tbl1, v)
    end
end

--- Create a list holding `len` times the value `fill`.
-- If `fill` is a table, every slot refers to that same table.
-- @param len the length of the list
-- @param fill the element value; defaults to 0 only when nil (`false` is kept)
-- @return the new list
function table.vector(len, fill)
    if fill == nil then
        fill = 0
    end
    local v = {}
    for i = 1, len do
        v[i] = fill
    end
    return v
end

--- Concatenate two lists into a new list.
-- @param tbl1 the first list
-- @param tbl2 the second list
-- @return a new list with elements 1..#tbl1 of `tbl1` followed by elements
-- 1..#tbl2 of `tbl2`
function table.connect(tbl1, tbl2)
    local res = {}
    for i = 1, #tbl1 do
        table.insert(res, tbl1[i])
    end
    for i = 1, #tbl2 do
        table.insert(res, tbl2[i])
    end
    return res
end

--- Shallow-merge two tables into a new table.
-- @param tbl1 the base table
-- @param tbl2 the table whose entries override those of `tbl1` on conflict
-- @return a new table with all key/value pairs of both tables
function table.merge(tbl1, tbl2)
    local res = {}
    for k, v in pairs(tbl1) do
        res[k] = v
    end
    for k, v in pairs(tbl2) do
        res[k] = v
    end
    return res
end

-- the following lines trigger the initialization of basic modules
nerv.include('matrix/init.lua')
nerv.include('io/init.lua')
nerv.include('layer/init.lua')
nerv.include('nn/init.lua')