Skip to content

Commit 0fee464

Browse files
committed
fix: de-specialize 3 arg mul!
1 parent 89fe3d4 commit 0fee464

File tree

3 files changed

+52
-18
lines changed

3 files changed

+52
-18
lines changed

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic
105105
end
106106

107107
function create_result(
108-
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
108+
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char},
109109
path,
110110
result_stores,
111111
)

src/Overlay.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,24 +119,28 @@ end
119119
# LinearAlgebra.jl overloads
120120
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
121121
## without specializing on every possible combination of types
122-
@reactant_overlay @noinline function LinearAlgebra.mul!(
123-
C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number
122+
for (cT, aT, bT) in (
123+
(:AbstractVector, :AbstractMatrix, :AbstractVector),
124+
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
124125
)
125-
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
126-
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
127-
else
128-
LinearAlgebra._mul!(C, A, B, α, β)
129-
end
130-
return C
131-
end
126+
@eval begin
127+
@reactant_overlay @noinline function LinearAlgebra.mul!(
128+
C::$cT, A::$aT, B::$bT, α::Number, β::Number
129+
)
130+
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
131+
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
132+
else
133+
LinearAlgebra._mul!(C, A, B, α, β)
134+
end
135+
return C
136+
end
132137

133-
@reactant_overlay @noinline function LinearAlgebra.mul!(
134-
C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number
135-
)
136-
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
137-
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
138-
else
139-
LinearAlgebra._mul!(C, A, B, α, β)
138+
# Needed mostly for 1.10 where 3-arg mul is often specialized
139+
@reactant_overlay @noinline function LinearAlgebra.mul!(
140+
C::$cT, A::$aT, B::$bT
141+
)
142+
call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
143+
return C
144+
end
140145
end
141-
return C
142146
end

test/wrapped_arrays.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,33 @@ end
172172
@test all(iszero, y_res)
173173
end
174174
end
175+
176+
function lower_triangular_write(x)
177+
y = LowerTriangular(copy(x))
178+
@. y *= 2
179+
return y
180+
end
181+
182+
function upper_triangular_write(x)
183+
y = UpperTriangular(copy(x))
184+
@. y *= 2
185+
return y
186+
end
187+
188+
function tridiagonal_write(x)
189+
y = Tridiagonal(copy(x))
190+
@. y *= 2
191+
return y
192+
end
193+
194+
@testset "Broadcasted Multiply and Alloate" begin
195+
@testset "$(aType)" for (aType, fn) in [
196+
("LowerTriangular", lower_triangular_write),
197+
("UpperTriangular", upper_triangular_write),
198+
("Tridiagonal", tridiagonal_write),
199+
]
200+
x = rand(4, 4)
201+
x_ra = Reactant.to_rarray(x)
202+
@test @jit(fn(x_ra)) fn(x)
203+
end
204+
end

0 commit comments

Comments
 (0)