summaryrefslogtreecommitdiff
path: root/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'matrix')
-rw-r--r--matrix/cumatrix.c15
-rw-r--r--matrix/generic/matrix.c18
-rw-r--r--matrix/init.c3
-rw-r--r--matrix/init.lua38
-rw-r--r--matrix/matrix.c5
-rw-r--r--matrix/matrix.lua17
6 files changed, 68 insertions, 28 deletions
diff --git a/matrix/cumatrix.c b/matrix/cumatrix.c
new file mode 100644
index 0000000..87ad57b
--- /dev/null
+++ b/matrix/cumatrix.c
@@ -0,0 +1,15 @@
+#define MATRIX_DATA_FREE(ptr) free(ptr)
+#define MATRIX_DATA_ALLOC(size) malloc(size)
+#define MATRIX_DATA_STRIDE(ncol) (sizeof(float) * (ncol))
+#define MATRIX_GENERIC
+#define nerv_float_matrix_(NAME) nerv_float_matrix_cuda_ ## NAME
+#include "generic/matrix.c"
+
+const char *nerv_float_matrix_(tname) = "nerv.FloatCuMatrix";
+int nerv_float_matrix_(get_elem)(lua_State *L) {
+ return nerv_error_method_not_implemented(L);
+}
+
+int nerv_float_matrix_(set_elem)(lua_State *L) {
+ return nerv_error_method_not_implemented(L);
+}
diff --git a/matrix/generic/matrix.c b/matrix/generic/matrix.c
index 3bb12ef..29919d8 100644
--- a/matrix/generic/matrix.c
+++ b/matrix/generic/matrix.c
@@ -3,7 +3,7 @@
#include "matrix.h"
extern const char *nerv_matrix_tname;
-const char *nerv_float_matrix_tname = "nerv.FloatMatrix";
+extern const char *nerv_float_matrix_(tname);
void nerv_float_matrix_(data_free)(Matrix *self) {
if (--(*self->data_ref) == 0)
@@ -24,12 +24,12 @@ int nerv_float_matrix_(new)(lua_State *L) {
self->data_ref = (long *)malloc(sizeof(long));
*self->data_ref = 0;
nerv_float_matrix_(data_retain)(self);
- luaT_pushudata(L, self, nerv_float_matrix_tname);
+ luaT_pushudata(L, self, nerv_float_matrix_(tname));
return 1;
}
int nerv_float_matrix_(destroy)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
nerv_float_matrix_(data_free)(self);
return 0;
}
@@ -50,7 +50,7 @@ static Matrix *nerv_float_matrix_(getrow)(Matrix *self, int row) {
}
static int nerv_float_matrix_(newindex)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
if (lua_isnumber(L, 2))
{
int idx = luaL_checkinteger(L, 2);
@@ -74,7 +74,7 @@ static int nerv_float_matrix_(newindex)(lua_State *L) {
static int nerv_float_matrix_(index)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
if (lua_isnumber(L, 2))
{
int idx = luaL_checkinteger(L, 2);
@@ -88,7 +88,7 @@ static int nerv_float_matrix_(index)(lua_State *L) {
{
if (idx < 0 || idx >= self->nrow)
nerv_error(L, "index must be within range [0, %d)", self->nrow);
- luaT_pushudata(L, nerv_float_matrix_(getrow)(self, idx), nerv_float_matrix_tname);
+ luaT_pushudata(L, nerv_float_matrix_(getrow)(self, idx), nerv_float_matrix_(tname));
}
lua_pushboolean(L, 1);
return 2;
@@ -101,13 +101,13 @@ static int nerv_float_matrix_(index)(lua_State *L) {
}
static int nerv_float_matrix_(ncol)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
lua_pushinteger(L, self->ncol);
return 1;
}
static int nerv_float_matrix_(nrow)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
lua_pushinteger(L, self->nrow);
return 1;
}
@@ -124,7 +124,7 @@ static const luaL_Reg nerv_float_matrix_(methods)[] = {
};
void nerv_float_matrix_(init)(lua_State *L) {
- luaT_newmetatable(L, nerv_float_matrix_tname, nerv_matrix_tname,
+ luaT_newmetatable(L, nerv_float_matrix_(tname), nerv_matrix_tname,
nerv_float_matrix_(new), nerv_float_matrix_(destroy), NULL);
luaL_register(L, NULL, nerv_float_matrix_(methods));
lua_pop(L, 1);
diff --git a/matrix/init.c b/matrix/init.c
index e251628..e723f55 100644
--- a/matrix/init.c
+++ b/matrix/init.c
@@ -2,6 +2,8 @@
#include "generic/matrix.h"
const char *nerv_matrix_tname = "nerv.Matrix";
+void nerv_float_matrix_host_init(lua_State *L);
+void nerv_float_matrix_cuda_init(lua_State *L);
static const luaL_Reg matrix_methods[] = {
{"__tostring__", nerv_error_method_not_implemented },
{"__add__", nerv_error_method_not_implemented },
@@ -16,4 +18,5 @@ void nerv_matrix_init(lua_State *L) {
luaL_register(L, NULL, matrix_methods);
lua_pop(L, 1);
nerv_float_matrix_host_init(L);
+ nerv_float_matrix_cuda_init(L);
}
diff --git a/matrix/init.lua b/matrix/init.lua
new file mode 100644
index 0000000..59b8384
--- /dev/null
+++ b/matrix/init.lua
@@ -0,0 +1,38 @@
+function nerv.FloatCuMatrix:__tostring__()
+ local ncol = self:ncol()
+ local nrow = self:nrow()
+ local strt = {}
+
+ if nrow == 1 then
+ for col = 0, ncol - 1 do
+ table.insert(strt, string.format("%f ", self[col]))
+ end
+ table.insert(strt, "\n")
+ else
+ for row = 0, nrow - 1 do
+ local rp = self[row]
+ for col = 0, ncol - 1 do
+ table.insert(strt, string.format("%f ", rp[col]))
+ end
+ table.insert(strt, "\n")
+ end
+ end
+ table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol))
+ return table.concat(strt)
+end
+
+function nerv.FloatMatrix:__tostring__()
+ local ncol = self:ncol()
+ local nrow = self:nrow()
+ local i = 0
+ local strt = {}
+ for row = 0, nrow - 1 do
+ for col = 0, ncol - 1 do
+ table.insert(strt, string.format("%f ", self:get_elem(i)))
+ i = i + 1
+ end
+ table.insert(strt, "\n")
+ end
+ table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol))
+ return table.concat(strt)
+end
diff --git a/matrix/matrix.c b/matrix/matrix.c
index 3a593e5..0e5f75f 100644
--- a/matrix/matrix.c
+++ b/matrix/matrix.c
@@ -5,8 +5,9 @@
#define nerv_float_matrix_(NAME) nerv_float_matrix_host_ ## NAME
#include "generic/matrix.c"
+const char *nerv_float_matrix_(tname) = "nerv.FloatMatrix";
int nerv_float_matrix_(get_elem)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
if (idx < 0 || idx >= self->nmax)
nerv_error(L, "index must be within range [0, %d)", self->nmax);
@@ -15,7 +16,7 @@ int nerv_float_matrix_(get_elem)(lua_State *L) {
}
int nerv_float_matrix_(set_elem)(lua_State *L) {
- Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_tname);
+ Matrix *self = luaT_checkudata(L, 1, nerv_float_matrix_(tname));
int idx = luaL_checkinteger(L, 2);
float v = luaL_checknumber(L, 3);
long upper = self->nrow * self->ncol;
diff --git a/matrix/matrix.lua b/matrix/matrix.lua
deleted file mode 100644
index b9e4876..0000000
--- a/matrix/matrix.lua
+++ /dev/null
@@ -1,17 +0,0 @@
-function nerv.FloatMatrix:__tostring__()
- local ncol = self:ncol()
- local nrow = self:nrow()
- local i = 0
- local strt = {}
- for row = 0, nrow - 1 do
--- local rp = self[row]
- for col = 0, ncol - 1 do
--- table.insert(strt, string.format("%f ", rp[col]))
- table.insert(strt, string.format("%f ", self:get_elem(i)))
- i = i + 1
- end
- table.insert(strt, "\n")
- end
- table.insert(strt, string.format("[Float Matrix %d x %d]", nrow, ncol))
- return table.concat(strt)
-end