Skip to content

Commit 4b26e4a

Browse files
committed
inference: improve associativity of tmerge
and slow down convergence rate of Tuple and UnionAll types
1 parent c57f0fc commit 4b26e4a

File tree

3 files changed

+116
-23
lines changed

3 files changed

+116
-23
lines changed

base/compiler/typelimits.jl

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -327,21 +327,95 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
327327
# XXX: this should never happen
328328
return Any
329329
end
330-
if unionlen(typea) + unionlen(typeb) > MAX_TYPEUNION_LEN
330+
# converge the Tuple part of the Union less quickly
331+
# this is a bit more tricky than other branch, so
332+
# do this first, so that the lattice is over Tuple is associative:
333+
# tmerge(A, tmerge(B, C)) == tmerge(tmerge(A, B), C)
334+
tuplea = typeintersect(typea, Tuple)
335+
tupleb = typeintersect(typeb, Tuple)
336+
if tuplea !== Union{} && tupleb !== Union{}
337+
if isconcretetype(tuplea) && isconcretetype(tupleb)
338+
# if the tuple part is concrete, try just returning the Union
339+
u = Union{typea, typeb}
340+
if unionlen(u) <= MAX_TYPEUNION_LEN
341+
return u
342+
end
343+
end
344+
# otherwise, make a single non-concrete Tuple containing both
345+
tuplejoin = Tuple
346+
if tuplea <: Tuple && tupleb <: Tuple
347+
# converge the Tuple element-wise if they are the same length
348+
# see 4ee2b41552a6bc95465c12ca66146d69b354317b, be59686f7613a2ccfd63491c7b354d0b16a95c05,
349+
if nothing !== tuplelen(tuplea) === tuplelen(tupleb)
350+
tuplejoin = tuplemerge(tuplea, tupleb)
351+
end
352+
# TODO: else, merge them into a single Tuple{Vararg{T}} instead (#22120)?
353+
end
354+
# now rejoin it with the rest of the type union elements
355+
typea = Union{tuplejoin, typea}
356+
end
357+
u = Union{typea, typeb}
358+
if unionlen(u) <= MAX_TYPEUNION_LEN
331359
# don't let type unions get too big
332360
# this sets our convergence rate (e.g. worst-case compiler performance)
333-
namea, nameb = _typename(typea), _typename(typeb)
334-
if namea isa Const && nameb isa Const && namea.val === nameb.val
335-
# If they have the same type name, widen to that instead
336-
# of widening fully (or using a slower convergence like typejoin)
337-
wrapper = (namea.val::Core.TypeName).wrapper
338-
if typea <: wrapper && typeb <: wrapper
339-
# This can happen when a typevar has bounds too wide for its context
340-
return wrapper
361+
return u
362+
end
363+
# see if either of them is a DataType
364+
# which can be used to collapse the Union
365+
# by swapping it for the widest possible version
366+
# of itself (the wrapper)
367+
# this currently violates associativity, which reduces the predictability
368+
# of the result but is not inherently wrong
369+
# TODO: instead, call `uniontypes` on `u`, and swap all of the results for their wrappers
370+
unwrapa = unwrap_unionall(typea)
371+
unwrapb = unwrap_unionall(typeb)
372+
if unwrapa isa DataType
373+
wrapa = unwrapa.name.wrapper
374+
if typea <: wrapa
375+
u = Union{wrapa, typeb}
376+
if unionlen(u) <= MAX_TYPEUNION_LEN
377+
return u
341378
end
342379
end
343-
# TODO: something smarter, like a common supertype?
344-
return Any
380+
elseif unwrapb isa DataType
381+
wrapb = unwrapb.name.wrapper
382+
if typeb <: wrapb
383+
u = Union{typea, wrapb}
384+
if unionlen(u) <= MAX_TYPEUNION_LEN
385+
return u
386+
end
387+
end
388+
end
389+
# finally, just return the widest possible type
390+
return Any
391+
end
392+
393+
# the inverse of switchtupleunion, with limits on max element union size
394+
function tuplemerge(@nospecialize(a), @nospecialize(b))
395+
if isa(a, UnionAll)
396+
return UnionAll(a.var, tuplemerge(a.body, b))
397+
elseif isa(b, UnionAll)
398+
return UnionAll(b.var, tuplemerge(a, b.body))
399+
elseif isa(a, Union)
400+
return tuplemerge(tuplemerge(a.a, a.b), b)
401+
elseif isa(b, Union)
402+
return tuplemerge(a, tuplemerge(b.a, b.b))
403+
end
404+
a = a::DataType
405+
b = b::DataType
406+
ap, bp = a.parameters, b.parameters
407+
lar = length(ap)::Int
408+
lbr = length(bp)::Int
409+
@assert lar === lbr && a.name === b.name === Tuple.name "assertion failure"
410+
p = Vector{Any}(undef, lar)
411+
for i = 1:lar
412+
api = ap[i]
413+
bpi = bp[i]
414+
if unionlen(unwrap_unionall(api)) + unionlen(unwrap_unionall(bpi)) > MAX_TYPEUNION_LEN
415+
p[i] = Any
416+
else
417+
p[i] = Union{api, bpi}
418+
end
345419
end
346-
return Union{typea, typeb}
420+
return Tuple{p...}
347421
end

base/compiler/typeutils.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,23 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci
139139
end
140140
return tunion
141141
end
142+
143+
tuplelen(@nospecialize tpl) = nothing
144+
function tuplelen(tpl::DataType)
145+
l = length(tpl.parameters)::Int
146+
if l > 0
147+
last = unwrap_unionall(tpl.parameters[l])
148+
if isvarargtype(last)
149+
N = last.parameters[2]
150+
N isa Int || return nothing
151+
l += N - 1
152+
end
153+
end
154+
return l
155+
end
156+
tuplelen(tpl::UnionAll) = tuplelen(tpl.body)
157+
function tuplelen(tpl::Union)
158+
la, lb = tuplelen(tpl.a), tuplelen(tpl.b)
159+
la == lb && return la
160+
return nothing
161+
end

base/promotion.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,25 @@ function typejoin(@nospecialize(a), @nospecialize(b))
1818
return b
1919
elseif b <: a
2020
return a
21-
elseif isa(a,UnionAll)
21+
elseif isa(a, UnionAll)
2222
return UnionAll(a.var, typejoin(a.body, b))
23-
elseif isa(b,UnionAll)
23+
elseif isa(b, UnionAll)
2424
return UnionAll(b.var, typejoin(a, b.body))
25-
elseif isa(a,TypeVar)
25+
elseif isa(a, TypeVar)
2626
return typejoin(a.ub, b)
27-
elseif isa(b,TypeVar)
27+
elseif isa(b, TypeVar)
2828
return typejoin(a, b.ub)
29-
elseif isa(a,Union)
30-
a′ = typejoin(a.a, a.b)
31-
return a′ === a ? typejoin(a, b) : typejoin(a′, b)
32-
elseif isa(b,Union)
33-
b′ = typejoin(b.a, b.b)
34-
return b′ === b ? typejoin(a, b) : typejoin(a, b′)
29+
elseif isa(a, Union)
30+
return typejoin(typejoin(a.a, a.b), b)
31+
elseif isa(b, Union)
32+
return typejoin(a, typejoin(b.a, b.b))
3533
elseif a <: Tuple
3634
if !(b <: Tuple)
3735
return Any
3836
end
3937
ap, bp = a.parameters, b.parameters
40-
lar = length(ap)::Int; lbr = length(bp)::Int
38+
lar = length(ap)::Int
39+
lbr = length(bp)::Int
4140
if lar == 0
4241
return Tuple{Vararg{tailjoin(bp, 1)}}
4342
end

0 commit comments

Comments
 (0)