Skip to content

Commit

Permalink
propagate ambiguities from rrule lookup instead of failing inexplicably
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir committed Aug 2, 2022
1 parent 5c80f55 commit 328eb4d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig)
matching_cr_sig(::DataType, ::UnionAll) = false
matching_cr_sig(::UnionAll, ::DataType) = false
matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s)

matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234

type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...}
function type_tuple_tail(d::UnionAll)
body = Base.unwrap_unionall(d)
Expand Down
9 changes: 9 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,15 @@ using Zygote: ZygoteRuleConfig
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
end
end

# https://github.com/FluxML/Zygote.jl/issues/1234
@testset "rrule lookup ambiguities" begin
f_ambig(x, y) = x + y
ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(f_ambig, 1, 2)
end
end

@testset "ChainRulesCore.rrule_via_ad" begin
Expand Down

0 comments on commit 328eb4d

Please sign in to comment.