diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 7f2106b8f00701..8eb85bc3a509c0 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter, end end for edge in edges - add_backedge!(edge, sv) + add_backedge!(sv, edge) end # also need an edge to the method table in case something gets # added that did not intersect with any existing method if isa(matches, MethodMatches) - matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv) + matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype) else for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts) - thisfullmatch || add_mt_backedge!(mt, atype, sv) + thisfullmatch || add_mt_backedge!(sv, mt, atype) end end end @@ -882,7 +882,11 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul end res = concrete_eval_call(interp, f, result, arginfo, sv) if isa(res, ConstCallResults) - add_backedge!(res.const_result.mi, sv, invoketypes) + if invoketypes === nothing + add_backedge!(sv, res.const_result.mi) + else + add_invoke_backedge!(sv, invoketypes, res.const_result.mi) + end return res end mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv) @@ -929,7 +933,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul result = inf_result.result # if constant inference hits a cycle, just bail out isa(result, InferenceState) && return nothing - add_backedge!(mi, sv) + add_backedge!(sv, mi) return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects) end @@ -1685,7 +1689,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn ti = tienv[1]; env = tienv[2]::SimpleVector result = abstract_call_method(interp, method, ti, env, false, sv) (; rt, edge, effects) = result - edge !== nothing && add_backedge!(edge::MethodInstance, sv, types) + edge !== nothing && add_invoke_backedge!(sv, types, edge::MethodInstance) match = MethodMatch(ti, env, method, argtype <: method.sig) res = nothing sig = match.spec_types @@ -1839,7 +1843,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, sig = argtypes_to_type(arginfo.argtypes) result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv) (; rt, edge, effects) = result - edge !== nothing && add_backedge!(edge, sv) + edge !== nothing && add_backedge!(sv, edge) tt = closure.typ sigT = (unwrap_unionall(tt)::DataType).parameters[1] match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt)) diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index e9bd7474d265de..e1d20f01042c47 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta return nothing end -function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int) +function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int) update_valid_age!(frame, caller) backedge = (caller, currpc) contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge) - add_backedge!(frame.linfo, caller) + add_backedge!(caller, frame.linfo) return frame end # temporarily accumulate our edges to later add as backedges in the callee -function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing)) - isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs - edges = caller.stmt_edges[caller.currpc] - if edges === nothing - edges = caller.stmt_edges[caller.currpc] = [] +function add_backedge!(caller::InferenceState, li::MethodInstance) + edges = get_stmt_edges!(caller) + if edges !== nothing + push!(edges, li) end - if invokesig !== nothing - push!(edges, invokesig) + return nothing +end + +function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance) + edges = get_stmt_edges!(caller) + if edges !== nothing + push!(edges, invokesig, li) end - push!(edges, li) return nothing end # used to temporarily accumulate our no method errors to later add as backedges in the callee method table -function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState) - isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs +function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ)) + edges = get_stmt_edges!(caller) + if edges !== nothing + push!(edges, mt, typ) + end + return nothing +end + +function get_stmt_edges!(caller::InferenceState) + if !isa(caller.linfo.def, Method) + return nothing # don't add backedges to toplevel exprs + end edges = caller.stmt_edges[caller.currpc] if edges === nothing edges = caller.stmt_edges[caller.currpc] = [] end - push!(edges, mt) - push!(edges, typ) - return nothing + return edges end function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 1b6d19bc152b6a..51cc95a728ff07 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -65,9 +65,13 @@ intersect!(et::EdgeTracker, range::WorldRange) = et.valid_worlds[] = intersect(et.valid_worlds[], range) push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi) -function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance) - invokesig === nothing && return push!(et.edges, mi) +function add_backedge!(et::EdgeTracker, mi::MethodInstance) + push!(et.edges, mi) + return nothing +end +function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance) push!(et.edges, invokesig, mi) + return nothing end function push!(et::EdgeTracker, ci::CodeInstance) intersect!(et, WorldRange(min_world(li), max_world(li))) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index d6920a0d482405..f8fef7c3dd2085 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -831,7 +831,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8) inferred_src = match.src if isa(inferred_src, ConstAPI) # use constant calling convention - et !== nothing && add_edge!(et, invokesig, mi) + if et !== nothing + if invokesig === nothing + add_backedge!(et, mi) + else + add_invoke_backedge!(et, invokesig, mi) + end + end return ConstantCase(quoted(inferred_src.val)) else src = inferred_src # ::Union{Nothing,CodeInfo} for NativeInterpreter @@ -842,7 +848,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8) if code isa CodeInstance if use_const_api(code) # in this case function can be inlined to a constant - et !== nothing && add_edge!(et, invokesig, mi) + if et !== nothing + if invokesig === nothing + add_backedge!(et, mi) + else + add_invoke_backedge!(et, invokesig, mi) + end + end return ConstantCase(quoted(code.rettype_const)) else src = @atomic :monotonic code.inferred @@ -866,7 +878,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8) src === nothing && return compileable_specialization(et, match, effects; compilesig_invokes=state.params.compilesig_invokes) - et !== nothing && add_edge!(et, invokesig, mi) + if et !== nothing + if invokesig === nothing + add_backedge!(et, mi) + else + add_invoke_backedge!(et, invokesig, mi) + end + end return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects) end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index dad2ba0a7e3b14..c5d19166ad7d0f 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -806,7 +806,7 @@ function merge_call_chain!(interp::AbstractInterpreter, parent::InferenceState, # of recursion. merge_effects!(interp, parent, Effects(EFFECTS_TOTAL; terminates=false)) while true - add_cycle_backedge!(child, parent, parent.currpc) + add_cycle_backedge!(parent, child, parent.currpc) union_caller_cycle!(ancestor, child) merge_effects!(interp, child, Effects(EFFECTS_TOTAL; terminates=false)) child = parent