Skip to content

Commit

Permalink
inference: fix backedge computation for const-prop'ed callsite
Browse files Browse the repository at this point in the history
With this commit `abstract_call_method_with_const_args` doesn't add
backedge but rather returns the backedge to the caller, letting the
callers like `abstract_call_gf_by_type` and `abstract_invoke` take the
responsibility to add backedge to current context appropriately.

As a result, this fixes the backedge calculation for const-prop'ed
`invoke` callsite.

For example, for the following call graph,
```julia
foo(a::Int) = a > 0 ? :int : println(a)
foo(a::Integer) = a > 0 ? "integer" : println(a)

bar(a::Int) = @invoke foo(a::Integer)
```

Previously we added the wrong backedge `nothing, bar(Int64) from bar(Int64)`:
```julia
julia> last(only(code_typed(()->bar(42))))
String

julia> let m = only(methods(foo, (UInt,)))
           @eval Core.Compiler for (sig, caller) in BackedgeIterator($m.specializations[1].backedges)
               println(sig, ", ", caller)
           end
       end
Tuple{typeof(Main.foo), Integer}, bar(Int64) from bar(Int64)
nothing, bar(Int64) from bar(Int64)
```
but now we only add `invoke`-backedge:
```julia
julia> last(only(code_typed(()->bar(42))))
String

julia> let m = only(methods(foo, (UInt,)))
           @eval Core.Compiler for (sig, caller) in BackedgeIterator($m.specializations[1].backedges)
               println(sig, ", ", caller)
           end
       end
Tuple{typeof(Main.foo), Integer}, bar(Int64) from bar(Int64)
```
  • Loading branch information
aviatesk committed Sep 13, 2022
1 parent 386e2c7 commit 9dd281d
Showing 1 changed file with 20 additions and 25 deletions.
45 changes: 20 additions & 25 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
(; rt, edge, effects) = result
edge === nothing || push!(edges, edge)
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result,
Expand All @@ -135,12 +134,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result !== nothing
if const_call_result.rt ᵢ rt
rt = const_call_result.rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
this_rt = tmerge(this_rt, rt)
if bail_out_call(interp, this_rt, sv)
break
Expand All @@ -153,7 +153,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
(; rt, edge, effects) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
edge === nothing || push!(edges, edge)
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
Expand All @@ -169,12 +168,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if this_const_rt ᵢ this_rt
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
end
@assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context"
seen += 1
Expand Down Expand Up @@ -831,17 +831,18 @@ function concrete_eval_call(interp::AbstractInterpreter,
if eligible
args = collect_const_args(arginfo, #=start=#2)
world = get_world_counter(interp)
edge = result.edge::MethodInstance
value = try
Core._call_in_world_total(world, f, args...)
catch
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
return ConstCallResults(Union{}, ConcreteResult(result.edge::MethodInstance, result.effects), result.effects)
return ConstCallResults(Union{}, ConcreteResult(edge, result.effects), result.effects, edge)
end
if is_inlineable_constant(value) || call_result_unused(sv)
# If the constant is not inlineable, still do the const-prop, since the
# code that led to the creation of the Const may be inlineable in the same
# circumstance and may be optimizable.
return ConstCallResults(Const(value), ConcreteResult(result.edge::MethodInstance, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
return ConstCallResults(Const(value), ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL, edge)
end
return false
else # eligible for semi-concrete evaluation
Expand All @@ -868,27 +869,22 @@ struct ConstCallResults
rt::Any
const_result::ConstResult
effects::Effects
edge::MethodInstance
ConstCallResults(@nospecialize(rt),
const_result::ConstResult,
effects::Effects) =
new(rt, const_result, effects)
effects::Effects,
edge::MethodInstance) =
new(rt, const_result, effects, edge)
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, @nospecialize(invoketypes=nothing))
sv::InferenceState)
if !const_prop_enabled(interp, sv, match)
return nothing
end
res = concrete_eval_call(interp, f, result, arginfo, sv)
if isa(res, ConstCallResults)
if invoketypes === nothing
add_backedge!(sv, res.const_result.mi)
else
add_invoke_backedge!(sv, invoketypes, res.const_result.mi)
end
return res
end
isa(res, ConstCallResults) && return res
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try semi-concrete evaluation
Expand All @@ -900,7 +896,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
if isa(ir, IRCode)
T = ir_abstract_constant_propagation(interp, mi_cache, sv, mi, ir, arginfo.argtypes)
if !isa(T, Type) || typeintersect(T, Bool) === Union{}
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects)
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects, mi)
end
end
end
Expand Down Expand Up @@ -933,8 +929,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return nothing
add_backedge!(sv, mi)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects, mi)
end

# if there's a possibility we could get a better result with these constant arguments
Expand Down Expand Up @@ -1689,7 +1684,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_invoke_backedge!(sv, types, edge::MethodInstance)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1702,14 +1696,15 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
# argtypes′[i] = t ⊑ a ? t : a
# end
const_call_result = abstract_call_method_with_const_args(interp, result,
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv, types)
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv)
const_result = nothing
if const_call_result !== nothing
if (typeinf_lattice(interp), const_call_result.rt, rt)
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
effects = Effects(effects; nonoverlayed=!overlayed)
edge !== nothing && add_invoke_backedge!(sv, types, edge)
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
end

Expand Down Expand Up @@ -1843,7 +1838,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(sv, edge)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand All @@ -1853,7 +1847,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
nothing, arginfo, match, sv)
if const_call_result !== nothing
if const_call_result.rt rt
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
end
Expand All @@ -1869,6 +1863,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
end
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, effects, info)
end

Expand Down

0 comments on commit 9dd281d

Please sign in to comment.