From 89074bccf99d0834f64f743c60c84d08d057d01b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 25 Nov 2022 08:23:08 -0500
Subject: [PATCH] tweaks

---
 Project.toml                   |  2 +-
 docs/src/training/train_api.md | 18 ++++++++++++++++--
 docs/src/training/training.md  | 22 +++++++++++++++++-----
 src/Flux.jl                    |  1 +
 4 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2832ddb8e9..6a0da4670c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -33,7 +33,7 @@ MacroTools = "0.5"
 NNlib = "0.8.9"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
-Optimisers = "0.2.10"
+Optimisers = "0.2.11"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
diff --git a/docs/src/training/train_api.md b/docs/src/training/train_api.md
index bd2f0c83c3..b8c05b240e 100644
--- a/docs/src/training/train_api.md
+++ b/docs/src/training/train_api.md
@@ -9,14 +9,28 @@ Flux.Optimise.train!(loss, model, data, opt; cb)
 To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl)
 and follow its setup instructions.
 
-The new version of Flux's training code was written as an independent package, called Optimisers.jl.
-However, at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s)
+The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl).
+This is designed to allow for immutable objects.
+But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s)
 which can be updated in-place. Thus objects returned by `update!` can be ignored.
 
 ```@docs
 Optimisers.update!
 ```
 
+### Modifiers
+
+The state returned by `setup` can be modified to temporarily prevent training of
+some parts of the model, or to change the learning rate used.
+The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust`:
+
+```@docs
+Optimisers.adjust
+Optimisers.freeze!
+Optimisers.thaw!
+```
+
+
 ## Implicit style
 
 Flux used to handle gradients, training, and optimisation rules quite differently.
diff --git a/docs/src/training/training.md b/docs/src/training/training.md
index 7c34e7b9a8..791345ce7e 100644
--- a/docs/src/training/training.md
+++ b/docs/src/training/training.md
@@ -326,15 +326,27 @@ The first, [`WeightDecay`](@ref) adds `0.42` times original parameter to the gra
 matching the gradient of the penalty above (with the same, unrealistically large, constant).
 After that, in either case, [`Adam`](@ref) computes the final update.
 
-The same mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ).
+The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ).
 
 Besides L2 / weight decay, another common and quite different kind of regularisation is
-provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some ... ??
-
-?? do we discuss test/train mode here too?
+provided by the [`Dropout`](@ref Flux.Dropout) layer. This randomly sets some outputs of
+the previous layer to zero during training.
+It should switch on and off automatically, but see [trainmode!](@ref Flux.trainmode!) / [testmode!](@ref Flux.testmode!) to enable or disable it manually.
 
 
 ## Freezing, Schedules
 
-?? maybe these also fit in here.
+Finer control of training is possible by modifying the optimiser state returned by `setup`. For instance, to exclude part of the model from training:
+
+```julia
+model = Chain(enc = encoder, dec = decoder)
+
+opt = Flux.setup(Adam(), model)
+
+Flux.freeze!(opt.layers.enc)  # corresponds to model.layers.enc
+```
+!!! note
+    This `freeze!` goes with the "explicit" style.
+    The earlier "implicit" equivalent was to pass to `gradient` an object referencing only
+    part of the model, such as `Flux.params(model.layers.enc)`.
 
diff --git a/src/Flux.jl b/src/Flux.jl
index 53749b85fa..4933b95aa2 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -8,6 +8,7 @@ using MacroTools: @forward
 @reexport using NNlib
 using MLUtils
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
+using Optimisers: freeze!, thaw!, adjust
 
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
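
For reference, the explicit-style workflow that the documentation above describes can be sketched as follows. This is an illustrative example only, not part of the patch: the `encoder`/`decoder` layers and the random data are invented for demonstration, and it assumes a Flux version providing `Flux.setup` together with Optimisers.jl 0.2.11, as required by the `Project.toml` bump above.

```julia
using Flux  # assumes Flux.setup, freeze!, thaw!, adjust are available, per the patch above

# Toy model and data, invented purely for illustration:
encoder = Dense(4 => 3, relu)
decoder = Dense(3 => 1)
model = Chain(enc = encoder, dec = decoder)
x, y = rand(Float32, 4, 16), rand(Float32, 1, 16)

opt = Flux.setup(Adam(0.01), model)   # optimiser state mirroring the model's structure

Flux.freeze!(opt.layers.enc)          # exclude the encoder from training
for epoch in 1:10
  grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
  Flux.update!(opt, model, grads[1])  # only the decoder's parameters are updated
end

Flux.thaw!(opt.layers.enc)            # let the encoder train again
opt = Flux.adjust(opt, 0.001)         # `adjust` returns a new state tree with a lower learning rate
```

Note that `freeze!`, `thaw!`, and `adjust` all act on the optimiser state rather than on the model itself, which is why the state must be created with `setup` before they can be used.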