WIP: projector implementation (returning a closure) #382
Codecov Report
@@ Coverage Diff @@
## master #382 +/- ##
==========================================
+ Coverage 89.12% 89.78% +0.65%
==========================================
Files 14 15 +1
Lines 561 607 +46
==========================================
+ Hits 500 545 +45
- Misses 61 62 +1
This seems simpler than I thought, maybe something like this is the way to go. I left some comments...
# Tangent
function projector(::Type{<:Tangent}, x::T) where {T}
    project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...)
It's still not clear to me what's going to call this. Clearly we will not have x::Tangent in the forward pass. So this thing is perhaps trying to serve several functions, and perhaps they can be clarified.
    project(dx::AbstractZero) = zero(x)
    project(dx::AbstractThunk) = project(unthunk(dx))
    return project
end
I wonder if there should be some struct Project which is returned, in part to avoid writing these out every time.
Could you clarify how this would work?
One attempt is here: https://gist.github.com/mcabbott/8a84086cc604d34b5e8dff2eb3839f3a
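To make the idea concrete, here is a minimal sketch of a callable struct that handles the zero and thunk cases once, instead of repeating them in every closure. It is not the content of the linked gist. ZeroLike, ZeroTangentLike, ThunkLike and unthunk_like are hypothetical stand-ins for ChainRulesCore's AbstractZero, AbstractThunk and unthunk, used so that the snippet runs without the package; the field layout of Project is likewise an assumption.

```julia
# Hypothetical stand-ins for ChainRulesCore's zero and thunk types.
abstract type ZeroLike end
struct ZeroTangentLike <: ZeroLike end
struct ThunkLike{F}
    f::F
end
unthunk_like(t::ThunkLike) = t.f()

# Callable struct storing the primal and a type-specific projection function.
struct Project{X,F}
    x::X
    f::F
end

# The generic cases are written once, for every Project.
(p::Project)(dx::ZeroLike) = zero(p.x)
(p::Project)(dx::ThunkLike) = p(unthunk_like(dx))
# Anything else is forwarded to the stored projection.
(p::Project)(dx) = p.f(dx)

# Each primal type then only has to supply its own projection.
projector(x::Real) = Project(x, dx -> real(dx))

p = projector(1.0)
p(ZeroTangentLike())            # zero of the primal
p(ThunkLike(() -> 2.0 + 3.0im)) # unthunks, then drops the imaginary part
```

With this layout, adding a new primal type means adding one projector method rather than three closure methods.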
function projector(::Type{T}, x::T) where {T<:Real}
    project(dx::Real) = T(dx)
    project(dx::Number) = T(real(dx)) # to avoid InexactError
I think this is too tight, as projector(2)(3.5) is going to be an InexactError, right? As is projector(false)(1.5).
And more generally, what if (say) I want to put dual numbers into the pullback? My impression is that that should be allowed. Which is what led me to think that only known problems should be projected out, like dx::Complex when x::Real, or anything when x::Bool. But it would be nice if the door were open for packages to add to the list of "things which get projected like Complex -> Real".
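The difference between the two behaviours can be sketched as follows. The names tight_projector, loose_projector and project_real are hypothetical and are not part of the PR; returning nothing for a Bool primal is only a stand-in for "no tangent".

```julia
# Tight version, as in the diff: convert every cotangent to the primal's type.
tight_projector(x::T) where {T<:Real} = dx -> T(dx)

# Loose version: only project out known problems, pass everything else through.
project_real(dx::Complex) = real(dx)  # Complex -> Real is a known problem
project_real(dx) = dx                 # Float64, dual numbers, etc. are left alone

loose_projector(x::Real) = project_real
loose_projector(x::Bool) = dx -> nothing  # stand-in: a Bool has no tangent

loose_projector(2)(3.5)           # passes through unchanged
loose_projector(2.0)(1.0 + 2.0im) # keeps only the real part
# tight_projector(2)(3.5) would throw an InexactError, since Int(3.5) is inexact.
```

Under this sketch, a package could opt in to the "projected like Complex -> Real" list by adding its own method to project_real.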
Yeah that sounds like a relatively serious downside
Seems to have made it into the tagged version:
julia> ProjectTo(1)(2.5)
ERROR: InexactError: Int64(2.5)
(jl_5kFIPa) pkg> st ChainRulesCore
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_5kFIPa/Project.toml`
[d360d2e6] ChainRulesCore v0.10.11
function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M}
    projM = projector(M, parent(x))
    uplo = Symbol(x.uplo)
    project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo)
I don't think this is right: you need to symmetrise, not merely apply the wrapper.
There's a fairly efficient one here:
https://github.com/FluxML/Zygote.jl/pull/965/files#diff-9bc4a61f220da7bc58a4009fe88887b5b584b3d6139c68b0e13cbdbcd21f7289R48
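As a small illustration of the point (a naive version, not the efficient one linked above, and assuming real entries), wrapping keeps one triangle and discards the other, whereas symmetrising averages the two. The helper names wrap_only and symmetrise are made up for this sketch.

```julia
using LinearAlgebra

# Only applies the wrapper: the triangle not named by uplo is ignored.
wrap_only(dx::AbstractMatrix, uplo::Symbol) = Symmetric(dx, uplo)

# Symmetrises first, so both triangles contribute.
symmetrise(dx::AbstractMatrix, uplo::Symbol) =
    Symmetric((dx .+ transpose(dx)) ./ 2, uplo)

dx = [1.0 2.0; 4.0 5.0]
wrap_only(dx, :U)   # uses dx[1, 2] for both off-diagonal entries, drops dx[2, 1]
symmetrise(dx, :U)  # off-diagonal entries are the mean of dx[1, 2] and dx[2, 1]
```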
function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N}
    sizex = size(x)
    projT = projector(zero(T))
    project(dx::Array{T, N}) = dx # identity
    project(dx::AbstractArray) = project(collect(dx)) # from Diagonal
Here I also wonder if this is the right behaviour. Maybe the ability to reproduce a similar dense array is desirable sometimes, but making the default projector materialise when it doesn't have to seems odd --- shouldn't we preserve Diagonal or Fill backwards as many steps as possible, by default?
But again maybe this is trying to serve multiple purposes which perhaps can be clarified.
Maybe there ought to be abstract types involved, something like:
projector(x::Real) = projector(Real, x)
projector(x::Bool) = projector(Nothing, x)
projector(x::AbstractArray{<:Real}) = projector(AbstractArray{Real}, x)
projector(x::AbstractArray) = projector(AbstractArray, x)
where projector(AbstractArray, x)(dx) may reshape but won't do more.
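A rough sketch of what "may reshape but won't do more" could look like is given below. The function name reshape_only_projector is hypothetical, and this is only one reading of the suggestion: structured cotangents such as Diagonal are returned untouched when the size already matches.

```julia
using LinearAlgebra

# Remembers only the size of the primal; never materialises a dense copy.
function reshape_only_projector(x::AbstractArray)
    sz = size(x)
    return dx -> size(dx) == sz ? dx : reshape(dx, sz)
end

p = reshape_only_projector(zeros(2, 2))
d = Diagonal([1.0, 2.0])
p(d)             # same size, so the Diagonal is returned as is
p([1, 2, 3, 4])  # wrong shape, so it is reshaped to 2×2
```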
This is the method which specifically wants the output to be a dense array, i.e. where x is a Matrix in the projector(x) call. When x is a Diagonal, a different projector method would be hit.
I couldn't quite see how to generalise the method for an arbitrary AbstractArray (see how the Diagonal and Symmetric cases are different). My plan was to just add the dispatch for any type that we need to make ChainRules rules work.
closed in favour of #385
Alternative to #380.
Some observations (comments welcome):
... project implementation #380. Hopefully it is also easier to read.
... projector(::Diagonal), where we create a projV to project the vector representing the diagonal. cc @mcabbott
I imagine we also want to add some extra requirements to ChainRulesTestUtils, to make sure that the rules always return a correct differential. Something like _is_appropriate(primal, tangent), where the unthunk(tangent) must either be the same type as the primal, a Tangent, or an AbstractZero.
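One possible shape for such a check is sketched here. Only the name _is_appropriate and the three accepted cases come from the description above; ZeroLike, NoTangentLike, TangentLike, ThunkLike and unthunk_like are hypothetical stand-ins for the ChainRulesCore types, so that the snippet runs on its own.

```julia
# Hypothetical stand-ins for AbstractZero, Tangent and AbstractThunk.
abstract type ZeroLike end
struct NoTangentLike <: ZeroLike end
struct TangentLike{P} end
struct ThunkLike{F}
    f::F
end
unthunk_like(t::ThunkLike) = t.f()
unthunk_like(t) = t  # non-thunks are returned unchanged

# After unthunking, the tangent must be the primal's type, a Tangent, or a zero.
function _is_appropriate(primal, tangent)
    t = unthunk_like(tangent)
    return t isa typeof(primal) || t isa TangentLike || t isa ZeroLike
end

_is_appropriate(1.0, 2.0)                       # same type as the primal
_is_appropriate([1.0], ThunkLike(() -> [2.0]))  # thunk that unthunks to a Vector{Float64}
_is_appropriate(1.0, 2)                         # Int tangent for a Float64 primal
```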