Skip to content

Commit 58d8678

Browse files
authored
Merge pull request #362 from bjodah/matmul-operator
Add support for Python's matrix multiplication operator
2 parents ec4127b + 35027d7 commit 58d8678

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,6 +3230,8 @@ cdef class DenseMatrixBase(MatrixBase):
32303230
raise ValueError("sizes don't match.")
32313231
else:
32323232
self.thisptr = new symengine.DenseMatrix(0, 0, v_)
3233+
elif col is not None and (row*col != v_.size()):
3234+
raise ValueError("Number of elements should equal rows*columns.")
32333235
else:
32343236
self.thisptr = new symengine.DenseMatrix(row, v_.size() / row, v_)
32353237

@@ -3271,6 +3273,13 @@ cdef class DenseMatrixBase(MatrixBase):
32713273
else:
32723274
return NotImplemented
32733275

3276+
def __matmul__(a, b):
3277+
a = _sympify(a, False)
3278+
b = _sympify(b, False)
3279+
if (a.ncols() != b.nrows()):
3280+
raise ShapeError("Invalid shapes for matrix multiplication. Got %s %s" % (a.shape, b.shape))
3281+
return a.mul_matrix(b)
3282+
32743283
def __truediv__(a, b):
32753284
return div_matrices(a, b)
32763285

@@ -4889,7 +4898,7 @@ cdef class LambdaDouble(_Lambdify):
48894898
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
48904899
c_out = out
48914900
for idx in range(nbroadcast):
4892-
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
4901+
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
48934902

48944903
cpdef as_scipy_low_level_callable(self):
48954904
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
@@ -5135,7 +5144,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
51355144
raise ValueError("Long double not supported on this platform")
51365145
else:
51375146
raise ValueError("Unknown numpy dtype.")
5138-
5147+
51395148
if as_scipy:
51405149
return ret.as_scipy_low_level_callable()
51415150
return ret

symengine/tests/test_matrices.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
HAVE_NUMPY = False
1313

1414

15+
def test_init():
16+
raises(ValueError, lambda: DenseMatrix(2, 1, [0]*4))
17+
18+
1519
def test_get():
1620
A = DenseMatrix([[1, 2], [3, 4]])
1721

@@ -229,6 +233,8 @@ def test_mul_matrix():
229233

230234
assert A.mul_matrix(B) == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
231235
assert A * B == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
236+
assert A @ B == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
237+
assert (A @ DenseMatrix(2, 1, [0]*2)).shape == (2, 1)
232238

233239
C = DenseMatrix(2, 3, [1, 2, 3, 2, 3, 4])
234240
D = DenseMatrix(3, 2, [3, 4, 4, 5, 5, 6])

0 commit comments

Comments
 (0)