Skip to content

Commit 7890971

Browse files
committed
Fix length calculation of broadcast over tuples
1 parent 674e64b commit 7890971

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

base/broadcast.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ end
127127
end
128128

129129
Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
130+
Base.@propagate_inbounds _broadcast_getindex(::Type{Tuple}, A::Tuple{Any}, I) = A[1]
130131
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
131132
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
132133
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
@@ -333,13 +334,25 @@ end
333334
end
334335
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
335336
@inline broadcast_c(f, ::Type{Tuple}, A, Bs...) =
336-
tuplebroadcast(f, first_tuple(A, Bs...), A, Bs...)
337+
tuplebroadcast(f, tuplebroadcast_length(A, Bs...), A, Bs...)
337338
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
338339
ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N))
339340
@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Type{T}, As...) where {N,T} =
340341
ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N))
341-
first_tuple(A::Tuple, Bs...) = A
342-
@inline first_tuple(A, Bs...) = first_tuple(Bs...)
342+
tuplebroadcast_length(A, B) = (nothing,)
343+
tuplebroadcast_length(A, ::Tuple{}) = ()
344+
tuplebroadcast_length(::Tuple{}, A) = ()
345+
tuplebroadcast_length(::Tuple, ::Tuple{}) = ()
346+
tuplebroadcast_length(::Tuple{}, ::Tuple) = ()
347+
tuplebroadcast_length(::Tuple{Any}, ::Tuple{}) = ()
348+
tuplebroadcast_length(::Tuple{}, ::Tuple{Any}) = ()
349+
tuplebroadcast_length(A::Tuple, ::Tuple{Any}) = A
350+
tuplebroadcast_length(::Tuple{Any}, A::Tuple) = A
351+
tuplebroadcast_length(A::NTuple{N,Any}, ::NTuple{N,Any}...) where {N} = A
352+
tuplebroadcast_length(::Tuple, ::Tuple) =
353+
throw(DimensionMismatch("tuples could not be broadcast to a common size"))
354+
@inline tuplebroadcast_length(A, Bs...) =
355+
tuplebroadcast_length(A, tuplebroadcast_length(Bs...))
343356
tuplebroadcast_getargs(::Tuple{}, k) = ()
344357
@inline tuplebroadcast_getargs(As, k) =
345358
(_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...)

test/broadcast.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,16 @@ end
519519
Nullable("hello"))
520520
end
521521

522-
# Issue #21291
523-
let t = (0, 1, 2)
524-
o = 1
525-
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
522+
@testset "broadcast resulting in tuples" begin
523+
# Issue #21291
524+
let t = (0, 1, 2)
525+
o = 1
526+
@test @inferred(broadcast(+, t, o)) == (1, 2, 3)
527+
end
528+
529+
# Issue #23647
530+
@test (1, 2, 3) .+ (1,) == (1,) .+ (1, 2, 3) == (2, 3, 4)
531+
@test (1,) .+ () == () .+ (1,) == () .+ () == ()
532+
@test (1, 2) .+ (1, 2) == (2, 4)
533+
@test_throws DimensionMismatch (1, 2) .+ (1, 2, 3)
526534
end

0 commit comments

Comments
 (0)