Skip to content

Commit

Permalink
inference: fixes and improvements for backedge computation (#46741)
Browse files Browse the repository at this point in the history
This commit consists of the following changes:

* inference: setup separate functions for each backedge kind

  Also changes the argument list so that they are ordered as
  `(caller, [backedge information])`.

* inference: fix backedge computation for const-prop'ed callsite

  With this commit `abstract_call_method_with_const_args` doesn't add
  backedge but rather returns the backedge to the caller, letting the
  callers like `abstract_call_gf_by_type` and `abstract_invoke` take the
  responsibility to add backedge to current context appropriately.
  
  As a result, this fixes the backedge calculation for const-prop'ed
  `invoke` callsite.
  
  For example, for the following call graph,
  ```julia
  foo(a::Int) = a > 0 ? :int : println(a)
  foo(a::Integer) = a > 0 ? "integer" : println(a)
  
  bar(a::Int) = @invoke foo(a::Integer)
  ```
  
  Previously we added the wrong backedge `nothing, bar(Int64) from bar(Int64)`:
  ```julia
  julia> last(only(code_typed(()->bar(42))))
  String
  
  julia> let m = only(methods(foo, (UInt,)))
  @eval Core.Compiler for (sig, caller) in BackedgeIterator($m.specializations[1].backedges)
  println(sig, ", ", caller)
  end
  end
  Tuple{typeof(Main.foo), Integer}, bar(Int64) from bar(Int64)
  nothing, bar(Int64) from bar(Int64)
  ```
  but now we only add `invoke`-backedge:
  ```julia
  julia> last(only(code_typed(()->bar(42))))
  String
  
  julia> let m = only(methods(foo, (UInt,)))
  @eval Core.Compiler for (sig, caller) in BackedgeIterator($m.specializations[1].backedges)
  println(sig, ", ", caller)
  end
  end
  Tuple{typeof(Main.foo), Integer}, bar(Int64) from bar(Int64)
  ```

* inference: make `BackedgePair` struct

* add invalidation test for `invoke` call

* optimizer: fixup inlining backedge calculation

  Should fix the following backedge calculation:
  ```julia
  julia> m = which(unique, Tuple{Any})
  unique(itr)
    @ Base set.jl:170
  
  julia> specs = collect(Iterators.filter(m.specializations) do mi
             mi === nothing && return false
             return mi.specTypes.parameters[end] === Vector{Int}   # find specialization of `unique(::Any)` for `::Vector{Int}`
         end)
  Any[]
  
  julia> Base._unique_dims([1,2,3],:)   # no existing callers with specialization `Vector{Int}`, let's make one
  3-element Vector{Int64}:
   1
   2
   3
  
  julia> mi = only(Iterators.filter(m.specializations) do mi
             mi === nothing && return false
             return mi.specTypes.parameters[end] === Vector{Int}   # find specialization of `unique(::Any)` for `::Vector{Int}`
         end)
  MethodInstance for unique(::Vector{Int64})
  
  julia> mi.def
  unique(itr)
    @ Base set.jl:170
  
  julia> mi.backedges
  3-element Vector{Any}:
   Tuple{typeof(unique), Any}
   MethodInstance for Base._unique_dims(::Vector{Int64}, ::Colon)
   MethodInstance for Base._unique_dims(::Vector{Int64}, ::Colon) # <= now we don't register this backedge
 ```
  • Loading branch information
aviatesk authored Sep 15, 2022
1 parent 94c3a15 commit 997e336
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 106 deletions.
43 changes: 21 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
(; rt, edge, effects) = result
edge === nothing || push!(edges, edge)
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result,
Expand All @@ -135,12 +134,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result !== nothing
if const_call_result.rt ᵢ rt
rt = const_call_result.rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
this_rt = tmerge(this_rt, rt)
if bail_out_call(interp, this_rt, sv)
break
Expand All @@ -153,7 +153,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
(; rt, edge, effects) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
edge === nothing || push!(edges, edge)
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
Expand All @@ -169,12 +168,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if this_const_rt ᵢ this_rt
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
end
@assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context"
seen += 1
Expand Down Expand Up @@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter,
end
end
for edge in edges
add_backedge!(edge, sv)
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(mt, atype, sv)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
end
Expand Down Expand Up @@ -838,17 +838,18 @@ function concrete_eval_call(interp::AbstractInterpreter,
f = invoke
end
world = get_world_counter(interp)
edge = result.edge::MethodInstance
value = try
Core._call_in_world_total(world, f, args...)
catch
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
return ConstCallResults(Union{}, ConcreteResult(result.edge::MethodInstance, result.effects), result.effects)
return ConstCallResults(Union{}, ConcreteResult(edge, result.effects), result.effects, edge)
end
if is_inlineable_constant(value) || call_result_unused(sv)
# If the constant is not inlineable, still do the const-prop, since the
# code that led to the creation of the Const may be inlineable in the same
# circumstance and may be optimizable.
return ConstCallResults(Const(value), ConcreteResult(result.edge::MethodInstance, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
return ConstCallResults(Const(value), ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL, edge)
end
return false
else # eligible for semi-concrete evaluation
Expand All @@ -875,10 +876,12 @@ struct ConstCallResults
rt::Any
const_result::ConstResult
effects::Effects
edge::MethodInstance
ConstCallResults(@nospecialize(rt),
const_result::ConstResult,
effects::Effects) =
new(rt, const_result, effects)
effects::Effects,
edge::MethodInstance) =
new(rt, const_result, effects, edge)
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter,
Expand All @@ -888,10 +891,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
return nothing
end
res = concrete_eval_call(interp, f, result, arginfo, sv, invokecall)
if isa(res, ConstCallResults)
add_backedge!(res.const_result.mi, sv, invokecall === nothing ? nothing : invokecall.lookupsig)
return res
end
isa(res, ConstCallResults) && return res
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try semi-concrete evaluation
Expand All @@ -903,7 +903,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
if isa(ir, IRCode)
T = ir_abstract_constant_propagation(interp, mi_cache, sv, mi, ir, arginfo.argtypes)
if !isa(T, Type) || typeintersect(T, Bool) === Union{}
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects)
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects, mi)
end
end
end
Expand Down Expand Up @@ -936,8 +936,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return nothing
add_backedge!(mi, sv)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects, mi)
end

# if there's a possibility we could get a better result with these constant arguments
Expand Down Expand Up @@ -1692,7 +1691,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge::MethodInstance, sv, lookupsig)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1711,10 +1709,11 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
const_result = nothing
if const_call_result !== nothing
if (typeinf_lattice(interp), const_call_result.rt, rt)
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
effects = Effects(effects; nonoverlayed=!overlayed)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
end

Expand Down Expand Up @@ -1848,7 +1847,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge, sv)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand All @@ -1858,7 +1856,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
nothing, arginfo, match, sv)
if const_call_result !== nothing
if const_call_result.rt rt
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
end
Expand All @@ -1874,6 +1872,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
end
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, effects, info)
end

Expand Down
41 changes: 26 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(frame.linfo, caller)
add_backedge!(caller, frame.linfo)
return frame
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing))
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
function add_backedge!(caller::InferenceState, li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
end
if invokesig !== nothing
push!(edges, invokesig)
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
end
push!(edges, li)
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState)
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, mt)
push!(edges, typ)
return nothing
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
Expand Down
13 changes: 6 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt))
intersect!(et::EdgeTracker, range::WorldRange) =
et.valid_worlds[] = intersect(et.valid_worlds[], range)

push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
invokesig === nothing && return push!(et.edges, mi)
push!(et.edges, invokesig, mi)
function add_backedge!(et::EdgeTracker, mi::MethodInstance)
push!(et.edges, mi)
return nothing
end
function push!(et::EdgeTracker, ci::CodeInstance)
intersect!(et, WorldRange(min_world(li), max_world(li)))
push!(et, ci.def)
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
push!(et.edges, invokesig, mi)
return nothing
end

struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInterpreter}
Expand Down
Loading

0 comments on commit 997e336

Please sign in to comment.