Automatic ChainRules compatibility #579

Closed
gdalle opened this issue Mar 6, 2022 · 20 comments

Comments

@gdalle
Member

gdalle commented Mar 6, 2022

In a recent Slack discussion, @mohamed82008 posted this useful code snippet that shouldn't go to waste.
With a little bit of work, this could be turned into a macro that automatically translates a ChainRulesCore.frule into its ForwardDiff.Dual-compatible counterpart.
Since ChainRulesCore is a very light dependency, would it make sense to include such a thing in ForwardDiff? Judging by the reactions on the Slack #autodiff channel, lots of people would find it useful.

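using ChainRulesCore, ForwardDiff

# Adds a method f(x::Vector{<:Dual}) that evaluates f on the plain values and
# pushes all N partial directions through the frule in a single call.
macro ForwardDiff_frule(f)
    quote
        function $(esc(f))(x::Vector{<:ForwardDiff.Dual{T}}) where {T}
            # Values, plus an n×N matrix whose row i holds the N partials of x[i]
            xv, Δx = ForwardDiff.value.(x), reduce(vcat, transpose.(ForwardDiff.partials.(x)))
            out, Δf = ChainRulesCore.frule((NoTangent(), Δx), $(esc(f)), xv)
            if out isa Real
                # Scalar output: Δf holds one entry per partial direction
                return ForwardDiff.Dual{T}(out, ForwardDiff.Partials(Tuple(Δf)))
            elseif out isa Vector
                # Vector output: row i of Δf holds the partials of out[i]
                return ForwardDiff.Dual{T}.(out, ForwardDiff.Partials.(Tuple.(eachrow(Δf))))
            else
                error("Unsupported output type.")
            end
        end
    end
end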

f1(x) = sum(x)
function ChainRulesCore.frule((_, Δx), ::typeof(f1), x::AbstractVector{<:Number})
  println("frule was used")
  return f1(x), sum(Δx, dims = 1)
end

f2(x) = x
function ChainRulesCore.frule((_, Δx), ::typeof(f2), x::AbstractVector{<:Number})
  println("frule was used")
  return f2(x), Δx
end

@ForwardDiff_frule f1
ForwardDiff.gradient(f1, rand(3))
# frule was used
# 3-element Vector{Float64}:
#  1.0
#  1.0
#  1.0

@ForwardDiff_frule f2
ForwardDiff.jacobian(f2, rand(3))
# frule was used
# 3×3 Matrix{Float64}:
#  1.0  0.0  0.0
#  0.0  1.0  0.0
#  0.0  0.0  1.0
@doddgray

doddgray commented Mar 7, 2022

+1
I think this kind of ForwardDiff interoperability with ChainRulesCore frules, and even opt-in access to the many frules defined in ChainRules.jl, would be a widely appreciated update if at all possible.

Is avoiding dependency on ChainRulesCore/ChainRules the issue preventing this?

@gdalle
Member Author

gdalle commented Mar 9, 2022

For future reference, @mohamed82008 is on the move again: JuliaNonconvex/NonconvexUtils.jl#6

@mcabbott
Member

mcabbott commented Jun 28, 2022

One way to make this fully automatic would be to move the most basic definitions of Dual, Partials etc. from here to ChainRulesCore. Then @scalar_rule f ... could automatically add methods for f(::Dual), in addition to its existing methods rrule(f, x) and frule, etc.
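To make the idea concrete, here is a hand-written sketch of the kind of `f(::Dual)` method that `@scalar_rule` could emit under this proposal. ChainRulesCore does not generate such methods today. `mysin` is a hypothetical function, and the `Dual` method only relies on the `frule` that `@scalar_rule` does define, calling it once per partial direction.

```julia
using ChainRulesCore, ForwardDiff

# Hypothetical scalar function; @scalar_rule defines its frule and rrule.
mysin(x) = sin(x)
@scalar_rule mysin(x) cos(x)

# Hand-written sketch of a method @scalar_rule *could* add (it does not today).
function mysin(x::ForwardDiff.Dual{T}) where {T}
    xv = ForwardDiff.value(x)
    # Push each of the N partial directions through the frule separately.
    results = map(p -> ChainRulesCore.frule((NoTangent(), p), mysin, xv),
                  Tuple(ForwardDiff.partials(x)))
    y = first(first(results))   # primal value (identical in every call)
    Δy = map(last, results)     # tuple of N directional derivatives
    return ForwardDiff.Dual{T}(y, ForwardDiff.Partials(Δy))
end
```

Calling the rule once per direction keeps the sketch independent of how a given `frule` handles batched tangents, at the cost of recomputing the primal N times.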

This may not be very difficult to do, although I haven't tried.

I don't know how other people feel about entangling these two packages. This one has slowly acquired quite a few dependencies (including CRC, indirectly); adding CRC directly may in fact make it lighter-weight, as those packages could (and probably already all do) define rules themselves using CRC.

@oxinabox
Member

oxinabox commented Aug 9, 2022

I am not keen on that particular solution. In general, I want @scalar_rule to be less powerful, not more.

@ThummeTo

ThummeTo commented Oct 7, 2022

The solution from @mohamed82008 is a working macro now, and it is as easy as typing:

import ForwardDiff, NonconvexUtils

NonconvexUtils.@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)
NonconvexUtils.@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})
NonconvexUtils.@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})

This gives your function f1 ForwardDiff-compatible methods for scalar, vector, and matrix arguments, based on an existing frule (see the source code).
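The macro calls above assume that an `frule` for `f1` already exists. A minimal sketch of one for the vector case, assuming a hypothetical elementwise product `f1(x1, x2) = x1 .* x2`:

```julia
using ChainRulesCore

# Hypothetical two-argument function: elementwise product.
f1(x1, x2) = x1 .* x2

# Forward rule via the product rule: d(x1 .* x2) = Δx1 .* x2 .+ x1 .* Δx2
function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f1),
                              x1::AbstractVector{<:Real}, x2::AbstractVector{<:Real})
    return f1(x1, x2), Δx1 .* x2 .+ x1 .* Δx2
end
```

The scalar and matrix macro lines would need `frule` methods with matching argument types. Because the tangent expression only uses broadcasting, this rule also gives the right answer if each tangent arrives as an n×N matrix of stacked directions.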

In my opinion, this should definitely be added to ForwardDiff; it would be a really nice feature.

@mohamed82008
Member

mohamed82008 commented Oct 7, 2022

My implementation even works for structs if you pass in the constructor (see the tests). Struct support needs more infrastructure, though, than the simple vec/reshape needed for vector/matrix support. Even if the macro does not live in ForwardDiff, I would be happy if someone put it in a separate lightweight package and added a section to the ForwardDiff documentation pointing to the new package.
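For array inputs, the vec/reshape step can be sketched by hand as follows. This is an illustration under stated assumptions, not the NonconvexUtils implementation: `sqmat` is a hypothetical function, its `frule` follows the usual convention that the tangent has the same shape as the input, and the `Dual` method calls the rule once per partial direction before flattening with `vec` and restoring the shape with `reshape`.

```julia
using ChainRulesCore, ForwardDiff

# Hypothetical matrix function; its frule takes a tangent ΔX shaped like X.
sqmat(X) = X * X
function ChainRulesCore.frule((_, ΔX), ::typeof(sqmat), X::AbstractMatrix{<:Real})
    return sqmat(X), ΔX * X + X * ΔX
end

# Dual method: strip the values, push each of the N directions through the
# frule, then flatten the outputs with vec and rebuild the shape with reshape.
function sqmat(X::AbstractMatrix{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N}
    Xv = ForwardDiff.value.(X)
    # ForwardDiff.partials.(X, i) is the i-th seed direction as a matrix.
    results = ntuple(i -> ChainRulesCore.frule((NoTangent(), ForwardDiff.partials.(X, i)), sqmat, Xv), N)
    out = first(results[1])                 # primal (identical in every call)
    Δs = map(r -> vec(last(r)), results)    # N flattened output tangents
    flat = [ForwardDiff.Dual{T}(v, ForwardDiff.Partials(ntuple(i -> Δs[i][k], N)))
            for (k, v) in enumerate(vec(out))]
    return reshape(flat, size(out))
end
```

For a 2×2 input `A`, `ForwardDiff.jacobian(sqmat, A)` should then agree with the closed-form Jacobian of X ↦ X², `kron(transpose(A), I) + kron(I, A)`, in column-major `vec` ordering.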

@mohamed82008
Member

As I mentioned to @gdalle before, every feature of NonconvexUtils should probably be its own package :)

@ThummeTo

I can separate the function and put it in a dedicated package if you wish.

Any suggestions for names?
It's basically an interface between ForwardDiff and ChainRulesCore. That is what I was searching Google for: how can I make ForwardDiff use frules? Or simply ForwardDiffFRule.jl?

@mohamed82008
Copy link
Member

ForwardDiffChainRules?

@ThummeTo

Ok, mini-package is coming. I will post it here!

@ThummeTo

Find the repo here: ForwardDiffChainRules.jl

I tried to keep it lightweight; e.g., the SparseArrays and JuMP dispatches are only added if the corresponding libraries are loaded (via Requires.jl). The rest is basically copy-paste (with author attribution). The CI tests currently fail (I will check this soon, or maybe @mohamed82008 has a clue). As soon as this is fixed and @mohamed82008 gives his OK, I can register the version.

Regards!

@mohamed82008
Member

I can take a look in a few days. Sorry, I'm a bit busy for the next couple of days.

@mohamed82008
Member

@ThummeTo can I get an invite to the repo? I fixed it locally and used https://github.com/JuliaNonconvex/DifferentiableFlatten.jl. I can open a PR or push directly to master.

@ThummeTo

Of course, you should have an invitation now. Much appreciated!

@ThummeTo

ThummeTo commented Oct 18, 2022

So it's ready for a first release, I guess, @mohamed82008?

@mohamed82008
Member

Yes

@mohamed82008
Member

Since we have a sufficiently lightweight package that implements this feature now, should we close this issue @gdalle?

@oxinabox
Member

Can we document this in the docs here (and also in ChainRulesCore.jl)?

@gdalle
Member Author

gdalle commented Oct 18, 2022

> Since we have a sufficiently lightweight package that implements this feature now, should we close this issue @gdalle?

Let's close it once this is referenced in the ForwardDiff and ChainRulesCore docs?

@gdalle
Member Author

gdalle commented Nov 17, 2022

Both links have been added; closing this.

gdalle closed this as completed Nov 17, 2022

6 participants