Skip to content

Commit

Permalink
Merge pull request #1379 from FluxML/bc/fwd-undef-ref
Browse files Browse the repository at this point in the history
Forward mode: don't use `zerolike` fallback for `GlobalRef`s
  • Loading branch information
ToucheSir authored Feb 26, 2023
2 parents 6937f06 + 14120e9 commit 215828d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/forward/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ end
# TODO figure out why this made a test fail
zerolike(x::Union{Module,Type}) = nothing

# Required to not get an UndefRefError on 1.10
zerolike(x::GlobalRef) = nothing

# TODO: `@non_differentiable` and `@linear`

@tangent zerolike(x) = zerolike(x), _ -> zerolike(x)
Expand Down
22 changes: 14 additions & 8 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ trace_contains(st, func, file, line) = any(st) do fr
end

bad(x) = x
const bad_def_line = (@__LINE__) + 1
@adjoint bad(x) = x, Δ -> error("bad")

const bad_call_line = (@__LINE__) + 3
function badly(x)
x = x + 1
x = bad(x)
Expand All @@ -30,11 +32,11 @@ y, back = pullback(badly, 2)
@test_throws Exception back(1)
bt = try back(1) catch e stacktrace(catch_backtrace()) end

@test trace_contains(bt, nothing, "compiler.jl", 20)
if VERSION >= v"1.6-"
@test_broken trace_contains(bt, :badly, "compiler.jl", 24)
@test trace_contains(bt, nothing, "compiler.jl", bad_def_line)
if VERSION <= v"1.6-" || VERSION >= v"1.10-"
@test trace_contains(bt, :badly, "compiler.jl", bad_call_line)
else
@test trace_contains(bt, :badly, "compiler.jl", 24)
@test_broken trace_contains(bt, :badly, "compiler.jl", bad_call_line)
end

# Type inference checks
Expand Down Expand Up @@ -81,10 +83,14 @@ y, back = @test_inferred pullback(((a,b),) -> a, (5, 10))

# testcase for issue #808
# testing that methods(Base.show) does not throw. Having something more specific would be too fragile
buf = IOBuffer()
Base.show(buf, methods(Base.show))
str_repr = String(take!(buf))
@test !isempty(str_repr)
show_err = try
buf = IOBuffer()
Base.show(buf, methods(Base.show))
nothing
catch ex
ex
end
@test show_err === nothing

struct Funky
x
Expand Down

0 comments on commit 215828d

Please sign in to comment.