use a functor for projection #385
Nice, thanks, that's such an elegant solution, will take it over from here if that's ok?
Please do.
I am not sure I understand what you are saying here. Is it that in
struct ProjectTo{P, D<:NamedTuple}
    info::D
end
It is hard to construct good examples because many of the LinearAlgebra structured array types only accept vector-space elements. I think this is very rare, and I think the thing we need to do is be clear in the docs.
Co-authored-by: Lyndon White <[email protected]>
Some last comments,
other than that LGTM.
I do approve this, I can't actually click approve as I am the original author.
But you can approve yourself and merge.
We have both looked closely enough at this
Co-authored-by: Lyndon White <[email protected]>
Is this clarified and written up somewhere? There has been a blizzard of implementation but I am not sure I've seen clarity on the basic goal here, what types, how they are chosen, and where this gets applied. (FluxML/Zygote.jl#965 was one take on these questions, but I'm not sure anyone read it.)
function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix)
    return $SymHerm(project.parent(dx), project.uplo)
This looks a lot like it just applies a Symmetric wrapper, rather than projecting onto the space of symmetric matrices. I think that's wrong, and raised this point on one of the other implementations. Maybe you disagree and can point me to where this was discussed?
Earlier comment was #382 (comment)
Oh, sorry, I meant to comment on this but forgot in the middle of the many small things that came up in the PR. I don't think we are projecting onto the space of symmetric matrices, but rather onto the space of `Symmetric` matrices, which can indeed hold an asymmetric `data` field.
Is there an example that gives an unexpected result? I stared at the finite differencing a little bit and that does seem odd, but not long enough to figure out whether it is a project or finite differencing issue.
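A small sketch of the distinction being debated here, using only the `LinearAlgebra` standard library. The matrix `dx` is made up for illustration; nothing below is taken from the PR. Wrapping in `Symmetric` reads one triangle and ignores the other, whereas the orthogonal projection onto the subspace of symmetric matrices averages the two.

```julia
using LinearAlgebra

# A made-up, non-symmetric matrix standing in for an incoming differential.
dx = [1.0 2.0; 4.0 3.0]

# Wrapping: Symmetric reads only one triangle (the upper one by default),
# so the lower-triangle entry 4.0 is ignored, although it is still stored.
wrapped = Symmetric(dx)

# Orthogonal projection onto symmetric matrices (Frobenius inner product):
# average the matrix with its transpose.
projected = (dx + dx') / 2

wrapped.data   # still the original asymmetric matrix
wrapped        # behaves like [1.0 2.0; 2.0 3.0]
projected      # [1.0 3.0; 3.0 3.0]
```

The two results differ whenever `dx` is not already symmetric, which is the case the review comments are concerned with.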
Yes, my Zygote PR has examples. The finite-differencing code was giving bizarre answers and should for now be ignored. The mathematical question seems pretty clear.
"Many small things" seems accurate, I worry that in rushing to sort them all out, we've lost sight of big-picture questions.
My understanding is that we have decided to only project onto valid tangent spaces (and not on arbitrary primals, so we don't accept Tuples or NamedTuples). @oxinabox will expand the write-up in the docs on this. See JuliaDiff/ChainRules.jl#467 (I will add your other concerns there as well).
Alternative to #380, #382 and #383. See usage in JuliaDiff/ChainRules.jl#459.
Use a `ProjectTo` functor, instead of passing a target type + info keyword arg to a `project` function.
In place of the `preproject` function from #383, which returns an `info` namedtuple, it uses the constructor for `ProjectTo`.
It's kinda a curried functor. One would do `ProjectTo(2.5)(1.0 + 2.0im)`, or more realistically `project = ProjectTo(2.5); ...; project(1.0 + 2.0im)`.
An additional change over #383 is that it only allows you to pass the type alone into constructing `ProjectTo` for subtypes of `Number`, since other things might have structure that we need to capture, so recursion breaks.
(See e.g. the additional tests over #383 for Arrays of Arrays, and Arrays of Any.)
One reason not to allow this at all is that it may be easy to misread: `ProjectTo(Float64)(1.0 + 2.0im)` is not `ProjectTo{Float64}(1.0 + 2.0im)`. But it is probably worth keeping for arrays of numbers, since it makes it easy to write the optimized case for those, which doesn't store one `ProjectTo` per element.
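The curried usage and the type-only constructor described above can be sketched with a toy functor. Everything here is hypothetical: `ToyProjectTo`, its constructors and its call methods are illustrative assumptions, not the code in this PR; only the overall shape (a type parameter plus a `NamedTuple` field called `info`) follows the discussion.

```julia
# Hypothetical toy version of the functor; names and behaviour are assumptions.
struct ToyProjectTo{P,D<:NamedTuple}
    info::D
end
ToyProjectTo{P}(info::D) where {P,D<:NamedTuple} = ToyProjectTo{P,D}(info)
ToyProjectTo{P}() where {P} = ToyProjectTo{P}(NamedTuple())

# Build a projector from a primal value, or from the type alone (Numbers only).
ToyProjectTo(::T) where {T<:Number} = ToyProjectTo{T}()
ToyProjectTo(::Type{T}) where {T<:Number} = ToyProjectTo{T}()

# Calling the functor projects a differential onto the stored number type.
(::ToyProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx))
(::ToyProjectTo{T})(dx::Number) where {T<:Complex} = convert(T, dx)

project = ToyProjectTo(2.5)               # "curried": capture the primal's type once
y1 = project(1.0 + 2.0im)                 # apply later; the imaginary part is dropped
y2 = ToyProjectTo(Float64)(1.0 + 2.0im)   # type-only construction, then a call

# Easy to misread: the braces make this a constructor call, not a projection.
# In this toy version no such constructor exists, so it throws a MethodError.
misread_fails = try
    ToyProjectTo{Float64}(1.0 + 2.0im)
    false
catch err
    err isa MethodError
end
```

Here `y1` and `y2` are both `1.0`, while the brace form fails outright, which is one way the two spellings could be kept from being confused silently.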
Another minor additional change is that it makes use of some of the tooling for working with structs, like `construct` and `backing`, which should be faster; those were optimized fairly carefully.
Some observations (comments welcome):
- Compared with the WIP `projector` implementation (returning a closure) #382, and the WIP `preproject` and `project` implementation #383, it allows us to not close over the primal value (`x`).
- As in the WIP `preproject` and `project` implementation #383, and unlike the WIP `projector` implementation (returning a closure) #382, where the returned closure cannot be extended, the `project` function can be extended by someone developing a `Quaternions` package. They would simply define `(::ProjectTo{Real})(dx::Quaternion) = ...`
- Unlike the `project` implementation #380, where there are 3-arg functions, and unlike the WIP `preproject` and `project` implementation #383, calling `project` doesn't need an `info` kwarg, as everything needed is stored in the functor (information about the primal `x`, such as size or element type).
- `ProjectTo` contains a NamedTuple, one element of which can be a `ProjectTo` for an inner projection.
- It addresses the issue seen in the `preproject` implementation ChainRules.jl#457, that closing over the primal type values doesn't infer, because it puts the primal type into the type parameters of `ProjectTo`. (I think this is a big one.)
- Unlike the WIP `preproject` and `project` implementation #383, it doesn't have unadorned NamedTuples floating around, so debugging can be easier.
Something I am not certain on is whether our type arg is the type of the primal or not.
I think it is not, right?
Is it the type of the destination differential?
Which, for most types we want this for, will be a natural tangent type for the primal, and probably even equal to the primal.
We might want to document somewhere that the destination type arg must be a valid tangent type.
(And so a `NamedTuple` is not valid.)
I am not 100% sold on the name `ProjectTo`.
IMO this is the best way forward of the four PRs. Thoughts?
I am putting this up here, but @mzgubic will probably take it back over tomorrow.
It is very heavily based on #383, and all the learnings that went into that.
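To make the extensibility and nesting observations above concrete, here is a rough sketch. It is not the PR's implementation: `ToyProjectTo`, `ToyQuaternion`, and the `element` and `axes` entries of `info` are all hypothetical names chosen for illustration. It shows a third-party type adding a method to the functor, and an array projector storing a single inner projector rather than one per element.

```julia
# Hypothetical toy functor; names and fields are assumptions, not the PR's code.
struct ToyProjectTo{P,D<:NamedTuple}
    info::D
end
ToyProjectTo{P}(info::D) where {P,D<:NamedTuple} = ToyProjectTo{P,D}(info)
ToyProjectTo{P}() where {P} = ToyProjectTo{P}(NamedTuple())

# Scalars: nothing needs to be stored beyond the type parameter.
ToyProjectTo(::T) where {T<:Real} = ToyProjectTo{T}()
(::ToyProjectTo{T})(dx::Real) where {T<:Real} = convert(T, dx)

# Arrays of reals: store ONE inner projector for the element type, plus the axes.
function ToyProjectTo(x::AbstractArray{T}) where {T<:Real}
    info = (element = ToyProjectTo(zero(T)), axes = axes(x))
    return ToyProjectTo{typeof(x)}(info)
end
function (p::ToyProjectTo{<:AbstractArray})(dx::AbstractArray)
    axes(dx) == p.info.axes || throw(DimensionMismatch("axes do not match the primal"))
    return map(p.info.element, dx)   # the inner functor is applied elementwise
end

# --- what a hypothetical quaternion package could add on its own ---
struct ToyQuaternion
    s::Float64
    v1::Float64
    v2::Float64
    v3::Float64
end
# Project a quaternion differential onto a real primal by keeping the scalar part.
(::ToyProjectTo{T})(dx::ToyQuaternion) where {T<:Real} = convert(T, dx.s)

p = ToyProjectTo([1.0, 2.0])
dq = [ToyQuaternion(1.0, 0.0, 0.0, 2.0), ToyQuaternion(5.0, 1.0, 0.0, 0.0)]
projected = p(dq)   # [1.0, 5.0]
```

Because the extension is just another method on the functor type, the array method picks it up through `map` without any changes, which is the composition property the observations point at.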