diff options
Diffstat (limited to 'nerv/init.lua')
-rw-r--r-- | nerv/init.lua | 237 |
1 files changed, 210 insertions, 27 deletions
diff --git a/nerv/init.lua b/nerv/init.lua index 6312df1..ff944b8 100644 --- a/nerv/init.lua +++ b/nerv/init.lua @@ -13,6 +13,10 @@ function nerv.error_method_not_implemented() nerv.error("method not implemented"); end +function nerv.set_logfile(filename) + nerv._logfile = io.open(filename, "w") +end + --- Format a string just like `sprintf` in C. -- @param fmt the format string -- @param ... args, the data to be formatted @@ -25,7 +29,13 @@ end -- @param fmt the format string -- @param ... args, the data to be formatted function nerv.printf(fmt, ...) - io.write(nerv.sprintf(fmt, ...)) + local line = nerv.sprintf(fmt, ...) + io.stderr:write(line) + -- duplicate the all output to the log file, if set + if nerv._logfile then + nerv._logfile:write(line) + nerv._logfile:flush() + end end --- Raise an global error with the formatted message. @@ -88,24 +98,27 @@ function nerv.class(tname, parenttname) end function table.val_to_str(v) - if "string" == type(v) then - v = string.gsub(v, "\n", "\\n") - if string.match(string.gsub(v,"[^'\"]",""), '^"+$') then - return "'" .. v .. "'" + if "string" == type(v) then + v = string.gsub(v, "\n", "\\n") + if string.match(string.gsub(v,"[^'\"]",""), '^"+$') then + return "'" .. v .. "'" + end + return '"' .. string.gsub(v,'"', '\\"') .. '"' + else + return "table" == type(v) and table.tostring(v) or + (("number" == type(v) or + "string" == type(v) or + "boolean" == type(v)) and tostring(v)) or + nil -- failed to serialize end - return '"' .. string.gsub(v,'"', '\\"') .. '"' - else - return "table" == type(v) and table.tostring(v) or - tostring(v) - end end function table.key_to_str (k) - if "string" == type(k) and string.match(k, "^[_%a][_%a%d]*$") then - return k - else - return "[" .. table.val_to_str(k) .. "]" - end + if "string" == type(k) and string.match(k, "^[_%a][_%a%d]*$") then + return k + else + return "[" .. table.val_to_str(k) .. "]" + end end --- Get the string representation of a table, which can be executed as a valid @@ -114,18 +127,18 @@ end -- @return the string representation which will result in a Lua table entity -- when evaluated function table.tostring(tbl) - local result, done = {}, {} - for k, v in ipairs(tbl) do - table.insert(result, table.val_to_str(v)) - done[k] = true - end - for k, v in pairs(tbl) do - if not done[k] then - table.insert(result, - table.key_to_str(k) .. "=" .. table.val_to_str(v)) + local result, done = {}, {} + for k, v in ipairs(tbl) do + table.insert(result, table.val_to_str(v)) + done[k] = true end - end - return "{" .. table.concat(result, ",") .. "}" + for k, v in pairs(tbl) do + if not done[k] then + table.insert(result, + table.key_to_str(k) .. "=" .. table.val_to_str(v)) + end + end + return "{" .. table.concat(result, ",") .. "}" end --- Get the class by name. @@ -172,10 +185,180 @@ function nerv.include(filename) return dofile(nerv.dirname(caller) .. filename) end +--- Parse the command-line options and arguments +-- @param argv the argrument list to parsed +-- @param options The specification of options, should be a list of tables, +-- each one for exactly one available option, say `v`, with `v[1]`, `v[2]`, +-- `v[3]` indicating the full name of the option, the short form of the option +-- (when it is a boolean option) and the type of the value controlled by the +-- option. `default` and `desc` keys can also be specified to set the default +-- value and description of the option. +-- +-- An example of specification: +-- {{"aaa", "a", "boolean", default = false, desc = "an option called aaa"}, +-- {"bbb", "b", "boolean", default = true, desc = "bbb is set to be true if --bbb=no does not present"}, +-- {"ccc", nil, "int", default = 0, desc = "ccc expects an integeral value"}}` +-- +-- @return args, opts The non-option arguments and parsed options. `opts` is +-- again a list of tables, each of which corresponds to one table in parameter +-- `options`. The parsed value could be accessed by `opts["aaa"].val` (which is +-- `true` if "--aaa" or "-a" is specified). +function nerv.parse_args(argv, options, unordered) + local is_opt_exp = "^[-](.*)$" + local sim_opt_exp = "^[-]([a-z]+)$" + local opt_exp = "^[-][-]([^=]+)$" + local opt_with_val_exp = "^[-][-]([^=]+)=([^=]+)$" + local opts = {} + local sopts = {} + local args = {} + local arg_start = false + local function err() + nerv.error("invalid format of option specification") + end + for _, v in ipairs(options) do + if type(v) ~= "table" or + (v[1] == nil and v[2] == nil) or + v[3] == nil then + err() + end + local opt_full = v[1] + local opt_short = v[2] + local opt_type = v[3] + local opt_meta = {type = opt_type, + desc = v.desc or "", + val = v.default} + if opt_short ~= nil then + if type(opt_short) ~= "string" or #opt_short ~= 1 then err() end + if opt_type ~= "boolean" then + nerv.error("only boolean option could have short form") + end + sopts[opt_short] = opt_meta + end + if opt_full ~= nil then + if type(opt_full) ~= "string" then err() end + opts[opt_full] = opt_meta + end + end + for _, token in ipairs(argv) do + if ((not arg_start) or unordered) and token:match(is_opt_exp) then + local k = token:match(sim_opt_exp) + if k then + for c in k:gmatch"." do + if sopts[c] then + sopts[c].val = true + else + nerv.error("invalid option -%s", c) + end + end + else + local k = token:match(opt_exp) + if k then + if opts[k] == nil then + nerv.error("invalid option %s", token) + end + if opts[k].type ~= "boolean" then + nerv.error("invalid option --%s: " .. + "a %s value needs to be specified", + k, opts[k].type) + else + opts[k].val = true + end + else + local k, v = token:match(opt_with_val_exp) + if k then + if opts[k] == nil then + nerv.error("invalid option %s", token) + end + if opts[k].type == "boolean" then + if v == "yes" then + opts[k].val = true + elseif v == "no" then + opts[k].val = false + else + nerv.error("boolean value should be \"yes\" or \"no\"") + end + elseif opts[k].type == "int" then + local t = tonumber(v) + opts[k].val = t + if t == nil or math.floor(t) ~= t then + nerv.error("int value is expected") + end + elseif opts[k].type == "number" then + local t = tonumber(v) + opts[k].val = t + if t == nil then + nerv.error("numeric value is expected") + end + elseif opts[k].type == "string" then + opts[k].val = v + else + nerv.error("unrecognized type %s", opts[k].type) + end + else + nerv.error("unrecognized option %s", token) + end + end + end + else + table.insert(args, token) + arg_start = true + end + end + return args, opts +end + +--- Print usage information of the command-line options +-- @param options the list of options used in `parse_args` +function nerv.print_usage(options) + local full_maxlen = 0 + local type_maxlen = 0 + local default_maxlen = 0 + for _, v in ipairs(options) do + local opt_full = v[1] + local opt_short = v[2] + local opt_type = v[3] + full_maxlen = math.max(full_maxlen, #opt_full or 0) + type_maxlen = math.max(full_maxlen, #opt_type or 0) + default_maxlen = math.max(full_maxlen, #tostring(v.default) or 0) + end + local function pattern_gen() + return string.format("\t%%-%ds\t%%-2s\t%%-%ds\t%%-%ds\t%%s\n", + full_maxlen, type_maxlen, default_maxlen) + end + nerv.printf("\n") + nerv.printf(pattern_gen(), "Option", "Abbr.", "Type", "Default", "Desc.") + for _, v in ipairs(options) do + local opt_full = v[1] + local opt_short = v[2] + local opt_type = v[3] + nerv.printf(pattern_gen(), + (opt_full and '--' .. opt_full) or "", + (opt_short and '-' .. opt_short) or "", + opt_type, + (v.default ~= nil and tostring(v.default)) or "", + v.desc or "") + end + nerv.printf("\n") +end + +function table.extend(tbl1, tbl2) + for _, v in ipairs(tbl2) do + table.insert(tbl1, v) + end +end + +function table.vector(len, fill) + local v = {} + fill = fill or 0 + for i = 1, len do + table.insert(v, fill) + end + return v +end + -- the following lines trigger the initialization of basic modules nerv.include('matrix/init.lua') nerv.include('io/init.lua') nerv.include('layer/init.lua') nerv.include('nn/init.lua') -nerv.include('tnn/init.lua') |