aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--class.lua250
-rw-r--r--class_example.lua24
-rw-r--r--nerv.lua1
4 files changed, 276 insertions, 1 deletions
diff --git a/Makefile b/Makefile
index 27f3380..fa323a1 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
.PHONY: all clean luajit
OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o
LIBS := libnerv.so
-LUA_LIBS := matrix/init.lua nerv.lua
+LUA_LIBS := matrix/init.lua nerv.lua class.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
CUDA_BASE := /usr/local/cuda-6.5
CUDA_INCLUDE := -I $(CUDA_BASE)/include/
diff --git a/class.lua b/class.lua
new file mode 100644
index 0000000..d260c31
--- /dev/null
+++ b/class.lua
@@ -0,0 +1,250 @@
+--- Provides a reuseable and convenient framework for creating classes in Lua.
+-- Two possible notations:
+--
+-- B = class(A)
+-- class.B(A)
+--
+-- The latter form creates a named class within the current environment. Note
+-- that this implicitly brings in `pl.utils` as a dependency.
+--
+-- See the Guide for further @{01-introduction.md.Simplifying_Object_Oriented_Programming_in_Lua|discussion}
+-- @module pl.class
+
+local error, getmetatable, io, pairs, rawget, rawset, setmetatable, tostring, type =
+ _G.error, _G.getmetatable, _G.io, _G.pairs, _G.rawget, _G.rawset, _G.setmetatable, _G.tostring, _G.type
+local compat
+
+-- this trickery is necessary to prevent the inheritance of 'super' and
+-- the resulting recursive call problems.
+local function call_ctor (c,obj,...)
+ -- nice alias for the base class ctor
+ local base = rawget(c,'_base')
+ if base then
+ local parent_ctor = rawget(base,'_init')
+ while not parent_ctor do
+ base = rawget(base,'_base')
+ if not base then break end
+ parent_ctor = rawget(base,'_init')
+ end
+ if parent_ctor then
+ rawset(obj,'super',function(obj,...)
+ call_ctor(base,obj,...)
+ end)
+ end
+ end
+ local res = c._init(obj,...)
+ rawset(obj,'super',nil)
+ return res
+end
+
+--- initializes an __instance__ upon creation.
+-- @function class:_init
+-- @param ... parameters passed to the constructor
+-- @usage local Cat = class()
+-- function Cat:_init(name)
+-- --self:super(name) -- call the ancestor initializer if needed
+-- self.name = name
+-- end
+--
+-- local pussycat = Cat("pussycat")
+-- print(pussycat.name) --> pussycat
+
+--- checks whether an __instance__ is derived from some class.
+-- Works the other way around as `class_of`.
+-- @function instance:is_a
+-- @param some_class class to check against
+-- @return `true` if `instance` is derived from `some_class`
+-- @usage local pussycat = Lion() -- assuming Lion derives from Cat
+-- if pussycat:is_a(Cat) then
+-- -- it's true
+-- end
+local function is_a(self,klass)
+ local m = getmetatable(self)
+ if not m then return false end --*can't be an object!
+ while m do
+ if m == klass then return true end
+ m = rawget(m,'_base')
+ end
+ return false
+end
+
+--- checks whether an __instance__ is derived from some class.
+-- Works the other way around as `is_a`.
+-- @function some_class:class_of
+-- @param some_instance instance to check against
+-- @return `true` if `some_instance` is derived from `some_class`
+-- @usage local pussycat = Lion() -- assuming Lion derives from Cat
+-- if Cat:class_of(pussycat) then
+-- -- it's true
+-- end
+local function class_of(klass,obj)
+ if type(klass) ~= 'table' or not rawget(klass,'is_a') then return false end
+ return klass.is_a(obj,klass)
+end
+
+--- cast an object to another class.
+-- It is not clever (or safe!) so use carefully.
+-- @param some_instance the object to be changed
+-- @function some_class:cast
+local function cast (klass, obj)
+ return setmetatable(obj,klass)
+end
+
+
+local function _class_tostring (obj)
+ local mt = obj._class
+ local name = rawget(mt,'_name')
+ setmetatable(obj,nil)
+ local str = tostring(obj)
+ setmetatable(obj,mt)
+ if name then str = name ..str:gsub('table','') end
+ return str
+end
+
+local function tupdate(td,ts,dont_override)
+ for k,v in pairs(ts) do
+ if not dont_override or td[k] == nil then
+ td[k] = v
+ end
+ end
+end
+
+local function _class(base,c_arg,c)
+ -- the class `c` will be the metatable for all its objects,
+ -- and they will look up their methods in it.
+ local mt = {} -- a metatable for the class to support __call and _handler
+ -- can define class by passing it a plain table of methods
+ local plain = type(base) == 'table' and not getmetatable(base)
+ if plain then
+ c = base
+ base = c._base
+ else
+ c = c or {}
+ end
+
+ if type(base) == 'table' then
+ -- our new class is a shallow copy of the base class!
+ -- but be careful not to wipe out any methods we have been given at this point!
+ tupdate(c,base,plain)
+ c._base = base
+ -- inherit the 'not found' handler, if present
+ if rawget(c,'_handler') then mt.__index = c._handler end
+ elseif base ~= nil then
+ error("must derive from a table type",3)
+ end
+
+ c.__index = c
+ setmetatable(c,mt)
+ if not plain then
+ c._init = nil
+ end
+
+ if base and rawget(base,'_class_init') then
+ base._class_init(c,c_arg)
+ end
+
+ -- expose a ctor which can be called by <classname>(<args>)
+ mt.__call = function(class_tbl,...)
+ local obj
+ if rawget(c,'_create') then obj = c._create(...) end
+ if not obj then obj = {} end
+ setmetatable(obj,c)
+
+ if rawget(c,'_init') then -- explicit constructor
+ local res = call_ctor(c,obj,...)
+ if res then -- _if_ a ctor returns a value, it becomes the object...
+ obj = res
+ setmetatable(obj,c)
+ end
+ elseif base and rawget(base,'_init') then -- default constructor
+ -- make sure that any stuff from the base class is initialized!
+ call_ctor(base,obj,...)
+ end
+
+ if base and rawget(base,'_post_init') then
+ base._post_init(obj)
+ end
+
+ if not rawget(c,'__tostring') then
+ c.__tostring = _class_tostring
+ end
+ return obj
+ end
+ -- Call Class.catch to set a handler for methods/properties not found in the class!
+ c.catch = function(self, handler)
+ if type(self) == "function" then
+ -- called using . instead of :
+ handler = self
+ end
+ c._handler = handler
+ mt.__index = handler
+ end
+ c.is_a = is_a
+ c.class_of = class_of
+ c.cast = cast
+ c._class = c
+
+ return c
+end
+
+--- create a new class, derived from a given base class.
+-- Supporting two class creation syntaxes:
+-- either `Name = class(base)` or `class.Name(base)`.
+-- The first form returns the class directly and does not set its `_name`.
+-- The second form creates a variable `Name` in the current environment set
+-- to the class, and also sets `_name`.
+-- @function class
+-- @param base optional base class
+-- @param c_arg optional parameter to class constructor
+-- @param c optional table to be used as class
+local class
+class = setmetatable({},{
+ __call = function(fun,...)
+ return _class(...)
+ end,
+ __index = function(tbl,key)
+ if key == 'class' then
+ io.stderr:write('require("pl.class").class is deprecated. Use require("pl.class")\n')
+ return class
+ end
+ compat = compat or require 'pl.compat'
+ local env = compat.getfenv(2)
+ return function(...)
+ local c = _class(...)
+ c._name = key
+ rawset(env,key,c)
+ return c
+ end
+ end
+})
+
+class.properties = class()
+
+function class.properties._class_init(klass)
+ klass.__index = function(t,key)
+ -- normal class lookup!
+ local v = klass[key]
+ if v then return v end
+ -- is it a getter?
+ v = rawget(klass,'get_'..key)
+ if v then
+ return v(t)
+ end
+ -- is it a field?
+ return rawget(t,'_'..key)
+ end
+ klass.__newindex = function (t,key,value)
+ -- if there's a setter, use that, otherwise directly set table
+ local p = 'set_'..key
+ local setter = klass[p]
+ if setter then
+ setter(t,value)
+ else
+ rawset(t,key,value)
+ end
+ end
+end
+
+
+return class
+
diff --git a/class_example.lua b/class_example.lua
new file mode 100644
index 0000000..ab69b70
--- /dev/null
+++ b/class_example.lua
@@ -0,0 +1,24 @@
+A = nerv.class()
+function A:_init(x)
+ self.x = x
+end
+function A:f()
+ return self.x
+end
+
+function A:g()
+ return self.x + 1
+end
+
+B = nerv.class(A)
+
+function B:f()
+ return self.x * self.x
+end
+
+a = A(3)
+b = B(3)
+print(a:f())
+print(b:f())
+print(b:g())
+
diff --git a/nerv.lua b/nerv.lua
index 5b53306..de2e701 100644
--- a/nerv.lua
+++ b/nerv.lua
@@ -1,2 +1,3 @@
require 'libnerv'
require 'matrix.init'
+nerv.class = require 'class'