Skip to content

Commit ad5c866

Browse files
committed
Fix type stability of Broadcast.flatten
Fixes #27988.
1 parent 98061ab commit ad5c866

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

base/broadcast.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ some cases.
295295
function flatten(bc::Broadcasted{Style}) where {Style}
296296
isflat(bc) && return bc
297297
# concatenate the nested arguments into {a, b, c, d}
298-
args = cat_nested(x->x.args, bc)
298+
args = cat_nested(bc)
299299
# build a function `makeargs` that takes a "flat" argument list and
300300
# and creates the appropriate input arguments for `f`, e.g.,
301301
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
@@ -318,14 +318,9 @@ _isflat(args::NestedTuple) = false
318318
_isflat(args::Tuple) = _isflat(tail(args))
319319
_isflat(args::Tuple{}) = true
320320

321-
cat_nested(fieldextractor, bc::Broadcasted) = cat_nested(fieldextractor, fieldextractor(bc), ())
322-
323-
cat_nested(fieldextractor, t::Tuple, rest) =
324-
(t[1], cat_nested(fieldextractor, tail(t), rest)...)
325-
cat_nested(fieldextractor, t::Tuple{<:Broadcasted,Vararg{Any}}, rest) =
326-
cat_nested(fieldextractor, cat_nested(fieldextractor, fieldextractor(t[1]), tail(t)), rest)
327-
cat_nested(fieldextractor, t::Tuple{}, tail) = cat_nested(fieldextractor, tail, ())
328-
cat_nested(fieldextractor, t::Tuple{}, tail::Tuple{}) = ()
321+
cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
322+
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
323+
cat_nested() = ()
329324

330325
make_makeargs(bc::Broadcasted) = make_makeargs(()->(), bc.args)
331326
@inline function make_makeargs(makeargs, t::Tuple)

test/broadcast.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,17 @@ let X = zeros(2, 3)
722722
@test X == [1 1 1; 2 2 2]
723723
end
724724

725+
# issue #27988: inference of Broadcast.flatten
726+
using .Broadcast: Broadcasted
727+
let
728+
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
729+
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
730+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
731+
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
732+
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
733+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
734+
end
735+
725736
# Issue #26127: multiple splats in a fused dot-expression
726737
let f(args...) = *(args...)
727738
x, y, z = (1,2), 3, (4, 5)

0 commit comments

Comments
 (0)