Skip to content

Commit b4228f7

Browse files
committed
Improve inference for common reflection operations
Make inference be able to infer the type constructor. This is a little tricky, since we don't really have a good way to represent this. It's not quite `Const(TypeVar(:T,lb,ub))`, because a) the bounds may only be accurate up to type equality and b) TypeVar is not `isbits`, so it's not actually egal to the value we'll have at runtime. Additionally Type{T} already has meaning as a partially constructed type (e.g. an unwrapped UnionAll), so using T::Type{T} runs the risk of confusing this with an unwrapped type. Instead, introduce a new Const-like type, `PartialTypeVar`, which carries the type var and also keeps track of whether the bounds were egal or only typequal (we don't take advantage of that yet, but we could in the future). Additionally, improve the inference of `typename` and allow constant folding of field accesses on TypeName and SimpleVectors (to be able to constant fold T.parameters[1]).
1 parent b1b214d commit b4228f7

File tree

2 files changed

+100
-13
lines changed

2 files changed

+100
-13
lines changed

base/inference.jl

Lines changed: 91 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ immutable Const
6363
Const(v::ANY) = new(v)
6464
end
6565

66+
immutable PartialTypeVar
67+
tv::TypeVar
68+
lb_certain::Bool
69+
ub_certain::Bool
70+
PartialTypeVar(tv::TypeVar, lb_certain::ANY, ub_certain::ANY) = new(tv, lb_certain, ub_certain)
71+
end
72+
6673
function rewrap(t::ANY, u::ANY)
6774
isa(t, Const) && return t
6875
rewrap_unionall(t, u)
@@ -638,6 +645,7 @@ function limit_type_depth(t::ANY, d::Int, cov::Bool=true, var::Union{Void,TypeVa
638645
return (cov && !stillcov) ? UnionAll(var, R) : R
639646
end
640647

648+
const DataType_name_fieldindex = fieldindex(DataType, :name)
641649
const DataType_parameters_fieldindex = fieldindex(DataType, :parameters)
642650
const DataType_types_fieldindex = fieldindex(DataType, :types)
643651
const DataType_super_fieldindex = fieldindex(DataType, :super)
@@ -671,7 +679,8 @@ function getfield_tfunc(s00::ANY, name)
671679
if isa(sv, Module) && isa(nv, Symbol)
672680
return abstract_eval_global(sv, nv)
673681
end
674-
if (isa(sv, DataType) || isimmutable(sv)) && isdefined(sv, nv)
682+
if (isa(sv, DataType) || isa(sv, SimpleVector) || isa(sv, TypeName)
683+
|| isimmutable(sv)) && isdefined(sv, nv)
675684
return abstract_eval_constant(getfield(sv, nv))
676685
end
677686
end
@@ -729,7 +738,8 @@ function getfield_tfunc(s00::ANY, name)
729738
sp = nothing
730739
end
731740
if (sp !== nothing &&
732-
(fld == DataType_parameters_fieldindex ||
741+
(fld == DataType_name_fieldindex ||
742+
fld == DataType_parameters_fieldindex ||
733743
fld == DataType_types_fieldindex ||
734744
fld == DataType_super_fieldindex))
735745
return Const(getfield(sp, fld))
@@ -821,15 +831,19 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
821831
return Any
822832
end
823833
uncertain = false
834+
uncertain_typevar = false
824835
tparams = Any[]
825836
for i = 1:largs
826837
ai = args[i]
827838
if isType(ai)
828839
aip1 = ai.parameters[1]
829840
uncertain |= has_free_typevars(aip1)
830841
push!(tparams, aip1)
831-
elseif isa(ai, Const) && (isa(ai.val, Type) || valid_tparam(ai.val))
842+
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) || valid_tparam(ai.val))
832843
push!(tparams, ai.val)
844+
elseif isa(ai, PartialTypeVar)
845+
uncertain_typevar = true
846+
push!(tparams, ai.tv)
833847
else
834848
# TODO: return `Bottom` for trying to apply a non-UnionAll
835849
#if !istuple && i-1 > length(headtype.parameters)
@@ -855,7 +869,8 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
855869
appl = headtype
856870
uncertain = true
857871
end
858-
!uncertain && return Const(appl)
872+
!uncertain && !uncertain_typevar && return Const(appl)
873+
!uncertain && return Type{appl}
859874
if type_too_complex(appl,0)
860875
return Type{_} where _<:headtype
861876
end
@@ -1386,19 +1401,62 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
13861401
end
13871402
end
13881403
return Any
1389-
elseif f === UnionAll
1390-
if length(fargs) == 3 && isa(argtypes[2], Const)
1391-
tv = argtypes[2].val
1392-
if isa(tv, TypeVar)
1404+
elseif f === TypeVar
1405+
lb = Union{}
1406+
ub = Any
1407+
ub_certain = lb_certain = true
1408+
if length(fargs) >= 2 && isa(argtypes[2], Const)
1409+
nv = argtypes[2].val
1410+
ubidx = 3
1411+
if length(fargs) >= 4
1412+
ubidx = 4
13931413
if isa(argtypes[3], Const)
1394-
body = argtypes[3].val
1414+
lb = argtypes[3].val
13951415
elseif isType(argtypes[3])
1396-
body = argtypes[3].parameters[1]
1416+
lb = argtypes[3].parameters[1]
1417+
lb_certain = false
13971418
else
1398-
return Any
1419+
return TypeVar
1420+
end
1421+
end
1422+
if length(fargs) >= ubidx
1423+
if isa(argtypes[ubidx], Const)
1424+
ub = argtypes[ubidx].val
1425+
elseif isType(argtypes[ubidx])
1426+
ub = argtypes[ubidx].parameters[1]
1427+
ub_certain = false
1428+
else
1429+
return TypeVar
13991430
end
1400-
return abstract_eval_constant(UnionAll(tv, body))
14011431
end
1432+
tv = TypeVar(nv, lb, ub)
1433+
return PartialTypeVar(tv, lb_certain, ub_certain)
1434+
end
1435+
return TypeVar
1436+
elseif f === UnionAll
1437+
if length(fargs) == 3
1438+
canconst = true
1439+
if isa(argtypes[3], Const)
1440+
body = argtypes[3].val
1441+
elseif isType(argtypes[3])
1442+
body = argtypes[3].parameters[1]
1443+
canconst = false
1444+
else
1445+
return Any
1446+
end
1447+
if isa(argtypes[2], Const)
1448+
tv = argtypes[2].val
1449+
elseif isa(argtypes[2], PartialTypeVar)
1450+
ptv = argtypes[2]
1451+
tv = ptv.tv
1452+
canconst = false
1453+
else
1454+
return Any
1455+
end
1456+
!isa(tv, TypeVar) && return Any
1457+
ret = canconst ? abstract_eval_constant(UnionAll(tv, body)) :
1458+
Type{UnionAll(tv, body)}
1459+
return ret
14021460
end
14031461
return Any
14041462
elseif f === return_type
@@ -1411,7 +1469,15 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
14111469
tm = _topmod(sv)
14121470
if length(argtypes)>2 && argtypes[3] Int
14131471
at2 = widenconst(argtypes[2])
1414-
if (at2 <: Tuple ||
1472+
if at2 <: SimpleVector && istopfunction(tm, f, :getindex)
1473+
if isa(argtypes[2], Const) && isa(argtypes[3], Const)
1474+
svecval = argtypes[2].val
1475+
idx = argtypes[3].val
1476+
if isa(idx, Int) && 1 <= idx <= length(svecval)
1477+
return Const(getindex(svecval, idx))
1478+
end
1479+
end
1480+
elseif (at2 <: Tuple ||
14151481
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
14161482
# allow tuple indexing functions to take advantage of constant
14171483
# index arguments.
@@ -1433,6 +1499,17 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
14331499

14341500
if istopfunction(tm, f, :promote_type) || istopfunction(tm, f, :typejoin)
14351501
return Type
1502+
elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
1503+
t = argtypes[2]
1504+
if isa(t, Const) || isType(t)
1505+
val = isa(t, Const) ? t.val : t.parameters[1]
1506+
try
1507+
return Const(typename(val))
1508+
catch
1509+
return Union{}
1510+
end
1511+
end
1512+
return Any
14361513
end
14371514

14381515
if sv.params.inlining
@@ -1683,6 +1760,7 @@ function ⊑(a::ANY, b::ANY)
16831760
end
16841761

16851762
widenconst(c::Const) = isa(c.val, Type) ? Type{c.val} : typeof(c.val)
1763+
widenconst(c::PartialTypeVar) = TypeVar
16861764
widenconst(t::ANY) = t
16871765

16881766
issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)

test/inference.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,12 @@ tpara18457{I}(::Type{AbstractMyType18457{I}}) = I
491491
tpara18457{A<:AbstractMyType18457}(::Type{A}) = tpara18457(supertype(A))
492492
@test tpara18457(MyType18457{true}) === true
493493

494+
fUnionAll{T}(::Type{T}) = Type{S} where S <: T
495+
@inferred fUnionAll(Real) == Type{T} where T <: Real
496+
@inferred fUnionAll(Rational{T} where T <: AbstractFloat) == Type{T} where T<:(Rational{S} where S <: AbstractFloat)
497+
498+
fComplicatedUnionAll{T}(::Type{T}) = Type{Tuple{S,rand() >= 0.5 ? Int : Float64}} where S <: T
499+
let pub = Base.parameter_upper_bound, x = fComplicatedUnionAll(Real)
500+
@test pub(pub(x, 1), 1) == Real
501+
@test pub(pub(x, 1), 2) == Int || pub(pub(x, 1), 2) == Float64
502+
end

0 commit comments

Comments
 (0)