aboutsummaryrefslogtreecommitdiff
path: root/nerv/init.lua
blob: 6312df17b1471726543fe072fd5bdfd03da28066 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
--- NERV: a Lua-based toolkit for high-performance deep learning.
-- This file contains miscellaneous utility functions of NERV and, at the end,
-- initializes NERV by including the `init.lua` of the other basic modules.
-- @author Ted Yin <ted.sybil@gmail.com>
-- @module nerv

require 'libnerv'

--- Placeholder for methods that have no implementation.
-- Invoking it raises an error through `nerv.error`, telling the user that the
-- method is not implemented.
function nerv.error_method_not_implemented()
    local message = "method not implemented"
    nerv.error(message)
end

--- Build a string from a format specification, like `sprintf` in C.
-- Thin wrapper around `string.format`.
-- @param fmt the format string
-- @param ... the values to be substituted into `fmt`
-- @return the formatted string
function nerv.sprintf(fmt, ...)
    local formatted = string.format(fmt, ...)
    return formatted
end

--- Write a formatted string to the default output file (normally stdout).
-- The text is built by `nerv.sprintf`; no newline is appended.
-- @param fmt the format string
-- @param ... the values to be substituted into `fmt`
function nerv.printf(fmt, ...)
    local text = nerv.sprintf(fmt, ...)
    io.write(text)
end

--- Raise a global error with the formatted message.
-- The message is prefixed with "[nerv] internal error: " and terminated with
-- a newline before being passed to Lua's `error`.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.error(fmt, ...)
    local message = nerv.sprintf("[nerv] internal error: " .. fmt .. "\n", ...)
    -- Level 2 attributes the error to the code that called nerv.error; with
    -- the default level every message would point at this line instead.
    error(message, 2)
end

--- Print a notification message that begins with a timestamp and "info".
-- Instead of using `nerv.printf`, normal users should use this to print any
-- notification information. A newline is appended to the message.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.info(fmt, ...)
    -- "%Y-%m-%d" produces the same text as the C99-only "%F" specifier, but
    -- is also accepted by C89 runtimes, where "%F" is rejected or undefined.
    local stamp = os.date("%H:%M:%S %Y-%m-%d")
    -- The timestamp and the caller's format are merged into one format string
    -- first; the caller's arguments are substituted by nerv.printf afterwards.
    nerv.printf(string.format("(%s)[nerv] info: %s\n", stamp, fmt), ...)
end

--- Print a warning message that begins with a timestamp and "warning".
-- Instead of using `nerv.printf`, normal users should use this to print any
-- warnings. A newline is appended to the message.
-- @param fmt the format string
-- @param ... args, the data to be formatted
function nerv.warning(fmt, ...)
    -- "%Y-%m-%d" produces the same text as the C99-only "%F" specifier, but
    -- is also accepted by C89 runtimes, where "%F" is rejected or undefined.
    local stamp = os.date("%H:%M:%S %Y-%m-%d")
    -- The timestamp and the caller's format are merged into one format string
    -- first; the caller's arguments are substituted by nerv.printf afterwards.
    nerv.printf(string.format("(%s)[nerv] warning: %s\n", stamp, fmt), ...)
end

--- Declare a new class (Torch-compatible).
-- Registers `tname` through `nerv.newmetatable`, handing it two closures: a
-- constructor, which creates an instance and runs its `__init` method when one
-- is reachable from the instance, and a factory, which creates a bare instance
-- without running `__init`.
-- @param tname the class name
-- @param parenttname the parent class name (from which it inherits); may be
-- omitted
-- @return the value returned by `nerv.newmetatable` for `tname` (presumably
-- the class metatable; `nerv.newmetatable` is defined outside this file)
-- @return the result of `nerv.getmetatable(parenttname)`, or nil when no
-- parent name is given
function nerv.class(tname, parenttname)

   -- Create an empty instance bound to the type registered as `tname`.
   local function factory()
      local instance = {}
      nerv.setmetatable(instance, tname)
      return instance
   end

   -- Create an instance and forward the arguments to its initializer, if any.
   local function constructor(...)
      local instance = factory()
      if instance.__init then
         instance:__init(...)
      end
      return instance
   end

   local class_mt = nerv.newmetatable(tname, parenttname, constructor, nil,
                                      factory)
   if not parenttname then
      return class_mt, nil
   end
   -- Parentheses keep only the first value returned by nerv.getmetatable.
   return class_mt, (nerv.getmetatable(parenttname))
end

do
  -- Characters that must not appear verbatim inside a quoted Lua string,
  -- mapped to the escape sequence that reproduces them.
  local escapes = { ["\\"] = "\\\\", ["\n"] = "\\n", ["\r"] = "\\r" }

  --- Convert a value into its Lua source representation.
  -- Strings become quoted literals: single quotes are used when the only quote
  -- characters in the string are double quotes, otherwise double quotes are
  -- used and embedded double quotes are escaped. Tables are delegated to
  -- `table.tostring`; every other value goes through `tostring`.
  -- @param v the value to convert
  -- @return the string representation of `v`
  function table.val_to_str(v)
    if "string" ~= type(v) then
      return "table" == type(v) and table.tostring(v) or
        tostring(v)
    end
    -- Backslashes are escaped in the same pass as the line breaks; otherwise
    -- a literal backslash would be read back as the start of an escape
    -- sequence (or, at the end of the string, would swallow the closing
    -- quote).
    local text = string.gsub(v, "[\\\n\r]", escapes)
    -- Extra parentheses drop gsub's second result (the match count).
    local quotes_only = (string.gsub(text, "[^'\"]", ""))
    if string.match(quotes_only, '^"+$') then
      return "'" .. text .. "'"
    end
    return '"' .. string.gsub(text, '"', '\\"') .. '"'
  end
end

do
  -- Words that look like identifiers but cannot be used as bare keys in a
  -- table constructor. "goto" is only reserved in newer Lua versions; the
  -- bracketed form chosen for it is valid everywhere.
  local reserved = {
    ["and"] = true, ["break"] = true, ["do"] = true, ["else"] = true,
    ["elseif"] = true, ["end"] = true, ["false"] = true, ["for"] = true,
    ["function"] = true, ["goto"] = true, ["if"] = true, ["in"] = true,
    ["local"] = true, ["nil"] = true, ["not"] = true, ["or"] = true,
    ["repeat"] = true, ["return"] = true, ["then"] = true, ["true"] = true,
    ["until"] = true, ["while"] = true,
  }

  --- Convert a table key into the form used on the left of `=` in a table
  -- constructor.
  -- A string that is a valid, non-reserved identifier is returned unchanged;
  -- any other key is rendered by `table.val_to_str` and wrapped in brackets.
  -- @param k the key to convert
  -- @return the string representation of the key
  function table.key_to_str (k)
    if "string" == type(k) and not reserved[k]
        and string.match(k, "^[_%a][_%a%d]*$") then
      return k
    end
    return "[" .. table.val_to_str(k) .. "]"
  end
end

--- Serialize a table into a string of Lua source code.
-- Entries visited by `ipairs` are emitted first as positional values; every
-- remaining key is emitted as a `key=value` pair in the order `pairs` yields
-- them. Values are rendered by `table.val_to_str` and keys by
-- `table.key_to_str`.
-- NOTE(review): no guard against cyclic tables is visible in this function.
-- @param tbl the table
-- @return a table-constructor string meant to produce an equivalent Lua table
-- when evaluated
function table.tostring(tbl)
  local pieces, positional = {}, {}
  -- Sequence part: values only, remembering which indices were handled.
  for index, item in ipairs(tbl) do
    pieces[#pieces + 1] = table.val_to_str(item)
    positional[index] = true
  end
  -- Everything else: explicit keys.
  for key, item in pairs(tbl) do
    if not positional[key] then
      local entry = table.key_to_str(key) .. "=" .. table.val_to_str(item)
      pieces[#pieces + 1] = entry
    end
  end
  return "{" .. table.concat(pieces, ",") .. "}"
end

--- Look up a class by its name.
-- The name is compiled as the Lua expression `return <tname>` and evaluated,
-- so dotted names such as `a.b.c` are resolved through the global
-- environment. An error is raised if the expression does not compile.
-- NOTE(review): `tname` is executed as code; it should come from trusted
-- input only.
-- @param tname the name of the class
-- @return the class entity
function nerv.get_type(tname)
    local chunk, compile_err = loadstring("return " .. tname)
    return assert(chunk, compile_err)()
end

--- Test whether an object is an instance of the given class.
-- Follows the chain of metatables starting at `obj` (the metatable of `obj`,
-- then the metatable of that metatable, and so on) looking for the metatable
-- that `nerv.getmetatable` returns for `tname`.
-- @param obj the object ("class instance")
-- @param tname the class name ("type name")
-- @return true if the class metatable is found on the chain, false otherwise
function nerv.is_type(obj, tname)
    local wanted = nerv.getmetatable(tname)
    local current = getmetatable(obj)
    -- Climb until the chain ends or the wanted metatable is reached.
    while current and current ~= wanted do
        current = getmetatable(current)
    end
    -- A surviving `current` can only be the wanted metatable.
    if current then
        return true
    end
    return false
end

--- Strip the last component from a file name.
-- The result keeps its trailing slash so a file name can be appended to it
-- directly; a path without any slash yields the empty string.
-- @param filename the path to a file
-- @return the path to the containing directory (with trailing `/`), or `''`
function nerv.dirname(filename)
    -- No slash at all: there is no directory part.
    if not filename:match(".-/.-") then
        return ''
    end
    -- Greedy first capture runs up to the last slash; keep only that part.
    local dir = string.gsub(filename, "(.*/)(.*)", "%1")
    return dir
end

--- Include a script file (chunk) into the current script.
-- An analogy to `#include` in C: `filename` is resolved relative to the
-- directory of the file that calls `nerv.include` and is then run with
-- `dofile`. When the caller was not loaded from a file (for example a chunk
-- compiled from a string), `filename` is used as given.
-- @param filename the path to a file, relative to the calling script
-- @return all values returned by the chunk
function nerv.include(filename)
    -- Level 2 is the function that called nerv.include, so this lookup has to
    -- stay directly in this function body.
    local info = debug.getinfo(2, "S")
    local source = info and info.source or ''
    local base = ''
    -- Only chunks loaded from a file have a source of the form "@<path>";
    -- string chunks carry their own text and C functions carry "=[C]", neither
    -- of which is a path.
    if source:sub(1, 1) == "@" then
        base = nerv.dirname(source:sub(2))
    end
    return dofile(base .. filename)
end

-- Trigger the initialization of the basic modules by including each module's
-- `init.lua`, in the order listed.
local module_inits = {
    'matrix/init.lua',
    'io/init.lua',
    'layer/init.lua',
    'nn/init.lua',
    'tnn/init.lua',
}
for _, init_path in ipairs(module_inits) do
    nerv.include(init_path)
end