Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: fixes and improvements for backedge computation #46741

Merged
merged 5 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
(; rt, edge, effects) = result
edge === nothing || push!(edges, edge)
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result,
Expand All @@ -135,12 +134,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result !== nothing
if const_call_result.rt ᵢ rt
rt = const_call_result.rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
this_rt = tmerge(this_rt, rt)
if bail_out_call(interp, this_rt, sv)
break
Expand All @@ -153,7 +153,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
(; rt, edge, effects) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
edge === nothing || push!(edges, edge)
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
Expand All @@ -169,12 +168,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if this_const_rt ᵢ this_rt
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
end
@assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context"
seen += 1
Expand Down Expand Up @@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter,
end
end
for edge in edges
add_backedge!(edge, sv)
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(mt, atype, sv)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
end
Expand Down Expand Up @@ -838,17 +838,18 @@ function concrete_eval_call(interp::AbstractInterpreter,
f = invoke
end
world = get_world_counter(interp)
edge = result.edge::MethodInstance
value = try
Core._call_in_world_total(world, f, args...)
catch
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
return ConstCallResults(Union{}, ConcreteResult(result.edge::MethodInstance, result.effects), result.effects)
return ConstCallResults(Union{}, ConcreteResult(edge, result.effects), result.effects, edge)
end
if is_inlineable_constant(value) || call_result_unused(sv)
# If the constant is not inlineable, still do the const-prop, since the
# code that led to the creation of the Const may be inlineable in the same
# circumstance and may be optimizable.
return ConstCallResults(Const(value), ConcreteResult(result.edge::MethodInstance, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
return ConstCallResults(Const(value), ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL, edge)
end
return false
else # eligible for semi-concrete evaluation
Expand All @@ -875,10 +876,12 @@ struct ConstCallResults
rt::Any
const_result::ConstResult
effects::Effects
edge::MethodInstance
ConstCallResults(@nospecialize(rt),
const_result::ConstResult,
effects::Effects) =
new(rt, const_result, effects)
effects::Effects,
edge::MethodInstance) =
new(rt, const_result, effects, edge)
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter,
Expand All @@ -888,10 +891,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
return nothing
end
res = concrete_eval_call(interp, f, result, arginfo, sv, invokecall)
if isa(res, ConstCallResults)
add_backedge!(res.const_result.mi, sv, invokecall === nothing ? nothing : invokecall.lookupsig)
return res
end
isa(res, ConstCallResults) && return res
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try semi-concrete evaluation
Expand All @@ -903,7 +903,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
if isa(ir, IRCode)
T = ir_abstract_constant_propagation(interp, mi_cache, sv, mi, ir, arginfo.argtypes)
if !isa(T, Type) || typeintersect(T, Bool) === Union{}
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects)
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects, mi)
end
end
end
Expand Down Expand Up @@ -936,8 +936,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return nothing
add_backedge!(mi, sv)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects, mi)
end

# if there's a possibility we could get a better result with these constant arguments
Expand Down Expand Up @@ -1692,7 +1691,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge::MethodInstance, sv, lookupsig)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1711,10 +1709,11 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
const_result = nothing
if const_call_result !== nothing
if (typeinf_lattice(interp), const_call_result.rt, rt)
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
effects = Effects(effects; nonoverlayed=!overlayed)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
end

Expand Down Expand Up @@ -1848,7 +1847,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge, sv)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand All @@ -1858,7 +1856,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
nothing, arginfo, match, sv)
if const_call_result !== nothing
if const_call_result.rt rt
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
end
Expand All @@ -1874,6 +1872,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
end
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, effects, info)
end

Expand Down
41 changes: 26 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(frame.linfo, caller)
add_backedge!(caller, frame.linfo)
return frame
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing))
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
function add_backedge!(caller::InferenceState, li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
end
if invokesig !== nothing
push!(edges, invokesig)
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
end
push!(edges, li)
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState)
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, mt)
push!(edges, typ)
return nothing
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
Expand Down
13 changes: 6 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt))
intersect!(et::EdgeTracker, range::WorldRange) =
et.valid_worlds[] = intersect(et.valid_worlds[], range)

push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
invokesig === nothing && return push!(et.edges, mi)
push!(et.edges, invokesig, mi)
function add_backedge!(et::EdgeTracker, mi::MethodInstance)
push!(et.edges, mi)
return nothing
end
function push!(et::EdgeTracker, ci::CodeInstance)
intersect!(et, WorldRange(min_world(li), max_world(li)))
push!(et, ci.def)
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
push!(et.edges, invokesig, mi)
return nothing
end

struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInterpreter}
Expand Down
Loading