Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ssa conversion for catch blocks #117

Merged
merged 8 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 60 additions & 19 deletions src/passes/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,41 +201,83 @@ function prune!(ir::IR)
return ir
end

function slotsused(bl)
slots = []
walk(ex) = prewalk(x -> (x isa Slot && !(x in slots) && push!(slots, x); x), ex)
for (v, st) in bl
ex = st.expr
isexpr(ex, :(=)) ? walk(ex.args[2]) : walk(ex)
end
return slots
struct CatchBranch
defs::Dict{Slot,Any}
v::Variable
end

function ssa!(ir::IR)
current = 1
defs = Dict(b => Dict{Slot,Any}() for b in 1:length(ir.blocks))
todo = Dict(b => Dict{Int,Vector{Slot}}() for b in 1:length(ir.blocks))
catches = Dict()
handlers = []
function reaching(b, v)
haskey(defs[b.id], v) && return defs[b.id][v]
catch_branches = Dict{Int,Vector{CatchBranch}}()
handlers = Int[]
function reaching(b, slot)
haskey(defs[b.id], slot) && return defs[b.id][slot]
b.id == 1 && return undef
x = defs[b.id][v] = argument!(b, type = v.type, insert = false)
x = defs[b.id][slot] = argument!(b, type = slot.type, insert = false)
for pred in predecessors(b)
if pred.id < current
for br in branches(pred, b)
push!(br.args, reaching(pred, v))
push!(br.args, reaching(pred, slot))
end
else
push!(get!(todo[pred.id], b.id, Slot[]), v)
push!(get!(todo[pred.id], b.id, Slot[]), slot)
end
Comment on lines +215 to 226
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is purely just a clarification of the existing code, right? to rename v to slot
(A good clarification)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! In other parts of the package, v often refers to a Variable (SSA value) so I renamed to clarify.

end

if haskey(catch_branches, b.id)
# for each 'catch' branch to this catch block (catch block has `length(predecessors(b)) == 0`),
# we try to find the dominating definition for slot v.
# defs[block(ir, cbr.v).id] contains the defs at the end of
# the block, so we use the cached defs in catch_branches instead.
for cbr in catch_branches[b.id]
cbr_v = cbr.v
stmt = ir[cbr_v]
if haskey(cbr.defs, slot)
# Slot v was defined at catch branch
push!(stmt.expr.args, cbr.defs[slot])
else
# Find slot v definition from instruction cbr_v
b = block(ir, cbr_v)
if b.id == 1
push!(stmt.expr.args, undef)
continue
end

# there is already a def for this slot as an argument to the block
# but which was added after the catch branch.
if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable
bdef = defs[b.id][slot]
(def_b, loc) = ir.defs[bdef.id]
if def_b == b.id && loc < 0
push!(stmt.expr.args, bdef)
continue
end
end

# get the slot definition from each predecessors of the block owning the catch 'branch'
new_arg = defs[b.id][slot] = argument!(b; type=slot.type, insert=false)
push!(stmt.expr.args, new_arg)
for pred in predecessors(b)
if pred.id < current
for br in branches(pred, b)
push!(br.args, reaching(pred, slot))
end
else
push!(get!(todo[pred.id], b.id, Slot[]), slot)
end
end
end
end
end

return x
end
function catchbranch!(v, slot = nothing)
for handler in handlers
args = reaching.((block(ir, v),), catches[handler])
insertafter!(ir, v, Expr(:catch, handler, args...))
cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler)))
push!(get!(Vector{CatchBranch}, catch_branches, handler), cbr)
end
end
for b in blocks(ir)
Expand All @@ -248,10 +290,9 @@ function ssa!(ir::IR)
catchbranch!(v, ex.args[1])
delete!(ir, v)
elseif isexpr(ex, :enter)
catches[ex.args[1]] = slotsused(block(ir, ex.args[1]))
push!(handlers, ex.args[1])
catchbranch!(v)
elseif isexpr(ex, :leave) && !haskey(catches, current)
elseif isexpr(ex, :leave) && !haskey(catch_branches, current)
pop!(handlers)
else
ir[v] = rename(ex)
Expand Down
139 changes: 138 additions & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function err3(f)
end

@test passthrough(err3, () -> 2+2) == 4
@test_broken passthrough(err3, () -> 0//0) == 1
@test passthrough(err3, () -> 0//0) == 1

@dynamo function mullify(a...)
ir = IR(a...)
Expand Down Expand Up @@ -211,3 +211,140 @@ end
@test (code_typed(func_ir, Tuple{typeof(func_ir)}) |> only
isa Pair{Core.CodeInfo,DataType})
end

function f_try_catch(x)
y = 0.
try
y = sqrt(x)
catch

end
y
end

function f_try_catch2(x, cond)
local y
if cond
y = 2x
end

try
x = 3 * error()
catch
end

y
end

function f_try_catch3()
local x
try
error()
catch
x = 42
end
x
end

function f_try_catch4(x, cond)
local y
try
throw(x)
catch err
if cond
y = err + x
end
end
y
end

function f_try_catch5(x, cond)
local y
cond && (x = 2x)
try
y = x
cond && error()
catch
y = x + 1
end
y
end

function f_try_catch6(cond, y)
x = 1

if cond
y = 10y
else
y = 10y
end

try
cond && error()
catch
y = 2x
end

y+x
end

function f_try_catch7()
local x = 1.

for _ in 1:10

try
x = sqrt(x)
x -= 1.
catch
x = -x
end

x = x ^ 2
end

x
end

@testset "try/catch" begin
ir = @code_ir f_try_catch(1.)
fir = func(ir)
@test fir(nothing,1.) === 1.
@test fir(nothing,-1.) === 0.

ir = @code_ir f_try_catch2(1., false)
fir = func(ir)

# This should be @test_throws UndefVarError fir(nothing,42,false)
# See TODO in `IRTools.slots!`
@test fir(nothing, 42, false) === IRTools.undef
@test fir(nothing, 42, true) === 84

ir = @code_ir f_try_catch3()
@test any(ir) do (_, stmt)
IRTools.isexpr(stmt.expr, :catch) &&
length(stmt.expr.args) == 1
end
fir = func(ir)
@test fir(nothing) == 42

ir = @code_ir f_try_catch4(42, false)
fir = func(ir)
# This should be @test_throws UndefVarError fir(nothing,42,false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a test_throws UndefVarError ... broken=true for this? (and the one above)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. It looks like @test_throws does not support broken=true yet so I had added @test try ... end broken=true instead.

# See TODO in `IRTools.slots!`
@test fir(nothing, 42, false) === IRTools.undef
@test fir(nothing, 42, true) === 84

ir = @code_ir f_try_catch5(1, false)
fir = func(ir)
@test fir(nothing, 3, false) === 3
@test fir(nothing, 3, true) === 7

ir = @code_ir f_try_catch6(true, 1)
fir = func(ir)
@test fir(nothing, true, 1) === 3
@test fir(nothing, false, 1) === 11

ir = @code_ir f_try_catch7()
@test func(ir)(nothing) === 1.
end
Loading