diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 11ee3a8..1a8d236 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -201,41 +201,83 @@ function prune!(ir::IR) return ir end -function slotsused(bl) - slots = [] - walk(ex) = prewalk(x -> (x isa Slot && !(x in slots) && push!(slots, x); x), ex) - for (v, st) in bl - ex = st.expr - isexpr(ex, :(=)) ? walk(ex.args[2]) : walk(ex) - end - return slots +struct CatchBranch + defs::Dict{Slot,Any} + v::Variable end function ssa!(ir::IR) current = 1 defs = Dict(b => Dict{Slot,Any}() for b in 1:length(ir.blocks)) todo = Dict(b => Dict{Int,Vector{Slot}}() for b in 1:length(ir.blocks)) - catches = Dict() - handlers = [] - function reaching(b, v) - haskey(defs[b.id], v) && return defs[b.id][v] + catch_branches = Dict{Int,Vector{CatchBranch}}() + handlers = Int[] + function reaching(b, slot) + haskey(defs[b.id], slot) && return defs[b.id][slot] b.id == 1 && return undef - x = defs[b.id][v] = argument!(b, type = v.type, insert = false) + x = defs[b.id][slot] = argument!(b, type = slot.type, insert = false) for pred in predecessors(b) if pred.id < current for br in branches(pred, b) - push!(br.args, reaching(pred, v)) + push!(br.args, reaching(pred, slot)) end else - push!(get!(todo[pred.id], b.id, Slot[]), v) + push!(get!(todo[pred.id], b.id, Slot[]), slot) end end + + if haskey(catch_branches, b.id) + # for each 'catch' branch to this catch block (catch block has `length(predecessors(b)) == 0`), + # we try to find the dominating definition for slot v. + # defs[block(ir, cbr.v).id] contains the defs at the end of + # the block, so we use the cached defs in catch_branches instead. + for cbr in catch_branches[b.id] + cbr_v = cbr.v + stmt = ir[cbr_v] + if haskey(cbr.defs, slot) + # Slot v was defined at catch branch + push!(stmt.expr.args, cbr.defs[slot]) + else + # Find slot v definition from instruction cbr_v + b = block(ir, cbr_v) + if b.id == 1 + push!(stmt.expr.args, undef) + continue + end + + # there is already a def for this slot as an argument to the block + # but which was added after the catch branch. + if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable + bdef = defs[b.id][slot] + (def_b, loc) = ir.defs[bdef.id] + if def_b == b.id && loc < 0 + push!(stmt.expr.args, bdef) + continue + end + end + + # get the slot definition from each predecessors of the block owning the catch 'branch' + new_arg = defs[b.id][slot] = argument!(b; type=slot.type, insert=false) + push!(stmt.expr.args, new_arg) + for pred in predecessors(b) + if pred.id < current + for br in branches(pred, b) + push!(br.args, reaching(pred, slot)) + end + else + push!(get!(todo[pred.id], b.id, Slot[]), slot) + end + end + end + end + end + return x end function catchbranch!(v, slot = nothing) for handler in handlers - args = reaching.((block(ir, v),), catches[handler]) - insertafter!(ir, v, Expr(:catch, handler, args...)) + cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler))) + push!(get!(Vector{CatchBranch}, catch_branches, handler), cbr) end end for b in blocks(ir) @@ -248,10 +290,9 @@ function ssa!(ir::IR) catchbranch!(v, ex.args[1]) delete!(ir, v) elseif isexpr(ex, :enter) - catches[ex.args[1]] = slotsused(block(ir, ex.args[1])) push!(handlers, ex.args[1]) catchbranch!(v) - elseif isexpr(ex, :leave) && !haskey(catches, current) + elseif isexpr(ex, :leave) && !haskey(catch_branches, current) pop!(handlers) else ir[v] = rename(ex) diff --git a/test/compiler.jl b/test/compiler.jl index 6d887e8..c89fa97 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -75,7 +75,7 @@ function err3(f) end @test passthrough(err3, () -> 2+2) == 4 -@test_broken passthrough(err3, () -> 0//0) == 1 +@test passthrough(err3, () -> 0//0) == 1 @dynamo function mullify(a...) ir = IR(a...) @@ -222,10 +222,140 @@ function f_try_catch(x) y end +function f_try_catch2(x, cond) + local y + if cond + y = 2x + end + + try + x = 3 * error() + catch + end + + y +end + +function f_try_catch3() + local x + try + error() + catch + x = 42 + end + x +end + +function f_try_catch4(x, cond) + local y + try + throw(x) + catch err + if cond + y = err + x + end + end + y +end + +function f_try_catch5(x, cond) + local y + cond && (x = 2x) + try + y = x + cond && error() + catch + y = x + 1 + end + y +end + +function f_try_catch6(cond, y) + x = 1 + + if cond + y = 10y + else + y = 10y + end + + try + cond && error() + catch + y = 2x + end + + y+x +end + +function f_try_catch7() + local x = 1. + + for _ in 1:10 + + try + x = sqrt(x) + x -= 1. + catch + x = -x + end + + x = x ^ 2 + end + + x +end + @testset "try/catch" begin ir = @code_ir f_try_catch(1.) - @test true fir = func(ir) - @test fir(nothing,1.) == 1. - @test_broken fir(nothing,-1.) == 1. + @test fir(nothing,1.) === 1. + @test fir(nothing,-1.) === 0. + + ir = @code_ir f_try_catch2(1., false) + fir = func(ir) + + # This should be @test_throws UndefVarError fir(nothing,42,false) + # See TODO in `IRTools.slots!` + @test_broken try + fir(nothing,42,false) + false + catch e + e isa UndefVarError + end + @test fir(nothing, 42, false) === IRTools.undef + @test fir(nothing, 42, true) === 84 + + ir = @code_ir f_try_catch3() + @test all(ir) do (_, stmt) + !IRTools.isexpr(stmt.expr, :catch) || + length(stmt.expr.args) == 1 + end + fir = func(ir) + @test fir(nothing) == 42 + + ir = @code_ir f_try_catch4(42, false) + fir = func(ir) + # This should be @test_throws UndefVarError fir(nothing,42,false) + @test_broken try + fir(nothing, 42, false) + false + catch e + e isa UndefVarError + end + @test fir(nothing, 42, false) === IRTools.undef + @test fir(nothing, 42, true) === 84 + + ir = @code_ir f_try_catch5(1, false) + fir = func(ir) + @test fir(nothing, 3, false) === 3 + @test fir(nothing, 3, true) === 7 + + ir = @code_ir f_try_catch6(true, 1) + fir = func(ir) + @test fir(nothing, true, 1) === 3 + @test fir(nothing, false, 1) === 11 + + ir = @code_ir f_try_catch7() + @test func(ir)(nothing) === 1. end