Skip to content

Commit 8aaf4db

Browse files
committed
Improve inference for common reflection operations
Make inference be able to infer the type constructor. This is a little tricky, since we don't really have a good way to represent this. It's not quite `Const(TypeVar(:T,lb,ub))`, because a) the bounds may only be accurate up to type equality and b) TypeVar is not `isbits`, so it's not actually egal to the value we'll have at runtime. Additionally Type{T} already has meaning as a partially constructed type (e.g. an unwrapped UnionAll), so using T::Type{T} runs the risk of confusing this with an unwrapped type. Instead, introduce a new Const-like type, `PartialTypeVar`, which carries the type var and also keeps track of whether the bounds were egal or only typequal (we don't take advantage of that yet, but we could in the future). Additionally, improve the inference of `typename` and allow constant folding of field accesses on TypeName and SimpleVectors (to be able to constant fold T.parameters[1]).
1 parent 9554bec commit 8aaf4db

File tree

2 files changed

+121
-14
lines changed

2 files changed

+121
-14
lines changed

base/inference.jl

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ type Conditional
8282
end
8383
end
8484

85+
immutable PartialTypeVar
86+
tv::TypeVar
87+
# N.B.: Currently unused, but would allow turning something back
88+
# into Const, if the bounds are pulled out of this TypeVar
89+
lb_certain::Bool
90+
ub_certain::Bool
91+
PartialTypeVar(tv::TypeVar, lb_certain::Bool, ub_certain::Bool) = new(tv, lb_certain, ub_certain)
92+
end
93+
8594
function rewrap(t::ANY, u::ANY)
8695
isa(t, Const) && return t
8796
isa(t, Conditional) && return t
@@ -678,6 +687,7 @@ function limit_type_depth(t::ANY, d::Int, cov::Bool=true, var::Union{Void,TypeVa
678687
return (cov && !stillcov) ? UnionAll(var, R) : R
679688
end
680689

690+
const DataType_name_fieldindex = fieldindex(DataType, :name)
681691
const DataType_parameters_fieldindex = fieldindex(DataType, :parameters)
682692
const DataType_types_fieldindex = fieldindex(DataType, :types)
683693
const DataType_super_fieldindex = fieldindex(DataType, :super)
@@ -713,7 +723,8 @@ function getfield_tfunc(s00::ANY, name)
713723
if isa(sv, Module) && isa(nv, Symbol)
714724
return abstract_eval_global(sv, nv)
715725
end
716-
if (isa(sv, DataType) || isimmutable(sv)) && isdefined(sv, nv)
726+
if (isa(sv, DataType) || isa(sv, SimpleVector) || isa(sv, TypeName)
727+
|| isimmutable(sv)) && isdefined(sv, nv)
717728
return abstract_eval_constant(getfield(sv, nv))
718729
end
719730
end
@@ -774,7 +785,8 @@ function getfield_tfunc(s00::ANY, name)
774785
sp = nothing
775786
end
776787
if (sp !== nothing &&
777-
(fld == DataType_parameters_fieldindex ||
788+
(fld == DataType_name_fieldindex ||
789+
fld == DataType_parameters_fieldindex ||
778790
fld == DataType_types_fieldindex ||
779791
fld == DataType_super_fieldindex))
780792
return Const(getfield(sp, fld))
@@ -905,15 +917,19 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
905917
return Any
906918
end
907919
uncertain = false
920+
uncertain_typevar = false
908921
tparams = Any[]
909922
outervars = Any[]
910923
for i = 1:largs
911924
ai = args[i]
912925
if isType(ai)
913926
aip1 = ai.parameters[1]
914927
push!(tparams, aip1)
915-
elseif isa(ai, Const) && (isa(ai.val, Type) || valid_tparam(ai.val))
928+
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) || valid_tparam(ai.val))
916929
push!(tparams, ai.val)
930+
elseif isa(ai, PartialTypeVar)
931+
uncertain_typevar = true
932+
push!(tparams, ai.tv)
917933
else
918934
# TODO: return `Bottom` for trying to apply a non-UnionAll
919935
uncertain = true
@@ -956,7 +972,8 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
956972
# doesn't match, which could happen if a type estimate is too coarse
957973
return Type{_} where _<:headtype
958974
end
959-
!uncertain && return Const(appl)
975+
!uncertain && !uncertain_typevar && return Const(appl)
976+
!uncertain && return Type{appl}
960977
if isvarargtype(headtype)
961978
return Type
962979
end
@@ -1476,6 +1493,25 @@ function Pair_name()
14761493
return _Pair_name
14771494
end
14781495

1496+
_typename(a) = Union{}
1497+
_typename(a::Vararg) = Any
1498+
_typename(a::TypeVar) = Any
1499+
_typename(a::DataType) = Const(a.name)
1500+
function _typename(a::Union)
1501+
ta = _typename(a.a)
1502+
tb = _typename(a.b)
1503+
ta === tb ? tb : (ta === Any || tb === Any) ? Any : Union{}
1504+
end
1505+
_typename(union::UnionAll) = typename(union.body)
1506+
function typename_static(t)
1507+
# N.B.: typename maps type equivalence classes to a single value
1508+
if isa(t, Const) || isType(t)
1509+
return _typename(isa(t, Const) ? t.val : t.parameters[1])
1510+
else
1511+
return Any
1512+
end
1513+
end
1514+
14791515
function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
14801516
if f === _apply
14811517
length(fargs) > 1 || return Any
@@ -1557,19 +1593,63 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
15571593
end
15581594
end
15591595
return Any
1560-
elseif f === UnionAll
1561-
if length(fargs) == 3 && isa(argtypes[2], Const)
1562-
tv = argtypes[2].val
1563-
if isa(tv, TypeVar)
1596+
elseif f === TypeVar
1597+
lb = Union{}
1598+
ub = Any
1599+
ub_certain = lb_certain = true
1600+
if length(fargs) >= 2 && isa(argtypes[2], Const)
1601+
nv = argtypes[2].val
1602+
ubidx = 3
1603+
if length(fargs) >= 4
1604+
ubidx = 4
15641605
if isa(argtypes[3], Const)
1565-
body = argtypes[3].val
1606+
lb = argtypes[3].val
15661607
elseif isType(argtypes[3])
1567-
body = argtypes[3].parameters[1]
1608+
lb = argtypes[3].parameters[1]
1609+
lb_certain = false
15681610
else
1569-
return Any
1611+
return TypeVar
1612+
end
1613+
end
1614+
if length(fargs) >= ubidx
1615+
if isa(argtypes[ubidx], Const)
1616+
ub = argtypes[ubidx].val
1617+
elseif isType(argtypes[ubidx])
1618+
ub = argtypes[ubidx].parameters[1]
1619+
ub_certain = false
1620+
else
1621+
return TypeVar
15701622
end
1571-
return abstract_eval_constant(UnionAll(tv, body))
15721623
end
1624+
tv = TypeVar(nv, lb, ub)
1625+
return PartialTypeVar(tv, lb_certain, ub_certain)
1626+
end
1627+
return TypeVar
1628+
elseif f === UnionAll
1629+
if length(fargs) == 3
1630+
canconst = true
1631+
if isa(argtypes[3], Const)
1632+
body = argtypes[3].val
1633+
elseif isType(argtypes[3])
1634+
body = argtypes[3].parameters[1]
1635+
canconst = false
1636+
else
1637+
return Any
1638+
end
1639+
if isa(argtypes[2], Const)
1640+
tv = argtypes[2].val
1641+
elseif isa(argtypes[2], PartialTypeVar)
1642+
ptv = argtypes[2]
1643+
tv = ptv.tv
1644+
canconst = false
1645+
else
1646+
return Any
1647+
end
1648+
!isa(tv, TypeVar) && return Any
1649+
(!isa(body, Type) || !isa(body, TypeVar)) && return Any
1650+
theunion = UnionAll(tv, body)
1651+
ret = canconst ? abstract_eval_constant(theunion) : Type{theunion}
1652+
return ret
15731653
end
15741654
return Any
15751655
elseif f === return_type
@@ -1595,7 +1675,16 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
15951675

15961676
if length(argtypes)>2 && argtypes[3] Int
15971677
at2 = widenconst(argtypes[2])
1598-
if (at2 <: Tuple ||
1678+
if at2 <: SimpleVector && istopfunction(tm, f, :getindex)
1679+
if isa(argtypes[2], Const) && isa(argtypes[3], Const)
1680+
svecval = argtypes[2].val
1681+
idx = argtypes[3].val
1682+
if isa(idx, Int) && 1 <= idx <= length(svecval) &
1683+
isassigned(svecval, idx)
1684+
return Const(getindex(svecval, idx))
1685+
end
1686+
end
1687+
elseif (at2 <: Tuple ||
15991688
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
16001689
# allow tuple indexing functions to take advantage of constant
16011690
# index arguments.
@@ -1617,6 +1706,12 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
16171706

16181707
if istopfunction(tm, f, :promote_type) || istopfunction(tm, f, :typejoin)
16191708
return Type
1709+
elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
1710+
t = argtypes[2]
1711+
if isa(t, Const) || isType(t)
1712+
return typename_static(t)
1713+
end
1714+
return Any
16201715
end
16211716

16221717
if sv.params.inlining
@@ -1905,6 +2000,7 @@ function widenconst(c::Const)
19052000
return typeof(c.val)
19062001
end
19072002
end
2003+
widenconst(c::PartialTypeVar) = TypeVar
19082004
widenconst(t::ANY) = t
19092005

19102006
issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)
@@ -3552,7 +3648,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
35523648
if method.name == :getindex || method.name == :next || method.name == :indexed_next
35533649
if length(atypes) > 2 && atypes[3] Int
35543650
at2 = widenconst(atypes[2])
3555-
if (at2 <: Tuple ||
3651+
if (at2 <: Tuple || at2 <: SimpleVector ||
35563652
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
35573653
force_infer = true
35583654
end

test/inference.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,14 @@ f11015(a::AT11015) = g11015(Base.fieldtype(typeof(a), :f), true)
588588
g11015(::Type{Bool}, ::Bool) = 2.0
589589
@test Int <: Base.return_types(f11015, (AT11015,))[1]
590590
@test f11015(AT11015(true)) === 1
591+
592+
# Inference for some type-level computation
593+
fUnionAll{T}(::Type{T}) = Type{S} where S <: T
594+
@inferred fUnionAll(Real) == Type{T} where T <: Real
595+
@inferred fUnionAll(Rational{T} where T <: AbstractFloat) == Type{T} where T<:(Rational{S} where S <: AbstractFloat)
596+
597+
fComplicatedUnionAll{T}(::Type{T}) = Type{Tuple{S,rand() >= 0.5 ? Int : Float64}} where S <: T
598+
let pub = Base.parameter_upper_bound, x = fComplicatedUnionAll(Real)
599+
@test pub(pub(x, 1), 1) == Real
600+
@test pub(pub(x, 1), 2) == Int || pub(pub(x, 1), 2) == Float64
601+
end

0 commit comments

Comments
 (0)