aboutsummaryrefslogtreecommitdiff
path: root/nerv/lib/matrix/generic/matrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/lib/matrix/generic/matrix.c')
-rw-r--r--nerv/lib/matrix/generic/matrix.c8
1 files changed, 8 insertions, 0 deletions
diff --git a/nerv/lib/matrix/generic/matrix.c b/nerv/lib/matrix/generic/matrix.c
index 4246751..fd5d28f 100644
--- a/nerv/lib/matrix/generic/matrix.c
+++ b/nerv/lib/matrix/generic/matrix.c
@@ -10,6 +10,8 @@ void nerv_matrix_(data_free)(Matrix *self, Status *status) {
{
/* free matrix data */
MATRIX_DATA_FREE(MATRIX_ELEM_PTR(self), status);
+ curandDestroyGenerator(*(self->curand_gen));
+ free(self->curand_gen);
free(self->data_ref);
free(self);
}
@@ -39,6 +41,11 @@ Matrix *nerv_matrix_(create)(long nrow, long ncol, Status *status) {
}
self->data_ref = (long *)malloc(sizeof(long));
*self->data_ref = 0;
+
+ self->curand_gen = (curandGenerator_t*)malloc(sizeof(curandGenerator_t));
+ curandCreateGenerator(self->curand_gen, CURAND_RNG_PSEUDO_DEFAULT);
+ curandSetPseudoRandomGeneratorSeed(*(self->curand_gen), time(NULL));
+
nerv_matrix_(data_retain)(self);
NERV_SET_STATUS(status, NERV_NORMAL, 0);
return self;
@@ -57,6 +64,7 @@ Matrix *nerv_matrix_(getrow)(Matrix *self, int row) {
prow->nmax = prow->ncol;
MATRIX_ELEM_PTR(prow) = MATRIX_ROW_PTR(self, row);
prow->data_ref = self->data_ref;
+ prow->curand_gen = self->curand_gen;
nerv_matrix_(data_retain)(prow);
return prow;
}