Skip to content

Commit

Permalink
inference: fix the correctness of inference bail out interface (#48826)
Browse files Browse the repository at this point in the history
Since we allow overloading of the `bail_out_xxx` hooks, we need to make
sure that we widen both type and effects to the top when bailing on
inference regardless of the condition presumed by a hook.

This commit particularly fixes the correctness of `bail_out_apply`
(fixes #48807). I wanted to make a simplified test case for this, but
it turns out to be a bit tricky since it relies on the details of
multiple match analysis and the bail out logic.
  • Loading branch information
aviatesk authored Mar 1, 2023
1 parent be70dab commit 533a094
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
29 changes: 18 additions & 11 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
match = applicable[i]::MethodMatch
method = match.method
sig = match.spec_types
if bail_out_toplevel_call(interp, sig, sv)
if bail_out_toplevel_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
# only infer concrete call sites in top-level expressions
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
rettype = Any
break
end
this_rt = Bottom
Expand Down Expand Up @@ -190,8 +189,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = tmerge(conditionals[2][i], cnd.elsetype)
end
end
if bail_out_call(interp, rettype, sv, effects)
add_remark!(interp, sv, "One of the matched returned maximally imprecise information. Bailing on call.")
if bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.")
break
end
end
Expand All @@ -201,7 +200,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
info = ConstCallInfo(info, const_results)
end

if seen != napplicable
if seen napplicable
# there is unanalyzed candidate, widen type and effects to the top
rettype = Any
# there may be unanalyzed effects within unseen dispatch candidate,
# but we can still ignore nonoverlayed effect here since we already accounted for it
all_effects = merge_effects(all_effects, EFFECTS_UNKNOWN)
Expand Down Expand Up @@ -1545,7 +1546,9 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
retinfos = ApplyCallInfo[]
retinfo = UnionSplitApplyCallInfo(retinfos)
for i = 1:length(ctypes)
napplicable = length(ctypes)
seen = 0
for i = 1:napplicable
ct = ctypes[i]
arginfo = infos[i]
lct = length(ct)
Expand All @@ -1559,17 +1562,21 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
end
call = abstract_call(interp, ArgInfo(nothing, ct), si, sv, max_methods)
seen += 1
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
effects = merge_effects(effects, call.effects)
if bail_out_apply(interp, res, sv)
if i != length(ctypes)
# No point carrying forward the info, we're not gonna inline it anyway
retinfo = NoCallInfo()
end
if bail_out_apply(interp, InferenceLoopState(ct, res, effects), sv)
add_remark!(interp, sv, "_apply_iterate inference reached maximally imprecise information. Bailing on.")
break
end
end
if seen napplicable
# there is unanalyzed candidate, widen type and effects to the top
res = Any
effects = Effects()
retinfo = NoCallInfo() # NOTE this is necessary to prevent the inlining processing
end
# TODO: Add a special info type to capture all the iteration info.
# For now, only propagate info if we don't also union-split the iteration
return CallMeta(res, effects, retinfo)
Expand Down
21 changes: 15 additions & 6 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,23 @@ is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(overr

add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return

function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
struct InferenceLoopState
sig
rt
effects::Effects
function InferenceLoopState(@nospecialize(sig), @nospecialize(rt), effects::Effects)
new(sig, rt, effects)
end
end

function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
end
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode}, effects::Effects)
return rt === Any && !is_foldable(effects)
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return state.rt === Any && !is_foldable(state.effects)
end
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
return rt === Any
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return state.rt === Any
end

was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND
Expand Down

0 comments on commit 533a094

Please sign in to comment.