
Fix #62 #70

Open · wants to merge 8 commits into base: master

Conversation

jondeuce (Contributor) commented Apr 22, 2022:

This adds a couple of small changes on top of this draft PR in order to fix #62:

  1. Wrap offset indices in a dummy struct Offset to fix the issue mentioned in Attempt to fix #62 #63 for arrays of arrays. For example, the offset structure for x = [[1.0, 2.0]] is now something like o = [Offset(4)], which is not leaflike, whereas previously it was o = [4]. This also opens the door to storing more information in the wrapper struct (original array size? eltype?), but that doesn't seem necessary at this time.
  2. y = backing(re(y)) allows functor(x) to return children which aren't its own fields: y is first restructured to match the structure of x, and then the NamedTuple backing of re(y) is extracted and passed to Tangent. It also adds some symmetry with _trainable_biwalk, which naturally restructures the output of _trainmap, whereas _Tangent_biwalk previously did not.

Closes #63 (replaces).
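A minimal, self-contained sketch of point 1 (the names Offset and isleaflike here are illustrative stand-ins, not the PR's actual definitions): wrapping the offsets in a struct means the offset structure for an array of arrays is no longer itself a plain numeric array, so it is not mistaken for a leaf.

```julia
# Hypothetical sketch: wrap flat-vector offsets so they are not "leaflike".
struct Offset
    i::Int
end

x = [[1.0, 2.0]]       # array of arrays
o_old = [4]            # old offset structure: a numeric vector, hence leaflike
o_new = [Offset(4)]    # new: a vector of wrapper structs

# crude stand-in for the leaf test used when walking structures
isleaflike(o) = o isa AbstractArray{<:Number}

@assert isleaflike(o_old)
@assert !isleaflike(o_new)
```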

src/destructure.jl (review thread, marked outdated)
```julia
y = _trainmap(f, ch, _trainable(x), au)
y isa Tuple{} && return NoT
p = ProjectTo(x)
if p isa ProjectTo  # e.g. Array, NamedTuple
  p(y)
else  # p === identity for unknown structs
  y = backing(re(y))  # extract NamedTuple backing from re(y); required if x has children which aren't its own fields
```
Member:

Note to self, this I need to think about. Some of this complication was working around things that are now fixed in CRC.jl, if I remember right.

jondeuce (Contributor, Author):

Yeah, admittedly this line took some trial and error and is a little bit above my pay-grade. I managed to convince myself, but perhaps there's something cleaner.

Member:

Ok, I think I finally understand what's going on. Sorry it took a while.

re constructs another Skip containing the gradient, and backing turns that into a NamedTuple with the same field names, which is what Tangent wants.

The only way I can see this failing is this: If the primal type's constructor is fussy about what types it can accept, then it may not be happy to accept something which is valid as its gradient. E.g. if there is only Skip(::AbstractLayer), and re tries to make one with a Tangent.
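The mechanism just described can be sketched with a stand-in Skip type (this Skip and the hand-rolled backing step are illustrative, not the package's code): the reconstructor wraps the gradient child in another Skip, and extracting its fields as a NamedTuple gives the shape Tangent wants.

```julia
# Stand-in for the wrapper type from the tests; a fussier constructor,
# e.g. one restricted to Skip(::AbstractLayer), would error below when
# handed a gradient value instead of a layer.
struct Skip{T}
    layer::T
end

re = l -> Skip(l)                 # reconstructor, as @functor would provide
s = re([2.0])                     # Skip containing the gradient child

# `backing`-like step: NamedTuple with the same field names
fields = fieldnames(typeof(s))    # (:layer,)
nt = NamedTuple{fields}(getfield.(Ref(s), fields))
@assert nt == (layer = [2.0],)
```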

jondeuce (Contributor, Author):

No worries! Yes, I struggled with that edge case too. Unfortunately I think it's quite tricky to work around. For example, suppose you have a user-defined functor(m::MyModel) = (m.w,), w -> .... Then:

  1. In general there's no way to reconstruct MyModel (or even a NamedTuple of fields/values) without re, as you do not know the corresponding field name given only (m.w,), but
  2. As you say, if the primal constructor isn't sufficiently generic then it won't be able to store Tangent/Nothing/etc. values in its fields and will error before backing can unpack it again

Avoiding re would be ideal, but I think that would require functor to always return NamedTuples on custom structs. I noticed that this is the default in @functor, though, so maybe it's not such a painful requirement? In the meantime I can at least add a branch that avoids re for structs that are functored to NamedTuples.
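The field-name problem in point 1 can be made concrete (MyModel and my_functor here mirror the hypothetical `functor(m::MyModel) = (m.w,), w -> ...` from the comment; they are not library code): a plain-Tuple set of children carries no field names, so only the reconstructor can put a child back as :w.

```julia
struct MyModel
    w::Vector{Float64}
end

# hand-rolled equivalent of a user-defined functor returning a Tuple
my_functor(m::MyModel) = (m.w,), ch -> MyModel(ch[1])

m = MyModel([1.0, 2.0])
ch, re = my_functor(m)
@assert ch == ([1.0, 2.0],)      # a plain Tuple: no field names to be found
m2 = re((10 .* ch[1],))          # only `re` knows the child belongs to :w
@assert m2.w == [10.0, 20.0]
```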

Member:

In fact there's another problem I didn't spot before, what a mess:

```julia
julia> ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]);  # from tests: a,c are functor-ed, and only a is trainable

julia> v2, re2 = destructure(ac)
([1.0, 2.0], Restructure(TwoThirds, ..., 2))

julia> gradient(ac) do x  # with Tangent{typeof(x), typeof(y)}(y)
         w2, _ = destructure(x)
         w2[2]^2
       end
((a = [0.0, 4.0], b = nothing, c = [4.0, 5.0]),)

# Same, with z = backing(re(y)):
julia> gradient(ac) do x
         w2, _ = destructure(x)
         w2[2]^2
       end
┌ Info: last case
│   x = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
│   y = (a = [0.0, 4.0], c = [4.0, 5.0])
└   z = NamedTuple{(:a, :b, :c), Tuple{Any, Any, Any}}(([0.0, 4.0], [3.0], [4.0, 5.0]))
((a = [0.0, 4.0], b = [3.0], c = [4.0, 5.0]),)
```

jondeuce (Contributor, Author):

Oh yikes. That's a good example, hits all the pain points at once. If I'm understanding correctly, the gradient should be ((a = [0.0, 4.0], b = nothing, c = nothing),), right?

I think the problem is the _trainmap above; it populates the nothing values from _trainable (non-trainable fields) with the primal values, when they should be NoT. That's how the b and/or c values get back in there.

Member:

Yes, I think _trainmap needs to do something isnothing(t) ? NoT : f(t, a) here. That's where c = [4.0, 5.0] is coming from.

But b = [3.0] is coming from this PR's trick of calling the reconstructor made by @functor:

```julia
julia> ch, re = Functors.functor(ac)
((a = [1.0, 2.0], c = [4.0, 5.0]), var"#1#2"{TwoThirds}(TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])))

julia> re((a = [10, 20], c = nothing))
TwoThirds([10, 20], [3.0], nothing)
```
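The suggested `isnothing(t) ? NoT : f(t, a)` fix can be sketched over plain tuples (all names here, fixed_trainmap and the singleton NoT standing in for ChainRulesCore's NoTangent(), are illustrative): non-trainable slots, recorded as `nothing` in the trainable template, map to the zero tangent instead of being refilled from the primal.

```julia
struct NoT end   # stand-in for ChainRulesCore.NoTangent()

fixed_trainmap(f, ch, tr, au) = map(ch, tr, au) do c, t, a
    isnothing(t) ? NoT() : f(t, a)
end

ch = ([1.0, 2.0], [4.0, 5.0])    # functor-ed children of TwoThirds: (a, c)
tr = ([1.0, 2.0], nothing)       # trainable template: only a is trainable
au = ([0.0, 4.0], [6.0, 7.0])    # incoming gradient pieces
y = fixed_trainmap((t, a) -> a, ch, tr, au)
@assert y == ([0.0, 4.0], NoT())  # c no longer leaks the primal value
```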

jondeuce (Contributor, Author) commented May 1, 2022:

Gotcha. So on top of the modified _trainmap to fix c, one would still have to filter backing(re(y)) to replace repopulated primal values which aren't functor-ed with NoT in order to fix b.

EDIT: But, based on the output of Tangent{typeof(x), typeof(y)}(y), maybe the modified _trainmap alone would be enough and backing(re(y)) isn't needed after all, as Tangent will assign NoT to omitted fields in y automatically.

EDIT 2: Never mind, that would still fail for children which aren't fields, like Skip.

jondeuce (Contributor, Author):

Alright, I pushed something that works for both Skip and your TwoThirds example (modified _trainmap + filtering backing(re(y))). But since it still uses re, it would still fail for fussy constructors.

Merging this pull request may close: destructure doesn't work correctly with certain functors