Skip to content

Commit 27fbfa9

Browse files
committed
IRShow: label builtin / intrinsic / dynamic calls in code_typed
This makes it much easier to spot dynamic dispatches: ```julia 3 ── %9 = (isa)(%4, @NamedTuple{x::Int64, y})::Bool └─── goto JuliaLang#5 if not %9 4 ── %11 = π (%4, @NamedTuple{x::Int64, y}) └─── goto JuliaLang#6 5 ── %13 = (Tuple{Int64, Any})(%4)::Tuple{Int64, Any} │ %14 = (getfield)(%13, 1)::Int64 │ %15 = (getfield)(%13, 2)::Any │ %16 = %new(@NamedTuple{x::Int64, y}, %14, %15)::@NamedTuple{x::Int64, y} ``` is now: ```julia 3 ── %9 = builtin (isa)(%4, @NamedTuple{x::Int64, y})::Bool └─── goto JuliaLang#5 if not %9 4 ── %11 = π (%4, @NamedTuple{x::Int64, y}) └─── goto JuliaLang#6 5 ── %13 = dynamic (Tuple{Int64, Any})(%4)::Tuple{Int64, Any} │ %14 = builtin (getfield)(%13, 1)::Int64 │ %15 = builtin (getfield)(%13, 2)::Any │ %16 = %new(@NamedTuple{x::Int64, y}, %14, %15)::@NamedTuple{x::Int64, y} ``` This is on by default when displaying a CodeInfo, and off by default for `code_warntype`, unless optimize=true. Can be enabled / disabled via IRShowConfig.label_dynamic_calls
1 parent 9844d85 commit 27fbfa9

File tree

11 files changed

+179
-49
lines changed

11 files changed

+179
-49
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,6 +3114,7 @@ end
31143114
abstract_eval_ssavalue(s::SSAValue, sv::InferenceState) = abstract_eval_ssavalue(s, sv.ssavaluetypes)
31153115

31163116
function abstract_eval_ssavalue(s::SSAValue, ssavaluetypes::Vector{Any})
3117+
(1 s.id length(ssavaluetypes)) || throw(InvalidIRError())
31173118
typ = ssavaluetypes[s.id]
31183119
if typ === NOT_FOUND
31193120
return Bottom

base/compiler/optimize.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,26 +411,35 @@ function argextype(@nospecialize(x), compact::IncrementalCompact, sptypes::Vecto
411411
isa(x, AnySSAValue) && return types(compact)[x]
412412
return argextype(x, compact, sptypes, compact.ir.argtypes)
413413
end
414-
argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{VarState}) = argextype(x, src, sptypes, src.slottypes::Vector{Any})
414+
function argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{VarState})
415+
return argextype(x, src, sptypes, src.slottypes::Union{Vector{Any},Nothing})
416+
end
415417
function argextype(
416418
@nospecialize(x), src::Union{IRCode,IncrementalCompact,CodeInfo},
417-
sptypes::Vector{VarState}, slottypes::Vector{Any})
419+
sptypes::Vector{VarState}, slottypes::Union{Vector{Any},Nothing})
418420
if isa(x, Expr)
419421
if x.head === :static_parameter
420-
return sptypes[x.args[1]::Int].typ
422+
idx = x.args[1]::Int
423+
(1 idx length(sptypes)) || throw(InvalidIRError())
424+
return sptypes[idx].typ
421425
elseif x.head === :boundscheck
422426
return Bool
423427
elseif x.head === :copyast
428+
length(x.args) == 0 && throw(InvalidIRError())
424429
return argextype(x.args[1], src, sptypes, slottypes)
425430
end
426431
Core.println("argextype called on Expr with head ", x.head,
427432
" which is not valid for IR in argument-position.")
428433
@assert false
429434
elseif isa(x, SlotNumber)
435+
slottypes === nothing && return Any
436+
(1 x.id length(slottypes)) || throw(InvalidIRError())
430437
return slottypes[x.id]
431438
elseif isa(x, SSAValue)
432439
return abstract_eval_ssavalue(x, src)
433440
elseif isa(x, Argument)
441+
slottypes === nothing && return Any
442+
(1 x.n length(slottypes)) || throw(InvalidIRError())
434443
return slottypes[x.n]
435444
elseif isa(x, QuoteNode)
436445
return Const(x.value)
@@ -444,7 +453,15 @@ function argextype(
444453
return Const(x)
445454
end
446455
end
447-
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
456+
function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
457+
ssavaluetypes = src.ssavaluetypes
458+
if ssavaluetypes isa Int
459+
(1 s.id ssavaluetypes) || throw(InvalidIRError())
460+
return Any
461+
else
462+
return abstract_eval_ssavalue(s, ssavaluetypes::Vector{Any})
463+
end
464+
end
448465
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]
449466

450467
"""

base/compiler/ssair/ir.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ Instruction(is::InstructionStream) = Instruction(is, add_new_idx!(is))
313313
fldarray = getfield(getfield(node, :data), fld)
314314
fldidx = getfield(node, :idx)
315315
(fld === :line) && return (fldarray[3fldidx-2], fldarray[3fldidx-1], fldarray[3fldidx-0])
316+
(1 fldidx length(fldarray)) || throw(InvalidIRError())
316317
return fldarray[fldidx]
317318
end
318319
@inline function setindex!(node::Instruction, @nospecialize(val), fld::Symbol)
@@ -481,11 +482,16 @@ function block_for_inst(ir::IRCode, inst::Int)
481482
end
482483

483484
function getindex(ir::IRCode, s::SSAValue)
485+
id = s.id
486+
(id 1) || throw(InvalidIRError())
484487
nstmts = length(ir.stmts)
485-
if s.id <= nstmts
486-
return ir.stmts[s.id]
488+
if id <= nstmts
489+
return ir.stmts[id]
487490
else
488-
return ir.new_nodes.stmts[s.id - nstmts]
491+
id -= nstmts
492+
stmts = ir.new_nodes.stmts
493+
(id length(stmts)) || throw(InvalidIRError())
494+
return stmts[id]
489495
end
490496
end
491497

@@ -801,12 +807,13 @@ end
801807
types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir)
802808

803809
function getindex(compact::IncrementalCompact, ssa::SSAValue)
804-
@assert ssa.id < compact.result_idx
810+
(1 ssa.id compact.result_idx) || throw(InvalidIRError())
805811
return compact.result[ssa.id]
806812
end
807813

808814
function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
809815
id = ssa.id
816+
(id 1) || throw(InvalidIRError())
810817
if id < compact.idx
811818
new_idx = compact.ssa_rename[id]::Int
812819
return compact.result[new_idx]
@@ -818,12 +825,15 @@ function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
818825
return compact.ir.new_nodes.stmts[id]
819826
end
820827
id -= length(compact.ir.new_nodes)
828+
(id length(compact.pending_nodes.stmts)) || throw(InvalidIRError())
821829
return compact.pending_nodes.stmts[id]
822830
end
823831

824832
function getindex(compact::IncrementalCompact, ssa::NewSSAValue)
825833
if ssa.id < 0
826-
return compact.new_new_nodes.stmts[-ssa.id]
834+
stmts = compact.new_new_nodes.stmts
835+
(-ssa.id length(stmts)) || throw(InvalidIRError())
836+
return stmts[-ssa.id]
827837
else
828838
return compact[SSAValue(ssa.id)]
829839
end
@@ -1069,6 +1079,7 @@ function getindex(view::TypesView, v::OldSSAValue)
10691079
id = v.id
10701080
ir = view.ir.ir
10711081
stmts = ir.stmts
1082+
(id 1) || throw(InvalidIRError())
10721083
if id <= length(stmts)
10731084
return stmts[id][:type]
10741085
end
@@ -1077,7 +1088,9 @@ function getindex(view::TypesView, v::OldSSAValue)
10771088
return ir.new_nodes.stmts[id][:type]
10781089
end
10791090
id -= length(ir.new_nodes)
1080-
return view.ir.pending_nodes.stmts[id][:type]
1091+
stmts = view.ir.pending_nodes.stmts
1092+
(id length(stmts)) || throw(InvalidIRError())
1093+
return stmts[id][:type]
10811094
end
10821095

10831096
function kill_current_use!(compact::IncrementalCompact, @nospecialize(val))
@@ -1204,20 +1217,27 @@ end
12041217

12051218
getindex(view::TypesView, idx::SSAValue) = getindex(view, idx.id)
12061219
function getindex(view::TypesView, idx::Int)
1220+
(idx 1) || throw(InvalidIRError())
12071221
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
12081222
return view.ir.result[idx][:type]
12091223
elseif isa(view.ir, IncrementalCompact) && view.ir.renamed_new_nodes
12101224
if idx <= length(view.ir.result)
12111225
return view.ir.result[idx][:type]
12121226
else
1213-
return view.ir.new_new_nodes.stmts[idx - length(view.ir.result)][:type]
1227+
idx -= length(view.ir.result)
1228+
stmts = view.ir.new_new_nodes.stmts
1229+
(idx length(stmts)) || throw(InvalidIRError())
1230+
return stmts[idx][:type]
12141231
end
12151232
else
12161233
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
12171234
if idx <= length(ir.stmts)
12181235
return ir.stmts[idx][:type]
12191236
else
1220-
return ir.new_nodes.stmts[idx - length(ir.stmts)][:type]
1237+
idx -= length(ir.stmts)
1238+
stmts = ir.new_nodes.stmts
1239+
(idx length(stmts)) || throw(InvalidIRError())
1240+
return stmts[idx][:type]
12211241
end
12221242
end
12231243
end

0 commit comments

Comments
 (0)