aboutsummaryrefslogtreecommitdiff
path: root/nerv/lib/matrix/generic/cumatrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'nerv/lib/matrix/generic/cumatrix.c')
-rw-r--r--nerv/lib/matrix/generic/cumatrix.c2
1 files changed, 1 insertions, 1 deletions
diff --git a/nerv/lib/matrix/generic/cumatrix.c b/nerv/lib/matrix/generic/cumatrix.c
index 7582725..bf93b77 100644
--- a/nerv/lib/matrix/generic/cumatrix.c
+++ b/nerv/lib/matrix/generic/cumatrix.c
@@ -41,7 +41,7 @@ void nerv_matrix_(mul)(Matrix *c, const Matrix *a, const Matrix *b,
int bm = b->nrow, bn = b->ncol;
if (ta == CUBLAS_OP_T) SWAP(am, an);
if (tb == CUBLAS_OP_T) SWAP(bm, bn);
- if (an != bm)
+ if (an != bm || (am != c->nrow && bn != c->ncol))
NERV_EXIT_STATUS(status, MAT_WRONG_MULT_DIM, 0);
/* Because matrix in Nerv is row-major, here b comes first */
PROFILE_START