change opt to state
mcabbott committed Dec 7, 2022
1 parent 5e62649 commit a572da2
Showing 4 changed files with 33 additions and 30 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -25,9 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)
state = Flux.setup(Adam(), model)
for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, state)
end

plot(x -> 2x-x^3, -2, 2, legend=false)
10 changes: 5 additions & 5 deletions docs/src/models/quickstart.md
@@ -27,7 +27,7 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH
loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)

optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc.
state = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc.

# Training loop, using the whole data set 1000 times:
losses = []
@@ -38,12 +38,12 @@ losses = []
y_hat = m(x)
Flux.crossentropy(y_hat, y)
end
Flux.update!(optim, model, grads[1])
Flux.update!(state, model, grads[1])
push!(losses, loss) # logging, outside gradient context
end
end

optim # parameters, momenta and output have all changed
state # parameters, momenta and output have all changed
out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false)

mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far!
@@ -95,7 +95,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.

```julia
for epoch in 1:1_000
train!(model, loader, optim) do m, x, y
train!(model, loader, state) do m, x, y
y_hat = m(x)
Flux.crossentropy(y_hat, y)
end
@@ -110,7 +110,7 @@ end
```
(gradient of a zero-argument function) or
```
train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim)
train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt)
```
(with `Flux.params`) is in the old "implicit" style.
This still works on Flux 0.13, but will be removed from Flux 0.14.
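
For comparison, here is a sketch of the same call in the new "explicit" style, assuming the `model`, `loader`, and a three-argument `loss(model, x, y)` from the surrounding example:

```julia
# Explicit style (sketch): the loss now receives the model itself, and the
# state returned by setup replaces both Flux.params(model) and the bare opt.
state = Flux.setup(Flux.Adam(0.01), model)
Flux.train!((m, x, y) -> loss(m, x, y), model, loader, state)
```
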
2 changes: 1 addition & 1 deletion docs/src/training/reference.md
@@ -16,7 +16,7 @@ see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for det

```@docs
Flux.Train.setup
Flux.Train.train!(loss, model, data, opt; cb)
Flux.Train.train!(loss, model, data, state; cb)
Optimisers.update!
```

47 changes: 25 additions & 22 deletions docs/src/training/training.md
@@ -12,6 +12,9 @@ are handled one-by-one. One *epoch* of training means that each example is used
something like this:

```julia
# Initialise the optimiser for this model:
state = Flux.setup(rule, model)

for data in train_set
# Unpack this element (for supervised training):
input, label = data
@@ -24,16 +27,16 @@ for data in train_set
end

# Update the parameters so as to reduce the objective,
# according to a particular optimiser:
Flux.update!(opt, model, grads[1])
# according to the chosen optimisation rule:
Flux.update!(state, model, grads[1])
end
```

This loop can also be written using the function [`train!`](@ref Flux.Train.train!),
but it's helpful to understand the pieces first:

```julia
train!(model, train_set, opt) do m, x, y
train!(model, train_set, state) do m, x, y
loss(m(x), y)
end
```
@@ -113,7 +116,7 @@ fmap(model, grads[1]) do p, g
end
```

A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`.
A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(state, model, grads[1])`.
And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct.

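As a rough sketch (standalone Julia, not Flux's internal code), the update that `Descent` performs on each parameter array is simply:

```julia
# Conceptual sketch only, with stand-in arrays rather than a real model:
η = 0.1                  # the learning rate is all that Descent stores
p = randn(Float32, 3)    # stand-in for one parameter array
g = randn(Float32, 3)    # stand-in for its gradient
p .-= η .* g             # the update rule: p ← p - η * g
```
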
However, there are many other optimisation rules, which adjust the step size and
@@ -126,13 +129,13 @@ first argument of `update!`. Like this:

```julia
# Initialise momentum
opt = Flux.setup(Momentum(0.01, 0.9), model)
state = Flux.setup(Momentum(0.01, 0.9), model)

for data in train_set
grads = [...]

# Update both model parameters and optimiser state:
Flux.update!(opt, model, grads[1])
Flux.update!(state, model, grads[1])
end
```

@@ -192,17 +195,17 @@ Simple training loops like the one above can be written compactly using
the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads:

```julia
opt = Flux.setup(Adam(), model)
state = Flux.setup(Adam(), model)

for epoch in 1:100
Flux.train!(model, train_set, opt) do m, x, y
Flux.train!(model, train_set, state) do m, x, y
loss(m(x), y)
end
end
```

Or explicitly writing the anonymous function which this `do` block creates,
`train!((m,x,y) -> loss(m(x),y), model, train_set, opt)` is exactly equivalent.
`train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent.

!!! compat "Implicit-style `train!`"
This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument.
@@ -224,7 +227,7 @@ callback API. Here is an example, in which it may be helpful to note:
* Julia's `break` and `continue` keywords let you exit from parts of the loop.

```julia
opt = Flux.setup(Adam(), model)
state = Flux.setup(Adam(), model)

my_log = []
for epoch in 1:100
@@ -248,7 +251,7 @@ for epoch in 1:100
continue
end

Flux.update!(opt, model, grads[1])
Flux.update!(state, model, grads[1])
end

# Compute some accuracy, and save details as a NamedTuple
@@ -300,13 +303,13 @@ So there is a simpler way to implement exactly the same thing, by modifying the
instead of the loss function. This is done by replacing this:

```julia
opt = Flux.setup(Adam(0.1), model)
state = Flux.setup(Adam(0.1), model)
```

with this:

```julia
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
decay_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
```

Flux's optimisers are really modifications applied to the gradient before using it to update
@@ -328,12 +331,12 @@ Finer control of training, you may wish to alter the learning rate mid-way throu
This can be done with [`adjust!`](@ref Flux.adjust!), like this:

```julia
opt = Flux.setup(Adam(0.1), model) # initialise once
state = Flux.setup(Adam(0.1), model) # initialise once

for epoch in 1:1000
train!([...], opt) # Train with η = 0.1 for first 100,
train!([...], state) # Train with η = 0.1 for first 100,
if epoch == 100 # then change to use η = 0.01 for the rest.
Flux.adjust!(opt, 0.01)
Flux.adjust!(state, 0.01)
end
end
```
@@ -342,7 +345,7 @@ end
With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to
directly mutate the `Adam` struct, `opt.eta = 0.001`.

Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt, beta = (0.8, 0.99))`.
Other hyper-parameters can also be adjusted, such as `Flux.adjust!(state, beta = (0.8, 0.99))`.
And such modifications can be applied to just one part of the model.
For instance, this sets a different learning rate for the encoder and the decoder:

@@ -351,23 +354,23 @@
bimodel = Chain(enc = [...], dec = [...])

# This returns a tree whose structure matches the model:
opt = Flux.setup(Adam(0.02), bimodel)
state = Flux.setup(Adam(0.02), bimodel)

# Adjust the learning rate to be used for bimodel.layers.enc
Flux.adjust!(opt.layers.enc, 0.03)
Flux.adjust!(state.layers.enc, 0.03)
```

To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!).
This is a temporary modification, reversed by `thaw!`:

```julia
Flux.freeze!(opt.layers.enc)
Flux.freeze!(state.layers.enc)

# Now training won't update parameters in bimodel.layers.enc
train!(loss, bimodel, data, opt)
train!(loss, bimodel, data, state)

# Un-freeze the entire model:
Flux.thaw!(opt)
Flux.thaw!(state)
```

!!! compat "Flux ≤ 0.13"