Skip to content

Commit 71ab90c

Browse files
committed
raise an exception if conflicting shape is given to DenseMatrix at construction
1 parent cc478b2 commit 71ab90c

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,6 +3230,8 @@ cdef class DenseMatrixBase(MatrixBase):
32303230
raise ValueError("sizes don't match.")
32313231
else:
32323232
self.thisptr = new symengine.DenseMatrix(0, 0, v_)
3233+
elif col is not None and (row*col != v_.size()):
3234+
raise ValueError("Number of elements should equal rows*columns.")
32333235
else:
32343236
self.thisptr = new symengine.DenseMatrix(row, v_.size() / row, v_)
32353237

symengine/tests/test_matrices.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
HAVE_NUMPY = False
1313

1414

15+
def test_init():
16+
raises(ValueError, DenseMatrix(2, 1, [0]*4))
17+
18+
1519
def test_get():
1620
A = DenseMatrix([[1, 2], [3, 4]])
1721

@@ -230,7 +234,7 @@ def test_mul_matrix():
230234
assert A.mul_matrix(B) == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
231235
assert A * B == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
232236
assert A @ B == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
233-
assert (A @ DenseMatrix(2, 1, [0]*4)).shape == (2, 1)
237+
assert (A @ DenseMatrix(2, 1, [0]*2)).shape == (2, 1)
234238

235239
C = DenseMatrix(2, 3, [1, 2, 3, 2, 3, 4])
236240
D = DenseMatrix(3, 2, [3, 4, 4, 5, 5, 6])

0 commit comments

Comments
 (0)