Skip to content

Commit 21a751f

Browse files
committed
Improve inference for common reflection operations
Make inference be able to infer the type constructor. This is a little tricky, since we don't really have a good way to represent this. It's not quite `Const(TypeVar(:T,lb,ub))`, because a) the bounds may only be accurate up to type equality and b) TypeVar is not `isbits`, so it's not actually egal to the value we'll have at runtime. Additionally Type{T} already has meaning as a partially constructed type (e.g. an unwrapped UnionAll), so using T::Type{T} runs the risk of confusing this with an unwrapped type. Instead, introduce a new Const-like type, `PartialTypeVar`, which carries the type var and also keeps track of whether the bounds were egal or only typequal (we don't take advantage of that yet, but we could in the future). Additionally, improve the inference of `typename` and allow constant folding of field accesses on TypeName and SimpleVectors (to be able to constant fold T.parameters[1]).
1 parent cf385fe commit 21a751f

File tree

2 files changed

+115
-14
lines changed

2 files changed

+115
-14
lines changed

base/inference.jl

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ immutable Const
6363
Const(v::ANY) = new(v)
6464
end
6565

66+
immutable PartialTypeVar
67+
tv::TypeVar
68+
lb_certain::Bool
69+
ub_certain::Bool
70+
PartialTypeVar(tv::TypeVar, lb_certain::ANY, ub_certain::ANY) = new(tv, lb_certain, ub_certain)
71+
end
72+
6673
function rewrap(t::ANY, u::ANY)
6774
isa(t, Const) && return t
6875
rewrap_unionall(t, u)
@@ -638,6 +645,7 @@ function limit_type_depth(t::ANY, d::Int, cov::Bool=true, var::Union{Void,TypeVa
638645
return (cov && !stillcov) ? UnionAll(var, R) : R
639646
end
640647

648+
const DataType_name_fieldindex = fieldindex(DataType, :name)
641649
const DataType_parameters_fieldindex = fieldindex(DataType, :parameters)
642650
const DataType_types_fieldindex = fieldindex(DataType, :types)
643651
const DataType_super_fieldindex = fieldindex(DataType, :super)
@@ -671,7 +679,8 @@ function getfield_tfunc(s00::ANY, name)
671679
if isa(sv, Module) && isa(nv, Symbol)
672680
return abstract_eval_global(sv, nv)
673681
end
674-
if (isa(sv, DataType) || isimmutable(sv)) && isdefined(sv, nv)
682+
if (isa(sv, DataType) || isa(sv, SimpleVector) || isa(sv, TypeName)
683+
|| isimmutable(sv)) && isdefined(sv, nv)
675684
return abstract_eval_constant(getfield(sv, nv))
676685
end
677686
end
@@ -729,7 +738,8 @@ function getfield_tfunc(s00::ANY, name)
729738
sp = nothing
730739
end
731740
if (sp !== nothing &&
732-
(fld == DataType_parameters_fieldindex ||
741+
(fld == DataType_name_fieldindex ||
742+
fld == DataType_parameters_fieldindex ||
733743
fld == DataType_types_fieldindex ||
734744
fld == DataType_super_fieldindex))
735745
return Const(getfield(sp, fld))
@@ -821,15 +831,19 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
821831
return Any
822832
end
823833
uncertain = false
834+
uncertain_typevar = false
824835
tparams = Any[]
825836
for i = 1:largs
826837
ai = args[i]
827838
if isType(ai)
828839
aip1 = ai.parameters[1]
829840
uncertain |= has_free_typevars(aip1)
830841
push!(tparams, aip1)
831-
elseif isa(ai, Const) && (isa(ai.val, Type) || valid_tparam(ai.val))
842+
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) || valid_tparam(ai.val))
832843
push!(tparams, ai.val)
844+
elseif isa(ai, PartialTypeVar)
845+
uncertain_typevar = true
846+
push!(tparams, ai.tv)
833847
else
834848
# TODO: return `Bottom` for trying to apply a non-UnionAll
835849
#if !istuple && i-1 > length(headtype.parameters)
@@ -855,7 +869,8 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
855869
appl = headtype
856870
uncertain = true
857871
end
858-
!uncertain && return Const(appl)
872+
!uncertain && !uncertain_typevar && return Const(appl)
873+
!uncertain && return Type{appl}
859874
if type_too_complex(appl,0)
860875
return Type{_} where _<:headtype
861876
end
@@ -1352,6 +1367,25 @@ function Pair_name()
13521367
return _Pair_name
13531368
end
13541369

1370+
_typename(a) = Union{}
1371+
_typename(a::Vararg) = Any
1372+
_typename(a::TypeVar) = Any
1373+
_typename(a::DataType) = Const(a.name)
1374+
function _typename(a::Union)
1375+
ta = _typename(a.a)
1376+
tb = _typename(a.b)
1377+
ta == tb ? tb : (ta === Any || tb == Any) ? Any : Union{}
1378+
end
1379+
_typename(union::UnionAll) = typename(union.body)
1380+
function typename_static(t)
1381+
# N.B.: typename maps type equivalence classes to a single value
1382+
if isa(t, Const) || isType(t)
1383+
return _typename(isa(t, Const) ? t.val : t.parameters[1])
1384+
else
1385+
return Any
1386+
end
1387+
end
1388+
13551389
function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
13561390
if f === _apply
13571391
length(fargs) > 1 || return Any
@@ -1386,19 +1420,62 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
13861420
end
13871421
end
13881422
return Any
1389-
elseif f === UnionAll
1390-
if length(fargs) == 3 && isa(argtypes[2], Const)
1391-
tv = argtypes[2].val
1392-
if isa(tv, TypeVar)
1423+
elseif f === TypeVar
1424+
lb = Union{}
1425+
ub = Any
1426+
ub_certain = lb_certain = true
1427+
if length(fargs) >= 2 && isa(argtypes[2], Const)
1428+
nv = argtypes[2].val
1429+
ubidx = 3
1430+
if length(fargs) >= 4
1431+
ubidx = 4
13931432
if isa(argtypes[3], Const)
1394-
body = argtypes[3].val
1433+
lb = argtypes[3].val
13951434
elseif isType(argtypes[3])
1396-
body = argtypes[3].parameters[1]
1435+
lb = argtypes[3].parameters[1]
1436+
lb_certain = false
13971437
else
1398-
return Any
1438+
return TypeVar
13991439
end
1400-
return abstract_eval_constant(UnionAll(tv, body))
14011440
end
1441+
if length(fargs) >= ubidx
1442+
if isa(argtypes[ubidx], Const)
1443+
ub = argtypes[ubidx].val
1444+
elseif isType(argtypes[ubidx])
1445+
ub = argtypes[ubidx].parameters[1]
1446+
ub_certain = false
1447+
else
1448+
return TypeVar
1449+
end
1450+
end
1451+
tv = TypeVar(nv, lb, ub)
1452+
return PartialTypeVar(tv, lb_certain, ub_certain)
1453+
end
1454+
return TypeVar
1455+
elseif f === UnionAll
1456+
if length(fargs) == 3
1457+
canconst = true
1458+
if isa(argtypes[3], Const)
1459+
body = argtypes[3].val
1460+
elseif isType(argtypes[3])
1461+
body = argtypes[3].parameters[1]
1462+
canconst = false
1463+
else
1464+
return Any
1465+
end
1466+
if isa(argtypes[2], Const)
1467+
tv = argtypes[2].val
1468+
elseif isa(argtypes[2], PartialTypeVar)
1469+
ptv = argtypes[2]
1470+
tv = ptv.tv
1471+
canconst = false
1472+
else
1473+
return Any
1474+
end
1475+
!isa(tv, TypeVar) && return Any
1476+
ret = canconst ? abstract_eval_constant(UnionAll(tv, body)) :
1477+
Type{UnionAll(tv, body)}
1478+
return ret
14021479
end
14031480
return Any
14041481
elseif f === return_type
@@ -1411,7 +1488,15 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
14111488
tm = _topmod(sv)
14121489
if length(argtypes)>2 && argtypes[3] Int
14131490
at2 = widenconst(argtypes[2])
1414-
if (at2 <: Tuple ||
1491+
if at2 <: SimpleVector && istopfunction(tm, f, :getindex)
1492+
if isa(argtypes[2], Const) && isa(argtypes[3], Const)
1493+
svecval = argtypes[2].val
1494+
idx = argtypes[3].val
1495+
if isa(idx, Int) && 1 <= idx <= length(svecval)
1496+
return Const(getindex(svecval, idx))
1497+
end
1498+
end
1499+
elseif (at2 <: Tuple ||
14151500
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
14161501
# allow tuple indexing functions to take advantage of constant
14171502
# index arguments.
@@ -1433,6 +1518,12 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
14331518

14341519
if istopfunction(tm, f, :promote_type) || istopfunction(tm, f, :typejoin)
14351520
return Type
1521+
elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
1522+
t = argtypes[2]
1523+
if isa(t, Const) || isType(t)
1524+
return typename_static(t)
1525+
end
1526+
return Any
14361527
end
14371528

14381529
if sv.params.inlining
@@ -1683,6 +1774,7 @@ function ⊑(a::ANY, b::ANY)
16831774
end
16841775

16851776
widenconst(c::Const) = isa(c.val, Type) ? Type{c.val} : typeof(c.val)
1777+
widenconst(c::PartialTypeVar) = TypeVar
16861778
widenconst(t::ANY) = t
16871779

16881780
issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)
@@ -3300,7 +3392,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
33003392
if method.name == :getindex || method.name == :next || method.name == :indexed_next
33013393
if length(atypes) > 2 && atypes[3] Int
33023394
at2 = widenconst(atypes[2])
3303-
if (at2 <: Tuple ||
3395+
if (at2 <: Tuple || at2 <: SimpleVector ||
33043396
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
33053397
force_infer = true
33063398
end

test/inference.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,12 @@ tpara18457{I}(::Type{AbstractMyType18457{I}}) = I
491491
tpara18457{A<:AbstractMyType18457}(::Type{A}) = tpara18457(supertype(A))
492492
@test tpara18457(MyType18457{true}) === true
493493

494+
fUnionAll{T}(::Type{T}) = Type{S} where S <: T
495+
@inferred fUnionAll(Real) == Type{T} where T <: Real
496+
@inferred fUnionAll(Rational{T} where T <: AbstractFloat) == Type{T} where T<:(Rational{S} where S <: AbstractFloat)
497+
498+
fComplicatedUnionAll{T}(::Type{T}) = Type{Tuple{S,rand() >= 0.5 ? Int : Float64}} where S <: T
499+
let pub = Base.parameter_upper_bound, x = fComplicatedUnionAll(Real)
500+
@test pub(pub(x, 1), 1) == Real
501+
@test pub(pub(x, 1), 2) == Int || pub(pub(x, 1), 2) == Float64
502+
end

0 commit comments

Comments
 (0)