Skip to content

Commit

Permalink
inference: Model type propagation through exceptions
Browse files Browse the repository at this point in the history
Currently the type of a caught exception is always modeled as `Any`.
This isn't a huge problem, because control flow in Julia is generally
assumed to be somewhat slow, so the extra type imprecision of not
knowing the return type does not matter all that much. However,
there are a few situations where it matters. For example:

```
maybe_getindex(A, i) =
    try; A[i]; catch e; isa(e, BoundsError) && return nothing; rethrow(); end
```

At present, we cannot infer :nothrow for this method, even if that
is the only error type that `A[i]` can throw. This is particularly
noticable, since we can now optimize away `:nothrow` exception frames
entirely (#51674). Note that this PR still does not make the above
example particularly efficient (at least interprocedurally), though
specialized codegen could be added on top of this to make that happen.
It does however improve the inference result.

A second major motivation of this change is that reasoning about
exception types is likely to be a major aspect of any future work
on interface checking (since interfaces imply the absence of
MethodErrors), so this PR lays the groundwork for appropriate modeling
of these error paths.

Note that this PR adds all the required plumbing, but does not yet have
a particularly precise model of error types for our builtins, bailing
to `Any` for any builtin not known to be `:nothrow`. This can be improved
in follow up PRs as required.
  • Loading branch information
Keno committed Nov 13, 2023
1 parent 16e61e2 commit 4128c99
Show file tree
Hide file tree
Showing 22 changed files with 386 additions and 211 deletions.
6 changes: 3 additions & 3 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,13 @@ eval(Core, quote
end)

function CodeInstance(
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
ipo_effects::UInt32, effects::UInt32, @nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#),
relocatability::UInt8)
return ccall(:jl_new_codeinst, Ref{CodeInstance},
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
ipo_effects, effects, argescapes,
relocatability)
end
Expand Down
288 changes: 181 additions & 107 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions base/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,60 @@ function Effects(effects::Effects = _EFFECTS_UNKNOWN;
nonoverlayed)
end

function better_effects(new::Effects, old::Effects)
any_improved = false
if new.consistent == ALWAYS_TRUE
any_improved |= old.consistent != ALWAYS_TRUE
elseif new.consistent != old.consistent
return false
end
if new.effect_free == ALWAYS_TRUE
any_improved |= old.consistent != ALWAYS_TRUE
elseif new.effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
old.effect_free == ALWAYS_TRUE && return false
any_improved |= old.effect_free != EFFECT_FREE_IF_INACCESSIBLEMEMONLY
elseif new.effect_free != old.effect_free
return false
end
if new.nothrow
any_improved |= !old.nothrow
elseif new.nothrow != old.nothrow
return false
end
if new.terminates
any_improved |= !old.terminates
elseif new.terminates != old.terminates
return false
end
if new.notaskstate
any_improved |= !old.notaskstate
elseif new.notaskstate != old.notaskstate
return false
end
if new.inaccessiblememonly == ALWAYS_TRUE
any_improved |= old.inaccessiblememonly != ALWAYS_TRUE
elseif new.inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
old.inaccessiblememonly == ALWAYS_TRUE && return false
any_improved |= old.inaccessiblememonly != INACCESSIBLEMEM_OR_ARGMEMONLY
elseif new.inaccessiblememonly != old.inaccessiblememonly
return false
end
if new.noub == ALWAYS_TRUE
any_improved |= old.noub != ALWAYS_TRUE
elseif new.noub == NOUB_IF_NOINBOUNDS
old.noub == ALWAYS_TRUE && return false
any_improved |= old.noub != NOUB_IF_NOINBOUNDS
elseif new.noub != old.noub
return false
end
if new.nonoverlayed
any_improved |= !old.nonoverlayed
elseif new.nonoverlayed != old.nonoverlayed
return false
end
return any_improved
end

function merge_effects(old::Effects, new::Effects)
return Effects(
merge_effectbits(old.consistent, new.consistent),
Expand Down
58 changes: 36 additions & 22 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed

mutable struct TryCatchFrame
exct
const enter_idx
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -218,7 +223,8 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand All @@ -239,6 +245,7 @@ mutable struct InferenceState
unreachable::BitSet # statements that were found to be statically unreachable
valid_worlds::WorldRange
bestguess #::Type
exc_bestguess
ipo_effects::Effects

#= flags =#
Expand Down Expand Up @@ -266,7 +273,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at = compute_trycatch(code, BitSet())
handler_at, handlers = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -296,6 +303,7 @@ mutable struct InferenceState

valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
bestguess = Bottom
exc_bestguess = Bottom
ipo_effects = EFFECTS_TOTAL

insert_coverage = should_insert_coverage(mod, src)
Expand All @@ -315,9 +323,9 @@ mutable struct InferenceState

return new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, ipo_effects,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
end
Expand Down Expand Up @@ -347,16 +355,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)
handler_at = fill((0, 0), n)
handlers = TryCatchFrame[]

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(handlers, TryCatchFrame(Bottom, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
handler_at[l] = pc
handler_at[l] = (handler_id, handler_id)
push!(ip, l)
end
end
Expand All @@ -369,25 +380,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
cur_stacks = handler_at[pc]
@assert cur_stacks != (0, 0) "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if handler_at[l] != cur_stacks
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
handler_at[l] = cur_stacks
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
# Already set above
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif head === :leave
l = 0
for j = 1:length(stmt.args)
Expand All @@ -403,19 +415,21 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end
l += 1
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[cur_hand]
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
end
cur_hand == 0 && break
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
elseif head === :pop_exception
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
cur_stacks == (0, 0) && break
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
if handler_at[pc´] != 0
@assert false "unbalanced try/catch"
end
handler_at[pc´] = cur_hand
if handler_at[pc´] != cur_stacks
handler_at[pc´] = cur_stacks
elseif !in(pc´, ip)
break # already visited
end
Expand All @@ -424,7 +438,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end

@assert first(ip) == n + 1
return handler_at
return handler_at, handlers
end

# check if coverage mode is enabled
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
si = StmtInfo(true) # TODO better job here?
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
(; rt, exct, effects, info) = abstract_call(interp, arginfo, si, irsv)
irsv.ir.stmts[irsv.curridx][:info] = info
return RTEffects(rt, effects)
return RTEffects(rt, exct, effects)
end

function update_phi!(irsv::IRInterpretationState, from::Int, to::Int)
Expand Down
15 changes: 10 additions & 5 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
end

# Record the correct exception handler for all critical sections
handler_at = compute_trycatch(code, BitSet())
handler_at, handlers = compute_trycatch(code, BitSet())

phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
Expand Down Expand Up @@ -627,10 +627,12 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
# The slot is live-in into this block. We need to
# Create a PhiC node in the catch entry block and
# an upsilon node in the corresponding enter block
varstate = sv.bb_vartables[li]
if varstate === nothing
continue
end
node = PhiCNode(Any[])
insertpoint = first_insert_for_bb(code, cfg, li)
varstate = sv.bb_vartables[li]
@assert varstate !== nothing
vt = varstate[idx]
phic_ssa = NewSSAValue(
insert_node!(ir, insertpoint,
Expand Down Expand Up @@ -690,6 +692,9 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
new_nodes = ir.new_nodes
@timeit "SSA Rename" while !isempty(worklist)
(item::Int, pred, incoming_vals) = pop!(worklist)
if sv.bb_vartables[item] === nothing
continue
end
# Rename existing phi nodes first, because their uses occur on the edge
# TODO: This isn't necessary if inlining stops replacing arguments by slots.
for idx in cfg.blocks[item].stmts
Expand Down Expand Up @@ -810,8 +815,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
incoming_vals[id] = Pair{Any, Any}(thisval, thisdef)
has_pinode[id] = false
enter_idx = idx
while handler_at[enter_idx] != 0
enter_idx = handler_at[enter_idx]
while handler_at[enter_idx][1] != 0
(; enter_idx) = handlers[handler_at[enter_idx][1]]
leave_block = block_for_inst(cfg, code[enter_idx].args[1]::Int)
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
if cidx !== nothing
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and any additional information (`call.info`) for a given generic call.
"""
struct CallMeta
rt::Any
exct::Any
effects::Effects
info::CallInfo
end
Expand Down
Loading

0 comments on commit 4128c99

Please sign in to comment.