Skip to content

Commit

Permalink
Optimize away try/catch blocks that are known not to trigger (#51674)
Browse files Browse the repository at this point in the history
This leverages the support from #51590 to delete any try catch block
that is known not to be triggered (either because the try-body is empty
to because we have proven `:nothrow` for all contained statements).
  • Loading branch information
Keno authored Oct 15, 2023
1 parent 4a1d74e commit 0acca3c
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 63 deletions.
41 changes: 21 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2981,6 +2981,20 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
end
end

function propagate_to_error_handler!(frame::InferenceState, currpc::Int, W::BitSet, 𝕃ᵢ::AbstractLattice, currstate::VarTable)
# If this statement potentially threw, propagate the currstate to the
# exception handler, BEFORE applying any state changes.
cur_hand = frame.handler_at[currpc]
if cur_hand != 0
enter = frame.src.code[cur_hand]::Expr
l = enter.args[1]::Int
exceptbb = block_for_inst(frame.cfg, l)
if update_bbstate!(𝕃ᵢ, frame, exceptbb, currstate)
push!(W, exceptbb)
end
end
end

# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !is_inferred(frame)
Expand Down Expand Up @@ -3037,6 +3051,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if nothrow
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
else
propagate_to_error_handler!(frame, currpc, W, 𝕃ᵢ, currstate)
merge_effects!(interp, frame, EFFECTS_THROWS)
end

Expand Down Expand Up @@ -3107,12 +3122,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
ssavaluetypes[frame.currpc] = Any
@goto find_next_bb
elseif isexpr(stmt, :enter)
# Propagate entry info to exception handler
l = stmt.args[1]::Int
catchbb = block_for_inst(frame.cfg, l)
if update_bbstate!(𝕃ᵢ, frame, catchbb, currstate)
push!(W, catchbb)
end
ssavaluetypes[currpc] = Any
@goto fallthrough
elseif isexpr(stmt, :leave)
ssavaluetypes[currpc] = Any
@goto fallthrough
end
Expand All @@ -3121,26 +3133,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# Process non control-flow statements
(; changes, type) = abstract_eval_basic_statement(interp,
stmt, currstate, frame)
if (get_curr_ssaflag(frame) & IR_FLAG_NOTHROW) != IR_FLAG_NOTHROW
propagate_to_error_handler!(frame, currpc, W, 𝕃ᵢ, currstate)
end
if type === Bottom
ssavaluetypes[currpc] = Bottom
@goto find_next_bb
end
if changes !== nothing
stoverwrite1!(currstate, changes)
let cur_hand = frame.handler_at[currpc], l, enter
while cur_hand != 0
enter = frame.src.code[cur_hand]::Expr
l = enter.args[1]::Int
exceptbb = block_for_inst(frame.cfg, l)
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
if stupdate1!(𝕃ᵢ, states[exceptbb]::VarTable, changes)
push!(W, exceptbb)
end
cur_hand = frame.handler_at[cur_hand]
end
end
end
if type === nothing
ssavaluetypes[currpc] = Any
Expand Down
55 changes: 34 additions & 21 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -839,27 +839,35 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
code = copy_exprargs(ci.code)
for i = 1:length(code)
expr = code[i]
if !(i in sv.unreachable) && isa(expr, GotoIfNot)
# Replace this live GotoIfNot with:
# - no-op if :nothrow and the branch target is unreachable
# - cond if :nothrow and both targets are unreachable
# - typeassert if must-throw
block = block_for_inst(sv.cfg, i)
if ssavaluetypes[i] === Bottom
destblock = block_for_inst(sv.cfg, expr.dest)
cfg_delete_edge!(sv.cfg, block, block + 1)
((block + 1) != destblock) && cfg_delete_edge!(sv.cfg, block, destblock)
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
elseif i + 1 in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif expr.dest in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
if !(i in sv.unreachable)
if isa(expr, GotoIfNot)
# Replace this live GotoIfNot with:
# - no-op if :nothrow and the branch target is unreachable
# - cond if :nothrow and both targets are unreachable
# - typeassert if must-throw
block = block_for_inst(sv.cfg, i)
if ssavaluetypes[i] === Bottom
destblock = block_for_inst(sv.cfg, expr.dest)
cfg_delete_edge!(sv.cfg, block, block + 1)
((block + 1) != destblock) && cfg_delete_edge!(sv.cfg, block, destblock)
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
elseif i + 1 in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif expr.dest in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
end
code[i] = expr
elseif isexpr(expr, :enter)
catchdest = expr.args[1]::Int
if catchdest in sv.unreachable
cfg_delete_edge!(sv.cfg, block_for_inst(sv.cfg, i), block_for_inst(sv.cfg, catchdest))
code[i] = nothing
end
end
code[i] = expr
end
end

Expand Down Expand Up @@ -1239,7 +1247,12 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
end
if el.head === :enter
tgt = el.args[1]::Int
el.args[1] = tgt + labelchangemap[tgt]
was_deleted = labelchangemap[tgt] == typemin(Int)
if was_deleted
body[i] = nothing
else
el.args[1] = tgt + labelchangemap[tgt]
end
elseif !is_meta_expr_head(el.head)
args = el.args
for i = 1:length(args)
Expand Down
15 changes: 15 additions & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,21 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
ssa_rename[idx] = nothing
return result_idx
end
elseif isexpr(stmt, :leave)
let i = 1
while i <= length(stmt.args)
if stmt.args[i] === nothing
deleteat!(stmt.args, i)
else
i += 1
end
end
end
if isempty(stmt.args)
# This :leave is dead
ssa_rename[idx] = nothing
return result_idx
end
end
typ = inst[:type]
if isa(typ, Const) && is_inlineable_constant(typ.val)
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
new_code[idx] = GotoIfNot(stmt.cond, new_dest)
end
elseif isexpr(stmt, :enter)
new_code[idx] = Expr(:enter, block_for_inst(cfg, stmt.args[1]::Int))
except_bb = block_for_inst(cfg, stmt.args[1]::Int)
new_code[idx] = Expr(:enter, except_bb)
ssavalmap[idx] = SSAValue(idx) # Slot to store token for pop_exception
elseif isexpr(stmt, :leave) || isexpr(stmt, :(=)) || isa(stmt, ReturnNode) ||
isexpr(stmt, :meta) || isa(stmt, NewvarNode)
Expand Down
18 changes: 0 additions & 18 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,24 +759,6 @@ function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
return changed
end

function stupdate1!(lattice::AbstractLattice, state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
for i = 1:length(state)
invalidated = invalidate_slotwrapper(state[i], changeid, change.conditional)
if invalidated !== nothing
state[i] = invalidated
end
end
# and update the type of it
newtype = change.vtype
oldtype = state[changeid]
if schanged(lattice, newtype, oldtype)
state[changeid] = smerge(lattice, oldtype, newtype)
return true
end
return false
end

function stoverwrite!(state::VarTable, newstate::VarTable)
for i = 1:length(state)
state[i] = newstate[i]
Expand Down
11 changes: 8 additions & 3 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5180,14 +5180,19 @@ static void emit_ssaval_assign(jl_codectx_t &ctx, ssize_t ssaidx_0based, jl_valu
ctx.ssavalue_assigned[ssaidx_0based] = true;
}

static void emit_varinfo_assign(jl_codectx_t &ctx, jl_varinfo_t &vi, jl_cgval_t rval_info, jl_value_t *l=NULL)
static void emit_varinfo_assign(jl_codectx_t &ctx, jl_varinfo_t &vi, jl_cgval_t rval_info, jl_value_t *l=NULL, bool allow_mismatch=false)
{
if (!vi.used || vi.value.typ == jl_bottom_type)
return;

// convert rval-type to lval-type
jl_value_t *slot_type = vi.value.typ;
rval_info = convert_julia_type(ctx, rval_info, slot_type);
// If allow_mismatch is set, type mismatches will not result in traps.
// This is used for upsilon nodes, where the destination can have a narrower
// type than the store, if inference determines that the store is never read.
Value *dummy = NULL;
Value **skip = allow_mismatch ? &dummy : NULL;
rval_info = convert_julia_type(ctx, rval_info, slot_type, skip);
if (rval_info.typ == jl_bottom_type)
return;

Expand Down Expand Up @@ -5284,7 +5289,7 @@ static void emit_upsilonnode(jl_codectx_t &ctx, ssize_t phic, jl_value_t *val)
// was unreachable and dead
val = NULL;
else
emit_varinfo_assign(ctx, vi, rval_info);
emit_varinfo_assign(ctx, vi, rval_info, NULL, true);
}
if (!val) {
if (vi.boxroot) {
Expand Down
69 changes: 69 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5284,3 +5284,72 @@ end
@test only(Base.return_types((x,f) -> getfield(x, f), (An51317, Symbol))) === Int
@test only(Base.return_types(x -> getfield(x, :b), (A51317,))) === Union{}
@test only(Base.return_types(x -> getfield(x, :b), (An51317,))) === Union{}

# Don't visit the catch block for empty try/catch
function completely_dead_try_catch()
try
catch
return 2.0
end
return 1
end
@test Base.return_types(completely_dead_try_catch) |> only === Int
@test fully_eliminated(completely_dead_try_catch)

function nothrow_try_catch()
try
1+1
catch
return 2.0
end
return 1
end
@test Base.return_types(nothrow_try_catch) |> only === Int
@test fully_eliminated(nothrow_try_catch)

may_error(b) = Base.inferencebarrier(b) && error()
function phic_type1()
a = 1
try
may_error(false)
a = 1.0
catch
return a
end
return 2
end
@test Base.return_types(phic_type1) |> only === Int
@test phic_type1() === 2

function phic_type2()
a = 1
try
may_error(false)
a = 1.0
may_error(false)
catch
return a
end
return 2
end
@test Base.return_types(phic_type2) |> only === Union{Int, Float64}
@test phic_type2() === 2

function phic_type3()
a = 1
try
may_error(false)
a = 1.0
may_error(false)
if Base.inferencebarrier(false)
a = Ref(1)
elseif Base.inferencebarrier(false)
a = nothing
end
catch
return a
end
return 2
end
@test Base.return_types(phic_type3) |> only === Union{Int, Float64}
@test phic_type3() === 2

0 comments on commit 0acca3c

Please sign in to comment.