Fix #62 (#70)

Open · wants to merge 8 commits into master
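Summary, as read from the diff below: `_flatten` now wraps its integer offsets in a small `Offset` struct, and the auxiliary offset and gradient trees are walked with `functor(aux)` rather than `functor(typeof(x), aux)`. This lets `destructure` handle types whose functor children are not their own fields, such as Flux 0.12's `Chain`, which the new test type `Skip` imitates (issue #62).

A minimal usage sketch of the fixed behaviour, reusing the `Skip` definition added in `test/runtests.jl` and the expected values from the new tests in `test/destructure.jl`; it is illustrative, not documented API:

```julia
using Optimisers, Zygote, Functors

struct Skip{T}  # like Flux 0.12's Chain: its functor children are not its fields
  layers::T
  Skip(ls...) = new{typeof(ls)}(ls)
end
Base.getindex(x::Skip, i::Integer) = x.layers[i]
Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

sk = Skip([1.0, 2.0], (x = 3, y = [4.0, 5.0]))
v, re = destructure(sk)   # v == [1, 2, 4, 5], per the new tests
re(10v)[1]                # == [10.0, 20.0], and re(10v) isa Skip

gradient(zero(v)) do w    # gradients indexed against the flat vector
  re(w)[2].y[1]
end                       # == ([0, 0, 1, 0],)
```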
35 changes: 23 additions & 12 deletions src/destructure.jl
@@ -53,16 +53,20 @@ end
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
Base.length(re::Restructure) = re.length

struct Offset
i::Int
end

# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
push!(arrays, _vec(y))
o = len[]
len[] = o + length(y)
o
Offset(o)
end
reduce(vcat, arrays), off, len[]
end
@@ -85,13 +89,13 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trai
end
end

_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
_getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1])
_getat(y::AbstractArray, off::Offset, flat::AbstractVector) =
ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes

function _trainable_biwalk(f, x, aux)
ch, re = functor(typeof(x), x)
au, _ = functor(typeof(x), aux)
au, _ = functor(aux)
_trainmap(f, ch, _trainable(x), au) |> re
end

@@ -103,13 +107,20 @@ end

function _Tangent_biwalk(f, x, aux) # use with prune = NoT
ch, re = functor(typeof(x), x)
au, _ = functor(typeof(x), aux)
y = _trainmap(f, ch, _trainable(x), au)
au, _ = functor(aux)
y = map(ch, _trainable(x), au) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
isnothing(t) ? NoT : f(t, a)
end
y isa Tuple{} && return NoT
p = ProjectTo(x)
if p isa ProjectTo # e.g. Array, NamedTuple
p(y)
else # p === identity for unknown structs
y = map(backing(x), backing(re(y))) do c, t
# backing(re(y)) extracts NamedTuple backing from re(y); required if x has children which aren't its own fields
# however, re(y) will repopulate primal field values from x which weren't functor-ed; these gradients should be NoT
c === t ? NoT : t
end
Tangent{typeof(x), typeof(y)}(y)
end
end
@@ -126,23 +137,23 @@ ChainRulesCore.@non_differentiable _zero(x)
function _grad!(x, dx, off, flat::AbstractVector)
x′, _ = functor(typeof(x), x)
dx′, _ = functor(typeof(x), base(dx))
off′, _ = functor(typeof(x), off)
off′, _ = functor(off)
for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
end
flat
end
function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T
dx_un = unthunk(dx)
T2 = promote_type(T, eltype(dx_un))
if T != T2 # then we must widen the type
flat = copyto!(similar(flat, T2), flat)
end
@views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
@views flat[off.i .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
flat
end
_grad!(x, dx::Zero, off, flat::AbstractVector) = flat
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity
_grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat # ambiguity

# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
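A hedged illustration of the change above from `functor(typeof(x), aux)` to `functor(aux)`: the auxiliary tree of offsets is a plain Tuple or NamedTuple, so rebuilding it through `typeof(x)`'s functor method breaks whenever that method reads fields the auxiliary tree does not have, as `Skip`'s does. The offset shape shown is illustrative, and `Offset` here is the internal struct added in this diff (not exported):

```julia
using Functors, Optimisers

struct Skip{T}  # functor children are not fields of the struct
  layers::T
  Skip(ls...) = new{typeof(ls)}(ls)
end
Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

sk  = Skip([1.0, 2.0], [3.0, 4.0])
off = (Optimisers.Offset(0), Optimisers.Offset(2))  # offsets as _flatten would lay them out (illustrative)

Functors.functor(typeof(sk), off)  # errors: a Tuple has no field `layers`
Functors.functor(off)              # works: the tuple is walked by its own structure
```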
34 changes: 34 additions & 0 deletions test/destructure.jl
@@ -181,6 +181,40 @@ end
end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end

@testset "issue 62" begin
# Flux.Chain used to have children which weren't its own fields, which Skip imitates.

sk = Skip([1.0, 2.0], (x=3, y=[4.0, 5.0]))
@test fmap(identity, sk) == sk

gk = gradient(x -> sum(x[2].y), sk)[1]
@test fmap(Zygote.accum, sk, gk) isa Skip # this relies on functor(typeof(x), dx)

st = fmapstructure(identity, sk)
@test st isa Tuple{Vector, NamedTuple}
@test_throws Exception fmap(+, sk, st) # this fails because of functor(typeof(x), dx)

v, re = destructure(sk)
@test v == [1,2,4,5]
@test re(10v) isa Skip
@test re(10v)[1] == [10, 20]

@test gradient(zero(v)) do w
re(w)[2].y[1]
end == ([0,0,1,0],)

@test gradient(sk) do x
w, _ = destructure(x)
w[1]^2 + w[4]^2
end == ((layers = ([2.0, 0.0], (x = nothing, y = [0.0, 10.0])),),)

ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]) # a,c are functor-ed, and only a is trainable
@test gradient(ac) do x
w2, _ = destructure(x)
w2[2]^2
end == ((a = [0.0, 4.0], b = nothing, c = nothing),)
end

@testset "DiffEqFlux issue 699" begin
# The gradient of `re` is a vector into which we accumulate contributions, and the issue
# is that one contribution may have a wider type than `v`, especially for `Dual` numbers.
17 changes: 17 additions & 0 deletions test/runtests.jl
@@ -13,6 +13,13 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

struct Skip{T} # like Flux 0.12's Chain
layers::T
Skip(ls...) = new{typeof(ls)}(ls)
end
Base.getindex(x::Skip, i::Integer) = x.layers[i]
Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin

@@ -165,6 +172,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
end

@testset "issue 62" begin
m62 = (s = Skip([1.0, 2.0], Foo([3.0], false)), t = [4.0, 5.0])
s62 = Optimisers.setup(Descent(), m62)
g62 = gradient(m -> m.s[2].x[1] + 3 * m.t[2], m62)
s, m = Optimisers.update(s62, m62, g62...)
@test m.s isa Skip
@test m.s[2].x ≈ [2.9]
@test m.t ≈ [4, 4.7]
end

end
@testset verbose=true "Destructure" begin
include("destructure.jl")