diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 87b3bcc..805f5ef 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -51,7 +51,11 @@ function usecounts(ir::IR) end function dominators(cfg; entry = 1) + blocks = reachable_blocks(cfg, entry) preds = cfg' + for i in 1:length(cfg.graph) + preds.graph[i] = filter(a -> a in blocks, preds[i]) + end blocks = [1:length(cfg.graph);] doms = Dict(b => Set(blocks) for b in blocks) while !isempty(blocks) @@ -72,8 +76,9 @@ function dominators(cfg; entry = 1) end function domtree(cfg::CFG; entry = 1) + blocks = sort(reachable_blocks(cfg, entry)) doms = dominators(cfg, entry = entry) - doms = Dict(b => filter(c -> b != c && b in doms[c], 1:length(cfg)) for b in 1:length(cfg)) + doms = Dict(b => filter(c -> b != c && b in doms[c], blocks) for b in blocks) children(b) = filter(c -> !(c in union(map(c -> doms[c], doms[b])...)), doms[b]) tree(b) = Pair{Int,Any}(b,tree.(children(b))) tree(entry) @@ -308,10 +313,10 @@ function ssa!(ir::IR) return ir end -function reachable_blocks(cfg::CFG) +function reachable_blocks(cfg::CFG, entry = 1) bs = Int[] reaches(b) = b in bs || (push!(bs, b); reaches.(cfg[b])) - reaches(1) + reaches(entry) return bs end diff --git a/test/compiler.jl b/test/compiler.jl index 1fce761..d77a840 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -377,3 +377,26 @@ end @test fir(nothing, 1.) == log(1. - log(2.)) @test fir(nothing, -1.) == 1. end + +@testset "functional" begin + relu(x) = (y = x > 0 ? x : 0) + ir = @code_ir relu(1) + + @test_nowarn functional(ir) +end + +@testset "while break" begin + function while_loop() + while true break end + end + tmp_ir = @code_ir(while_loop()) + @test_nowarn functional(tmp_ir) + + function while_loop2() + while true l == 0 && break end + end + tmp_ir = @code_ir(while_loop2()) + @test_nowarn functional(tmp_ir) +end + +