Add missing subtract rule #782

Merged 5 commits on Feb 12, 2024
@nmheim (Contributor) commented Feb 9, 2024

So far there was only a rule for negation; this adds a rule for subtraction (two-argument `-`).
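For reference, subtraction is linear, so both rules are one-liners: the pushforward (frule) subtracts the input tangents, and the pullback (rrule) sends the output cotangent unchanged to `x` and negated to `y`. This sketch ignores the shape and element-type issues that come up later in the thread.

```latex
z = x - y, \qquad
\dot{z} = \dot{x} - \dot{y} \quad \text{(frule)}, \qquad
\bar{x} = \bar{z}, \quad \bar{y} = -\bar{z} \quad \text{(rrule)}
```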

Comment on lines 441 to 445
frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x-y, Δx-Δy

function rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
    subtract_pullback(dy) = (NoTangent(), dy, -dy)
    return x-y, subtract_pullback
Member commented:

The autoformatter is probably complaining about the missing spaces around binary operators.

Suggested change (before):

frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x-y, Δx-Δy
function rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
    subtract_pullback(dy) = (NoTangent(), dy, -dy)
    return x-y, subtract_pullback

Suggested change (after):

frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = (x - y), (Δx - Δy)
function rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
    subtract_pullback(dy) = (NoTangent(), dy, -dy)
    return x - y, subtract_pullback
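The formatted rules, completed with the closing `end` that the review excerpt omits, can be exercised standalone as sketched below. This assumes only ChainRulesCore.jl is loaded, so the methods extend `ChainRulesCore.frule` and `ChainRulesCore.rrule` by qualified name; the frule body is written as an explicit tuple, which returns the same values as the suggestion. The inputs are illustrative.

```julia
using ChainRulesCore

# Pushforward: the tangent of x - y is Δx - Δy (the tangent of `-` itself is ignored).
ChainRulesCore.frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) =
    (x - y, Δx - Δy)

function ChainRulesCore.rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
    # Pullback: the output cotangent flows unchanged to x and negated to y.
    subtract_pullback(dy) = (NoTangent(), dy, -dy)
    return x - y, subtract_pullback
end

x, y = [3.0, 5.0], [1.0, 2.0]

# Forward mode: z == [2.0, 3.0], dz == [1.0, -1.0]
z, dz = frule((NoTangent(), [1.0, 0.0], [0.0, 1.0]), -, x, y)

# Reverse mode: grads == (NoTangent(), [1.0, 1.0], [-1.0, -1.0])
_, pb = rrule(-, x, y)
grads = pb([1.0, 1.0])
```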

Contributor (Author) replied:
Yep, fixed already! :)

@nmheim (Contributor, Author) commented Feb 9, 2024

@oxinabox, the rrule for multi-argument `+` has some reshape-related logic in its pullback that I don't understand:

function rrule(::typeof(+), arrs::AbstractArray...)
    y = +(arrs...)
    arr_axs = map(axes, arrs)  # remember each input's axes (its shape)
    function add_pullback(dy_raw)
        dy = unthunk(dy_raw)  # reshape will otherwise unthunk N times
        # give each input a cotangent reshaped to that input's own axes
        return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
    end
    return y, add_pullback
end

Should those be added here as well?
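One plausible reading of the reshape (a sketch, not stated in the thread): Base's array `+` and `-` accept operands whose shapes differ only by trailing singleton dimensions, so the output, and hence the incoming cotangent, can have a different shape from one of the inputs. Reshaping the cotangent to each input's `axes` restores that input's shape. The example below uses only Base Julia, with illustrative inputs.

```julia
x = [1.0, 2.0]                 # Vector, axes (Base.OneTo(2),)
y = reshape([3.0, 4.0], 2, 1)  # 2×1 Matrix: differs from x only by a trailing singleton dim
z = x - y                      # allowed by Base; the result is a 2×1 Matrix

dz = ones(size(z))             # a cotangent shaped like the output (2×1)

# Passing dz straight through would give x a 2×1 matrix as its cotangent;
# reshaping to axes(x) turns it back into a 2-element Vector.
dx = reshape(dz, axes(x))
dy = reshape(-dz, axes(y))     # y is already 2×1, so only the sign changes
```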

@oxinabox (Member) commented Feb 9, 2024

Yes; in particular, I think it should be done with ProjectTo (cf. #782).

Basically, ProjectTo sorts out various things that can go wrong when a tangent escapes the tangent space of this "manifold" (for our purposes, "manifold" basically means type + axes + structural zeros).
This can happen either because of the math, or because of minor errors in the types produced by some other operation (like adding or dropping a singleton array dimension, or, I think, using a Tuple rather than a Tangent{Tuple}).
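As a small illustration of the two failure modes mentioned here, assuming a recent ChainRulesCore.jl: a projector built from a real vector maps a complex cotangent back to real elements, and it reshapes a 2×1 cotangent back to the vector's axes. The inputs are illustrative.

```julia
using ChainRulesCore

x = [2.0, 1.0]
project_x = ProjectTo(x)  # records x's element type (Float64) and axes

# Complex cotangent for a real input: projection keeps only the real part.
a = project_x([1.0 + 1.0im, 2.0 + 2.0im])  # [1.0, 2.0]

# Extra singleton dimension: projection reshapes 2×1 back to a length-2 Vector.
b = project_x(reshape([1.0, 2.0], 2, 1))   # [1.0, 2.0]
```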

I believe a pertinent example of the kind that falls out of the math here would be:

_, pb = rrule(-, [2.0, 1.0], [3.0 + 2.0im, 4.0 + 5.0im])
pb([1.0 + 1.0im, 2.0 + 2.0im])  # a valid tangent, since the output is a complex vector

With the code as it stands, I believe (untested) this would give a complex vector as the tangent for the first input, even though that input is real.
In our opinion that is nonsense: adding such a tangent to the real primal escapes the manifold that the input is naturally constrained to.
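A minimal sketch of the ProjectTo approach for subtraction, not necessarily the code that was merged: build one projector per input when the rule runs, then apply it to that input's cotangent inside the pullback. It assumes only ChainRulesCore.jl is loaded and uses the same inputs as the example above. Because ProjectTo also reshapes away trailing singleton dimensions, it would cover the reshape question as well.

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
    project_x = ProjectTo(x)  # captures eltype and axes of x
    project_y = ProjectTo(y)  # captures eltype and axes of y
    function subtract_pullback(dz_raw)
        dz = unthunk(dz_raw)  # unthunk once, then project for each input
        return (NoTangent(), project_x(dz), project_y(-dz))
    end
    return x - y, subtract_pullback
end

_, pb = rrule(-, [2.0, 1.0], [3.0 + 2.0im, 4.0 + 5.0im])
grads = pb([1.0 + 1.0im, 2.0 + 2.0im])
# grads[2] is now the real vector [1.0, 2.0]; grads[3] stays complex.
```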

@nmheim (Contributor, Author) commented Feb 9, 2024

> yes, in particular it should be done with ProjectTo I think (c.f. #782)

Ahh, that makes sense! Thank you!

@oxinabox (Member) commented Feb 9, 2024

Shall we add a test that exercises the ProjectTo path, like

test_rrule(-, [2.0, 1.0], [3.0 + 2.0im, 4.0 + 5.0im])

?
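`test_rrule` comes from ChainRulesTestUtils.jl and also compares the pullback against finite differences. As a lighter, hand-written complement, the specific property at stake can be pinned down with the `Test` standard library. This sketch defines its own projected subtraction rule so it runs with only ChainRulesCore.jl loaded; that rule body is an assumption, not the merged code.

```julia
using ChainRulesCore, Test

function ChainRulesCore.rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
    project_x, project_y = ProjectTo(x), ProjectTo(y)
    function subtract_pullback(dz_raw)
        dz = unthunk(dz_raw)
        return (NoTangent(), project_x(dz), project_y(-dz))
    end
    return x - y, subtract_pullback
end

@testset "rrule(-) projects cotangents onto each input" begin
    x, y = [2.0, 1.0], [3.0 + 2.0im, 4.0 + 5.0im]
    z, pb = rrule(-, x, y)
    @test z == x - y
    _, dx, dy = pb([1.0 + 1.0im, 2.0 + 2.0im])
    @test dx isa Vector{Float64}              # real input gets a real cotangent
    @test dx == [1.0, 2.0]
    @test dy == [-1.0 - 1.0im, -2.0 - 2.0im]  # complex input keeps the full cotangent
end
```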

@oxinabox merged commit 971069c into JuliaDiff:main on Feb 12, 2024
5 of 11 checks passed

2 participants