Skip to content

Commit

Permalink
inference: setup separate functions for each backedge kind
Browse files Browse the repository at this point in the history
Also changes the argument list so that they are ordered as
`(caller, [backedge information])`.
  • Loading branch information
aviatesk committed Sep 13, 2022
1 parent 70bfa3f commit 386e2c7
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 28 deletions.
18 changes: 11 additions & 7 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter,
end
end
for edge in edges
add_backedge!(edge, sv)
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(mt, atype, sv)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
end
Expand Down Expand Up @@ -882,7 +882,11 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
end
res = concrete_eval_call(interp, f, result, arginfo, sv)
if isa(res, ConstCallResults)
add_backedge!(res.const_result.mi, sv, invoketypes)
if invoketypes === nothing
add_backedge!(sv, res.const_result.mi)
else
add_invoke_backedge!(sv, invoketypes, res.const_result.mi)
end
return res
end
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
Expand Down Expand Up @@ -929,7 +933,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return nothing
add_backedge!(mi, sv)
add_backedge!(sv, mi)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
end

Expand Down Expand Up @@ -1685,7 +1689,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge::MethodInstance, sv, types)
edge !== nothing && add_invoke_backedge!(sv, types, edge::MethodInstance)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand Down Expand Up @@ -1839,7 +1843,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge, sv)
edge !== nothing && add_backedge!(sv, edge)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand Down
41 changes: 26 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(frame.linfo, caller)
add_backedge!(caller, frame.linfo)
return frame
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing))
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
function add_backedge!(caller::InferenceState, li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
end
if invokesig !== nothing
push!(edges, invokesig)
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
end
push!(edges, li)
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState)
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, mt)
push!(edges, typ)
return nothing
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
Expand Down
8 changes: 6 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ intersect!(et::EdgeTracker, range::WorldRange) =
et.valid_worlds[] = intersect(et.valid_worlds[], range)

push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
invokesig === nothing && return push!(et.edges, mi)
function add_backedge!(et::EdgeTracker, mi::MethodInstance)
push!(et.edges, mi)
return nothing
end
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
push!(et.edges, invokesig, mi)
return nothing
end
function push!(et::EdgeTracker, ci::CodeInstance)
intersect!(et, WorldRange(min_world(li), max_world(li)))
Expand Down
24 changes: 21 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
inferred_src = match.src
if isa(inferred_src, ConstAPI)
# use constant calling convention
et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return ConstantCase(quoted(inferred_src.val))
else
src = inferred_src # ::Union{Nothing,CodeInfo} for NativeInterpreter
Expand All @@ -842,7 +848,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
Expand All @@ -866,7 +878,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
src === nothing && return compileable_specialization(et, match, effects;
compilesig_invokes=state.params.compilesig_invokes)

et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ function merge_call_chain!(interp::AbstractInterpreter, parent::InferenceState,
# of recursion.
merge_effects!(interp, parent, Effects(EFFECTS_TOTAL; terminates=false))
while true
add_cycle_backedge!(child, parent, parent.currpc)
add_cycle_backedge!(parent, child, parent.currpc)
union_caller_cycle!(ancestor, child)
merge_effects!(interp, child, Effects(EFFECTS_TOTAL; terminates=false))
child = parent
Expand Down

0 comments on commit 386e2c7

Please sign in to comment.