From 4651208b43a7b20bfebbd24a27455dfa7d2e49e0 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 3 Dec 2023 18:42:54 +0100 Subject: [PATCH 1/7] fix ssa conversion for catch blocks the slotused function was not enough to account for all slot used by the catch block and its successors. With this change, ssa conversion keeps a live list of all catch 'branch' instructions and fetches the reaching definitions for slots at the location of these :catch instructions. --- src/passes/passes.jl | 78 +++++++++++++++++++++++++++++++++----------- test/compiler.jl | 19 ++++++++++- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 11ee3a8..828d9e8 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -201,41 +201,82 @@ function prune!(ir::IR) return ir end -function slotsused(bl) - slots = [] - walk(ex) = prewalk(x -> (x isa Slot && !(x in slots) && push!(slots, x); x), ex) - for (v, st) in bl - ex = st.expr - isexpr(ex, :(=)) ? walk(ex.args[2]) : walk(ex) - end - return slots +struct CatchBranch + defs::Dict{Slot,Any} + v::Variable end function ssa!(ir::IR) current = 1 defs = Dict(b => Dict{Slot,Any}() for b in 1:length(ir.blocks)) todo = Dict(b => Dict{Int,Vector{Slot}}() for b in 1:length(ir.blocks)) - catches = Dict() - handlers = [] - function reaching(b, v) - haskey(defs[b.id], v) && return defs[b.id][v] + catch_branches = Dict{Int,Vector{CatchBranch}}() + handlers = Int[] + function reaching(b, slot) + haskey(defs[b.id], slot) && return defs[b.id][slot] b.id == 1 && return undef - x = defs[b.id][v] = argument!(b, type = v.type, insert = false) + x = defs[b.id][slot] = argument!(b, type = slot.type, insert = false) for pred in predecessors(b) if pred.id < current for br in branches(pred, b) - push!(br.args, reaching(pred, v)) + push!(br.args, reaching(pred, slot)) end else - push!(get!(todo[pred.id], b.id, Slot[]), v) + push!(get!(todo[pred.id], b.id, Slot[]), slot) end end + + if haskey(catch_branches, b.id) + # for each 'catch' branch to this catch block (catch block has `length(predecessors(b)) == 0`), + # we try to find the dominating definition for slot v. + # defs[block(ir, cbr.v).id] contains the defs at the end of + # the block, so we use the cached defs in catch_branches instead. + for cbr in catch_branches[b.id] + cbr_v = cbr.v + stmt = ir[cbr_v] + if haskey(cbr.defs, slot) + # Slot v was defined at catch branch + push!(stmt.expr.args, cbr.defs[slot]) + else + # Find slot v definition from instruction cbr_v + b = block(ir, cbr_v) + if b.id == 1 + push!(stmt.expr.args, undef) + continue + end + + # there is already a def for this slot as an argument to the block + # but which was added after the catch branch. + if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable + bdef = defs[b.id][slot] + (def_b, loc) = ir.defs[bdef.id] + if def_b == b.id && loc < 0 + push!(stmt.expr.args, bdef) + continue + end + end + + # get the slot definition from each predecessors of the block owning the catch 'branch' + defs[b.id][slot] = argument!(b; type=slot.type, insert=false) + for pred in predecessors(b) + if pred.id < current + for br in branches(pred, b) + push!(br.args, reaching(pred, slot)) + end + else + push!(get!(todo[pred.id], b.id, Slot[]), slot) + end + end + end + end + end + return x end function catchbranch!(v, slot = nothing) for handler in handlers - args = reaching.((block(ir, v),), catches[handler]) - insertafter!(ir, v, Expr(:catch, handler, args...)) + cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler))) + push!(get!(Vector{Variable}, catch_branches, handler), cbr) end end for b in blocks(ir) @@ -248,10 +289,9 @@ function ssa!(ir::IR) catchbranch!(v, ex.args[1]) delete!(ir, v) elseif isexpr(ex, :enter) - catches[ex.args[1]] = slotsused(block(ir, ex.args[1])) push!(handlers, ex.args[1]) catchbranch!(v) - elseif isexpr(ex, :leave) && !haskey(catches, current) + elseif isexpr(ex, :leave) && !haskey(catch_branches, current) pop!(handlers) else ir[v] = rename(ex) diff --git a/test/compiler.jl b/test/compiler.jl index 634ccec..4b0beef 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -75,7 +75,7 @@ function err3(f) end @test passthrough(err3, () -> 2+2) == 4 -@test_broken passthrough(err3, () -> 0//0) == 1 +@test passthrough(err3, () -> 0//0) == 1 @dynamo function mullify(a...) ir = IR(a...) @@ -211,3 +211,20 @@ end @test (code_typed(func_ir, Tuple{typeof(func_ir)}) |> only isa Pair{Core.CodeInfo,DataType}) end + +function f_try_catch(x) + y = 0. + try + y = sqrt(x) + catch + + end + y +end + +@testset "try/catch" begin + ir = @code_ir f_try_catch(1.) + fir = func(ir) + @test fir(nothing,1.) == 1. + @test fir(nothing,-1.) == 0. +end From 284684c06c3a028c0bd21b41494e5d8d0cbebd35 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 5 Dec 2023 12:21:22 +0100 Subject: [PATCH 2/7] add more tests --- test/compiler.jl | 80 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 4b0beef..82fe0d2 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -222,9 +222,85 @@ function f_try_catch(x) y end +function f_try_catch2(x, cond) + local y + if cond + y = 2x + end + + try + x = 3 * error() + catch + end + + y +end + +function f_try_catch3() + local x + try + error() + catch + x = 42 + end + x +end + +function f_try_catch4(x, cond) + local y + try + throw(x) + catch err + if cond + y = err + x + end + end + y +end + +function f_try_catch5(x, cond) + local y + cond && (x = 2x) + try + y = x + cond && error() + catch + y = x + 1 + end + y +end + @testset "try/catch" begin ir = @code_ir f_try_catch(1.) fir = func(ir) - @test fir(nothing,1.) == 1. - @test fir(nothing,-1.) == 0. + @test fir(nothing,1.) === 1. + @test fir(nothing,-1.) === 0. + + ir = @code_ir f_try_catch2(1., false) + fir = func(ir) + + # This should be @test_throws UndefVarError fir(nothing,42,false) + # See TODO in `IRTools.slots!` + @test fir(nothing, 42, false) === IRTools.undef + @test fir(nothing, 42, true) === 84 + + ir = @code_ir f_try_catch3() + @test any(ir) do (_, stmt) + IRTools.isexpr(stmt.expr, :catch) && + length(stmt.expr.args) == 1 + end + fir = func(ir) + @test fir(nothing) == 42 + + ir = @code_ir f_try_catch4(42, false) + fir = func(ir) + # This should be @test_throws UndefVarError fir(nothing,42,false) + # See TODO in `IRTools.slots!` + @test fir(nothing, 42, false) === IRTools.undef + @test fir(nothing, 42, true) === 84 + + ir = @code_ir f_try_catch5(1, false) + fir = func(ir) + @test fir(nothing, 3, false) === 3 + @test fir(nothing, 3, true) === 7 end From 0c9c2bd3656035ca0ed8edd532699c6cf01d9d7c Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 5 Dec 2023 12:25:22 +0100 Subject: [PATCH 3/7] Use right type in default --- src/passes/passes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 828d9e8..5f017d9 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -276,7 +276,7 @@ function ssa!(ir::IR) function catchbranch!(v, slot = nothing) for handler in handlers cbr = CatchBranch(copy(defs[current]), insertafter!(ir, v, Expr(:catch, handler))) - push!(get!(Vector{Variable}, catch_branches, handler), cbr) + push!(get!(Vector{CatchBranch}, catch_branches, handler), cbr) end end for b in blocks(ir) From 43fc81221315b594a37608c1257a9b0f2ec052e0 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 5 Dec 2023 12:55:40 +0100 Subject: [PATCH 4/7] fix indent && more tests --- src/passes/passes.jl | 79 ++++++++++++++++++++++---------------------- test/compiler.jl | 46 +++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 5f017d9..2022fb9 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -227,48 +227,49 @@ function ssa!(ir::IR) end if haskey(catch_branches, b.id) - # for each 'catch' branch to this catch block (catch block has `length(predecessors(b)) == 0`), - # we try to find the dominating definition for slot v. - # defs[block(ir, cbr.v).id] contains the defs at the end of - # the block, so we use the cached defs in catch_branches instead. - for cbr in catch_branches[b.id] - cbr_v = cbr.v - stmt = ir[cbr_v] - if haskey(cbr.defs, slot) - # Slot v was defined at catch branch - push!(stmt.expr.args, cbr.defs[slot]) + # for each 'catch' branch to this catch block (catch block has `length(predecessors(b)) == 0`), + # we try to find the dominating definition for slot v. + # defs[block(ir, cbr.v).id] contains the defs at the end of + # the block, so we use the cached defs in catch_branches instead. + for cbr in catch_branches[b.id] + cbr_v = cbr.v + stmt = ir[cbr_v] + if haskey(cbr.defs, slot) + # Slot v was defined at catch branch + push!(stmt.expr.args, cbr.defs[slot]) + else + # Find slot v definition from instruction cbr_v + b = block(ir, cbr_v) + if b.id == 1 + push!(stmt.expr.args, undef) + continue + end + + # there is already a def for this slot as an argument to the block + # but which was added after the catch branch. + if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable + bdef = defs[b.id][slot] + (def_b, loc) = ir.defs[bdef.id] + if def_b == b.id && loc < 0 + push!(stmt.expr.args, bdef) + continue + end + end + + # get the slot definition from each predecessors of the block owning the catch 'branch' + x = defs[b.id][slot] = argument!(b; type=slot.type, insert=false) + push!(stmt.expr.args, x) + for pred in predecessors(b) + if pred.id < current + for br in branches(pred, b) + push!(br.args, reaching(pred, slot)) + end else - # Find slot v definition from instruction cbr_v - b = block(ir, cbr_v) - if b.id == 1 - push!(stmt.expr.args, undef) - continue - end - - # there is already a def for this slot as an argument to the block - # but which was added after the catch branch. - if haskey(defs[b.id], slot) && defs[b.id][slot] isa Variable - bdef = defs[b.id][slot] - (def_b, loc) = ir.defs[bdef.id] - if def_b == b.id && loc < 0 - push!(stmt.expr.args, bdef) - continue - end - end - - # get the slot definition from each predecessors of the block owning the catch 'branch' - defs[b.id][slot] = argument!(b; type=slot.type, insert=false) - for pred in predecessors(b) - if pred.id < current - for br in branches(pred, b) - push!(br.args, reaching(pred, slot)) - end - else - push!(get!(todo[pred.id], b.id, Slot[]), slot) - end - end + push!(get!(todo[pred.id], b.id, Slot[]), slot) end + end end + end end return x diff --git a/test/compiler.jl b/test/compiler.jl index 82fe0d2..a0ecc0c 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -270,6 +270,42 @@ function f_try_catch5(x, cond) y end +function f_try_catch6(cond, y) + x = 1 + + if cond + y = 10y + else + y = 10y + end + + try + cond && error() + catch + y = 2x + end + + y+x +end + +function f_try_catch7() + local x = 1. + + for _ in 1:10 + + try + x = sqrt(x) + x -= 1. + catch + x = -x + end + + x = x ^ 2 + end + + x +end + @testset "try/catch" begin ir = @code_ir f_try_catch(1.) fir = func(ir) @@ -287,7 +323,7 @@ end ir = @code_ir f_try_catch3() @test any(ir) do (_, stmt) IRTools.isexpr(stmt.expr, :catch) && - length(stmt.expr.args) == 1 + length(stmt.expr.args) == 1 end fir = func(ir) @test fir(nothing) == 42 @@ -303,4 +339,12 @@ end fir = func(ir) @test fir(nothing, 3, false) === 3 @test fir(nothing, 3, true) === 7 + + ir = @code_ir f_try_catch6(true, 1) + fir = func(ir) + @test fir(nothing, true, 1) === 3 + @test fir(nothing, false, 1) === 11 + + ir = @code_ir f_try_catch7() + @test func(ir)(nothing) === 1. end From 44f59c30bc90ac95dc02f142be480e6fc7a5aa27 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Tue, 5 Dec 2023 12:57:36 +0100 Subject: [PATCH 5/7] Update passes.jl --- src/passes/passes.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 2022fb9..1a8d236 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -257,8 +257,8 @@ function ssa!(ir::IR) end # get the slot definition from each predecessors of the block owning the catch 'branch' - x = defs[b.id][slot] = argument!(b; type=slot.type, insert=false) - push!(stmt.expr.args, x) + new_arg = defs[b.id][slot] = argument!(b; type=slot.type, insert=false) + push!(stmt.expr.args, new_arg) for pred in predecessors(b) if pred.id < current for br in branches(pred, b) From 1d45fcbe0ca48a7095dcea9c3c54ca2dec104b4f Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 8 Dec 2023 20:39:15 +0100 Subject: [PATCH 6/7] add test_broken for UndefVarError --- test/compiler.jl | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index a0ecc0c..5f9f5dc 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -317,12 +317,18 @@ end # This should be @test_throws UndefVarError fir(nothing,42,false) # See TODO in `IRTools.slots!` + @test try + fir(nothing,42,false) + false + catch e + e isa UndefVarError + end broken=true @test fir(nothing, 42, false) === IRTools.undef @test fir(nothing, 42, true) === 84 ir = @code_ir f_try_catch3() - @test any(ir) do (_, stmt) - IRTools.isexpr(stmt.expr, :catch) && + @test all(ir) do (_, stmt) + !IRTools.isexpr(stmt.expr, :catch) || length(stmt.expr.args) == 1 end fir = func(ir) @@ -331,7 +337,12 @@ end ir = @code_ir f_try_catch4(42, false) fir = func(ir) # This should be @test_throws UndefVarError fir(nothing,42,false) - # See TODO in `IRTools.slots!` + @test try + fir(nothing, 42, false) + false + catch e + e isa UndefVarError + end broken=true @test fir(nothing, 42, false) === IRTools.undef @test fir(nothing, 42, true) === 84 From 2d6242deb229a5548c6507600f6d92a22e3ccd3a Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 8 Dec 2023 21:13:09 +0100 Subject: [PATCH 7/7] use test_broken for 1.6 --- test/compiler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 5f9f5dc..c89fa97 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -317,12 +317,12 @@ end # This should be @test_throws UndefVarError fir(nothing,42,false) # See TODO in `IRTools.slots!` - @test try + @test_broken try fir(nothing,42,false) false catch e e isa UndefVarError - end broken=true + end @test fir(nothing, 42, false) === IRTools.undef @test fir(nothing, 42, true) === 84 @@ -337,12 +337,12 @@ end ir = @code_ir f_try_catch4(42, false) fir = func(ir) # This should be @test_throws UndefVarError fir(nothing,42,false) - @test try + @test_broken try fir(nothing, 42, false) false catch e e isa UndefVarError - end broken=true + end @test fir(nothing, 42, false) === IRTools.undef @test fir(nothing, 42, true) === 84