diff --git a/src/destructure.jl b/src/destructure.jl index d000ff75..49f73ed1 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -53,16 +53,20 @@ end Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") Base.length(re::Restructure) = re.length +struct Offset + i::Int +end + # This flattens a model, and returns a web of offsets for later use: function _flatten(x) - isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case + isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x) # trivial case arrays = AbstractVector[] len = Ref(0) off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y push!(arrays, _vec(y)) o = len[] len[] = o + length(y) - o + Offset(o) end reduce(vcat, arrays), off, len[] end @@ -85,13 +89,13 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trai end end -_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1]) -_getat(y::AbstractArray, o::Int, flat::AbstractVector) = - ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes +_getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1]) +_getat(y::AbstractArray, off::Offset, flat::AbstractVector) = + ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes function _trainable_biwalk(f, x, aux) ch, re = functor(typeof(x), x) - au, _ = functor(typeof(x), aux) + au, _ = functor(aux) _trainmap(f, ch, _trainable(x), au) |> re end @@ -103,13 +107,20 @@ end function _Tangent_biwalk(f, x, aux) # use with prune = NoT ch, re = functor(typeof(x), x) - au, _ = functor(typeof(x), aux) - y = _trainmap(f, ch, _trainable(x), au) + au, _ = functor(aux) + y = map(ch, _trainable(x), au) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c) + isnothing(t) ? NoT : f(t, a) + end y isa Tuple{} && return NoT p = ProjectTo(x) if p isa ProjectTo # e.g. Array, NamedTuple p(y) else # p === identity for unknown structs + y = map(backing(x), backing(re(y))) do c, t + # backing(re(y)) extracts NamedTuple backing from re(y); required if x has children which aren't its own fields + # however, re(y) will repopulate primal field values from x which weren't functor-ed; these gradients should be NoT + c === t ? NoT : t + end Tangent{typeof(x), typeof(y)}(y) end end @@ -126,23 +137,23 @@ ChainRulesCore.@non_differentiable _zero(x) function _grad!(x, dx, off, flat::AbstractVector) x′, _ = functor(typeof(x), x) dx′, _ = functor(typeof(x), base(dx)) - off′, _ = functor(typeof(x), off) + off′, _ = functor(off) for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′) flat = _grad!(xᵢ, dxᵢ, oᵢ, flat) end flat end -function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T +function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T dx_un = unthunk(dx) T2 = promote_type(T, eltype(dx_un)) if T != T2 # then we must widen the type flat = copyto!(similar(flat, T2), flat) end - @views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes + @views flat[off.i .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes flat end _grad!(x, dx::Zero, off, flat::AbstractVector) = flat -_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity +_grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat # ambiguity # These are only needed for 2nd derivatives: function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) diff --git a/test/destructure.jl b/test/destructure.jl index d20f4f30..fce4ceeb 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -181,6 +181,40 @@ end end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],) end +@testset "issue 62" begin + # Flux.Chain used to have children which aren't its own fields, which Skip immitates. + + sk = Skip([1.0, 2.0], (x=3, y=[4.0, 5.0])) + @test fmap(identity, sk) == sk + + gk = gradient(x -> sum(x[2].y), sk)[1] + @test fmap(Zygote.accum, sk, gk) isa Skip # this relies on functor(typeof(x), dx) + + st = fmapstructure(identity, sk) + @test st isa Tuple{Vector, NamedTuple} + @test_throws Exception fmap(+, sk, st) # this fails because of functor(typeof(x), dx) + + v, re = destructure(sk) + @test v == [1,2,4,5] + @test re(10v) isa Skip + @test re(10v)[1] == [10, 20] + + @test gradient(zero(v)) do w + re(w)[2].y[1] + end == ([0,0,1,0],) + + @test gradient(sk) do x + w, _ = destructure(x) + w[1]^2 + w[4]^2 + end == ((layers = ([2.0, 0.0], (x = nothing, y = [0.0, 10.0])),),) + + ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]) # a,c are functor-ed, and only a is trainable + @test gradient(ac) do x + w2, _ = destructure(x) + w2[2]^2 + end == ((a = [0.0, 4.0], b = nothing, c = nothing),) +end + @testset "DiffEqFlux issue 699" begin # The gradient of `re` is a vector into which we accumulate contributions, and the issue # is that one contribution may have a wider type than `v`, especially for `Dual` numbers. diff --git a/test/runtests.jl b/test/runtests.jl index d47bce08..1a54c5e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,13 @@ struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) Optimisers.trainable(x::TwoThirds) = (a = x.a,) +struct Skip{T} # like Flux 0.12's Chain + layers::T + Skip(ls...) = new{typeof(ls)}(ls) +end +Base.getindex(x::Skip, i::Integer) = x.layers[i] +Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...) + @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin @@ -165,6 +172,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @test_throws ArgumentError Optimisers.setup(ADAMW(), m2) end + @testset "issue 62" begin + m62 = (s = Skip([1.0, 2.0], Foo([3.0], false)), t = [4.0, 5.0]) + s62 = Optimisers.setup(Descent(), m62) + g62 = gradient(m -> m.s[2].x[1] + 3 * m.t[2], m62) + s, m = Optimisers.update(s62, m62, g62...) + @test m.s isa Skip + @test m.s[2].x ≈ [2.9] + @test m.t ≈ [4, 4.7] + end + end @testset verbose=true "Destructure" begin include("destructure.jl")