-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make ProjectTo
convert Tangent
back to Diagonal
, etc, when safe
#446
Conversation
Codecov Report
@@ Coverage Diff @@
## main #446 +/- ##
==========================================
- Coverage 93.03% 90.16% -2.88%
==========================================
Files 15 15
Lines 862 925 +63
==========================================
+ Hits 802 834 +32
- Misses 60 91 +31
Continue to review full report at Codecov.
|
452a1b5
to
36e561c
Compare
ProjectTo
convert Tangent
back to Diagonal
, etc, when safe
36e561c
to
a318d7d
Compare
[Edited!] The simplest effect is like so, although right now this works only if it accepts julia> Zygote.gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
1.0 ⋅ ⋅
⋅ 0.0 ⋅
⋅ ⋅ 0.0 Zygote is applying the projection at the last backward step above. But the point of this is really to apply it at the first step. With that, the gradient can propagate through for instance further broadcasting steps: julia> gradient(x -> sum(sqrt.((cbrt.(x)).diag)), Diagonal([1,2,3]))[1]
3×3 Diagonal{Float64, Vector{Float64}}:
0.166667 ⋅ ⋅
⋅ 0.0935385 ⋅
⋅ ⋅ 0.0667187 Without this PR, Zygote would like to return a NamedTuple. Which cannot be handled by the gradient of broadcasting. In fact the generic julia> pullback(x -> parent(x)[1], Diagonal([1,2,3]))[2](1.0) # no projection
((diag = [1.0, 0.0, 0.0],),)
julia> gradient(x -> parent(x)[1], Diagonal([1,2,3]))[1] # with projection
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}})(::ChainRulesCore.Tangent{Any, NamedTuple{(:diag,), Tuple{Vector{Float64}}}}) That error is a problem with or without this PR. |
src/projection.jl
Outdated
|
||
# Diagonal | ||
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) | ||
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) | ||
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) | ||
function (project::ProjectTo{Diagonal})(dx::Tangent) # structural => natural | ||
return dx.diag isa Tangent ? dx.diag : Diagonal(project.diag(dx.diag)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain to me why the right thing to do is to return the diag
field of dx
if it's a Tangent
? I would have thought the correct thing to do would be to throw an error or something, since the conversion can't be performed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking that if you manage to produce something like this:
dx = Tangent{Diagonal}(; diag=Tanget{SVector}(; data = ??))
then we can't wrap dx.diag
. I'm not really sure that can happen, though.
Oh I see. In that case I meant to return dx
untouched. But I have a typo.
I meant also to think about dx.diag isa Thunk
, can that happen?
dx.diag isa AbstractZero
is handled by the Diagonal constructor now, this PR. Although perhaps a cleaner pattern would be to handle that here. What's nice about that is that project.diag(dx.diag)
could probably produce a Zero?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be fixed. I reversed the comparison to
return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx
I am going to leave this to @willtebbutt to review. |
Spotted today, this is an example in the wild of what this PR wants to do: That's for Tridiagonal. Notice that it always has to make two zero Vectors on each call, which seems a bit unfortunate. And is probably an argument for allowing these to make Diagonal, Bidiagonal, subspaces. |
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg | ||
for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg) | ||
@eval begin | ||
ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) | ||
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) | ||
function (project::ProjectTo{$UL})(dx::Diagonal) | ||
sub = project.parent | ||
sub_one = ProjectTo{project_type(sub)}(; | ||
element=sub.element, axes=(sub.axes[1],) | ||
) | ||
return Diagonal(sub_one(dx.diag)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To explain what's going on here:
- First, this used to include
UnitUpperTriangular
for which it was wrong. The gradient of that has to be zero on the diagonal, not one. So that has moved to its own case, where it doesUnitUpperTriangular(dx) .- I
instead, which makes in fact anUpperTriangular
, not a subtype. - Second, the handling of
(project::ProjectTo{$UL})(dx::Diagonal)
is much simplified. Instead of inventing the projector needed and handling the diagonal by hand, it exploits the fact thatmap(ProjectTo{Float32}, ::Diagonal)
already knows what to do.
The second idea greatly simplifies many more exotic examples below, such as (project::ProjectTo{Tridiagonal})(dx::Bidiagonal) = project.full(dx)
.
src/tangent_types/abstract_zero.jl
Outdated
LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(d, dl, :L) | ||
function LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractVector) | ||
d = fill!(similar(dl, length(dl) + 1), 0) | ||
Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du)) | ||
end | ||
# two Zeros: | ||
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d) | ||
LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractVector) = Bidiagonal(d, du, :U) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, accepting Diagonal and Bidiagonal as the gradients of Tridiagonal avoids many cases of making an empty array to pad out the type. These types only accept same-type vectors, so you cannot pad it with a Fill
.
OK I think this is done. Plays well with Zygote 6.30: julia> gradient(x -> sqrt(sum((x .^ 2).ev)), Bidiagonal([1,2,3], [4,5], :U))[1]
3×3 Bidiagonal{Float64, Vector{Float64}}:
0.0 0.624695 ⋅
⋅ 0.0 0.780869
⋅ ⋅ 0.0 This was wrong before: julia> gradient(x -> sum(abs, x), UnitUpperTriangular(rand(3,3)))[1]
3×3 UpperTriangular{Float64, Matrix{Float64}}:
0.0 1.0 1.0
⋅ 0.0 1.0
⋅ ⋅ 0.0 Two bugs I give up on for now: julia> UpperTriangular(Fill(3,3,3)) - I # not my fault. Fixed in Julia 1.8
ERROR: ArgumentError: Cannot setindex! to 2 for an AbstractFill with value 3.
julia> pullback(x -> sqrt(sum((x .^ 2).dv)), Bidiagonal([1,2,3], [4,5], :U))[2](1)[1] # My code makes a Diagonal, I have no idea where the Tri comes from
3×3 Tridiagonal{Float64, Vector{Float64}}:
0.267261 0.0 ⋅
0.0 0.534522 0.0
⋅ 0.0 0.801784 |
c97aafd
to
24c05c3
Compare
24c05c3
to
508cb75
Compare
1f61ed2
to
154f387
Compare
Sorry this is taking so long to review, It is just slightly too big (conceptually) for me to review during the time i normally have aside to review things, and so I keep starting to review it and not completing it. |
Bump? This is one of the last pieces of the ProjectTo story worked out last summer, but somehow hasn't made it. We'd have heard if its absence was holding anyone back, so it can't be that important, but it does clarify what the design is, a bit. |
sorry this has been on my to do for a long time. |
The example from #441 (comment) is
x -> sqrt(Diagonal(x))
, whose implementation issqrt.(x.diag)
. At present, Zygote returns a "structural" tangent for this, i.e. a NamedTuple. When the.diag
is the very first operation being performed, this is returned, but if it occurs after other operations, then their gradients will tend not to understand this.This PR proposes that there should be a method
ProjectTo{Diagonal}(::Tangent)
which converts this back to the "natural" form, i.e. to another Diagonal. To try it out:There
arewere two immediate hurdles here.One is that, to work with Base'sDone in FluxML/Zygote.jl#1104sqrt(::Diagonal)
method, you would have to insert a projection step into Zygote'slitereal_getproperty
adjoint definition. It's not immediately obvious to me how to do that.The second is that Zygote makes aSee below, I guess this wants FluxML/Zygote.jl#1057Tangent{Any}
here. I thought thatdx::Tangent{T}
was supposed to always havetypeof(x) == T
exactly. If that's not true, then must we worry about getting aTangent
which doesn't come from a Diagonal at all?I added a similar line for UpperTriangular. At present
x::UpperTriangular
acceptsdx::Diagonal
as a "natural" gradient. If it must acceptdx::Tangent{Diagonal}
too, well we could write a method for that. But how generally that can work I don't know. I'm not sure it's worth trying.Beyond the immediate, this is still an easy case of the problems discussed in #441 (not 411, sorry!), in that the Tangent contains a Vector and can trivially be re-wrapped to form a Diagonal. (Before finding this
Any
, I was trying to make dispatch restrict it to this case.) It doesn't address what should happen if the content of theTangent{Diagonal}
is some other weird structural thing, which cannot itself be ProjectTo'd to an AbstractVector -- that's what we don't have concrete examples of yet.Edit -- this particular example is now handled explicitly by JuliaDiff/ChainRules.jl#509.
Edit' -- the
Tangent{Any}
isn't an inference failure, it's explicitly constructed that way because the input NamedTuple doesn't have the type, here:https://github.com/FluxML/Zygote.jl/blob/05d0c2ae04f334a2ec61e42decfe1172d0f2e6e8/src/compiler/chainrules.jl#L126-L129