Skip to content

Commit 3b5c7e1

Browse files
committed
Improve inference for common reflection operations
Make inference be able to infer the type constructor. This is a little tricky, since we don't really have a good way to represent this. It's not quite `Const(TypeVar(:T,lb,ub))`, because a) the bounds may only be accurate up to type equality and b) TypeVar is not `isbits`, so it's not actually egal to the value we'll have at runtime. Additionally Type{T} already has meaning as a partially constructed type (e.g. an unwrapped UnionAll), so using T::Type{T} runs the risk of confusing this with an unwrapped type. Instead, introduce a new Const-like type, `PartialTypeVar`, which carries the type var and also keeps track of whether the bounds were egal or only typequal (we don't take advantage of that yet, but we could in the future). Additionally, improve the inference of `typename` and allow constant folding of field accesses on TypeName and SimpleVectors (to be able to constant fold T.parameters[1]).
1 parent 91fbcbd commit 3b5c7e1

File tree

3 files changed

+133
-21
lines changed

3 files changed

+133
-21
lines changed

base/inference.jl

Lines changed: 113 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ type Conditional
8282
end
8383
end
8484

85+
immutable PartialTypeVar
86+
tv::TypeVar
87+
# N.B.: Currently unused, but would allow turning something back
88+
# into Const, if the bounds are pulled out of this TypeVar
89+
lb_certain::Bool
90+
ub_certain::Bool
91+
PartialTypeVar(tv::TypeVar, lb_certain::Bool, ub_certain::Bool) = new(tv, lb_certain, ub_certain)
92+
end
93+
8594
function rewrap(t::ANY, u::ANY)
8695
isa(t, Const) && return t
8796
isa(t, Conditional) && return t
@@ -678,6 +687,7 @@ function limit_type_depth(t::ANY, d::Int, cov::Bool=true, var::Union{Void,TypeVa
678687
return (cov && !stillcov) ? UnionAll(var, R) : R
679688
end
680689

690+
const DataType_name_fieldindex = fieldindex(DataType, :name)
681691
const DataType_parameters_fieldindex = fieldindex(DataType, :parameters)
682692
const DataType_types_fieldindex = fieldindex(DataType, :types)
683693
const DataType_super_fieldindex = fieldindex(DataType, :super)
@@ -713,7 +723,8 @@ function getfield_tfunc(s00::ANY, name)
713723
if isa(sv, Module) && isa(nv, Symbol)
714724
return abstract_eval_global(sv, nv)
715725
end
716-
if (isa(sv, DataType) || isimmutable(sv)) && isdefined(sv, nv)
726+
if (isa(sv, DataType) || isa(sv, SimpleVector) || isa(sv, TypeName)
727+
|| isimmutable(sv)) && isdefined(sv, nv)
717728
return abstract_eval_constant(getfield(sv, nv))
718729
end
719730
end
@@ -774,7 +785,8 @@ function getfield_tfunc(s00::ANY, name)
774785
sp = nothing
775786
end
776787
if (sp !== nothing &&
777-
(fld == DataType_parameters_fieldindex ||
788+
(fld == DataType_name_fieldindex ||
789+
fld == DataType_parameters_fieldindex ||
778790
fld == DataType_types_fieldindex ||
779791
fld == DataType_super_fieldindex))
780792
return Const(getfield(sp, fld))
@@ -905,15 +917,20 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
905917
return Any
906918
end
907919
uncertain = false
920+
canconst = true
908921
tparams = Any[]
909922
outervars = Any[]
910923
for i = 1:largs
911924
ai = args[i]
912925
if isType(ai)
913926
aip1 = ai.parameters[1]
927+
canconst &= !isleaftype(aip1)
914928
push!(tparams, aip1)
915-
elseif isa(ai, Const) && (isa(ai.val, Type) || valid_tparam(ai.val))
929+
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) || valid_tparam(ai.val))
916930
push!(tparams, ai.val)
931+
elseif isa(ai, PartialTypeVar)
932+
canconst = false
933+
push!(tparams, ai.tv)
917934
else
918935
# TODO: return `Bottom` for trying to apply a non-UnionAll
919936
uncertain = true
@@ -956,11 +973,11 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
956973
# doesn't match, which could happen if a type estimate is too coarse
957974
return Type{_} where _<:headtype
958975
end
959-
!uncertain && return Const(appl)
976+
!uncertain && canconst && return Const(appl)
960977
if isvarargtype(headtype)
961978
return Type
962979
end
963-
if type_too_complex(appl,0)
980+
if uncertain && type_too_complex(appl,0)
964981
return Type{_} where _<:headtype
965982
end
966983
if istuple
@@ -1476,6 +1493,25 @@ function Pair_name()
14761493
return _Pair_name
14771494
end
14781495

1496+
_typename(a) = Union{}
1497+
_typename(a::Vararg) = Any
1498+
_typename(a::TypeVar) = Any
1499+
_typename(a::DataType) = Const(a.name)
1500+
function _typename(a::Union)
1501+
ta = _typename(a.a)
1502+
tb = _typename(a.b)
1503+
ta === tb ? tb : (ta === Any || tb === Any) ? Any : Union{}
1504+
end
1505+
_typename(union::UnionAll) = _typename(union.body)
1506+
function typename_static(t)
1507+
# N.B.: typename maps type equivalence classes to a single value
1508+
if isa(t, Const) || isType(t)
1509+
return _typename(isa(t, Const) ? t.val : t.parameters[1])
1510+
else
1511+
return Any
1512+
end
1513+
end
1514+
14791515
function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
14801516
if f === _apply
14811517
length(fargs) > 1 || return Any
@@ -1557,19 +1593,65 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
15571593
end
15581594
end
15591595
return Any
1560-
elseif f === UnionAll
1561-
if length(fargs) == 3 && isa(argtypes[2], Const)
1562-
tv = argtypes[2].val
1563-
if isa(tv, TypeVar)
1596+
elseif f === TypeVar
1597+
lb = Union{}
1598+
ub = Any
1599+
ub_certain = lb_certain = true
1600+
if length(fargs) >= 2 && isa(argtypes[2], Const)
1601+
nv = argtypes[2].val
1602+
ubidx = 3
1603+
if length(fargs) >= 4
1604+
ubidx = 4
15641605
if isa(argtypes[3], Const)
1565-
body = argtypes[3].val
1606+
lb = argtypes[3].val
15661607
elseif isType(argtypes[3])
1567-
body = argtypes[3].parameters[1]
1608+
lb = argtypes[3].parameters[1]
1609+
lb_certain = false
15681610
else
1569-
return Any
1611+
return TypeVar
15701612
end
1571-
return abstract_eval_constant(UnionAll(tv, body))
15721613
end
1614+
if length(fargs) >= ubidx
1615+
if isa(argtypes[ubidx], Const)
1616+
ub = argtypes[ubidx].val
1617+
elseif isType(argtypes[ubidx])
1618+
ub = argtypes[ubidx].parameters[1]
1619+
ub_certain = false
1620+
else
1621+
return TypeVar
1622+
end
1623+
end
1624+
tv = TypeVar(nv, lb, ub)
1625+
return PartialTypeVar(tv, lb_certain, ub_certain)
1626+
end
1627+
return TypeVar
1628+
elseif f === UnionAll
1629+
if length(fargs) == 3
1630+
canconst = true
1631+
if isa(argtypes[3], Const)
1632+
body = argtypes[3].val
1633+
elseif isType(argtypes[3])
1634+
body = argtypes[3].parameters[1]
1635+
canconst = false
1636+
else
1637+
return Any
1638+
end
1639+
if isa(argtypes[2], Const)
1640+
tv = argtypes[2].val
1641+
elseif isa(argtypes[2], PartialTypeVar)
1642+
ptv = argtypes[2]
1643+
tv = ptv.tv
1644+
canconst = false
1645+
else
1646+
return Any
1647+
end
1648+
!isa(tv, TypeVar) && return Any
1649+
if !isa(body, Type) && !isa(body, TypeVar)
1650+
return Any
1651+
end
1652+
theunion = UnionAll(tv, body)
1653+
ret = canconst ? abstract_eval_constant(theunion) : Type{theunion}
1654+
return ret
15731655
end
15741656
return Any
15751657
elseif f === return_type
@@ -1595,7 +1677,16 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
15951677

15961678
if length(argtypes)>2 && argtypes[3] Int
15971679
at2 = widenconst(argtypes[2])
1598-
if (at2 <: Tuple ||
1680+
if at2 <: SimpleVector && istopfunction(tm, f, :getindex)
1681+
if isa(argtypes[2], Const) && isa(argtypes[3], Const)
1682+
svecval = argtypes[2].val
1683+
idx = argtypes[3].val
1684+
if isa(idx, Int) && 1 <= idx <= length(svecval) &
1685+
isassigned(svecval, idx)
1686+
return Const(getindex(svecval, idx))
1687+
end
1688+
end
1689+
elseif (at2 <: Tuple ||
15991690
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
16001691
# allow tuple indexing functions to take advantage of constant
16011692
# index arguments.
@@ -1617,6 +1708,12 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
16171708

16181709
if istopfunction(tm, f, :promote_type) || istopfunction(tm, f, :typejoin)
16191710
return Type
1711+
elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
1712+
t = argtypes[2]
1713+
if isa(t, Const) || isType(t)
1714+
return typename_static(t)
1715+
end
1716+
return Any
16201717
end
16211718

16221719
if sv.params.inlining
@@ -1905,6 +2002,7 @@ function widenconst(c::Const)
19052002
return typeof(c.val)
19062003
end
19072004
end
2005+
widenconst(c::PartialTypeVar) = TypeVar
19082006
widenconst(t::ANY) = t
19092007

19102008
issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)
@@ -3552,7 +3650,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
35523650
if method.name == :getindex || method.name == :next || method.name == :indexed_next
35533651
if length(atypes) > 2 && atypes[3] Int
35543652
at2 = widenconst(atypes[2])
3555-
if (at2 <: Tuple ||
3653+
if (at2 <: Tuple || at2 <: SimpleVector ||
35563654
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
35573655
force_infer = true
35583656
end

test/inference.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,14 @@ f11015(a::AT11015) = g11015(Base.fieldtype(typeof(a), :f), true)
588588
g11015(::Type{Bool}, ::Bool) = 2.0
589589
@test Int <: Base.return_types(f11015, (AT11015,))[1]
590590
@test f11015(AT11015(true)) === 1
591+
592+
# Inference for some type-level computation
593+
fUnionAll{T}(::Type{T}) = Type{S} where S <: T
594+
@inferred fUnionAll(Real) == Type{T} where T <: Real
595+
@inferred fUnionAll(Rational{T} where T <: AbstractFloat) == Type{T} where T<:(Rational{S} where S <: AbstractFloat)
596+
597+
fComplicatedUnionAll{T}(::Type{T}) = Type{Tuple{S,rand() >= 0.5 ? Int : Float64}} where S <: T
598+
let pub = Base.parameter_upper_bound, x = fComplicatedUnionAll(Real)
599+
@test pub(pub(x, 1), 1) == Real
600+
@test pub(pub(x, 1), 2) == Int || pub(pub(x, 1), 2) == Float64
601+
end

test/reflection.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -629,9 +629,12 @@ end
629629
@test Base.parameter_upper_bound(ReflectionExample, 2) === Any
630630
@test Base.parameter_upper_bound(ReflectionExample{T, N} where T where N <: Real, 2) === Real
631631

632-
@test Base.typename(ReflectionExample{Float64, Int64}).wrapper === ReflectionExample
633-
@test Base.typename(ReflectionExample{Float64, N} where N).wrapper === ReflectionExample
634-
@test Base.typename(ReflectionExample{T, Int64} where T).wrapper === ReflectionExample
635-
@test Base.typename(ReflectionExample).wrapper === ReflectionExample
636-
@test Base.typename(Union{ReflectionExample{Union{},1},ReflectionExample{Float64,1}}).wrapper === ReflectionExample
637-
@test_throws ErrorException Base.typename(Union{Int, Float64})
632+
let
633+
wrapperT(T) = Base.typename(T).wrapper
634+
@test @inferred wrapperT(ReflectionExample{Float64, Int64}) == ReflectionExample
635+
@test @inferred wrapperT(ReflectionExample{Float64, N} where N) == ReflectionExample
636+
@test @inferred wrapperT(ReflectionExample{T, Int64} where T) == ReflectionExample
637+
@test @inferred wrapperT(ReflectionExample) == ReflectionExample
638+
@test @inferred wrapperT(Union{ReflectionExample{Union{},1},ReflectionExample{Float64,1}}) == ReflectionExample
639+
@test_throws ErrorException Base.typename(Union{Int, Float64})
640+
end

0 commit comments

Comments
 (0)