diff --git a/README.md b/README.md
index 5d2faf255d..9f7efe7ee7 100644
--- a/README.md
+++ b/README.md
@@ -25,9 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]
 
 model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
 
-optim = Flux.setup(Adam(), model)
+state = Flux.setup(Adam(), model)
 for epoch in 1:1000
-  Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
+  Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, state)
 end
 
 plot(x -> 2x-x^3, -2, 2, legend=false)
diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md
index f500c7e70e..de54949053 100644
--- a/docs/src/models/quickstart.md
+++ b/docs/src/models/quickstart.md
@@ -27,7 +27,7 @@ target = Flux.onehotbatch(truth, [true, false])                   # 2×1000 OneH
 loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
 # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
 
-optim = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.
+state = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.
 
 # Training loop, using the whole data set 1000 times:
 losses = []
@@ -38,12 +38,12 @@ losses = []
             y_hat = m(x)
             Flux.crossentropy(y_hat, y)
         end
-        Flux.update!(optim, model, grads[1])
+        Flux.update!(state, model, grads[1])
         push!(losses, loss)  # logging, outside gradient context
     end
 end
 
-optim # parameters, momenta and output have all changed
+state # parameters, momenta and output have all changed
 
 out2 = model(noisy |> gpu) |> cpu  # first row is prob. of true, second row p(false)
 mean((out2[1,:] .> 0.5) .== truth)  # accuracy 94% so far!
@@ -95,7 +95,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.
 
 ```julia
 for epoch in 1:1_000
-    train!(model, loader, optim) do m, x, y
+    train!(model, loader, state) do m, x, y
         y_hat = m(x)
         Flux.crossentropy(y_hat, y)
     end
@@ -110,7 +110,7 @@ end
     ```
     (gradient of a zero-argument function) or
     ```
-    train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim)
+    train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt)
     ```
     (with `Flux.params`) is in the old "implicit" style.
     This still works on Flux 0.13, but will be removed from Flux 0.14.
diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md
index a6054515ea..a440a9a641 100644
--- a/docs/src/training/reference.md
+++ b/docs/src/training/reference.md
@@ -16,7 +16,7 @@ see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for det
 
 ```@docs
 Flux.Train.setup
-Flux.Train.train!(loss, model, data, opt; cb)
+Flux.Train.train!(loss, model, data, state; cb)
 Optimisers.update!
 ```
 
diff --git a/docs/src/training/training.md b/docs/src/training/training.md
index ba3bba071d..b4ce944249 100644
--- a/docs/src/training/training.md
+++ b/docs/src/training/training.md
@@ -12,6 +12,9 @@ are handled one-by-one.
 One *epoch* of training means that each example is used something like this:
 
 ```julia
+# Initialise the optimiser for this model:
+state = Flux.setup(rule, model)
+
 for data in train_set
   # Unpack this element (for supervised training):
   input, label = data
@@ -24,8 +27,8 @@ for data in train_set
   end
 
   # Update the parameters so as to reduce the objective,
-  # according to a particular optimiser:
-  Flux.update!(opt, model, grads[1])
+  # according to the chosen optimisation rule:
+  Flux.update!(state, model, grads[1])
 end
 ```
 
@@ -33,7 +36,7 @@ This loop can also be written using the function [`train!`](@ref Flux.Train.trai
 but it's helpful to undersand the pieces first:
 
 ```julia
-train!(model, train_set, opt) do m, x, y
+train!(model, train_set, state) do m, x, y
   loss(m(x), y)
 end
 ```
@@ -113,7 +116,7 @@ fmap(model, grads[1]) do p, g
 end
 ```
 
-A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`.
+A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(state, model, grads[1])`.
 
 And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct.
 However, there are many other optimisation rules, which adjust the step size and
@@ -126,13 +129,13 @@ first argument of `update!`. Like this:
 
 ```julia
 # Initialise momentum
-opt = Flux.setup(Momentum(0.01, 0.9), model)
+state = Flux.setup(Momentum(0.01, 0.9), model)
 
 for data in train_set
   grads = [...]
 
   # Update both model parameters and optimiser state:
-  Flux.update!(opt, model, grads[1])
+  Flux.update!(state, model, grads[1])
 end
 ```
 
@@ -192,17 +195,17 @@ Simple training loops like the one above can be written compactly using the
 [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads:
 
 ```julia
-opt = Flux.setup(Adam(), model)
+state = Flux.setup(Adam(), model)
 
 for epoch in 1:100
-  Flux.train!(model, train_set, opt) do m, x, y
+  Flux.train!(model, train_set, state) do m, x, y
     loss(m(x), y)
   end
 end
 ```
 
 Or explicitly writing the anonymous function which this `do` block creates,
-`train!((m,x,y) -> loss(m(x),y), model, train_set, opt)` is exactly equivalent.
+`train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent.
 
 !!! compat "Implicit-style `train!`"
     This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument.
@@ -224,7 +227,7 @@ callback API. Here is an example, in which it may be helpful to note:
 * Julia's `break` and `continue` keywords let you exit from parts of the loop.
 
 ```julia
-opt = Flux.setup(Adam(), model)
+state = Flux.setup(Adam(), model)
 
 my_log = []
 for epoch in 1:100
@@ -248,7 +251,7 @@ for epoch in 1:100
       continue
     end
 
-    Flux.update!(opt, model, grads[1])
+    Flux.update!(state, model, grads[1])
   end
 
   # Compute some accuracy, and save details as a NamedTuple
@@ -300,13 +303,13 @@ So there is a simpler way to implement exactly the same thing, by modifying the
 instead of the loss function.
 This is done by replacing this:
 ```julia
-opt = Flux.setup(Adam(0.1), model)
+state = Flux.setup(Adam(0.1), model)
 ```
 
 with this:
 
 ```julia
-decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
+decay_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
 ```
 
 Flux's optimisers are really modifications applied to the gradient before using it to update
@@ -328,12 +331,12 @@ Finer control of training, you may wish to alter the learning rate mid-way throu
 This can be done with [`adjust!`](@ref Flux.adjust!), like this:
 
 ```julia
-opt = Flux.setup(Adam(0.1), model)  # initialise once
+state = Flux.setup(Adam(0.1), model)  # initialise once
 
 for epoch in 1:1000
-  train!([...], opt)  # Train with η = 0.1 for first 100,
+  train!([...], state)  # Train with η = 0.1 for first 100,
   if epoch == 100  # then change to use η = 0.01 for the rest.
-    Flux.adjust!(opt, 0.01)
+    Flux.adjust!(state, 0.01)
   end
 end
 ```
@@ -342,7 +345,7 @@ end
     With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to directly
     mutate the `Adam` struct, `opt.eta = 0.001`.
 
-Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt, beta = (0.8, 0.99))`.
+Other hyper-parameters can also be adjusted, such as `Flux.adjust!(state, beta = (0.8, 0.99))`.
 And such modifications can be applied to just one part of the model.
 For instance, this sets a different learning rate for the encoder and the decoder:
 
@@ -351,23 +354,23 @@ For instance, this sets a different learning rate for the encoder and the decode
 bimodel = Chain(enc = [...], dec = [...])
 
 # This returns a tree whose structure matches the model:
-opt = Flux.setup(Adam(0.02), bimodel)
+state = Flux.setup(Adam(0.02), bimodel)
 
 # Adjust the learning rate to be used for bimodel.layers.enc
-Flux.adjust!(opt.layers.enc, 0.03)
+Flux.adjust!(state.layers.enc, 0.03)
 ```
 
 To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!).
 This is a temporary modification, reversed by `thaw!`:
 
 ```julia
-Flux.freeze!(opt.layers.enc)
+Flux.freeze!(state.layers.enc)
 
 # Now training won't update parameters in bimodel.layers.enc
-train!(loss, bimodel, data, opt)
+train!(loss, bimodel, data, state)
 
 # Un-freeze the entire model:
-Flux.thaw!(opt)
+Flux.thaw!(state)
 ```
 
 !!! compat "Flux ≤ 0.13"
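Taken together, the explicit-style API that these edits document can be exercised end-to-end as below. This is a minimal illustrative sketch, not part of the patch: the toy data, epoch count and hyper-parameters are made up, but each call (`Flux.setup`, `Flux.gradient`, `Flux.update!`, `Flux.adjust!`) is the one named in the documentation above.

```julia
using Flux

# Toy 1-D regression data and model, mirroring the README example above.
data  = [([x], 2x - x^3) for x in -2:0.1f0:2]
model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

# `setup` walks the model once and returns a matching tree of optimiser state.
state = Flux.setup(Adam(0.01), model)

for epoch in 1:200
    for (x, y) in data
        # Explicit-style gradient: differentiate with respect to the model itself.
        grads = Flux.gradient(m -> abs2(m(x) - y), model)
        Flux.update!(state, model, grads[1])    # mutates both model and state
    end
    # Lower the learning rate part-way through training.
    epoch == 100 && Flux.adjust!(state, 0.001)
end
```

After the loop, `state` holds the Adam momenta and the adjusted learning rate, which is why the docs consistently pass the result of `setup` (here `state`) rather than the bare rule to `update!` and `train!`.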