Skip to content

Commit

Permalink
cl/zero
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 6, 2024
1 parent e5d187c commit bf480b3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
44 changes: 43 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,50 @@ julia> trainables(model)
Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
Float32[0.0, 0.0]

julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);
julia> l2reg(model) = sum([sum(abs2, p) for p in trainables(model)]);

julia> g = gradient(l2reg, model)[1];
```
Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ ` and `σ²` buffers are not.

Sometimes one wants to iterate over all trainable parameters in a model and the corresponding parameters of a matched structure such a gradient or the moving average of the model.
This can be done using `trainables(model, path=true)`. For instance, here is how to update the parameters
of a moving average model with the parameters of the model:

```julia
for (kp, p_avg) in trainables(model_avg, path=true)
p = getkeypath(model, kp)
p_avg .= 0.99 .* p_avg .+ 0.01 .* pnew
end
```

## Incomplete or nothing gradients

If the gradient is not available for some parameters, or branches of the model,
`update` will not take an optimisation step for those parameters.
This is the case when the gradient is `nothing` or a subtype of `ChainRules.AbstractZero`.

For stateful optimisers, skipping an update it is generaly not the same as updating with a zero gradient.
For example, in the case of Adam, the momentum and variance are updated even if the gradient is zero:

```julia-repl
julia> x = (a = ones(2), b = ones(2));
(a = [1.0, 1.0], b = [1.0, 1.0])
julia> opt_state = Optimisers.setup(Adam(0.1), x)
(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999))))
julia> g = (; a = ones(2), b = ones(2)); # First an update with a non-zero gradient to increase the momentum and variance
julia> Optimisers.update!(opt_state, x, g);
julia> opt_state # the state in `a` and `b` are the same
(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
julia> g = (; a = zeros(2), b = nothing); # Now an update with a zero gradient for a and no gradient for b
julia> Optimisers.update!(opt_state, x, g);
julia> opt_state # the state in `a` and `b` differ
(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
```
8 changes: 8 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,21 @@ end
subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
subtract!(x, x̄::Zero) = x

# If we get Zero from AD on a leaf we skip the optimizer step. See
# https://github.com/FluxML/Optimisers.jl/issues/140
_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing

function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
nothing
end

# If we get Zero from AD in correspondence of a non-leaf node
# we end the recursion. The optimizer step won't be taken.
# https://github.com/FluxML/Optimisers.jl/issues/140
_grads!(dict::IdDict, t, x, ::Zero...) = nothing

function _grads!(dict::IdDict, tree, x, x̄s...)
# The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
# functor(typeof(tree), base(x̄)), for things like Transpose
Expand Down

0 comments on commit bf480b3

Please sign in to comment.