Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a functor for projection #385

Merged
merged 44 commits into from
Jul 6, 2021
Merged

use a functor for projection #385

merged 44 commits into from
Jul 6, 2021

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Jun 29, 2021

Alternative to #380, #382 and #383. See usage in JuliaDiff/ChainRules.jl#459

use a ProjectTo functor, instead of passing to a project function a target type + info keyword arg.
In the place of #383 it uses the constructor for the ProjectTo rather than a preproject function returning an info namedtuple.

It's kinda a curried functor. One would do ProjectTo(2.5)(1.0 + 2.0im), or more realistically project = ProjectTo(2.5); ...; project(1.0 + 2.0im).

An additional change over #383 is it only allows you to pass the type-alone into constructing ProjectTo for subtypes of Number.
Since other things might have structure that we need to be capturing; so recursion breaks.
(See e.g. additional tests over #383 for Arrays of Arrays, and Arrays of Any)
One reason not to allow this at all is that it may be easy to misread: ProjectTo(Float64)(1.0 + 2.0im), is not ProjectTo{Float64}(1.0 + 2.0im)
but it is probably worth keeping for arrays of numbers, since it makes it easy to write the optimized case for that which doesn't store 1 ProjectTo per element.

Another minor additional change is it makes use of some of the tooling for working with structs like construct and backing which should be faster, those were optimized fairly carefully.

Some observations (comments welcome):

Something I am not certain on is if out type arg is the type of the primal or not.
I think it is not, right?
It the type of the destination differential?
Which for most types we want this for will be a natural tangent type for the primal, and probably even equal to the primal.
We might want to document this somewhere there the destination type-arg must be for a valid tangent type.
(And so a NamedTuple is not valid)

I am not 100% sold on the name ProjectTo

IMO this is the best way forward of the four PRs. Thoughts?

I am putting this up here, but probably @mzgubic, will take it back over tomorrow.
It is is very heavily based on #383, and all the learnings that went into that.

src/projection.jl Outdated Show resolved Hide resolved
Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks, that's such an elegant solution, will take it over from here if that's ok?

src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

Please do

@mzgubic
Copy link
Member

mzgubic commented Jun 30, 2021

Something I am not certain on is if out type arg is the type of the primal or not.
I think it is not, right?
It the type of the destination differential?
Which for most types we want this for will be a natural tangent type for the primal, and probably even equal to the primal.
We might want to document this somewhere there the destination type-arg must be for a valid tangent type.
(And so a NamedTuple is not valid)

I am not sure I understand what you are saying here. Is it that in

struct ProjectTo{P, D<:NamedTuple}
    info::D
end

P is not necessarily the primal type? The only two cases I can think of are Period and DateTimes, and the projection to a Tangent type. Are there any other cases you can think of, where P is not equal to the primal type?

@oxinabox
Copy link
Member Author

oxinabox commented Jun 30, 2021

It is hard to construct good examples because many of the LinearAlgebra structured array types, only accept vector-space elements.
Something like if the primal type was a Diagonal of Tuples.
then when doing sum(prod, Diagonal([(1.0, 2.0), (3.0, 4.0)]))
you would want to project not onto a Diagonal of Tuples but on to a Diagonal of Tangent{<:Tuple}s.
Except that sum errors in the primal as calling sum on a sparse array that has elements that don't define zero is not allowed.

I think this is very rare, and I think the thing we need to do is be clear in the docs.
We do not promise to project onto arbitrary types, only on to valid tangent types.
And the cases where the is most useful is where the primal is also a valid natural tangent type for itself.

test/projection.jl Outdated Show resolved Hide resolved
Miha Zgubic and others added 2 commits July 2, 2021 19:49
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
Copy link
Member Author

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some last comments,
other than that LGTM.
I do approve this, I can't actually click approve as I am the original author.
But you can approve yourself and merge.
We have both looked closely enough at this

docs/src/writing_good_rules.md Outdated Show resolved Hide resolved
docs/src/writing_good_rules.md Outdated Show resolved Hide resolved
docs/src/writing_good_rules.md Show resolved Hide resolved
docs/src/writing_good_rules.md Outdated Show resolved Hide resolved
docs/src/writing_good_rules.md Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
docs/src/writing_good_rules.md Show resolved Hide resolved
docs/Manifest.toml Outdated Show resolved Hide resolved
@mcabbott
Copy link
Member

mcabbott commented Jul 6, 2021

Something I am not certain on is if out type arg is the type of the primal or not.
I think it is not, right?
It the type of the destination differential?

Is this clarified and written up somewhere? There has been a blizzard of implementation but I am not sure I've seen clarity on the basic goal here, what types, how they are chosen, and where this gets applied. (FluxML/Zygote.jl#965 was one take on these questions, but I'm not sure anyone read it.)

Comment on lines +98 to +99
function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix)
return $SymHerm(project.parent(dx), project.uplo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a lot like it just applies a Symmetric wrapper, rather than projecting onto the space of symmetric matrices. I think that's wrong, and raised this point on one of the other implementations. Maybe you disagree and can point me to where this was discussed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Earlier comment was #382 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, I meant to comment on this but forgot in the middle of the many small things that came up in the PR. I don't think we are projecting onto the space of symmetric matrices, but rather on the space of Symmetric matrices. Which can indeed hold an asymmetric data field.

Is there an example that gives an unexpected result? I stared at the finite differencing a little bit and that does seem odd, but not long enough to figure out whether it is a project or finite differencing issue

Copy link
Member

@mcabbott mcabbott Jul 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, my Zygote PR has examples. The finite-differencing code was giving bizarre answers and should for now be ignored. The mathematical question seems pretty clear.

"Many small things" seems accurate, I worry that in rushing to sort them all out, we've lost sight of big-picture questions.

@mzgubic
Copy link
Member

mzgubic commented Jul 6, 2021

Something I am not certain on is if out type arg is the type of the primal or not.
I think it is not, right?
It the type of the destination differential?

Is this clarified and written up somewhere? There has been a blizzard of implementation but I am not sure I've seen clarity on the basic goal here, what types, how they are chosen, and where this gets applied. (FluxML/Zygote.jl#965 was one take on these questions, but I'm not sure anyone read it.)

My understanding is that we have decided to only project onto valid tangent spaces (and not on arbitrary primals, so we don't accept Tuple's or NamedTuples). @oxinabox will expand the write up in the docs on this. See JuliaDiff/ChainRules.jl#467 (I will add your other concerns there as well)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants