using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
const NoT = NoTangent()
"""
destructure(model) -> vector, reconstructor
Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
to a vector, and returns also a function which reverses this transformation.
Differentiable.
# Example
```jldoctest
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
julia> re([3, 5-im, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
```
"""
function destructure(x)
  flat, off, len = _flatten(x)
  flat, Restructure(x, off, len)
end
"""
Restructure(Model, ..., length)
This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
# Example
```julia
julia> using Flux, Optimisers
julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6))
julia> m = re(-4:1)
Dense(2, 2, σ) # 6 parameters
julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1)
true
```
"""
struct Restructure{T,S}
  model::T
  offsets::S
  length::Int
end
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
Base.length(re::Restructure) = re.length
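
# A hedged usage sketch (illustrative, not a doctest): since `destructure` and the
# returned `re` are differentiable, one can take gradients with respect to the flat
# vector, e.g. with Zygote; `model`, `x`, and the loss below are placeholders:
#
#   v, re = destructure(model)
#   g = Zygote.gradient(v -> sum(abs2, re(v)(x)), v)[1]  # gradient w.r.t. flat parameters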
# Offset(i) records how many parameters were flattened before a given leaf array,
# i.e. that leaf occupies flat[i .+ (1:length(leaf))]:
struct Offset
  i::Int
end
# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
  isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x)  # trivial case
  arrays = AbstractVector[]
  len = Ref(0)
  off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
    push!(arrays, _vec(y))
    o = len[]
    len[] = o + length(y)
    Offset(o)
  end
  reduce(vcat, arrays), off, len[]
end
_vec(x::Number) = LinRange(x,x,1)  # a lazy one-element vector, avoids allocating an Array for a scalar leaf
_vec(x::AbstractArray) = vec(x)
function ChainRulesCore.rrule(::typeof(_flatten), x)
  flat, off, len = _flatten(x)
  _maybewarn()
  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
  (flat, off, len), _flatten_back
end
# This reconstructs either a model like x, or a gradient for it:
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...)
  len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
  fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
    _getat(y, o, flat)
  end
end
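
# A hedged sketch of `_rebuild` (illustrative, using the offsets from the `_flatten`
# sketch above):
#
#   _rebuild((a = [1.0, 2.0], b = (c = [3.0],)),
#            (a = Offset(0), b = (c = Offset(2),)),
#            [10.0, 20.0, 30.0])
#   # == (a = [10.0, 20.0], b = (c = [30.0],))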
_getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1])
_getat(y::AbstractArray, off::Offset, flat::AbstractVector) =
  ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
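# For instance (a hedged sketch), `_getat([0.0, 0.0], Offset(2), [10.0, 20.0, 30.0, 40.0])`
# reads the slice after the first two entries, giving `[30.0, 40.0]`.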
function _trainable_biwalk(f, x, aux)
  ch, re = functor(typeof(x), x)
  au = _aux_children(aux)
  _trainmap(f, ch, _trainable(x), au) |> re
end
_aux_children(off) = functor(off)[1]
function _trainmap(f, ch, tr, aux)
  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
    isnothing(t) ? c : f(t, a)
  end
end
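
# A hedged note: non-trainable fields are those for which `_trainable` returns `nothing`,
# e.g. after a user overload such as `Optimisers.trainable(m::MyLayer) = (; weight = m.weight)`
# (MyLayer is a placeholder). Such fields are passed through unchanged rather than being
# rebuilt from the flat vector.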
function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
  ch, re = functor(typeof(x), x)
  au = _aux_children(aux)
  y = _trainmap(f, ch, _trainable(x), au)
  y isa Tuple{} && return NoT
  p = ProjectTo(x)
  if p isa ProjectTo  # e.g. Array, NamedTuple
    p(y)
  else  # p === identity for unknown structs
    y = backing(re(y))  # extract NamedTuple backing from re(y); required if x has children which aren't its own fields
    Tangent{typeof(x), typeof(y)}(y)
  end
end
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
  _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
  _rebuild(x, off, flat, len; kw...), _rebuild_back
end
_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad!
ChainRulesCore.@non_differentiable _zero(x)
# This is the gradient of model reconstruction, accumulating duplicates:
function _grad!(x, dx, off, flat::AbstractVector)
  x′, _ = functor(typeof(x), x)
  dx′, _ = functor(typeof(x), base(dx))
  off′ = _aux_children(off)
  for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
    flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
  end
  flat
end
function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T
  dx_un = unthunk(dx)
  T2 = promote_type(T, eltype(dx_un))
  if T != T2  # then we must widen the type
    flat = copyto!(similar(flat, T2), flat)
  end
  @views flat[off.i .+ (1:length(x))] .+= vec(dx_un)  # must visit all tied nodes
  flat
end
_grad!(x, dx::Zero, off, flat::AbstractVector) = flat
_grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat # ambiguity
# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
  @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3
  _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
  _grad!(x, dx, off, flat), _grad_back
end
base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure))
base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing) # Zygote version
_maybewarn() = nothing
function ChainRulesCore.rrule(::typeof(_maybewarn))
  @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
  nothing, _ -> (NoT,)
end