Skip to content

Commit

Permalink
Merge 0f9eb79 into f7dea04
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk authored Sep 14, 2022
2 parents f7dea04 + 0f9eb79 commit 81df1a1
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 57 deletions.
47 changes: 23 additions & 24 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
(; rt, edge, effects) = result
edge === nothing || push!(edges, edge)
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result,
Expand All @@ -135,12 +134,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if const_call_result !== nothing
if const_call_result.rt ᵢ rt
rt = const_call_result.rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
this_rt = tmerge(this_rt, rt)
if bail_out_call(interp, this_rt, sv)
break
Expand All @@ -153,7 +153,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
(; rt, edge, effects) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
edge === nothing || push!(edges, edge)
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
Expand All @@ -169,12 +168,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if this_const_rt ᵢ this_rt
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result) = const_call_result
(; effects, const_result, edge) = const_call_result
end
end
all_effects = merge_effects(all_effects, effects)
push!(const_results, const_result)
any_const_result |= const_result !== nothing
edge === nothing || push!(edges, edge)
end
@assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context"
seen += 1
Expand Down Expand Up @@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter,
end
end
for edge in edges
add_backedge!(edge, sv)
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(mt, atype, sv)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
end
Expand Down Expand Up @@ -831,17 +831,18 @@ function concrete_eval_call(interp::AbstractInterpreter,
if eligible
args = collect_const_args(arginfo, #=start=#2)
world = get_world_counter(interp)
edge = result.edge::MethodInstance
value = try
Core._call_in_world_total(world, f, args...)
catch
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
return ConstCallResults(Union{}, ConcreteResult(result.edge::MethodInstance, result.effects), result.effects)
return ConstCallResults(Union{}, ConcreteResult(edge, result.effects), result.effects, edge)
end
if is_inlineable_constant(value) || call_result_unused(sv)
# If the constant is not inlineable, still do the const-prop, since the
# code that led to the creation of the Const may be inlineable in the same
# circumstance and may be optimizable.
return ConstCallResults(Const(value), ConcreteResult(result.edge::MethodInstance, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
return ConstCallResults(Const(value), ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL, edge)
end
return false
else # eligible for semi-concrete evaluation
Expand All @@ -868,23 +869,22 @@ struct ConstCallResults
rt::Any
const_result::ConstResult
effects::Effects
edge::MethodInstance
ConstCallResults(@nospecialize(rt),
const_result::ConstResult,
effects::Effects) =
new(rt, const_result, effects)
effects::Effects,
edge::MethodInstance) =
new(rt, const_result, effects, edge)
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, @nospecialize(invoketypes=nothing))
sv::InferenceState)
if !const_prop_enabled(interp, sv, match)
return nothing
end
res = concrete_eval_call(interp, f, result, arginfo, sv)
if isa(res, ConstCallResults)
add_backedge!(res.const_result.mi, sv, invoketypes)
return res
end
isa(res, ConstCallResults) && return res
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try semi-concrete evaluation
Expand All @@ -896,7 +896,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
if isa(ir, IRCode)
T = ir_abstract_constant_propagation(interp, mi_cache, sv, mi, ir, arginfo.argtypes)
if !isa(T, Type) || typeintersect(T, Bool) === Union{}
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects)
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects, mi)
end
end
end
Expand Down Expand Up @@ -929,8 +929,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return nothing
add_backedge!(mi, sv)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects, mi)
end

# if there's a possibility we could get a better result with these constant arguments
Expand Down Expand Up @@ -1685,7 +1684,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge::MethodInstance, sv, types)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1698,14 +1696,15 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
# argtypes′[i] = t ⊑ a ? t : a
# end
const_call_result = abstract_call_method_with_const_args(interp, result,
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv, types)
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv)
const_result = nothing
if const_call_result !== nothing
if (typeinf_lattice(interp), const_call_result.rt, rt)
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
effects = Effects(effects; nonoverlayed=!overlayed)
edge !== nothing && add_invoke_backedge!(sv, types, edge)
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
end

Expand Down Expand Up @@ -1839,7 +1838,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge, sv)
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
Expand All @@ -1849,7 +1847,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
nothing, arginfo, match, sv)
if const_call_result !== nothing
if const_call_result.rt rt
(; rt, effects, const_result) = const_call_result
(; rt, effects, const_result, edge) = const_call_result
end
end
end
Expand All @@ -1865,6 +1863,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
end
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, effects, info)
end

Expand Down
41 changes: 26 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(frame.linfo, caller)
add_backedge!(caller, frame.linfo)
return frame
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing))
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
function add_backedge!(caller::InferenceState, li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
end
if invokesig !== nothing
push!(edges, invokesig)
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
end
push!(edges, li)
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState)
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, mt)
push!(edges, typ)
return nothing
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
Expand Down
8 changes: 6 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ intersect!(et::EdgeTracker, range::WorldRange) =
et.valid_worlds[] = intersect(et.valid_worlds[], range)

push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
invokesig === nothing && return push!(et.edges, mi)
function add_backedge!(et::EdgeTracker, mi::MethodInstance)
push!(et.edges, mi)
return nothing
end
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
push!(et.edges, invokesig, mi)
return nothing
end
function push!(et::EdgeTracker, ci::CodeInstance)
intersect!(et, WorldRange(min_world(li), max_world(li)))
Expand Down
24 changes: 21 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
inferred_src = match.src
if isa(inferred_src, ConstAPI)
# use constant calling convention
et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return ConstantCase(quoted(inferred_src.val))
else
src = inferred_src # ::Union{Nothing,CodeInfo} for NativeInterpreter
Expand All @@ -843,7 +849,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
Expand All @@ -867,7 +879,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
src === nothing && return compileable_specialization(et, match, effects;
compilesig_invokes=state.params.compilesig_invokes)

et !== nothing && add_edge!(et, invokesig, mi)
if et !== nothing
if invokesig === nothing
add_backedge!(et, mi)
else
add_invoke_backedge!(et, invokesig, mi)
end
end
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
end

Expand Down
14 changes: 7 additions & 7 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,13 +567,13 @@ function store_backedges(frame::InferenceResult, edges::Vector{Any})
nothing
end

function store_backedges(caller::MethodInstance, edges::Vector{Any})
for (typ, to) in BackedgeIterator(edges)
if isa(to, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), to, typ, caller)
function store_backedges(frame::MethodInstance, edges::Vector{Any})
for (; sig, caller) in BackedgeIterator(edges)
if isa(caller, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame)
else
typeassert(to, Core.MethodTable)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), to, typ, caller)
typeassert(caller, Core.MethodTable)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame)
end
end
end
Expand Down Expand Up @@ -806,7 +806,7 @@ function merge_call_chain!(interp::AbstractInterpreter, parent::InferenceState,
# of recursion.
merge_effects!(interp, parent, Effects(EFFECTS_TOTAL; terminates=false))
while true
add_cycle_backedge!(child, parent, parent.currpc)
add_cycle_backedge!(parent, child, parent.currpc)
union_caller_cycle!(ancestor, child)
merge_effects!(interp, child, Effects(EFFECTS_TOTAL; terminates=false))
child = parent
Expand Down
15 changes: 9 additions & 6 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ is_no_constprop(method::Union{Method,CodeInfo}) = method.constprop == 0x02
Return an iterator over a list of backedges. Iteration returns `(sig, caller)` elements,
which will be one of the following:
- `(nothing, caller::MethodInstance)`: a call made by ordinary inferrable dispatch
- `(invokesig, caller::MethodInstance)`: a call made by `invoke(f, invokesig, args...)`
- `(specsig, mt::MethodTable)`: an abstract call
- `BackedgePair(nothing, caller::MethodInstance)`: a call made by ordinary inferrable dispatch
- `BackedgePair(invokesig, caller::MethodInstance)`: a call made by `invoke(f, invokesig, args...)`
- `BackedgePair(specsig, mt::MethodTable)`: an abstract call
# Examples
Expand All @@ -254,7 +254,7 @@ julia> callyou(2.0)
julia> mi = first(which(callme, (Any,)).specializations)
MethodInstance for callme(::Float64)
julia> @eval Core.Compiler for (sig, caller) in BackedgeIterator(Main.mi.backedges)
julia> @eval Core.Compiler for (; sig, caller) in BackedgeIterator(Main.mi.backedges)
println(sig)
println(caller)
end
Expand All @@ -268,8 +268,11 @@ end

const empty_backedge_iter = BackedgeIterator(Any[])

const MethodInstanceOrTable = Union{MethodInstance, Core.MethodTable}
const BackedgePair = Pair{Union{Type, Nothing, MethodInstanceOrTable}, MethodInstanceOrTable}
struct BackedgePair
sig # ::Union{Nothing,Type}
caller::Union{MethodInstance,Core.MethodTable}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,Core.MethodTable}) = new(sig, caller)
end

function iterate(iter::BackedgeIterator, i::Int=1)
backedges = iter.backedges
Expand Down

0 comments on commit 81df1a1

Please sign in to comment.