Skip to content

Commit f9a6bbd

Browse files
ararslanandreasnoack
authored andcommitted
Check output dimensions for 2x2 and 3x3 in-place matmuls (#19227)
1 parent 4e2e931 commit f9a6bbd

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

base/linalg/matmul.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,9 @@ function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
656656
end
657657

658658
function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
659+
if !(size(A) == size(B) == size(C) == (2,2))
660+
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
661+
end
659662
@inbounds begin
660663
if tA == 'T'
661664
A11 = transpose(A[1,1]); A12 = transpose(A[2,1]); A21 = transpose(A[1,2]); A22 = transpose(A[2,2])
@@ -685,6 +688,9 @@ function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
685688
end
686689

687690
function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
691+
if !(size(A) == size(B) == size(C) == (3,3))
692+
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
693+
end
688694
@inbounds begin
689695
if tA == 'T'
690696
A11 = transpose(A[1,1]); A12 = transpose(A[2,1]); A13 = transpose(A[3,1])

test/linalg/matmul.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ let
3939
@test Ac_mul_Bc(Ai, Bi) == [-28.25-66im 9.75-58im; -26-89im 21-73im]
4040
@test_throws DimensionMismatch [1 2; 0 0; 0 0] * [1 2]
4141
end
42+
CC = ones(3, 3)
43+
@test_throws DimensionMismatch A_mul_B!(CC, AA, BB)
4244
end
4345
# 3x3
4446
let
@@ -62,6 +64,8 @@ let
6264
@test Ac_mul_Bc(Ai, Bi) == [1+2im 20.75+9im -44.75+42im; 19.5+17.5im -54-36.5im 51-14.5im; 13+7.5im 11.25+31.5im -43.25-14.5im]
6365
@test_throws DimensionMismatch [1 2 3; 0 0 0; 0 0 0] * [1 2 3]
6466
end
67+
CC = ones(4, 4)
68+
@test_throws DimensionMismatch A_mul_B!(CC, AA, BB)
6569
end
6670
# Generic integer matrix multiplication
6771
# Generic AbstractArrays

0 commit comments

Comments
 (0)