perhaps we should build regularisation into the same page
mcabbott committed Nov 21, 2022
1 parent 3d7eb3f commit c6bac9a
Showing 3 changed files with 105 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -28,7 +28,7 @@ makedocs(
"Training Models" => [
"Training" => "training/training.md",
"Training API 📚" => "training/train_api.md",
"Regularisation" => "models/regularisation.md",
# "Regularisation" => "models/regularisation.md",
"Loss Functions 📚" => "models/losses.md",
"Optimisation Rules 📚" => "training/optimisers.md", # TODO move optimiser intro up to Training
"Callback Helpers 📚" => "training/callbacks.md",
63 changes: 63 additions & 0 deletions docs/src/training/training.md
@@ -275,3 +275,66 @@ For more details on training in the implicit style, see [Flux 0.13.6 documentation]

For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1).

## Regularisation

The term *regularisation* covers a wide variety of techniques whose aim is to improve the
result of training, most often by avoiding overfitting.

Some of these can be implemented by simply modifying the loss function.
An *L2 penalty* (closely related to *weight decay*) adds to the loss a penalty proportional to `θ^2` for every scalar parameter `θ`,
and for a simple model could be implemented as follows:

```julia
Flux.gradient(model) do m
    result = m(input)
    penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
    my_loss(result, label) + 0.42 * penalty
end
```
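For concreteness, here is a self-contained version of that sketch; the layer sizes, fake data, and the `0.42` constant are all arbitrary choices for illustration:

```julia
using Flux

model = Dense(3 => 2)            # toy model, arbitrary sizes
input = randn(Float32, 3, 5)     # fake batch of 5 examples
label = randn(Float32, 2, 5)
my_loss = Flux.mse               # any differentiable loss works

grads = Flux.gradient(model) do m
    result = m(input)
    penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
    my_loss(result, label) + 0.42f0 * penalty
end
```

The returned `grads[1]` has the same field names as the model, and its `weight` component includes the penalty's contribution `0.42 .* model.weight`.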

Accessing each individual parameter array by hand won't scale to large models.
Instead, we can use [`Flux.params`](@ref) to collect all of them,
then apply a function to each one and sum the results:

```julia
pen_l2(x::AbstractArray) = sum(abs2, x)/2

Flux.gradient(model) do m
    result = m(input)
    penalty = sum(pen_l2, Flux.params(m))
    my_loss(result, label) + 0.42 * penalty
end
```

However, the gradient of this penalty term is very simple: it is proportional to the original weights.
So there is a simpler way to implement exactly the same thing, by modifying the optimiser
instead of the loss function. This is done by replacing this:

```julia
opt = Flux.setup(Adam(0.1), model)
```

with this:

```julia
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
```
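Putting these pieces together, the loss function then needs no penalty term at all. A runnable sketch, with a made-up model and data:

```julia
using Flux

model = Dense(3 => 2)
data = [(randn(Float32, 3, 5), randn(Float32, 2, 5)) for _ in 1:10]
init_weight = copy(model.weight)

decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
Flux.train!(model, data, decay_opt) do m, x, y
    Flux.mse(m(x), y)   # plain loss: WeightDecay supplies the penalty's gradient
end
```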

Flux's optimisers are really modifications applied to the gradient before it is used to update
the parameters, and `OptimiserChain` applies two such modifications in sequence.
The first, [`WeightDecay`](@ref), adds `0.42` times the original parameter to the gradient,
matching the gradient of the penalty above (with the same, unrealistically large, constant).
After that, in either case, [`Adam`](@ref) computes the final update.
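This equivalence can be checked numerically. A sketch comparing the two gradients for the weight matrix (the model, data, and constant are all illustrative):

```julia
using Flux

model = Dense(3 => 2)
x = randn(Float32, 3, 5)
y = randn(Float32, 2, 5)

# Gradient with the explicit L2 penalty written into the loss:
g_penalty = Flux.gradient(m -> Flux.mse(m(x), y) + 0.42f0 * sum(abs2, m.weight)/2, model)[1]

# Gradient of the bare loss, plus what WeightDecay(0.42) would add:
g_bare = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]
g_decay = g_bare.weight .+ 0.42f0 .* model.weight

g_decay ≈ g_penalty.weight   # the two agree up to floating-point error
```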

The same mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref).
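As a sketch of the effect (the huge gradient here is fabricated by hand, and plain `Descent` is used so the step is easy to predict):

```julia
using Flux

model = Dense(3 => 2)
clip_opt = Flux.setup(OptimiserChain(ClipGrad(1.0), Descent(0.1)), model)

w0 = copy(model.weight)
huge = (weight = fill(1f3, 2, 3), bias = fill(1f3, 2), σ = nothing)
Flux.update!(clip_opt, model, huge)

# Each entry of the huge gradient was clamped to 1.0 before the step,
# so with Descent(0.1) every weight moved by exactly 0.1:
maximum(abs, model.weight .- w0)
```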

Besides L2 / weight decay, another common and quite different kind of regularisation is
provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off a random fraction of
the outputs of the previous layer during each training step, which discourages the network
from relying too heavily on any single activation.

Note that dropout should be active only during training. Flux enables it automatically
inside gradient calls, and [`Flux.trainmode!`](@ref) / [`Flux.testmode!`](@ref) can be
used to control it explicitly.
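A sketch of how `Dropout` is typically used, with arbitrary layer sizes:

```julia
using Flux

model = Chain(Dense(10 => 32, relu), Dropout(0.5), Dense(32 => 2))
x = randn(Float32, 10)

Flux.trainmode!(model)   # dropout active: roughly half the 32 activations are zeroed
Flux.testmode!(model)    # dropout inactive: the model is deterministic again
y1 = model(x)
y2 = model(x)
y1 == y2                 # true in test mode
```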

## Freezing, Schedules

Related adjustments to the course of training, such as freezing some of the model's
parameters or varying the learning rate on a schedule, act through the same optimiser
mechanism described above.


41 changes: 41 additions & 0 deletions test/train.jl
@@ -91,3 +91,44 @@ end
@test y5 < y4
end

@testset "L2 regularisation" begin
    # New docs claim an exact equivalent. It's a bit long to put the example in there,
    # but perhaps the tests should contain it.

    model = Dense(3 => 2, tanh);
    init_weight = copy(model.weight);
    data = [(randn(Float32, 3, 5), randn(Float32, 2, 5)) for _ in 1:10];

    # Take 1: explicitly add a penalty in the loss function
    opt = Flux.setup(Adam(0.1), model)
    Flux.train!(model, data, opt) do m, x, y
        err = Flux.mse(m(x), y)
        l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
        err + 0.33 * l2
    end
    diff1 = model.weight .- init_weight

    # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
    model.weight .= init_weight
    model.bias .= 0
    pen2(x::AbstractArray) = sum(abs2, x)/2
    opt = Flux.setup(Adam(0.1), model)
    Flux.train!(model, data, opt) do m, x, y
        err = Flux.mse(m(x), y)
        l2 = sum(pen2, Flux.params(m))
        err + 0.33 * l2
    end
    diff2 = model.weight .- init_weight
    @test_broken diff1 ≈ diff2

    # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
    model.weight .= init_weight
    model.bias .= 0
    decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
    Flux.train!(model, data, decay_opt) do m, x, y
        Flux.mse(m(x), y)
    end
    diff3 = model.weight .- init_weight
    @test diff1 ≈ diff3
end
