Skip to content

Commit ab520c7

Browse files
authored
Add hook in inference recursion resolution for external AbstractInterpreter (#36401)
This extends hookability to the same-frame comparison in inference's recursion cycle detection. The case I ran into that made this necessary is a recursive, nested AD transform. In this case, inference must detect if two frames have different orders of derivatives (e.g. the primitive for `-`, again calls `-`; the external interpreter makes sure that inference results for these end up in different caches).
1 parent 063525f commit ab520c7

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

base/compiler/inferencestate.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ mutable struct InferenceState
4646
# `max_valid`, to be used in inlining
4747
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt}}
4848

49+
# The interpreter that created this inference state. Not looked at by
50+
# NativeInterpreter. But other interpreters may use this to detect cycles
51+
interp::AbstractInterpreter
52+
4953
# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
5054
function InferenceState(result::InferenceResult, src::CodeInfo,
5155
cached::Bool, interp::AbstractInterpreter)
@@ -107,7 +111,8 @@ mutable struct InferenceState
107111
Vector{InferenceState}(), # callers_in_cycle
108112
#=parent=#nothing,
109113
cached, false, false, false,
110-
IdDict{Any, Tuple{Any, UInt, UInt}}())
114+
IdDict{Any, Tuple{Any, UInt, UInt}}(),
115+
interp)
111116
result.result = frame
112117
cached && push!(get_inference_cache(interp), result)
113118
return frame

base/compiler/typeinfer.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,21 +439,25 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi
439439
end
440440
end
441441

442+
function is_same_frame(interp::AbstractInterpreter, linfo::MethodInstance, frame::InferenceState)
443+
return linfo === frame.linfo
444+
end
445+
442446
# Walk through `linfo`'s upstream call chain, starting at `parent`. If a parent
443447
# frame matching `linfo` is encountered, then there is a cycle in the call graph
444448
# (i.e. `linfo` is a descendant callee of itself). Upon encountering this cycle,
445449
# we "resolve" it by merging the call chain, which entails unioning each intermediary
446450
# frame's `callers_in_cycle` field and adding the appropriate backedges. Finally,
447451
# we return `linfo`'s pre-existing frame. If no cycles are found, `nothing` is
448452
# returned instead.
449-
function resolve_call_cycle!(linfo::MethodInstance, parent::InferenceState)
453+
function resolve_call_cycle!(interp::AbstractInterpreter, linfo::MethodInstance, parent::InferenceState)
450454
frame = parent
451455
uncached = false
452456
limited = false
453457
while isa(frame, InferenceState)
454458
uncached |= !frame.cached # ensure we never add an uncached frame to a cycle
455459
limited |= frame.limited
456-
if frame.linfo === linfo
460+
if is_same_frame(interp, linfo, frame)
457461
if uncached
458462
# our attempt to speculate into a constant call lead to an undesired self-cycle
459463
# that cannot be converged: poison our call-stack (up to the discovered duplicate frame)
@@ -465,7 +469,7 @@ function resolve_call_cycle!(linfo::MethodInstance, parent::InferenceState)
465469
return frame
466470
end
467471
for caller in frame.callers_in_cycle
468-
if caller.linfo === linfo
472+
if is_same_frame(interp, linfo, caller)
469473
if uncached
470474
poison_callstack(parent, frame, false)
471475
return true
@@ -496,7 +500,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
496500
# (if we asked resolve_call_cyle, it might instead detect that there is a cycle that it can't merge)
497501
frame = false
498502
else
499-
frame = resolve_call_cycle!(mi, caller)
503+
frame = resolve_call_cycle!(interp, mi, caller)
500504
end
501505
if frame === false
502506
# completely new

0 commit comments

Comments
 (0)