From c6bac9a53daddad1301d2bd60a695868d731ad05 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 21 Nov 2022 17:50:45 -0500
Subject: [PATCH] perhaps we should build regularisation into the same page

---
 docs/make.jl                  |  2 +-
 docs/src/training/training.md | 63 +++++++++++++++++++++++++++++++++++
 test/train.jl                 | 41 +++++++++++++++++++++++
 3 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/docs/make.jl b/docs/make.jl
index fc96fafff9..9e5e16afe6 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -28,7 +28,7 @@ makedocs(
         "Training Models" => [
             "Training" => "training/training.md",
             "Training API 📚" => "training/train_api.md",
-            "Regularisation" => "models/regularisation.md",
+            # "Regularisation" => "models/regularisation.md",
             "Loss Functions 📚" => "models/losses.md",
             "Optimisation Rules 📚" => "training/optimisers.md",  # TODO move optimiser intro up to Training
             "Callback Helpers 📚" => "training/callbacks.md",

diff --git a/docs/src/training/training.md b/docs/src/training/training.md
index 57c9638332..7c34e7b9a8 100644
--- a/docs/src/training/training.md
+++ b/docs/src/training/training.md
@@ -275,3 +275,66 @@ For more details on training in the implicit style, see [Flux 0.13.6 documentati
For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1).

## Regularisation

The term *regularisation* covers a wide variety of techniques aiming to improve the
result of training. This is often done to avoid overfitting.

Some of these can be implemented simply by modifying the loss function.
L2 regularisation (sometimes called ridge regression, or weight decay) adds to the loss
a penalty proportional to `θ^2` for every scalar parameter `θ`.
For a simple model it could be implemented as follows:

```julia
Flux.gradient(model) do m
  result = m(input)
  penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
  my_loss(result, label) + 0.42 * penalty
end
```

Accessing each individual parameter array by hand won't work well for large models.
Instead, we can use [`Flux.params`](@ref) to collect all of them,
then apply a function to each one and sum the results:

```julia
pen_l2(x::AbstractArray) = sum(abs2, x)/2

Flux.gradient(model) do m
  result = m(input)
  penalty = sum(pen_l2, Flux.params(m))
  my_loss(result, label) + 0.42 * penalty
end
```

However, the gradient of this penalty term is very simple: it is proportional to the original weights.
So there is a simpler way to implement exactly the same thing, by modifying the optimiser
instead of the loss function. This is done by replacing this:

```julia
opt = Flux.setup(Adam(0.1), model)
```

with this:

```julia
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
```

Flux's optimisers are really modifications applied to the gradient before using it to update
the parameters, and `OptimiserChain` applies two such modifications in sequence.
The first, [`WeightDecay`](@ref), adds `0.42` times the original parameter to the gradient,
matching the gradient of the penalty above (with the same, unrealistically large, constant).
After that, in either case, [`Adam`](@ref) computes the final update.

The same mechanism can be used for other purposes, such as gradient clipping with
[`ClipGrad`](@ref); a sketch using it appears at the end of this section.

Besides L2 / weight decay, another common and quite different kind of regularisation is
provided by the [`Dropout`](@ref Flux.Dropout) layer. During training, this randomly zeroes
some of the values passing through it (rescaling the others to compensate), which discourages
the network from relying too heavily on any one activation. Note that `Dropout` behaves
differently during training and testing: by default it is active only while gradients are
being computed, and [`Flux.testmode!`](@ref) can be used to switch this manually.
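
For illustration, here is a minimal sketch combining dropout with the gradient clipping
mentioned above. The layer sizes, the dropout probability `0.2`, the clipping threshold
`1.0`, and the learning rate are arbitrary choices, not recommendations:

```julia
using Flux  # OptimiserChain, ClipGrad, Adam are defined by Optimisers.jl;
            # depending on the Flux version you may also need `using Optimisers`.

model = Chain(
    Dense(3 => 16, relu),
    Dropout(0.2),    # in training mode, zeroes each activation with probability 0.2
    Dense(16 => 2),
)

# Clip each gradient component to the interval [-1, 1] before Adam uses it:
opt = Flux.setup(OptimiserChain(ClipGrad(1.0), Adam(0.01)), model)
```

With this setup, `Flux.train!(loss, model, data, opt)` applies dropout on every forward
pass and clipping to every gradient, with no change needed to the loss function.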

## Freezing, Schedules

Two further adjustments to training, freezing some of the model's parameters and
scheduling changes to the learning rate, may also belong on this page.

diff --git a/test/train.jl b/test/train.jl
index 49ecf9c751..c93a8dbf83 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -91,3 +91,44 @@ end
  @test y5 < y4
end

@testset "L2 regularisation" begin
  # The new docs claim an exact equivalence. The example is a bit long to include there,
  # but perhaps the tests should contain it.

  model = Dense(3 => 2, tanh);
  init_weight = copy(model.weight);
  data = [(randn(Float32, 3, 5), randn(Float32, 2, 5)) for _ in 1:10];

  # Take 1: explicitly add a penalty to the loss function
  opt = Flux.setup(Adam(0.1), model)
  Flux.train!(model, data, opt) do m, x, y
    err = Flux.mse(m(x), y)
    l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
    err + 0.33 * l2
  end
  diff1 = model.weight .- init_weight

  # Take 2: the same penalty, but computed via Flux.params. (This was broken for a while, with no tests!)
  model.weight .= init_weight
  model.bias .= 0
  pen2(x::AbstractArray) = sum(abs2, x)/2
  opt = Flux.setup(Adam(0.1), model)
  Flux.train!(model, data, opt) do m, x, y
    err = Flux.mse(m(x), y)
    l2 = sum(pen2, Flux.params(m))
    err + 0.33 * l2
  end
  diff2 = model.weight .- init_weight
  @test_broken diff1 ≈ diff2

  # Take 3: using WeightDecay instead. The /2 above is needed to match exactly.
  model.weight .= init_weight
  model.bias .= 0
  decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
  Flux.train!(model, data, decay_opt) do m, x, y
    Flux.mse(m(x), y)
  end
  diff3 = model.weight .- init_weight
  @test diff1 ≈ diff3
end