aboutsummaryrefslogtreecommitdiff
path: root/nerv/init.lua
blob: 9c1a5c8ffc59385ce30a6d34727e2b7d495be284 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
require 'libnerv'

--- Raise the stock "method not implemented" error through nerv.error.
-- Presumably used as the body of abstract/placeholder methods.
function nerv.error_method_not_implemented()
    local message = "method not implemented"
    nerv.error(message)
end

--- printf-style formatting: a thin alias for string.format.
-- @param fmt format string (a literal '%' must be written '%%')
-- @param ... values substituted into `fmt`
-- @return the formatted string
function nerv.sprintf(fmt, ...)
    local formatted = string.format(fmt, ...)
    return formatted
end

--- Format like nerv.sprintf and write the result to standard output.
-- No newline is appended; io.write is used, so output goes through the
-- default output file's buffering.
function nerv.printf(fmt, ...)
    local text = nerv.sprintf(fmt, ...)
    io.write(text)
end

--- Raise a nerv error with a printf-style message.
-- The raised string is "[nerv] internal error: " followed by `fmt`
-- formatted with the extra arguments and a trailing newline, prefixed by
-- Lua with the position of the code that called nerv.error.
-- @param fmt format string (a literal '%' must be written '%%')
-- @param ... values substituted into `fmt`
function nerv.error(fmt, ...)
    -- Level 2 attributes the error to the line that called nerv.error,
    -- instead of always pointing at this helper itself.
    error(nerv.sprintf("[nerv] internal error: " .. fmt .. "\n", ...), 2)
end

-- Shared implementation of the timestamped log helpers below. Writes
-- "(<HH:MM:SS> <YYYY-MM-DD>)[nerv] <level>: " .. fmt .. "\n" through
-- nerv.printf, so `fmt` is still expanded with the caller's arguments.
local function log_with_timestamp(level, fmt, ...)
    local stamp = os.date("%H:%M:%S %F")
    nerv.printf(string.format("(%s)[nerv] %s: %s\n", stamp, level, fmt), ...)
end

--- Print a timestamped informational message (same output as nerv.info).
-- @param fmt printf-style format string
-- @param ... values substituted into `fmt`
function nerv.mesg_with_timestamp(fmt, ...)
    log_with_timestamp("info", fmt, ...)
end

--- Print a timestamped informational message.
-- @param fmt printf-style format string
-- @param ... values substituted into `fmt`
function nerv.info(fmt, ...)
    log_with_timestamp("info", fmt, ...)
end

--- Print a timestamped warning message.
-- @param fmt printf-style format string
-- @param ... values substituted into `fmt`
function nerv.warning(fmt, ...)
    log_with_timestamp("warning", fmt, ...)
end

--- Declare a class through libnerv's Torch-style C API.
-- Calls nerv.newmetatable with the class name, the optional parent name,
-- a constructor, nil (presumably the destructor slot -- confirm against
-- libnerv) and a factory.
-- @param tname name of the class to register
-- @param parenttname optional name of the parent class
-- @return the class metatable, and the parent's metatable (nil when no
--   parent is given)
function nerv.class(tname, parenttname)
   -- Empty table tagged with this class's metatable; used directly as the
   -- factory and as the first step of the constructor.
   local function new_blank()
      local instance = {}
      nerv.setmetatable(instance, tname)
      return instance
   end

   -- Build an instance and, when the class provides __init, run it with the
   -- constructor's arguments.
   local function construct(...)
      local instance = new_blank()
      if not instance.__init then
         return instance
      end
      instance:__init(...)
      return instance
   end

   local class_mt = nerv.newmetatable(tname, parenttname, construct, nil, new_blank)
   if not parenttname then
      return class_mt, nil
   end
   -- Parentheses keep exactly one value from nerv.getmetatable.
   return class_mt, (nerv.getmetatable(parenttname))
end

--- Render a value as a Lua-literal-like string.
-- Strings are quoted: single quotes when the only quote characters they
-- contain are double quotes, otherwise double quotes with '"' escaped.
-- Backslashes, newlines and carriage returns are escaped so that the
-- result, read back as Lua source, yields the same string.
-- Tables are delegated to table.tostring; anything else goes through
-- tostring().
-- @param v any value
-- @return string representation of `v`
function table.val_to_str(v)
  if "string" == type(v) then
    -- Backslashes first, so the escapes added below (and for '"') are not
    -- themselves doubled.
    v = string.gsub(v, "\\", "\\\\")
    v = string.gsub(v, "\n", "\\n")
    -- A raw carriage return inside a quoted literal is a Lua syntax error.
    v = string.gsub(v, "\r", "\\r")
    if string.match(string.gsub(v, "[^'\"]", ""), '^"+$') then
      return "'" .. v .. "'"
    end
    return '"' .. string.gsub(v, '"', '\\"') .. '"'
  end
  if "table" == type(v) then
    return table.tostring(v)
  end
  return tostring(v)
end

do
  -- Lua reserved words: they look like identifiers but cannot be used as
  -- bare field names in a table constructor (e.g. `{end=1}` is invalid).
  local lua_keywords = {}
  for word in string.gmatch([[
    and break do else elseif end false for function goto if in
    local nil not or repeat return then true until while
  ]], "%a+") do
    lua_keywords[word] = true
  end

  --- Render a table key for use in a table constructor.
  -- Identifier-like strings that are not reserved words are emitted bare;
  -- every other key is bracketed around table.val_to_str(k).
  -- @param k table key
  -- @return key text, e.g. `name` or `["a b"]` or `[1]`
  function table.key_to_str(k)
    if "string" == type(k)
        and string.match(k, "^[_%a][_%a%d]*$")
        and not lua_keywords[k] then
      return k
    end
    return "[" .. table.val_to_str(k) .. "]"
  end
end

do
  -- Tables currently being rendered: the argument of each active
  -- table.tostring call, i.e. the ancestors of whatever is being visited.
  -- Meeting one again means the structure is cyclic. Weak keys keep this
  -- bookkeeping from pinning tables in memory.
  local in_progress = setmetatable({}, { __mode = "k" })

  -- Build "{...}": the sequence part first (in order, bare values), then
  -- every remaining key as key=value (order of these follows pairs()).
  local function render(tbl)
    local parts, in_sequence = {}, {}
    for i, v in ipairs(tbl) do
      parts[#parts + 1] = table.val_to_str(v)
      in_sequence[i] = true
    end
    for k, v in pairs(tbl) do
      if not in_sequence[k] then
        parts[#parts + 1] = table.key_to_str(k) .. "=" .. table.val_to_str(v)
      end
    end
    return "{" .. table.concat(parts, ",") .. "}"
  end

  --- Render a table as a Lua-constructor-like string.
  -- A table that (directly or indirectly) contains itself is rendered as
  -- "<cycle>" at the point of recursion instead of overflowing the stack.
  -- @param tbl table to render
  -- @return string representation of `tbl`
  function table.tostring(tbl)
    if type(tbl) ~= "table" then
      -- Not trackable as a key (e.g. nil); let render raise as before.
      return render(tbl)
    end
    if in_progress[tbl] then
      return "<cycle>"
    end
    in_progress[tbl] = true
    -- pcall so the mark is cleared even if rendering a value raises;
    -- otherwise later calls would wrongly report this table as cyclic.
    local ok, result = pcall(render, tbl)
    in_progress[tbl] = nil
    if not ok then
      error(result, 0)
    end
    return result
  end
end

--- Resolve a type by evaluating its name as a Lua expression.
-- Compiles and runs "return " .. tname, so `tname` may be any expression,
-- e.g. a dotted global name; the result is whatever it evaluates to (nil
-- when the name is undefined).
-- Raises (via assert) when `tname` is not a valid expression.
-- NOTE(review): `tname` is executed as Lua code -- never pass untrusted
-- input here.
-- @param tname type name / expression
-- @return the value the expression evaluates to
function nerv.get_type(tname)
    -- `loadstring` is absent in Lua 5.2+ (where `load` accepts strings);
    -- prefer it when present so Lua 5.1/LuaJIT behaviour is unchanged.
    local compile = loadstring or load
    return assert(compile("return " .. tname))()
end

--- Check whether `obj` belongs to class `tname`.
-- Follows the chain getmetatable(obj), getmetatable(that), ... and
-- succeeds if any link is the metatable registered for `tname`
-- (presumably a class metatable's own metatable is its parent class, so
-- subclass instances also match -- confirm against libnerv).
-- @param obj value to test
-- @param tname class name
-- @return true or false
function nerv.is_type(obj, tname)
    local wanted = nerv.getmetatable(tname)
    local link = getmetatable(obj)
    while true do
        if not link then
            return false
        end
        if link == wanted then
            return true
        end
        link = getmetatable(link)
    end
end

--- Directory part of a '/'-separated path, including the trailing '/'.
-- e.g. "a/b/c.lua" -> "a/b/", "/x" -> "/", "c.lua" -> ''.
-- @param filename path string
-- @return the prefix up to and including the last '/', or '' if none
function nerv.dirname(filename)
    -- The greedy ".*" anchored at the start runs to the last '/'.
    return filename:match("^(.*/)") or ''
end

--- Run a Lua file located relative to the calling file's directory.
-- @param filename path relative to the directory of the caller's source
-- @return whatever the included chunk returns (via dofile)
function nerv.include(filename)
    -- Level 2 is the function that called nerv.include.
    local source = debug.getinfo(2, "S").source
    -- Only a source starting with '@' names a file. String chunks and
    -- "=name" sources have no directory, so fall back to the current
    -- working directory instead of parsing them as a path.
    local caller = source:match("^@(.*)$") or ''
    return dofile(nerv.dirname(caller) .. filename)
end

-- Load the sub-packages in this fixed order; each path is resolved by
-- nerv.include against the directory of this file.
for _, subpackage in ipairs({
    'matrix/init.lua',
    'io/init.lua',
    'layer/init.lua',
    'nn/init.lua',
}) do
    nerv.include(subpackage)
end