Skip to content

Commit cc478b2

Browse files
committed
Add support for Python's matrix multiplication operator
1 parent d025f57 commit cc478b2

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3271,6 +3271,13 @@ cdef class DenseMatrixBase(MatrixBase):
32713271
else:
32723272
return NotImplemented
32733273

3274+
def __matmul__(a, b):
3275+
a = _sympify(a, False)
3276+
b = _sympify(b, False)
3277+
if (a.ncols() != b.nrows()):
3278+
raise ShapeError("Invalid shapes for matrix multiplication. Got %s %s" % (a.shape, b.shape))
3279+
return a.mul_matrix(b)
3280+
32743281
def __truediv__(a, b):
32753282
return div_matrices(a, b)
32763283

@@ -4889,7 +4896,7 @@ cdef class LambdaDouble(_Lambdify):
48894896
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
48904897
c_out = out
48914898
for idx in range(nbroadcast):
4892-
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
4899+
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
48934900

48944901
cpdef as_scipy_low_level_callable(self):
48954902
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
@@ -5135,7 +5142,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
51355142
raise ValueError("Long double not supported on this platform")
51365143
else:
51375144
raise ValueError("Unknown numpy dtype.")
5138-
5145+
51395146
if as_scipy:
51405147
return ret.as_scipy_low_level_callable()
51415148
return ret

symengine/tests/test_matrices.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ def test_mul_matrix():
229229

230230
assert A.mul_matrix(B) == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
231231
assert A * B == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
232+
assert A @ B == DenseMatrix(2, 2, [a + b, 0, c + d, 0])
233+
assert (A @ DenseMatrix(2, 1, [0]*4)).shape == (2, 1)
232234

233235
C = DenseMatrix(2, 3, [1, 2, 3, 2, 3, 4])
234236
D = DenseMatrix(3, 2, [3, 4, 4, 5, 5, 6])

0 commit comments

Comments
 (0)