Skip to content

Commit

Permalink
Make EnterNode save/restore dynamic scope
Browse files Browse the repository at this point in the history
As discussed in #51352, this gives `EnterNode` the ability to set
(and restore on leave or catch edge) jl_current_task->scope. Manual
modifications of the task field after the task has started are
considered undefined behavior. In addition, we gain a new intrinsic
to access current_task->scope and both inference and the optimizer
will forward scopes from EnterNodes to this intrinsic (non-interprocedurally).
Together with #51993 this is sufficient to fully optimize ScopedValues
(non-interprocedurally at least).
  • Loading branch information
Keno committed Dec 11, 2023
1 parent 46ad1c1 commit a583b92
Show file tree
Hide file tree
Showing 21 changed files with 208 additions and 61 deletions.
4 changes: 3 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ eval(Core, quote
ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable
GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))
EnterNode(dest::Int) = $(Expr(:new, :EnterNode, :dest))
EnterNode(dest::Int, @nospecialize(scope)) = $(Expr(:new, :EnterNode, :dest, :scope))
LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))
function LineNumberNode(l::Int, @nospecialize(f))
isa(f, String) && (f = Symbol(f))
Expand Down Expand Up @@ -966,7 +967,8 @@ arraysize(a::Array, i::Int) = sle_int(i, nfields(a.size)) ? getfield(a.size, i)
export arrayref, arrayset, arraysize, const_arrayref

# For convenience
EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest)
EnterNode(old::EnterNode, new_dest::Int) = isdefined(old, :scope) ?
EnterNode(new_dest, old.scope) : EnterNode(new_dest)

include(Core, "optimized_generics.jl")

Expand Down
13 changes: 13 additions & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3265,6 +3265,19 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, EnterNode)
ssavaluetypes[currpc] = Any
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
if isdefined(stmt, :scope)
scopet = abstract_eval_value(interp, stmt.scope, currstate, frame)
handler = frame.handlers[frame.handler_at[frame.currpc+1][1]]
@assert handler.scopet !== nothing
if !(𝕃ᵢ, scopet, handler.scopet)
handler.scopet = tmerge(𝕃ᵢ, scopet, handler.scopet)
if isdefined(handler, :scope_uses)
for bb in handler.scope_uses
push!(W, bb)
end
end
end
end
@goto fallthrough
elseif isexpr(stmt, :leave)
ssavaluetypes[currpc] = Any
Expand Down
6 changes: 4 additions & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed

mutable struct TryCatchFrame
exct
scopet
const enter_idx::Int
TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx)
scope_uses::Vector{Int}
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
end

mutable struct InferenceState
Expand Down Expand Up @@ -364,7 +366,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
stmt = code[pc]
if isa(stmt, EnterNode)
l = stmt.catch_dest
push!(handlers, TryCatchFrame(Bottom, pc))
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
Expand Down
1 change: 1 addition & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
result_idx += 1
end
elseif cfg_transforms_enabled && isa(stmt, EnterNode)
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa, mark_refined!)::EnterNode
label = bb_rename_succ[stmt.catch_dest]
@assert label > 0
ssa_rename[idx] = SSAValue(result_idx)
Expand Down
25 changes: 25 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,29 @@ function fold_ifelse!(compact::IncrementalCompact, idx::Int, stmt::Expr)
return false
end

function fold_current_scope!(compact::IncrementalCompact, idx::Int, stmt::Expr, lazydomtree)
domtree = get!(lazydomtree)

# The frontend enforces the invariant that any :enter dominates its active
# region, so all we have to do here is walk the domtree to find it.
dombb = block_for_inst(compact, SSAValue(idx))

local bbterminator
while true
dombb = domtree.idoms_bb[dombb]

# Did not find any dominating :enter - scope is inherited from the outside
dombb == 0 && return nothing

bbterminator = compact[SSAValue(last(compact.cfg_transform.result_bbs[dombb].stmts))][:stmt]
isa(bbterminator, EnterNode) || continue
isdefined(bbterminator, :scope) || continue
compact[idx] = bbterminator.scope
return nothing
end
end


# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}
Expand Down Expand Up @@ -1208,6 +1231,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact)
2 == (length(stmt.args) - (isexpr(stmt, :invoke) ? 2 : 1)) || continue
lift_keyvalue_get!(compact, idx, stmt, 𝕃ₒ)
elseif is_known_call(stmt, Core.current_scope, compact)
fold_current_scope!(compact, idx, stmt, lazydomtree)
elseif isexpr(stmt, :new)
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
end
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng
# given control flow information, we prefer to print these with the basic block #, instead of the ssa %
elseif isa(stmt, EnterNode)
print(io, "enter #", stmt.catch_dest, "")
if isdefined(stmt, :scope)
print(io, " with scope ")
show_unquoted(io, stmt.scope, indent)
end
elseif stmt isa GotoNode
print(io, "goto #", stmt.label)
elseif stmt isa PhiNode
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

function maybe_show_ir(ir::IRCode)
if isdefined(Core, :Main)
Core.Main.Base.display(ir)
invokelatest(Core.Main.Base.display, ir)
end
end

Expand Down
41 changes: 38 additions & 3 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2488,6 +2488,12 @@ function builtin_effects(𝕃::AbstractLattice, @nospecialize(f::Builtin), argty
return Effects(EFFECTS_TOTAL;
consistent = (isa(setting, Const) && setting.val === :conditional) ? ALWAYS_TRUE : ALWAYS_FALSE,
nothrow = compilerbarrier_nothrow(setting, nothing))
elseif f === Core.current_scope
length(argtypes) == 0 || return Effects(EFFECTS_THROWS; consistent=ALWAYS_FALSE)
return Effects(EFFECTS_TOTAL;
consistent = ALWAYS_FALSE,
notaskstate = false,
)
else
if contains_is(_CONSISTENT_BUILTINS, f)
consistent = ALWAYS_TRUE
Expand Down Expand Up @@ -2554,6 +2560,32 @@ function memoryop_noub(@nospecialize(f), argtypes::Vector{Any})
return false
end

function current_scope_tfunc(interp::AbstractInterpreter, sv::InferenceState)
pc = sv.currpc
while true
handleridx = sv.handler_at[pc][2]
if handleridx == 0
# No local scope available - inherited from the outside
return Any
end
pchandler = sv.handlers[handleridx]
# Remember that we looked at this handler, so we get re-scheduled
# if the scope information changes
isdefined(pchandler, :scope_uses) || (pchandler.scope_uses = Int[])
pcbb = block_for_inst(sv.cfg, pc)
if findfirst(pchandler.scope_uses, pcbb) === nothing
push!(pchandler.scope_uses, pcbb)
end
scope = pchandler.scopet
if scope !== nothing
# Found the scope - forward it
return scope
end
pc = pchandler.enter_idx
end
end
current_scope_tfunc(interp::AbstractInterpreter, sv) = Any

"""
builtin_nothrow(𝕃::AbstractLattice, f::Builtin, argtypes::Vector{Any}, rt) -> Bool
Expand All @@ -2568,9 +2600,6 @@ end
function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any},
sv::Union{AbsIntState, Nothing})
𝕃ᵢ = typeinf_lattice(interp)
if f === tuple
return tuple_tfunc(𝕃ᵢ, argtypes)
end
if isa(f, IntrinsicFunction)
if is_pure_intrinsic_infer(f) && all(@nospecialize(a) -> isa(a, Const), argtypes)
argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
Expand All @@ -2596,6 +2625,12 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
end
tf = T_IFUNC[iidx]
else
if f === tuple
return tuple_tfunc(𝕃ᵢ, argtypes)
elseif f === Core.current_scope
length(argtypes) == 0 || return Bottom
return current_scope_tfunc(interp, sv)
end
fidx = find_tfunc(f)
if fidx === nothing
# unknown/unhandled builtin function
Expand Down
9 changes: 8 additions & 1 deletion base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
:new => 1:typemax(Int),
:splatnew => 2:2,
:the_exception => 0:0,
:enter => 1:1,
:enter => 1:2,
:leave => 1:typemax(Int),
:pop_exception => 1:1,
:inbounds => 1:1,
Expand Down Expand Up @@ -160,6 +160,13 @@ function validate_code!(errors::Vector{InvalidCodeError}, c::CodeInfo, is_top_le
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.cond))
end
validate_val!(x.cond)
elseif isa(x, EnterNode)
if isdefined(x, :scope)
if !is_valid_argument(x.scope)
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.scope))
end
validate_val!(x.scope)
end
elseif isa(x, ReturnNode)
if isdefined(x, :val)
if !is_valid_return(x.val)
Expand Down
51 changes: 12 additions & 39 deletions base/scopedvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,6 @@ function Scope(scope, pairs::Pair{<:ScopedValue}...)
end
Scope(::Nothing) = nothing

"""
current_scope()::Union{Nothing, Scope}
Return the current dynamic scope.
"""
current_scope() = current_task().scope::Union{Nothing, Scope}

function Base.show(io::IO, scope::Scope)
print(io, Scope, "(")
first = true
Expand Down Expand Up @@ -113,8 +106,7 @@ return `nothing`. Otherwise returns `Some{T}` with the current
value.
"""
function get(val::ScopedValue{T}) where {T}
# Inline current_scope to avoid doing the type assertion twice.
scope = current_task().scope
scope = Core.current_scope()::Union{Scope, Nothing}
if scope === nothing
isassigned(val) && return Some{T}(val.default)
return nothing
Expand Down Expand Up @@ -148,25 +140,6 @@ function Base.show(io::IO, val::ScopedValue)
print(io, ')')
end

"""
with(f, (var::ScopedValue{T} => val::T)...)
Execute `f` in a new scope with `var` set to `val`.
"""
function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...)
@nospecialize
ct = Base.current_task()
current_scope = ct.scope::Union{Nothing, Scope}
ct.scope = Scope(current_scope, pair, rest...)
try
return f()
finally
ct.scope = current_scope
end
end

with(@nospecialize(f)) = f()

"""
@with vars... expr
Expand All @@ -184,18 +157,18 @@ macro with(exprs...)
else
error("@with expects at least one argument")
end
for expr in exprs
if expr.head !== :call || first(expr.args) !== :(=>)
error("@with expects arguments of the form `A => 2` got $expr")
end
end
exprs = map(esc, exprs)
quote
ct = $(Base.current_task)()
current_scope = ct.scope::$(Union{Nothing, Scope})
ct.scope = $(Scope)(current_scope, $(exprs...))
$(Expr(:tryfinally, esc(ex), :(ct.scope = current_scope)))
end
Expr(:tryfinally, esc(ex), :(), :($(Scope)($(Core.current_scope)()::Union{Nothing, Scope}, $(exprs...))))
end

"""
with(f, (var::ScopedValue{T} => val::T)...)
Execute `f` in a new scope with `var` set to `val`.
"""
function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...)
@with(pair, rest..., f())
end
with(@nospecialize(f)) = f()

end # module ScopedValues
4 changes: 4 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,11 @@ static jl_value_t *scm_to_julia_(fl_context_t *fl_ctx, value_t e, jl_module_t *m
else if (sym == jl_enter_sym) {
ex = scm_to_julia_(fl_ctx, car_(e), mod);
temp = jl_new_struct_uninit(jl_enternode_type);
jl_enternode_scope(temp) = NULL;
jl_enternode_catch_dest(temp) = jl_unbox_long(ex);
if (n == 2) {
jl_enternode_scope(temp) = scm_to_julia(fl_ctx, car_(cdr_(e)), mod);
}
}
else if (sym == jl_newvar_sym) {
ex = scm_to_julia_(fl_ctx, car_(e), mod);
Expand Down
1 change: 1 addition & 0 deletions src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ DECLARE_BUILTIN(setglobal);
DECLARE_BUILTIN(finalizer);
DECLARE_BUILTIN(_compute_sparams);
DECLARE_BUILTIN(_svec_ref);
DECLARE_BUILTIN(current_scope);

JL_CALLABLE(jl_f__structtype);
JL_CALLABLE(jl_f__abstracttype);
Expand Down
7 changes: 7 additions & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,12 @@ JL_CALLABLE(jl_f_ifelse)
return (args[0] == jl_false ? args[2] : args[1]);
}

JL_CALLABLE(jl_f_current_scope)
{
JL_NARGS(current_scope, 0, 0);
return jl_current_task->scope;
}

// apply ----------------------------------------------------------------------

static NOINLINE jl_svec_t *_copy_to(size_t newalloc, jl_value_t **oldargs, size_t oldalloc)
Expand Down Expand Up @@ -2158,6 +2164,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin_func("finalizer", jl_f_finalizer);
add_builtin_func("_compute_sparams", jl_f__compute_sparams);
add_builtin_func("_svec_ref", jl_f__svec_ref);
add_builtin_func("current_scope", jl_f_current_scope);

// builtin types
add_builtin("Any", (jl_value_t*)jl_any_type);
Expand Down
Loading

0 comments on commit a583b92

Please sign in to comment.