aboutsummaryrefslogblamecommitdiff
path: root/nerv/init.lua
blob: 6312df17b1471726543fe072fd5bdfd03da28066 (plain) (tree)
1
2
3
4
5
6
7
8





                                                                            
                 
 


                                                                  
                                            

                                         
 



                                             



                                  


                                            
                              


                                    


                                                     

                                                                      

   




                                                                            
                            
                

                                              

   




                                                                            





                                                 




                                                                    























                                                                              





















                                                                    




                                                                              













                                                          
 


                                     



                                                   


                                                









                                        

   


                                               








                                                             




                                                                               

                                                      
                                                   

   

                                                                  



                               
                            
--- NERV: a Lua-based toolkit for high-performance deep learning.
-- This file contains misc utility functions of NERV and finally initializes
-- NERV by including `init.lua` of other basic modules.
-- @author Ted Yin <ted.sybil@gmail.com>
-- @module nerv

require 'libnerv'

--- Placeholder for methods that have no implementation.
-- Calling it always raises a NERV error carrying the message
-- "method not implemented", giving the user a readable failure instead of
-- an "attempt to call a nil value".
function nerv.error_method_not_implemented()
    local message = "method not implemented"
    nerv.error(message)
end

--- Build a string from a printf-style format, as C's `sprintf` would.
-- This is a thin wrapper over `string.format`.
-- @param fmt the format string
-- @param ... args, the data to be formatted
-- @return the formatted string
function nerv.sprintf(fmt, ...)
    local formatted = string.format(fmt, ...)
    return formatted
end

--- Format with `nerv.sprintf` and write the result to the current default
-- output file (stdout unless `io.output` was redirected). No newline is
-- appended.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.printf(fmt, ...)
    local text = nerv.sprintf(fmt, ...)
    io.write(text)
end

--- Raise a global error with the formatted message.
-- The message is prefixed with "[nerv] internal error: " and terminated by a
-- newline.
-- @param fmt the format string
-- @param ... args, the data to be formatted
-- @raise always; the error value is the formatted message string
function nerv.error(fmt, ...)
    local msg = nerv.sprintf("[nerv] internal error: " .. fmt .. "\n", ...)
    -- Level 2: the "file:line:" position Lua prepends refers to the code
    -- that called nerv.error, not to this helper inside init.lua.
    error(msg, 2)
end

--- Print a notification message that begins with "info" and a timestamp.
-- Instead of using `nerv.printf`, normal users should use this to print any
-- notification information. Output format:
-- "(HH:MM:SS YYYY-MM-DD)[nerv] info: <message>\n".
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.info(fmt, ...)
    -- "%Y-%m-%d" produces the same text as "%F". "%F" is a C99 strftime
    -- extension that some platforms (and Lua builds restricted to C89
    -- conversions) reject.
    nerv.printf(
        string.format("(%s)[nerv] info: %s\n",
            os.date("%H:%M:%S %Y-%m-%d"), fmt), ...)
end

--- Print a warning message that begins with "warning" and a timestamp.
-- Instead of using `nerv.printf`, normal users should use this to print any
-- warnings. Output format:
-- "(HH:MM:SS YYYY-MM-DD)[nerv] warning: <message>\n".
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.warning(fmt, ...)
    -- "%Y-%m-%d" produces the same text as "%F". "%F" is a C99 strftime
    -- extension that some platforms (and Lua builds restricted to C89
    -- conversions) reject.
    nerv.printf(
        string.format("(%s)[nerv] warning: %s\n",
            os.date("%H:%M:%S %Y-%m-%d"), fmt), ...)
end

--- Create a class (Torch-compatible).
-- Use this to create a class in NERV. Two closures are passed to
-- `nerv.newmetatable`:
-- * a constructor, which allocates an instance tagged with `tname` and runs
--   its `__init` method (if any) with the constructor's arguments;
-- * a factory, which only allocates a tagged instance without running
--   `__init`.
-- @param tname the class name
-- @param parenttname the parent class name (from which it inherits), or nil
-- @return the created class (whatever `nerv.newmetatable` returns)
-- @return the parent's metatable, or nil when `parenttname` is not given
function nerv.class(tname, parenttname)
    -- Fresh, uninitialized instance of this class.
    local function allocate()
        local instance = {}
        nerv.setmetatable(instance, tname)
        return instance
    end

    -- Allocate, then let the class initialize itself from the call arguments.
    local function construct(...)
        local instance = allocate()
        if instance.__init then
            instance:__init(...)
        end
        return instance
    end

    local class_mt = nerv.newmetatable(tname, parenttname, construct, nil,
                                       allocate)
    if not parenttname then
        return class_mt, nil
    end
    -- Parentheses keep exactly one value, as the original local did.
    return class_mt, (nerv.getmetatable(parenttname))
end

-- Characters that cannot appear verbatim inside a one-line quoted Lua
-- string literal, mapped to their escape sequences. Backslash must be in
-- this set; otherwise "a\\nb" would come back as a newline.
local string_escapes = { ["\\"] = "\\\\", ["\n"] = "\\n", ["\r"] = "\\r" }

-- Reserved words: they match the identifier pattern but cannot be used as
-- bare table keys (`{end=1}` is a syntax error). `goto` is included because
-- Lua 5.2+/LuaJIT reserve it; bracketing it is harmless on 5.1.
local lua_reserved_words = {}
for word in string.gmatch([[
  and break do else elseif end false for function goto if in
  local nil not or repeat return then true until while
]], "%a+") do
  lua_reserved_words[word] = true
end

--- Get a Lua source representation of a single value.
-- Strings become quoted literals. Single quotes are used when the string
-- contains double quotes but no single quotes; otherwise double quotes are
-- used, with embedded `"` escaped. Tables are serialized via
-- `table.tostring`. Non-finite numbers become `0/0`, `1/0` or `-1/0`.
-- Everything else goes through `tostring`, which does not round-trip for
-- functions/userdata and may lose float precision.
-- @param v the value
-- @return a string that evaluates to (a copy of) `v`
function table.val_to_str(v)
  local vtype = type(v)
  if vtype == "string" then
    v = string.gsub(v, "[\\\n\r]", string_escapes)
    if string.match(string.gsub(v, "[^'\"]", ""), '^"+$') then
      return "'" .. v .. "'"
    end
    return '"' .. string.gsub(v, '"', '\\"') .. '"'
  elseif vtype == "table" then
    return table.tostring(v)
  elseif vtype == "number" then
    -- tostring() would give "inf"/"nan", which are not Lua expressions.
    if v ~= v then
      return "0/0"
    elseif v == math.huge then
      return "1/0"
    elseif v == -math.huge then
      return "-1/0"
    end
  end
  return tostring(v)
end

--- Get the Lua source form of a table key as used in a constructor.
-- Identifier-like strings that are not reserved words are emitted bare
-- (`name`); every other key is bracketed (`["end"]`, `[1.5]`, `[true]`).
-- @param k the key
-- @return the key part of a `key=value` constructor field
function table.key_to_str(k)
  if "string" == type(k) and string.match(k, "^[_%a][_%a%d]*$")
      and not lua_reserved_words[k] then
    return k
  end
  return "[" .. table.val_to_str(k) .. "]"
end

--- Get the string representation of a table, which can be executed as a valid
-- piece of Lua code.
-- The sequence part (1..n, as visited by `ipairs`) is emitted positionally,
-- and the remaining keys as `key=value` in `pairs` order (unspecified).
-- Nested tables are serialized recursively; there is no cycle detection, so
-- a self-referencing table recurses until the stack overflows.
-- @param tbl the table
-- @return the string representation which will result in a Lua table entity
-- when evaluated
function table.tostring(tbl)
  local result, done = {}, {}
  for k, v in ipairs(tbl) do
    result[#result + 1] = table.val_to_str(v)
    done[k] = true
  end
  for k, v in pairs(tbl) do
    if not done[k] then
      result[#result + 1] = table.key_to_str(k) .. "=" .. table.val_to_str(v)
    end
  end
  return "{" .. table.concat(result, ",") .. "}"
end

--- Get the class by name.
-- `tname` is compiled and evaluated as the Lua expression `return <tname>`
-- in the global environment, so it is normally a dotted path such as
-- "nerv.SomeClass". If the last field is missing, the result is nil. A
-- missing intermediate field raises an indexing error.
-- NOTE(review): since the name is executed as code, it must not come from
-- untrusted input.
-- @param tname the name of the class
-- @return the class entity
function nerv.get_type(tname)
    -- `loadstring` exists only in Lua 5.1/LuaJIT. Lua 5.2+ compiles strings
    -- with `load` instead.
    local compile = loadstring or load
    local chunk = assert(compile("return " .. tname))
    return chunk()
end

--- Check whether an object belongs to a class (directly or by inheritance).
-- Follows the chain getmetatable(obj), getmetatable(getmetatable(obj)), ...
-- and reports whether the metatable registered for `tname` appears in it.
-- This presumably relies on a class metatable's own metatable being its
-- parent's (Torch-style); TODO confirm in nerv.newmetatable.
-- @param obj the object ("class instance")
-- @param tname the class name ("type name")
-- @return true if the class metatable is found in the chain, false otherwise
function nerv.is_type(obj, tname)
    local wanted = nerv.getmetatable(tname)
    local current = obj
    repeat
        current = getmetatable(current)
        if current and current == wanted then
            return true
        end
    until not current
    return false
end

--- Strip the last path component from a file name.
-- Only "/" is treated as a separator.
-- @param filename the path to a file
-- @return the containing directory including its trailing "/" (e.g.
-- "a/b/c.lua" -> "a/b/"), or '' when `filename` contains no "/"
function nerv.dirname(filename)
    -- The greedy `.*` runs to the last "/", so the capture ends with the
    -- final separator. Without a separator there is no match at all.
    return filename:match("^(.*/)") or ''
end

--- Include a script file (chunk) into the current script.
-- An analogy to `#include` in C. Note that the effect is the same as executing
-- `dofile(filename)` at the current line, except that a relative `filename`
-- is resolved against the directory of the calling file.
-- @param filename the path to a file
-- @return all values returned by the chunk
function nerv.include(filename)
    -- For chunks loaded from a file, `source` is "@<path>". For other chunks
    -- (strings given to loadstring/load, "=stdin", tail calls) it is not a
    -- path, and taking its "dirname" could yield garbage. Resolve those
    -- relative to the current working directory instead.
    local source = debug.getinfo(2, "S").source
    local caller_dir = ''
    if source:sub(1, 1) == "@" then
        caller_dir = nerv.dirname(source:sub(2))
    end
    return dofile(caller_dir .. filename)
end

-- Initialize the basic modules by running each one's init.lua, resolved
-- relative to this file's directory. The list keeps the original order,
-- since later modules may rely on earlier ones.
for _, module_dir in ipairs({'matrix', 'io', 'layer', 'nn', 'tnn'}) do
    nerv.include(module_dir .. '/init.lua')
end