aboutsummaryrefslogtreecommitdiff
path: root/nerv/matrix
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/matrix')
-rw-r--r--nerv/matrix/generic/matrix.c11
-rw-r--r--nerv/matrix/init.lua25
2 files changed, 22 insertions, 14 deletions
diff --git a/nerv/matrix/generic/matrix.c b/nerv/matrix/generic/matrix.c
index ff89e54..8efe608 100644
--- a/nerv/matrix/generic/matrix.c
+++ b/nerv/matrix/generic/matrix.c
@@ -31,7 +31,7 @@ static int nerv_matrix_(lua_newindex)(lua_State *L) {
if (lua_isnumber(L, 2))
{
int idx = luaL_checkinteger(L, 2);
- if (self->nrow == 1)
+ if (self->dim == 1)
{
if (idx < 0 || idx >= self->ncol)
nerv_error(L, "index must be within range [0, %d)", self->ncol);
@@ -57,7 +57,7 @@ static int nerv_matrix_(lua_index)(lua_State *L) {
if (lua_isnumber(L, 2))
{
int idx = luaL_checkinteger(L, 2);
- if (self->nrow == 1)
+ if (self->dim == 1)
{
if (idx < 0 || idx >= self->ncol)
nerv_error(L, "index must be within range [0, %d)", self->ncol);
@@ -86,6 +86,12 @@ static int nerv_matrix_(lua_ncol)(lua_State *L) {
return 1;
}
+static int nerv_matrix_(lua_dim)(lua_State *L) {
+ Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
+ lua_pushinteger(L, self->dim);
+ return 1;
+}
+
static int nerv_matrix_(lua_nrow)(lua_State *L) {
Matrix *self = luaT_checkudata(L, 1, nerv_matrix_(tname));
lua_pushinteger(L, self->nrow);
@@ -103,6 +109,7 @@ static const luaL_Reg nerv_matrix_(methods)[] = {
{"set_elem", nerv_matrix_(lua_set_elem)},
{"ncol", nerv_matrix_(lua_ncol)},
{"nrow", nerv_matrix_(lua_nrow)},
+ {"dim", nerv_matrix_(lua_dim)},
{"get_dataref_value", nerv_matrix_(lua_get_dataref_value)},
{"__index__", nerv_matrix_(lua_index)},
{"__newindex__", nerv_matrix_(lua_newindex)},
diff --git a/nerv/matrix/init.lua b/nerv/matrix/init.lua
index 1a8925f..f230e9f 100644
--- a/nerv/matrix/init.lua
+++ b/nerv/matrix/init.lua
@@ -1,6 +1,7 @@
function nerv.Matrix:__tostring__()
local ncol = self:ncol()
local nrow = self:nrow()
+ local dim = self:dim()
local strt = {}
local fmt
if self.fmt then
@@ -8,12 +9,7 @@ function nerv.Matrix:__tostring__()
else
fmt = "%.8f "
end
- if nrow == 1 then
- for col = 0, ncol - 1 do
- table.insert(strt, string.format(fmt, self[col]))
- end
- table.insert(strt, "\n")
- else
+ if (dim == 2) then
for row = 0, nrow - 1 do
local rp = self[row]
for col = 0, ncol - 1 do
@@ -21,6 +17,11 @@ function nerv.Matrix:__tostring__()
end
table.insert(strt, "\n")
end
+ else
+ for col = 0, ncol - 1 do
+ table.insert(strt, string.format(fmt, self[col]))
+ end
+ table.insert(strt, "\n")
end
table.insert(strt, string.format(
"[%s %d x %d]", self.__typename, nrow, ncol))
@@ -28,19 +29,19 @@ function nerv.Matrix:__tostring__()
end
-- gen: a function takes take indices of the matrix and return the generated
--- all entrys in the matrix will be assigned by calling gen(i, j)
+-- all entrys in the matrix will be assigned by calling gen(i, j), for a vector, gen(j) will be called.
function nerv.Matrix:generate(gen)
- if (self:nrow() == 1) then
- for j = 0, self:ncol() - 1 do
- self[j] = gen(j)
- end
- else
+ if (self:dim() == 2) then
for i = 0, self:nrow() - 1 do
local row = self[i]
for j = 0, self:ncol() - 1 do
row[j] = gen(i, j)
end
end
+ else
+ for j = 0, self:ncol() - 1 do
+ self[j] = gen(j)
+ end
end
end