diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 078bb602a..89ed9fe8a 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -443,10 +443,10 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...) function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) - arr_axs = map(axes, arrs) + projs = map(ProjectTo, arrs) function add_pullback(dy_raw) - dy = unthunk(dy_raw) # reshape will otherwise unthunk N times - return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...) + dy = unthunk(dy_raw) # projs will otherwise unthunk N times + return (NoTangent(), map(proj -> proj(dy), projs)...) end return y, add_pullback end diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2682f3b8a..0a1416444 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -215,6 +215,8 @@ @gpu test_frule(+, randn(2), randn(2), randn(2)) # rev @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) - @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) + @gpu test_rrule(+, randn(3), randn(3, 1), randn(3, 1, 1)) + test_rrule(+, randn(3, 3), Diagonal(randn(3)), randn(3, 3, 1)) + test_rrule(+, randn(3, 3), Diagonal(randn(3)), Symmetric(randn(3, 3))) end end