Skip to content

Commit f47a4f4

Browse files
tkfKristofferC
authored andcommitted
Fix method ambiguities in SparseArrays (#30120)
* Remove unused struct CapturedScalars * Fix method ambiguities in SparseArrays * Fix HigherOrderFns._copy(f) (cherry picked from commit f10530e)
1 parent 7352037 commit f47a4f4

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

stdlib/SparseArrays/src/higherorderfns.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ function _copy(f, args...)
973973
parevalf, passedargstup = capturescalars(f, args)
974974
return _copy(parevalf, passedargstup...)
975975
end
976+
_copy(f) = throw(MethodError(_copy, (f,))) # avoid method ambiguity
976977

977978
function _shapecheckbc(f, args...)
978979
_aresameshape(args...) ? _noshapecheck_map(f, args...) : _diffshape_broadcast(f, args...)
@@ -1006,10 +1007,6 @@ end
10061007
_copyto!(parevalf, dest, passedsrcargstup...)
10071008
end
10081009

1009-
struct CapturedScalars{F, Args, Order}
1010-
args::Args
1011-
end
1012-
10131010
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
10141011
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
10151012
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
@@ -1024,9 +1021,13 @@ end
10241021
# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
10251022
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
10261023
capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
1024+
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} =
1025+
# This definition is identical to the one above and necessary only for
1026+
# avoiding method ambiguity.
1027+
capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
10271028
@inline capturescalars(f, mixedargs::Tuple{SparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
10281029
capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
1029-
@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
1030+
@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
10301031
capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
10311032

10321033
nonscalararg(::SparseVecOrMat) = true
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
using Test, SparseArrays
4+
@test detect_ambiguities(SparseArrays; imported=true, recursive=true) == []

stdlib/SparseArrays/test/higherorderfns.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,4 +632,28 @@ end
632632
@test minimum(sparse([1, 2], [1, 2], ones(Int32, 2)), dims = 1) isa Matrix
633633
end
634634

635+
@testset "Issue #30118" begin
636+
@test ((_, x) -> x).(Int, spzeros(3)) == spzeros(3)
637+
@test ((_, _, x) -> x).(Int, Int, spzeros(3)) == spzeros(3)
638+
@test ((_, _, _, x) -> x).(Int, Int, Int, spzeros(3)) == spzeros(3)
639+
@test_broken ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3)
640+
end
641+
642+
using SparseArrays.HigherOrderFns: SparseVecStyle
643+
644+
@testset "Issue #30120: method ambiguity" begin
645+
# HigherOrderFns._copy(f) was ambiguous. It may be impossible to
646+
# invoke this from dot notation and it is an error anyway. But
647+
# when someone invokes it by accident, we want it to produce a
648+
# meaningful error.
649+
err = try
650+
copy(Broadcast.Broadcasted{SparseVecStyle}(rand, ()))
651+
catch err
652+
err
653+
end
654+
@test err isa MethodError
655+
@test !occursin("is ambiguous", sprint(showerror, err))
656+
@test occursin("no method matching _copy(::typeof(rand))", sprint(showerror, err))
657+
end
658+
635659
end # module

stdlib/SparseArrays/test/sparse.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,4 +2341,19 @@ end
23412341
@test m2.module == SparseArrays
23422342
end
23432343

2344+
@testset "sprandn with type $T" for T in (Float64, Float32, Float16, ComplexF64, ComplexF32, ComplexF16)
2345+
@test sprandn(T, 5, 5, 0.5) isa AbstractSparseMatrix{T}
2346+
end
2347+
@testset "sprandn with invalid type $T" for T in (AbstractFloat, BigFloat, Complex)
2348+
@test_throws MethodError sprandn(T, 5, 5, 0.5)
2349+
end
2350+
2351+
@testset "method ambiguity" begin
2352+
# Ambiguity test is run inside a clean process.
2353+
# https://github.com/JuliaLang/julia/issues/28804
2354+
script = joinpath(@__DIR__, "ambiguous_exec.jl")
2355+
cmd = `$(Base.julia_cmd()) --startup-file=no $script`
2356+
@test success(pipeline(cmd; stdout=stdout, stderr=stderr))
2357+
end
2358+
23442359
end # module

0 commit comments

Comments
 (0)