Skip to content

Commit c8e4c1c

Browse files
committed
fix: matrix multiplication of wrapper types
1 parent b60ca6c commit c8e4c1c

File tree

4 files changed

+57
-12
lines changed

4 files changed

+57
-12
lines changed

src/Overlay.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,28 @@ for randfun in (:rand, :randn, :randexp)
115115
# end
116116
end
117117
end
118+
119+
# LinearAlgebra.jl overloads
120+
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
121+
## without specializing on every possible combination of types
122+
@reactant_overlay @noinline function LinearAlgebra.mul!(
123+
C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number
124+
)
125+
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
126+
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
127+
else
128+
LinearAlgebra._mul!(C, A, B, α, β)
129+
end
130+
return C
131+
end
132+
133+
@reactant_overlay @noinline function LinearAlgebra.mul!(
134+
C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number
135+
)
136+
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
137+
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
138+
else
139+
LinearAlgebra._mul!(C, A, B, α, β)
140+
end
141+
return C
142+
end

src/Reactant.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ const WrappedTracedRArray{T,N} = WrappedArray{
119119
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
120120
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
121121
const AnyTracedRMatrix{T} = Union{
122-
AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}}
122+
AnyTracedRArray{T,2},
123+
LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}},
124+
LinearAlgebra.Tridiagonal{TracedRNumber{T},TracedRArray{T,1}},
123125
}
124126
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
125127

src/stdlibs/LinearAlgebra.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ function TracedUtils.materialize_traced_array(
3434
return diagm(parent(x))
3535
end
3636

37-
function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
37+
function TracedUtils.materialize_traced_array(
38+
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}
39+
) where {T}
3840
return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
3941
end
4042

@@ -162,7 +164,7 @@ function TracedUtils.set_mlir_data!(
162164
end
163165

164166
# Core functions
165-
function LinearAlgebra.mul!(
167+
function overloaded_mul!(
166168
@nospecialize(C::TracedRArray{T,1}),
167169
@nospecialize(A::AnyTracedRMatrix),
168170
@nospecialize(B::AnyTracedRVector),
@@ -176,7 +178,7 @@ function LinearAlgebra.mul!(
176178
return C
177179
end
178180

179-
function LinearAlgebra.mul!(
181+
function overloaded_mul!(
180182
@nospecialize(C::TracedRArray{T,2}),
181183
@nospecialize(A::AnyTracedRMatrix),
182184
@nospecialize(B::AnyTracedRVector),
@@ -187,7 +189,7 @@ function LinearAlgebra.mul!(
187189
return C
188190
end
189191

190-
function LinearAlgebra.mul!(
192+
function overloaded_mul!(
191193
@nospecialize(C::TracedRArray{T,2}),
192194
@nospecialize(A::AnyTracedRMatrix),
193195
@nospecialize(B::AnyTracedRMatrix),

test/integration/linear_algebra.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra, Reactant
1+
using LinearAlgebra, Reactant, Test
22

33
function muladd2(A, x, b)
44
C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2))
@@ -143,13 +143,29 @@ end
143143
@test @jit(diagm(1 => x_ra1, 1 => x_ra2)) diagm(1 => x1, 1 => x2)
144144
end
145145

146-
# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be
147-
# optimized
146+
# TODO: Currently <Wrapper Type>(x) * x goes down the generic matmul path but it should
147+
# clearly be optimized
148148
mul_diagonal(x) = Diagonal(x) * x
149-
150-
@testset "mul_diagonal" begin
151-
x = rand(4)
149+
mul_tridiagonal(x) = Tridiagonal(x) * x
150+
mul_unit_lower_triangular(x) = UnitLowerTriangular(x) * x
151+
mul_unit_upper_triangular(x) = UnitUpperTriangular(x) * x
152+
mul_lower_triangular(x) = LowerTriangular(x) * x
153+
mul_upper_triangular(x) = UpperTriangular(x) * x
154+
mul_symmetric(x) = Symmetric(x) * x
155+
156+
@testset "Wrapper Types Matrix Multiplication" begin
157+
x = rand(4, 4)
152158
x_ra = Reactant.to_rarray(x)
153159

154-
@test @jit(mul_diagonal(x_ra)) mul_diagonal(x)
160+
@testset "$(wrapper_type)" for (wrapper_type, fn) in [
161+
(Diagonal, mul_diagonal),
162+
(Tridiagonal, mul_tridiagonal),
163+
(UnitLowerTriangular, mul_unit_lower_triangular),
164+
(UnitUpperTriangular, mul_unit_upper_triangular),
165+
(LowerTriangular, mul_lower_triangular),
166+
(UpperTriangular, mul_upper_triangular),
167+
(Symmetric, mul_symmetric),
168+
]
169+
@test @jit(fn(x_ra)) fn(x)
170+
end
155171
end

0 commit comments

Comments
 (0)