diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 99d8f4652..7c7de8655 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -73,7 +73,8 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig) matching_cr_sig(::DataType, ::UnionAll) = false matching_cr_sig(::UnionAll, ::DataType) = false matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s) - +matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234 + type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...} function type_tuple_tail(d::UnionAll) body = Base.unwrap_unionall(d) diff --git a/test/chainrules.jl b/test/chainrules.jl index 94ab9584a..e9cb4afbc 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -275,6 +275,15 @@ using Zygote: ZygoteRuleConfig @test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,) end end + + # https://github.com/FluxML/Zygote.jl/issues/1234 + @testset "rrule lookup ambiguities" begin + f_ambig(x, y) = x + y + ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0) + ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0) + + @test_throws MethodError pullback(f_ambig, 1, 2) + end end @testset "ChainRulesCore.rrule_via_ad" begin