diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | class.lua | 250 | ||||
-rw-r--r-- | class_example.lua | 24 | ||||
-rw-r--r-- | nerv.lua | 1 |
4 files changed, 276 insertions, 1 deletions
@@ -1,7 +1,7 @@ .PHONY: all clean luajit OBJS := oop_example.o nerv.o luaT.o common.o matrix/mmatrix.o matrix/cumatrix.o matrix/init.o matrix/cukernel.o LIBS := libnerv.so -LUA_LIBS := matrix/init.lua nerv.lua +LUA_LIBS := matrix/init.lua nerv.lua class.lua INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK CUDA_BASE := /usr/local/cuda-6.5 CUDA_INCLUDE := -I $(CUDA_BASE)/include/ diff --git a/class.lua b/class.lua new file mode 100644 index 0000000..d260c31 --- /dev/null +++ b/class.lua @@ -0,0 +1,250 @@ +--- Provides a reuseable and convenient framework for creating classes in Lua. +-- Two possible notations: +-- +-- B = class(A) +-- class.B(A) +-- +-- The latter form creates a named class within the current environment. Note +-- that this implicitly brings in `pl.utils` as a dependency. +-- +-- See the Guide for further @{01-introduction.md.Simplifying_Object_Oriented_Programming_in_Lua|discussion} +-- @module pl.class + +local error, getmetatable, io, pairs, rawget, rawset, setmetatable, tostring, type = + _G.error, _G.getmetatable, _G.io, _G.pairs, _G.rawget, _G.rawset, _G.setmetatable, _G.tostring, _G.type +local compat + +-- this trickery is necessary to prevent the inheritance of 'super' and +-- the resulting recursive call problems. +local function call_ctor (c,obj,...) + -- nice alias for the base class ctor + local base = rawget(c,'_base') + if base then + local parent_ctor = rawget(base,'_init') + while not parent_ctor do + base = rawget(base,'_base') + if not base then break end + parent_ctor = rawget(base,'_init') + end + if parent_ctor then + rawset(obj,'super',function(obj,...) + call_ctor(base,obj,...) + end) + end + end + local res = c._init(obj,...) + rawset(obj,'super',nil) + return res +end + +--- initializes an __instance__ upon creation. +-- @function class:_init +-- @param ... parameters passed to the constructor +-- @usage local Cat = class() +-- function Cat:_init(name) +-- --self:super(name) -- call the ancestor initializer if needed +-- self.name = name +-- end +-- +-- local pussycat = Cat("pussycat") +-- print(pussycat.name) --> pussycat + +--- checks whether an __instance__ is derived from some class. +-- Works the other way around as `class_of`. +-- @function instance:is_a +-- @param some_class class to check against +-- @return `true` if `instance` is derived from `some_class` +-- @usage local pussycat = Lion() -- assuming Lion derives from Cat +-- if pussycat:is_a(Cat) then +-- -- it's true +-- end +local function is_a(self,klass) + local m = getmetatable(self) + if not m then return false end --*can't be an object! + while m do + if m == klass then return true end + m = rawget(m,'_base') + end + return false +end + +--- checks whether an __instance__ is derived from some class. +-- Works the other way around as `is_a`. +-- @function some_class:class_of +-- @param some_instance instance to check against +-- @return `true` if `some_instance` is derived from `some_class` +-- @usage local pussycat = Lion() -- assuming Lion derives from Cat +-- if Cat:class_of(pussycat) then +-- -- it's true +-- end +local function class_of(klass,obj) + if type(klass) ~= 'table' or not rawget(klass,'is_a') then return false end + return klass.is_a(obj,klass) +end + +--- cast an object to another class. +-- It is not clever (or safe!) so use carefully. +-- @param some_instance the object to be changed +-- @function some_class:cast +local function cast (klass, obj) + return setmetatable(obj,klass) +end + + +local function _class_tostring (obj) + local mt = obj._class + local name = rawget(mt,'_name') + setmetatable(obj,nil) + local str = tostring(obj) + setmetatable(obj,mt) + if name then str = name ..str:gsub('table','') end + return str +end + +local function tupdate(td,ts,dont_override) + for k,v in pairs(ts) do + if not dont_override or td[k] == nil then + td[k] = v + end + end +end + +local function _class(base,c_arg,c) + -- the class `c` will be the metatable for all its objects, + -- and they will look up their methods in it. + local mt = {} -- a metatable for the class to support __call and _handler + -- can define class by passing it a plain table of methods + local plain = type(base) == 'table' and not getmetatable(base) + if plain then + c = base + base = c._base + else + c = c or {} + end + + if type(base) == 'table' then + -- our new class is a shallow copy of the base class! + -- but be careful not to wipe out any methods we have been given at this point! + tupdate(c,base,plain) + c._base = base + -- inherit the 'not found' handler, if present + if rawget(c,'_handler') then mt.__index = c._handler end + elseif base ~= nil then + error("must derive from a table type",3) + end + + c.__index = c + setmetatable(c,mt) + if not plain then + c._init = nil + end + + if base and rawget(base,'_class_init') then + base._class_init(c,c_arg) + end + + -- expose a ctor which can be called by <classname>(<args>) + mt.__call = function(class_tbl,...) + local obj + if rawget(c,'_create') then obj = c._create(...) end + if not obj then obj = {} end + setmetatable(obj,c) + + if rawget(c,'_init') then -- explicit constructor + local res = call_ctor(c,obj,...) + if res then -- _if_ a ctor returns a value, it becomes the object... + obj = res + setmetatable(obj,c) + end + elseif base and rawget(base,'_init') then -- default constructor + -- make sure that any stuff from the base class is initialized! + call_ctor(base,obj,...) + end + + if base and rawget(base,'_post_init') then + base._post_init(obj) + end + + if not rawget(c,'__tostring') then + c.__tostring = _class_tostring + end + return obj + end + -- Call Class.catch to set a handler for methods/properties not found in the class! + c.catch = function(self, handler) + if type(self) == "function" then + -- called using . instead of : + handler = self + end + c._handler = handler + mt.__index = handler + end + c.is_a = is_a + c.class_of = class_of + c.cast = cast + c._class = c + + return c +end + +--- create a new class, derived from a given base class. +-- Supporting two class creation syntaxes: +-- either `Name = class(base)` or `class.Name(base)`. +-- The first form returns the class directly and does not set its `_name`. +-- The second form creates a variable `Name` in the current environment set +-- to the class, and also sets `_name`. +-- @function class +-- @param base optional base class +-- @param c_arg optional parameter to class constructor +-- @param c optional table to be used as class +local class +class = setmetatable({},{ + __call = function(fun,...) + return _class(...) + end, + __index = function(tbl,key) + if key == 'class' then + io.stderr:write('require("pl.class").class is deprecated. Use require("pl.class")\n') + return class + end + compat = compat or require 'pl.compat' + local env = compat.getfenv(2) + return function(...) + local c = _class(...) + c._name = key + rawset(env,key,c) + return c + end + end +}) + +class.properties = class() + +function class.properties._class_init(klass) + klass.__index = function(t,key) + -- normal class lookup! + local v = klass[key] + if v then return v end + -- is it a getter? + v = rawget(klass,'get_'..key) + if v then + return v(t) + end + -- is it a field? + return rawget(t,'_'..key) + end + klass.__newindex = function (t,key,value) + -- if there's a setter, use that, otherwise directly set table + local p = 'set_'..key + local setter = klass[p] + if setter then + setter(t,value) + else + rawset(t,key,value) + end + end +end + + +return class + diff --git a/class_example.lua b/class_example.lua new file mode 100644 index 0000000..ab69b70 --- /dev/null +++ b/class_example.lua @@ -0,0 +1,24 @@ +A = nerv.class() +function A:_init(x) + self.x = x +end +function A:f() + return self.x +end + +function A:g() + return self.x + 1 +end + +B = nerv.class(A) + +function B:f() + return self.x * self.x +end + +a = A(3) +b = B(3) +print(a:f()) +print(b:f()) +print(b:g()) + @@ -1,2 +1,3 @@ require 'libnerv' require 'matrix.init' +nerv.class = require 'class' |