aboutsummaryrefslogtreecommitdiff
path: root/nerv/lib
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/lib')
-rw-r--r--nerv/lib/matrix/generic/matrix.c2
-rw-r--r--nerv/lib/matrix/init.lua7
-rw-r--r--nerv/lib/matrix/matrix.h1
3 files changed, 7 insertions, 3 deletions
diff --git a/nerv/lib/matrix/generic/matrix.c b/nerv/lib/matrix/generic/matrix.c
index a64759e..5dbad48 100644
--- a/nerv/lib/matrix/generic/matrix.c
+++ b/nerv/lib/matrix/generic/matrix.c
@@ -24,6 +24,7 @@ Matrix *nerv_matrix_(create)(long nrow, long ncol, Status *status) {
self->nrow = nrow;
self->ncol = ncol;
self->nmax = self->nrow * self->ncol;
+ self->dim = 2;
MATRIX_DATA_ALLOC(&MATRIX_ELEM_PTR(self), &self->stride,
sizeof(MATRIX_ELEM) * self->ncol, self->nrow,
status);
@@ -47,6 +48,7 @@ Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
Matrix *prow = (Matrix *)malloc(sizeof(Matrix));
prow->ncol = self->ncol;
prow->nrow = 1;
+ prow->dim = 1;
prow->stride = self->stride;
prow->nmax = prow->ncol;
MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row);
diff --git a/nerv/lib/matrix/init.lua b/nerv/lib/matrix/init.lua
index 1a8925f..89f89d6 100644
--- a/nerv/lib/matrix/init.lua
+++ b/nerv/lib/matrix/init.lua
@@ -1,6 +1,7 @@
function nerv.Matrix:__tostring__()
local ncol = self:ncol()
local nrow = self:nrow()
+ local dim = self:dim()
local strt = {}
local fmt
if self.fmt then
@@ -8,7 +9,7 @@ function nerv.Matrix:__tostring__()
else
fmt = "%.8f "
end
- if nrow == 1 then
+ if dim == 1 then
for col = 0, ncol - 1 do
table.insert(strt, string.format(fmt, self[col]))
end
@@ -28,9 +29,9 @@ function nerv.Matrix:__tostring__()
end
-- gen: a function takes take indices of the matrix and return the generated
--- all entrys in the matrix will be assigned by calling gen(i, j)
+-- all entrys in the matrix will be assigned by calling gen(i, j), if self is a row vector, gen(j) will be called
function nerv.Matrix:generate(gen)
- if (self:nrow() == 1) then
+ if (self:dim() == 1) then
for j = 0, self:ncol() - 1 do
self[j] = gen(j)
end
diff --git a/nerv/lib/matrix/matrix.h b/nerv/lib/matrix/matrix.h
index cbf32c2..67a6e30 100644
--- a/nerv/lib/matrix/matrix.h
+++ b/nerv/lib/matrix/matrix.h
@@ -6,6 +6,7 @@
typedef struct Matrix {
size_t stride; /* size of a row */
long ncol, nrow, nmax; /* dimension of the matrix */
+ int dim; /* dim == 2 for a matrix, dim == 1 for row vector */
union {
float *f;
double *d;