summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeterminant <[email protected]>2015-05-14 15:01:55 +0800
committerDeterminant <[email protected]>2015-05-14 15:01:55 +0800
commitf48dc493b5b77fd4e4472dd6c78b7542a4884129 (patch)
tree0b7a0f95df28fc100fc1fd252ce1d0215d19150d
parent46ccec6d5ad057476e945afa34981f7e8d732547 (diff)
add basic matrix implementation
-rw-r--r--Makefile16
-rw-r--r--common.c19
-rw-r--r--common.h9
-rw-r--r--matrix.c95
-rw-r--r--matrix.lua15
-rw-r--r--matrix_example.lua7
-rwxr-xr-xnerv2
-rw-r--r--nerv.c4
-rw-r--r--oop_example.c16
-rw-r--r--oop_example.lua2
10 files changed, 172 insertions, 13 deletions
diff --git a/Makefile b/Makefile
index e638b96..55510a7 100644
--- a/Makefile
+++ b/Makefile
@@ -1,18 +1,28 @@
.PHONY: all clean luajit
-OBJS := oop_example.o nerv.o luaT.o
+OBJS := oop_example.o nerv.o luaT.o common.o matrix.o
LIBS := libnerv.so
+LUA_LIBS := matrix.lua
INCLUDE := -I build/luajit-2.0/include/luajit-2.0/ -DLUA_USE_APICHECK
LDFLAGS := -L luajit-2.0/build/lib/ -llua -lm
+CFLAGS :=
OBJ_DIR := build/objs
+LUA_DIR := build/lua
+
OBJS := $(addprefix $(OBJ_DIR)/,$(OBJS))
LIBS := $(addprefix $(OBJ_DIR)/,$(LIBS))
-all: luajit $(OBJ_DIR) $(LIBS)
+LUA_LIBS := $(addprefix $(LUA_DIR)/,$(LUA_LIBS))
+
+all: luajit $(OBJ_DIR) $(LIBS) $(LUA_DIR) $(LUA_LIBS)
luajit:
./build_luajit.sh
$(OBJ_DIR):
-mkdir -p $(OBJ_DIR)
+$(LUA_DIR):
+ -mkdir -p $(LUA_DIR)
$(OBJ_DIR)/%.o: %.c
- gcc -c -o $@ $< $(INCLUDE) -fPIC
+ gcc -c -o $@ $< $(INCLUDE) -fPIC $(CFLAGS)
+$(LUA_DIR)/%.lua: %.lua
+ cp $< $@
$(OBJ_DIR)/luaT.o:
gcc -c -o $@ luaT/luaT.c $(INCLUDE) -fPIC
$(LIBS): $(OBJS)
diff --git a/common.c b/common.c
new file mode 100644
index 0000000..f5521fd
--- /dev/null
+++ b/common.c
@@ -0,0 +1,19 @@
+#ifndef NERV_COMMON_H
+#define NERV_COMMON_H
+#include "common.h"
+#include <stdarg.h>
+int nerv_error(lua_State *L, const char *err_mesg_fmt, ...) {
+ va_list ap;
+ va_start(ap, err_mesg_fmt);
+ lua_pushstring(L, "Nerv internal error: ");
+ lua_pushvfstring(L, err_mesg_fmt, ap);
+ lua_concat(L, 2);
+ lua_error(L);
+ va_end(ap);
+ return 0;
+}
+
+int nerv_error_method_not_implemented(lua_State *L) {
+ return nerv_error(L, "method not implemented");
+}
+#endif
diff --git a/common.h b/common.h
new file mode 100644
index 0000000..b316f20
--- /dev/null
+++ b/common.h
@@ -0,0 +1,9 @@
+#include "lua.h"
+#include "lauxlib.h"
+#include "lualib.h"
+#include "luaT/luaT.h"
+#include <stdio.h>
+#include <stdlib.h>
+
+int nerv_error(lua_State *L, const char *err_mesg_fmt, ...);
+int nerv_error_method_not_implemented(lua_State *L);
diff --git a/matrix.c b/matrix.c
new file mode 100644
index 0000000..9d93dba
--- /dev/null
+++ b/matrix.c
@@ -0,0 +1,95 @@
+#include "common.h"
+
+typedef struct Matrix {
+ long stride; /* size of a row */
+ long ncol, nrow, nmax; /* dimension of the matrix */
+ union {
+ float *f;
+ double *d;
+ } storage; /* pointer to actual storage */
+} Matrix;
+
+const char *float_matrix_tname = "nerv.FloatMatrix";
+const char *matrix_tname = "nerv.Matrix";
+
+int float_matrix_new(lua_State *L) {
+ Matrix *self = (Matrix *)malloc(sizeof(Matrix));
+ self->nrow = luaL_checkinteger(L, 1);
+ self->ncol = luaL_checkinteger(L, 2);
+ self->nmax = self->nrow * self->ncol;
+ self->stride = sizeof(float) * self->nrow;
+ self->storage.f = (float *)malloc(self->stride * self->ncol);
+ luaT_pushudata(L, self, float_matrix_tname);
+ return 1;
+}
+
+int float_matrix_destroy(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, float_matrix_tname);
+ free(self->storage.f);
+ fprintf(stderr, "[debug] destroyted\n");
+ return 0;
+}
+
+int nerv_float_matrix_get_elem(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, float_matrix_tname);
+ int idx = luaL_checkinteger(L, 2);
+ if (idx < 0 || idx >= self->nmax)
+ nerv_error(L, "index must be within range [0, %d)", self->nmax);
+ lua_pushnumber(L, self->storage.f[idx]);
+ return 1;
+}
+
+int nerv_float_matrix_set_elem(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, float_matrix_tname);
+ int idx = luaL_checkinteger(L, 2);
+ float v = luaL_checknumber(L, 3);
+ long upper = self->nrow * self->ncol;
+ if (idx < 0 || idx >= self->nmax)
+ nerv_error(L, "index must be within range [0, %d)", self->nmax);
+ self->storage.f[idx] = v;
+ return 0;
+}
+
+static int nerv_float_matrix_ncol(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, float_matrix_tname);
+ lua_pushinteger(L, self->ncol);
+ return 1;
+}
+
+static int nerv_float_matrix_nrow(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, float_matrix_tname);
+ lua_pushinteger(L, self->nrow);
+ return 1;
+}
+
+
+static const luaL_Reg float_matrix_methods[] = {
+ {"get_elem", nerv_float_matrix_get_elem},
+ {"set_elem", nerv_float_matrix_set_elem},
+ {"ncol", nerv_float_matrix_ncol},
+ {"nrow", nerv_float_matrix_nrow},
+ {NULL, NULL}
+};
+
+void nerv_float_matrix_init(lua_State *L) {
+ luaT_newmetatable(L, float_matrix_tname, matrix_tname,
+ float_matrix_new, float_matrix_destroy, NULL);
+ luaL_register(L, NULL, float_matrix_methods);
+ lua_pop(L, 1);
+}
+
+static const luaL_Reg matrix_methods[] = {
+ {"__tostring__", nerv_error_method_not_implemented },
+ {"__add__", nerv_error_method_not_implemented },
+ {"__sub__", nerv_error_method_not_implemented },
+ {"__mul__", nerv_error_method_not_implemented },
+ {NULL, NULL}
+};
+
+void nerv_matrix_init(lua_State *L) {
+ /* abstract class */
+ luaT_newmetatable(L, matrix_tname, NULL, NULL, NULL, NULL);
+ luaL_register(L, NULL, matrix_methods);
+ lua_pop(L, 1);
+ nerv_float_matrix_init(L);
+}
diff --git a/matrix.lua b/matrix.lua
new file mode 100644
index 0000000..2a70590
--- /dev/null
+++ b/matrix.lua
@@ -0,0 +1,15 @@
+function nerv.FloatMatrix:__tostring__()
+ local ncol = self:ncol()
+ local nrow = self:nrow()
+ local i = 0
+ local res = ""
+ for row = 0, nrow - 1 do
+ for col = 0, ncol - 1 do
+ res = res .. string.format("%f ", self:get_elem(i))
+ i = i + 1
+ end
+ res = res .. "\n"
+ end
+ res = res .. string.format("[Float Matrix %d x %d]", nrow, ncol)
+ return res
+end
diff --git a/matrix_example.lua b/matrix_example.lua
new file mode 100644
index 0000000..1ff129d
--- /dev/null
+++ b/matrix_example.lua
@@ -0,0 +1,7 @@
+t = nerv.FloatMatrix(2, 3)
+print(t:get_elem(1))
+t:set_elem(1, 3.23432);
+print(t:get_elem(1))
+print(t)
+t = nerv.FloatMatrix(10, 20)
+print(t)
diff --git a/nerv b/nerv
index 7b035c6..2eeab76 100755
--- a/nerv
+++ b/nerv
@@ -1,2 +1,2 @@
#!/bin/bash
-exec 'build/luajit-2.0/bin/luajit' -e "package.cpath=\"${PWD}/build/objs/?.so\"" -e "require 'libnerv' " "$@"
+exec 'build/luajit-2.0/bin/luajit' -e "package.cpath=\"${PWD}/build/objs/?.so\"" -e "require 'libnerv'" -e "package.path=\"${PWD}/build/lua/?.lua\"" -e "require 'matrix'" "$@"
diff --git a/nerv.c b/nerv.c
index 2118a3c..42b50b4 100644
--- a/nerv.c
+++ b/nerv.c
@@ -2,12 +2,14 @@
#include "lauxlib.h"
#include "lualib.h"
-extern int nerv_point_init(lua_State *L);
+extern void nerv_point_init(lua_State *L);
+extern void nerv_matrix_init(lua_State *L);
LUALIB_API int luaopen_libnerv(lua_State *L) {
lua_newtable(L);
lua_pushvalue(L, -1);
lua_setfield(L, LUA_GLOBALSINDEX, "nerv");
nerv_point_init(L);
+ nerv_matrix_init(L);
return 1;
}
diff --git a/oop_example.c b/oop_example.c
index fecac14..e9a4ffe 100644
--- a/oop_example.c
+++ b/oop_example.c
@@ -6,32 +6,34 @@
#include <stdio.h>
#include <stdlib.h>
+const char *point_tname = "nerv.Point";
+
typedef struct {
double x, y;
int arr[100];
} Point;
static int point_get_sinx (lua_State *L) {
- Point *p = luaT_checkudata(L, 1, "nerv.point");
+ Point *p = luaT_checkudata(L, 1, point_tname);
lua_pushnumber(L, sin(p->x));
return 1;
}
static int point_set_x (lua_State *L) {
- Point *p = luaT_checkudata(L, 1, "nerv.point");
+ Point *p = luaT_checkudata(L, 1, point_tname);
p->x = luaL_checknumber(L, 2);
return 0;
}
static int point_get_y(lua_State *L) {
- Point *p = luaT_checkudata(L, 1, "nerv.point");
+ Point *p = luaT_checkudata(L, 1, point_tname);
lua_pushnumber(L, sin(p->x));
return 1;
}
static int point_newindex(lua_State *L) {
- Point *p = luaT_checkudata(L, 1, "nerv.point");
+ Point *p = luaT_checkudata(L, 1, point_tname);
if (lua_isnumber(L, 2))
{
int d = luaL_checkinteger(L, 2);
@@ -49,7 +51,7 @@ static int point_newindex(lua_State *L) {
}
static int point_index(lua_State *L) {
- Point *p = luaT_checkudata(L, 1, "nerv.point");
+ Point *p = luaT_checkudata(L, 1, point_tname);
if (lua_isnumber(L, 2))
{
int d = luaL_checkinteger(L, 2);
@@ -69,7 +71,7 @@ int point_new(lua_State *L) {
Point *self = (Point *)malloc(sizeof(Point));
self->x = 0;
self->y = 0;
- luaT_pushudata(L, self, "nerv.point");
+ luaT_pushudata(L, self, point_tname);
return 1;
}
@@ -83,7 +85,7 @@ static const luaL_Reg point[] = {
};
void nerv_point_init(lua_State *L) {
- luaT_newmetatable(L, "nerv.point", NULL, point_new, NULL, NULL);
+ luaT_newmetatable(L, "nerv.Point", NULL, point_new, NULL, NULL);
luaL_register(L, NULL, point);
lua_pop(L, 1);
}
diff --git a/oop_example.lua b/oop_example.lua
index a4e7009..45da36e 100644
--- a/oop_example.lua
+++ b/oop_example.lua
@@ -1,4 +1,4 @@
-a = nerv.point()
+a = nerv.Point()
print(a:get_sinx())
a:set_x(3.14)
print(a:get_sinx())