Skip to content

Commit

Permalink
carefully cache freshly-inferred edge for call-site inlining
Browse files Browse the repository at this point in the history
Currently call-site inlining fails on freshly-inferred edge if its
source isn't inlineable. This happens because such sources are
exclusively cached globally. For successful call-site inlining, it's
necessary to cache these locally also, ensuring the inliner can access
them later.

To this end, the type of the `cache_mode` field of `InferenceState` has
been switched to `UInt8` from `Symbol`. This change allows it to
represent multiple caching strategies. A new caching mode can be
introduced, e.g. `VOLATILE_CACHE_MODE`, in the future.
  • Loading branch information
aviatesk committed Nov 2, 2023
1 parent 013311c commit c033de1
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 17 deletions.
45 changes: 37 additions & 8 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

const CACHE_MODE_NULL = 0x00
const CACHE_MODE_GLOBAL = 0x01 << 0
const CACHE_MODE_LOCAL = 0x01 << 1

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand Down Expand Up @@ -240,15 +244,15 @@ mutable struct InferenceState
# Whether to restrict inference of abstract call sites to avoid excessive work
# Set by default for toplevel frame.
restrict_abstract_call_sites::Bool
cache_mode::Symbol # TODO move this to InferenceResult?
cache_mode::UInt8 # TODO move this to InferenceResult?
insert_coverage::Bool

# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
interp::AbstractInterpreter

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult, src::CodeInfo, cache_mode::Symbol,
function InferenceState(result::InferenceResult, src::CodeInfo, cache_mode::Union{UInt8,Symbol},
interp::AbstractInterpreter)
linfo = result.linfo
world = get_world_counter(interp)
Expand Down Expand Up @@ -303,11 +307,21 @@ mutable struct InferenceState
end

restrict_abstract_call_sites = isa(def, Module)
@assert cache_mode === :no || cache_mode === :local || cache_mode === :global
if cache_mode isa Symbol
if cache_mode === :global
cache_mode = CACHE_MODE_GLOBAL
elseif cache_mode === :local
cache_mode = CACHE_MODE_LOCAL
elseif cache_mode === :no
cache_mode = CACHE_MODE_NULL
else
error("unexpected `cache_mode` is given")
end
end

# some more setups
InferenceParams(interp).unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
cache_mode === :local && push!(get_inference_cache(interp), result)
!iszero(cache_mode & CACHE_MODE_LOCAL) && push!(get_inference_cache(interp), result)

return new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
Expand Down Expand Up @@ -446,13 +460,28 @@ function should_insert_coverage(mod::Module, src::CodeInfo)
return false
end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
function InferenceState(result::InferenceResult, cache_mode::UInt8, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
world = get_world_counter(interp)
src = retrieve_code_info(result.linfo, world)
src === nothing && return nothing
validate_code_in_debug_mode(result.linfo, src, "lowered")
return InferenceState(result, src, cache, interp)
return InferenceState(result, src, cache_mode, interp)
end
InferenceState(result::InferenceResult, cache_mode::Symbol, interp::AbstractInterpreter) =
InferenceState(result, convert_cache_mode(cache_mode), interp)
InferenceState(result::InferenceResult, src::CodeInfo, cache_mode::Symbol, interp::AbstractInterpreter) =
InferenceState(result, src, convert_cache_mode(cache_mode), interp)

function convert_cache_mode(cache_mode::Symbol)
if cache_mode === :global
return CACHE_MODE_GLOBAL
elseif cache_mode === :local
return CACHE_MODE_LOCAL
elseif cache_mode === :no
return CACHE_MODE_NULL
end
error("unexpected `cache_mode` is given")
end

"""
Expand Down Expand Up @@ -666,7 +695,7 @@ end
function print_callstack(sv::InferenceState)
while sv !== nothing
print(sv.linfo)
sv.cache_mode === :global || print(" [uncached]")
is_cached(sv) || print(" [uncached]")
println()
for cycle in sv.callers_in_cycle
print(' ', cycle.linfo)
Expand Down Expand Up @@ -764,7 +793,7 @@ frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}
is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const)
is_constproped(::IRInterpretationState) = true

is_cached(sv::InferenceState) = sv.cache_mode === :global
is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)
is_cached(::IRInterpretationState) = false

method_info(sv::InferenceState) = sv.method_info
Expand Down
18 changes: 11 additions & 7 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,12 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
for caller in frames
finish!(caller.interp, caller)
if caller.cache_mode === :global
if is_cached(caller)
cache_result!(caller.interp, caller.result)
end
# Drop result.src here since otherwise it can waste memory.
# N.B. If the `cache_mode === :local`, the inliner may request to use it later.
if caller.cache_mode !== :local
# N.B. If cached locally, the inliner may request to use it later.
if iszero(caller.cache_mode & CACHE_MODE_LOCAL)
caller.result.src = nothing
end
end
Expand Down Expand Up @@ -546,7 +546,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# a parent may be cached still, but not this intermediate work:
# we can throw everything else away now
me.result.src = nothing
me.cache_mode = :no
me.cache_mode = CACHE_MODE_NULL
set_inlineable!(me.src, false)
unlock_mi_inference(interp, me.linfo)
elseif limited_src
Expand All @@ -558,7 +558,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# annotate fulltree with type information,
# either because we are the outermost code, or we might use this later
type_annotate!(interp, me)
doopt = (me.cache_mode !== :no || me.parent !== nothing)
doopt = (me.cache_mode != CACHE_MODE_NULL || me.parent !== nothing)
# Disable the optimizer if we've already determined that there's nothing for
# it to do.
if may_discard_trees(interp) && is_result_constabi_eligible(me.result)
Expand Down Expand Up @@ -817,15 +817,19 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# we already inferred this edge before and decided to discard the inferred code,
# nevertheless we re-infer it here again and keep it around in the local cache
# since the inliner will request to use it later
cache_mode = :local
cache_mode = CACHE_MODE_LOCAL
else
rt = cached_return_type(code)
effects = ipo_effects(code)
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
return EdgeCallResult(rt, mi, effects)
end
elseif is_stmt_inline(get_curr_ssaflag(caller))
# if this fresh is going to be inlined, we cache it locally too so that the inliner
# can see it later even in a case when the inferred source is discarded from the global cache
cache_mode = CACHE_MODE_GLOBAL | CACHE_MODE_LOCAL
else
cache_mode = :global # cache edge targets by default
cache_mode = CACHE_MODE_GLOBAL # cache edge targets globally by default
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_output(#=incremental=#false)
add_remark!(interp, caller, "Inference is disabled for the target module")
Expand Down
4 changes: 2 additions & 2 deletions stdlib/REPL/src/REPLCompletions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,10 @@ CC.bail_out_toplevel_call(::REPLInterpreter, ::CC.InferenceLoopState, ::CC.Infer
# `REPLInterpreter` is specifically used by `repl_eval_ex`, where all top-level frames are
# `repl_frame` always. However, this assumption wouldn't stand if `REPLInterpreter` were to
# be employed, for instance, by `typeinf_ext_toplevel`.
is_repl_frame(sv::CC.InferenceState) = sv.linfo.def isa Module && sv.cache_mode === :no
is_repl_frame(sv::CC.InferenceState) = sv.linfo.def isa Module && sv.cache_mode === CC.CACHE_MODE_NULL

function is_call_graph_uncached(sv::CC.InferenceState)
sv.cache_mode === :global && return false
sv.cache_mode === CC.CACHE_MODE_GLOBAL && return false
parent = sv.parent
parent === nothing && return true
return is_call_graph_uncached(parent::CC.InferenceState)
Expand Down
12 changes: 12 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,18 @@ end
end
end

@noinline fresh_edge_noinlined(a::Integer) = unresolvable(a)
let src = code_typed1((Integer,)) do x
@inline fresh_edge_noinlined(x)
end
@test count(iscall((src, fresh_edge_noinlined)), src.code) == 0
end
let src = code_typed1((Integer,)) do x
@inline fresh_edge_noinlined(x)
end
@test count(iscall((src, fresh_edge_noinlined)), src.code) == 0 # should be idempotent
end

# force constant-prop' for `setproperty!`
# https://github.com/JuliaLang/julia/pull/41882
let code = @eval Module() begin
Expand Down

0 comments on commit c033de1

Please sign in to comment.