Skip to content

Commit f396fee

Browse files
committed
Add a demo detecting all poorly-inferred calls
This generates JuliaLang/julia#36463
1 parent adaaafc commit f396fee

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed

demos/abstract_inference.jl

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
using MethodAnalysis
2+
using Base: _methods_by_ftype, get_world_counter, to_tuple_type, func_for_method_checked, remove_linenums!
3+
using Core: CodeInfo, SSAValue, SlotNumber, SimpleVector
4+
5+
if isdefined(Core.Compiler, :NativeInterpreter)
6+
getcode(meth, x, world, optimize; interp=Core.Compiler.NativeInterpreter(world)) =
7+
Core.Compiler.typeinf_code(interp, meth, x[1], x[2], optimize)
8+
else
9+
getcode(meth, x, world, optimize; params=Core.Compiler.Params(world)) =
10+
Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params)
11+
end
12+
13+
function infer_with_sig(m::Method; optimize=true, debuginfo=:none, world=get_world_counter(), kwargs...)
14+
tt = to_tuple_type(m.sig)
15+
meths = _methods_by_ftype(tt, -1, world)
16+
for x in meths
17+
x[3] == m || continue
18+
meth = func_for_method_checked(x[3], tt, x[2])
19+
(code, ty) = getcode(meth, x, world, optimize; kwargs...)
20+
debuginfo === :none && code !== nothing && remove_linenums!(code)
21+
return (code, x[2])=>ty
22+
end
23+
error("no match for ", m)
24+
end
25+
26+
struct BadCall
27+
callee::GlobalRef
28+
argtyps
29+
rettype
30+
end
31+
32+
function peeltype(@nospecialize(T))
33+
isa(T, Core.Compiler.Const) && return Core.Typeof(T.val)
34+
isa(T, Core.Compiler.PartialStruct) && return T.typ
35+
isa(T, Core.Compiler.MaybeUndef) && return T.typ
36+
return T
37+
end
38+
39+
resolve(g::GlobalRef) = isdefined(g.mod, g.name) ? getfield(g.mod, g.name) : nothing
40+
resolve(T::Type) = T
41+
42+
"""
43+
`tfunc(argtyps, rettype)` returns `true` if `rettype` is the expected type
44+
"""
45+
function bad_calls(src::CodeInfo, sparams::SimpleVector, @nospecialize(ty), tfuncs::AbstractDict)
46+
function lookup(a; typof::Bool=true)
47+
if isa(a, SSAValue)
48+
return peeltype(src.ssavaluetypes[a.id])
49+
elseif isa(a, SlotNumber)
50+
return peeltype(src.slottypes[a.id])
51+
elseif isdefined(Core.Compiler, :Argument) && isa(a, Core.Compiler.Argument)
52+
return peeltype(src.slottypes[a.n])
53+
elseif isa(a, GlobalRef) && isdefined(a.mod, a.name)
54+
return Core.Typeof(getfield(a.mod, a.name))
55+
elseif isa(a, Expr)
56+
if a.head === :static_parameter
57+
n = a.args[1]
58+
t = Any
59+
if 1 <= n <= length(sparams)
60+
t = sparams[n]
61+
end
62+
return t
63+
else
64+
error("unrecognized Expr head ", a.head)
65+
end
66+
end
67+
return typof ? Core.Typeof(peeltype(a)) : peeltype(a)
68+
end
69+
70+
badstmts = Pair{Int,BadCall}[]
71+
for (i, stmt) in enumerate(src.code)
72+
if isa(stmt, Expr)
73+
stmt.head === :call || continue
74+
g = stmt.args[1]
75+
isa(g, GlobalRef) || isa(g, Type) || continue
76+
tfunc = get(tfuncs, resolve(g), nothing)
77+
if tfunc !== nothing
78+
atyps = []
79+
for j = 2:length(stmt.args)
80+
a = stmt.args[j]
81+
push!(atyps, lookup(a))
82+
end
83+
sttyp = peeltype(src.ssavaluetypes[i])
84+
# Check to see if the next line has a typeassert
85+
if i < length(src.code)
86+
nextstmt = src.code[i+1]
87+
if isa(nextstmt, Expr) && nextstmt.head === :call
88+
c = nextstmt.args[1]
89+
if isa(c, GlobalRef) && c.mod === Core && c.name === :typeassert && nextstmt.args[2] == SSAValue(i)
90+
tatyp = lookup(nextstmt.args[3]; typof=false)
91+
sttyp = typeintersect(sttyp, tatyp)
92+
end
93+
end
94+
end
95+
if !tfunc(atyps, sttyp)
96+
push!(badstmts, i => BadCall(g, atyps, sttyp))
97+
end
98+
end
99+
end
100+
end
101+
return badstmts
102+
end
103+
104+
function tfunc_promote(atyps, @nospecialize(rettyp))
105+
# peeltyp(T) = T<:Type ? T.parameters[1] : T
106+
T = atyps[1]
107+
isa(T, TypeVar) && return true
108+
for i = 2:length(atyps)
109+
T = promote_type(T, atyps[i])
110+
end
111+
return rettyp === T
112+
end
113+
114+
function tfunc_promote_or_subtype(atyps, @nospecialize(rettyp))
115+
tfunc_promote(atyps, rettyp) && return true
116+
for a in atyps
117+
rettyp <: a && return true
118+
end
119+
return false
120+
end
121+
122+
function tfunc_sub1(atyps, @nospecialize(rettyp), @nospecialize(U))
123+
T = only(atyps)
124+
return T<:U && rettyp<:U
125+
end
126+
127+
tfunc_returns(atyps, @nospecialize(rettyp), @nospecialize(U)) = rettyp <: U
128+
129+
function gettyp(T)
130+
if isa(T, TypeVar)
131+
return gettyp(T.ub)
132+
elseif isa(T, UnionAll)
133+
return gettyp(Base.unwrap_unionall(T))
134+
elseif isa(T, DataType) && T<:Type
135+
return length(T.parameters) == 1 ? gettyp(T.parameters[1]) : Any
136+
else
137+
return T
138+
end
139+
end
140+
141+
function tfunc_convert(atyps, @nospecialize(rettyp))
142+
T = gettyp(atyps[1])
143+
return gettyp(rettyp) <: T
144+
end
145+
146+
function tfunc_iterate(atyps, @nospecialize(rettyp))
147+
atyps[1] <: AbstractString && return rettyp <: Union{Nothing,Tuple{AbstractChar,Int}}
148+
if atyps[1] <: AbstractArray
149+
T = eltype(atyps[1])
150+
return rettyp <: Union{Nothing,Tuple{T,Union{Int,CartesianIndex}}}
151+
end
152+
return rettyp <: Union{Nothing,Tuple{Any,Any}}
153+
end
154+
155+
function tfunc_getindex(atyps, @nospecialize(rettyp))
156+
Tel = gettyp(eltype(atyps[1]))
157+
if all(T->(Tt = gettyp(T); isa(Tt,Type) ? Tt<:Integer : false), atyps[2:end])
158+
return gettyp(rettyp) <: Tel
159+
end
160+
return true # don't try to infer non-scalar indexing
161+
end
162+
163+
tfuncs = IdDict{Any,Function}(
164+
Base.:(&) => tfunc_promote,
165+
Base.:(|) => tfunc_promote,
166+
Base.:(!) => (a,t)->tfunc_returns(a,t,Union{Bool,Missing}),
167+
Base.:(+) => tfunc_promote_or_subtype,
168+
Base.:(-) => tfunc_promote_or_subtype,
169+
Base.:((==)) => (a,t)->tfunc_returns(a,t,Union{Bool,Missing}),
170+
Base.:((<)) => (a,t)->tfunc_returns(a,t,Union{Bool,Missing}),
171+
Base.:((<=)) => (a,t)->tfunc_returns(a,t,Union{Bool,Missing}),
172+
Base.:((>)) => (a,t)->tfunc_returns(a,t,Union{Bool,Missing}),
173+
Base.:((>=)) => (a,t)->tfunc_returns(a,t,Union{Bool,Missing}),
174+
Base.:(cmp) => (a,t)->tfunc_returns(a,t,Int),
175+
Base.:(convert) => tfunc_convert,
176+
Base.:(cconvert) => tfunc_convert,
177+
Base.:(unsafe_convert) => tfunc_convert,
178+
Base.:(iterate) => tfunc_iterate,
179+
Base.:(getindex) => tfunc_getindex,
180+
Base.:(leading_zeros) => (a,t)->tfunc_sub1(a, t, Integer),
181+
Base.:(thisind) => (a,t)->tfunc_returns(a,t,Int),
182+
Base.:(prevind) => (a,t)->tfunc_returns(a,t,Int),
183+
Base.:(nextind) => (a,t)->tfunc_returns(a,t,Int),
184+
Base.:(ncodeunits) => (a,t)->tfunc_returns(a,t,Int),
185+
Base.:(codeunit) => (a,t)->tfunc_returns(a,t,Type{Union{UInt8,UInt16,UInt32}}),
186+
Base.:(eof) => (a,t)->tfunc_returns(a,t,Bool),
187+
Base.:(readline) => (a,t)->tfunc_returns(a,t,AbstractString),
188+
Base.:(displaysize) => (a,t)->tfunc_returns(a,t,Tuple{Int,Int}),
189+
Base.:(sizeof) => (a,t)->tfunc_returns(a,t,Int),
190+
Base.:(length) => (a,t)->tfunc_returns(a,t,Union{Int,UInt}),
191+
Base.:(size) => (a,t)->tfunc_returns(a,t,length(a) == 1 ? Tuple{Vararg{Int}} : Int),
192+
Base.:(axes) => (a,t)->tfunc_returns(a,t,length(a) == 1 ? Tuple{Vararg{<:AbstractUnitRange}} : AbstractUnitRange),
193+
Base.:(resize!) => (a,t)->tfunc_returns(a,t,a[1]),
194+
Base.:(copyto!) => (a,t)->tfunc_returns(a,t,a[1]),
195+
)
196+
for sym in (
197+
:isabspath,
198+
:isapprox,
199+
:isascii,
200+
:isblockdev,
201+
:ischardev,
202+
:iscntrl,
203+
:isdigit,
204+
:isdir,
205+
:isdirpath,
206+
:isdisjoint,
207+
:isempty,
208+
:isequal,
209+
:iseven,
210+
:isfifo,
211+
:isfile,
212+
:isfinite,
213+
:isinf,
214+
:isinteger,
215+
:isinteractive,
216+
:isless,
217+
:isletter,
218+
:islink,
219+
:islocked,
220+
:islowercase,
221+
:ismarked,
222+
:ismissing,
223+
:ismount,
224+
:isnan,
225+
:isnothing,
226+
:isnumeric,
227+
:isodd,
228+
:isone,
229+
:isopen,
230+
:ispath,
231+
:isperm,
232+
:ispow2,
233+
:isprint,
234+
:ispunct,
235+
:isreadable,
236+
:isreadonly,
237+
:isready,
238+
:isreal,
239+
:issetequal,
240+
:issetgid,
241+
:issetuid,
242+
:issocket,
243+
:issorted,
244+
:isspace,
245+
:issticky,
246+
:issubnormal,
247+
:issubset,
248+
:istaskdone,
249+
:istaskfailed,
250+
:istaskstarted,
251+
:istextmime,
252+
:isuppercase,
253+
:isvalid,
254+
:iswritable,
255+
:isxdigit,
256+
:iszero,
257+
)
258+
f = resolve(GlobalRef(Base, sym))
259+
f === nothing && continue
260+
tfuncs[f] = (a,t)->tfunc_returns(a,t,Union{Bool,Missing})
261+
end
262+
263+
function parcel_by_callee(badcalls::Dict{Method,Any})
264+
callers = IdDict{Any,Set{Method}}()
265+
for (m, prs) in badcalls
266+
for (idx, bc) in prs
267+
g = resolve(bc.callee)
268+
list = get!(callers, g, Set{Method}())
269+
push!(list, m)
270+
end
271+
end
272+
return callers
273+
end
274+
function print_sorted(callees)
275+
strs = String[]
276+
for (callee, list) in callees
277+
push!(strs, string(callee, ": ", length(list)))
278+
end
279+
sort!(strs)
280+
for str in strs
281+
println(str)
282+
end
283+
nothing
284+
end
285+
286+
bfs = Dict{Method,Any}()
287+
visit(Base) do item
288+
if isa(item, Method)
289+
isdefined(item, :generator) && return false
290+
try
291+
(src, sparams), ty = infer_with_sig(item)
292+
bs = bad_calls(src, sparams, ty, tfuncs)
293+
isempty(bs) || (bfs[item] = bs)
294+
catch err
295+
@show item
296+
throw(err)
297+
end
298+
return false
299+
end
300+
return true
301+
end
302+
303+
callees = parcel_by_callee(bfs)
304+
print_sorted(callees)

0 commit comments

Comments
 (0)