Re-write training docs (#2114)
* re-write training.md

* add train_api page for docstrings

* update basic.md to introduce explicit not implicit

* more links, comments on notes

* updates, rm some Optimisers detail

* mention TerminalLoggers

* tweaks

* perhaps we should build regularisation into the same page

* tweaks

* update quickstart + readme too

* finish freezing etc, update everything

* fix a test, etc

* add note to "advanced" page

* tweaks

* comments

* tweaks, bugs, missing files, etc

* move a sentence

* change opt to state

* new page lost in rebase

* don't say "explicit" so often

* opt to state in a few more places

* add three compat boxes about common errors / problems re old versions

* change to opt_state

* fixes

* fixup

* fixup

* fixup

* spelling & indentation
mcabbott authored Dec 15, 2022
1 parent 4f015e9 commit 40f0a63
Showing 21 changed files with 639 additions and 355 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -33,7 +33,7 @@ MacroTools = "0.5"
NNlib = "0.8.9"
NNlibCUDA = "0.2.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2.10"
Optimisers = "0.2.12"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
5 changes: 2 additions & 3 deletions README.md
@@ -25,10 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)
for epoch in 1:1000
  Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
end

plot(x -> 2x-x^3, -2, 2, legend=false)
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -3,6 +3,7 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -18,7 +18,6 @@ makedocs(
"Fitting a Line" => "models/overview.md",
"Gradients and Layers" => "models/basics.md",
"Training" => "training/training.md",
"Regularisation" => "models/regularisation.md", # consolidated in #2114
"Recurrence" => "models/recurrence.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
@@ -31,7 +30,8 @@ makedocs(
"Activation Functions" => "models/activation.md",
"Weight Initialisation" => "utilities.md",
"Loss Functions" => "models/losses.md",
"Optimisation Rules" => "training/optimisers.md", # TODO move optimiser intro up to Training
"Training API" => "training/reference.md",
"Optimisation Rules" => "training/optimisers.md",
"Shape Inference" => "outputsize.md",
"Flat vs. Nested" => "destructure.md",
"Callback Helpers" => "training/callbacks.md",
6 changes: 6 additions & 0 deletions docs/src/destructure.md
@@ -49,6 +49,12 @@ julia> Flux.destructure(grad) # acts on non-models, too
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
```

!!! compat "Flux ≤ 0.12"
Old versions of Flux had an entirely different implementation of `destructure`, which
had many bugs (and almost no tests). Many comments online still refer to that now-deleted
function, or to memories of it.


### All Parameters

The function `destructure` now lives in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl).
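
For orientation, here is a minimal sketch of the current `destructure` round-trip (the `Dense` layer and sizes are invented for illustration):

```julia
using Flux

model = Dense(2 => 3)                # any Flux model works here
flat, re = Flux.destructure(model)   # flat vector of all parameters, plus a Restructure closure
model2 = re(zero(flat))              # rebuild the same architecture from a new flat vector
```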
12 changes: 12 additions & 0 deletions docs/src/models/advanced.md
@@ -69,6 +69,10 @@ However, doing this requires the `struct` to have a corresponding constructor th

When we do not want to include all the model parameters (for example, in transfer learning), we can simply omit those layers from our call to `params`.

!!! compat "Flux ≤ 0.13"
The mechanism described here is for Flux's old "implicit" training style.
When upgrading for Flux 0.14, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.
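
As a rough sketch of that newer approach (the model here is invented for illustration; see the training docs for details):

```julia
using Flux

model = Chain(Dense(3 => 4, relu), Dense(4 => 2))
opt_state = Flux.setup(Adam(), model)

Flux.freeze!(opt_state.layers[1])   # the first layer is now skipped by update! / train!
# ... train as usual ...
Flux.thaw!(opt_state.layers[1])     # resume updating it
```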

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

@@ -155,6 +159,10 @@ model(xs)
# returns a single float vector with one value
```

!!! note
This `Join` layer is available from the [Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) package.


#### Using `Parallel`

Flux already provides [`Parallel`](@ref), which offers the same functionality. In this case, `Join` is just syntactic sugar for `Parallel`.
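
For instance, a rough sketch of that equivalence (layer sizes are invented here):

```julia
using Flux

# Two branches, one per input; `vcat` concatenates their outputs:
join_like = Parallel(vcat,
    Dense(4 => 2, relu),    # processes the first input
    Dense(3 => 2, relu))    # processes the second input

x1, x2 = rand(Float32, 4), rand(Float32, 3)
join_like(x1, x2)           # 4-element output, vcat of the two branch outputs
```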
@@ -223,3 +231,7 @@ function loss(x, ys, model)
return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
end
```

!!! note
This `Split` layer is available from the [Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) package.

80 changes: 61 additions & 19 deletions docs/src/models/basics.md
@@ -1,6 +1,6 @@
# [How Flux Works: Gradients and Layers](@id man-basics)

## Taking Gradients
## [Taking Gradients](@id man-taking-gradients)

Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)

@@ -29,35 +29,77 @@ julia> gradient(f, [2, 1], [2, 0])
([0.0, 2.0], [-0.0, -2.0])
```
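
As a gentler warm-up, a tiny sketch in the same spirit (the function `q` here is invented, not one used elsewhere on this page):

```julia
using Flux  # re-exports `gradient` from Zygote

q(x) = 3x^2 + 2x + 1
dq(x) = gradient(q, x)[1]   # derivative 6x + 2, computed by automatic differentiation

dq(2)   # 14
```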

These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model.


Machine learning models often have *hundreds* of parameter arrays.
Instead of passing them to `gradient` individually, we can store them together in a structure.
The simplest example is a named tuple, created by the following syntax:

```jldoctest basics
julia> nt = (a = [2, 1], b = [2, 0], c = tanh);

julia> g(x::NamedTuple) = sum(abs2, x.a .- x.b);

julia> g(nt)
1

julia> dg_nt = gradient(g, nt)[1]
(a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing)
```

Notice that `gradient` has returned a matching structure. The field `dg_nt.a` is the gradient
for `nt.a`, and so on. Some fields have no gradient, indicated by `nothing`.

Rather than define a function like `g` every time (and think up a name for it),
it is often useful to use anonymous functions: this one is `x -> sum(abs2, x.a .- x.b)`.
Anonymous functions can be defined either with `->` or with `do`,
and such `do` blocks are often useful if you have a few steps to perform:

```jldoctest basics
julia> gradient((x, y) -> sum(abs2, x.a ./ y .- x.b), nt, [1, 2])
((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25])

julia> gradient(nt, [1, 2]) do x, y
         z = x.a ./ y
         sum(abs2, z .- x.b)
       end
((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25])
```

Sometimes you may want to know the value of the function, as well as its gradient.
Rather than calling the function a second time, you can call [`withgradient`](@ref Zygote.withgradient) instead:

```jldoctest basics
julia> Flux.withgradient(g, nt)
(val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),))
```

!!! note "Implicit gradients"
Flux used to handle many parameters in a different way, using the [`params`](@ref Flux.params) function.
This uses a method of `gradient` which takes a zero-argument function, and returns a dictionary
through which the resulting gradients can be looked up:

```jldoctest basics
julia> x = [2, 1];

julia> y = [2, 0];

julia> gs = gradient(Flux.params(x, y)) do
         f(x, y)
       end
Grads(...)

julia> gs[x]
2-element Vector{Float64}:
 0.0
 2.0

julia> gs[y]
2-element Vector{Float64}:
 -0.0
 -2.0
```

Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.

This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.

## Building Simple Models

7 changes: 6 additions & 1 deletion docs/src/models/layers.md
@@ -15,7 +15,7 @@ The `Dense` exemplifies several features:
* It is annotated with [`@functor`](@ref Functors.@functor), which means that [`params`](@ref Flux.params) will see the contents, and [`gpu`](@ref Flux.gpu) will move their arrays to the GPU.

By contrast, `Chain` itself contains no parameters, but connects other layers together.
The section on [dataflow layers](@ref man-dataflow-layers) introduces others like this.

## Fully Connected

@@ -27,6 +27,11 @@ Flux.Scale

Perhaps `Scale` isn't quite fully connected, but it may be thought of as `Dense(Diagonal(s.weights), s.bias)`, and LinearAlgebra's `Diagonal` is a matrix which just happens to contain many zeros.
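
To make that remark concrete, a small sketch (sizes and values invented here) comparing the two:

```julia
using Flux, LinearAlgebra

w = [1.0, 2.0, 3.0]; b = [0.1, 0.2, 0.3]
s = Flux.Scale(w, b)                # element-wise map: w .* x .+ b
d = Dense(Matrix(Diagonal(w)), b)   # the same map, written as a (mostly zero) weight matrix

x = rand(3)
s(x) ≈ d(x)                         # true, up to floating-point rounding
```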

!!! compat "Flux ≤ 0.12"
Old versions of Flux accepted only `Dense(in, out, act)` and not `Dense(in => out, act)`.
This notation makes a `Pair` object. If you get an error like `MethodError: no method matching Dense(::Pair{Int64,Int64})`, this means that you should upgrade to Flux 0.13.
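
For reference, a quick sketch of the newer spelling (the sizes here are arbitrary):

```julia
using Flux

layer = Dense(2 => 3, relu)   # 2 inputs, 3 outputs; `2 => 3` is an ordinary Pair
layer(rand(Float32, 2))       # apply to a length-2 input, giving a length-3 output
```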


## Convolution Models

These layers are used to build convolutional neural networks (CNNs).
32 changes: 22 additions & 10 deletions docs/src/models/quickstart.md
@@ -27,25 +27,23 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH
loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)

optim = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.

# Training loop, using the whole data set 1000 times:
losses = []
@showprogress for epoch in 1:1_000
    for (x, y) in loader
        loss, grads = Flux.withgradient(model) do m
            # Evaluate model and loss inside gradient context:
            y_hat = m(x)
            Flux.crossentropy(y_hat, y)
        end
        Flux.update!(optim, model, grads[1])
        push!(losses, loss)  # logging, outside gradient context
    end
end

optim # parameters, momenta and output have all changed
out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false)

mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far!
Expand Down Expand Up @@ -89,17 +87,31 @@ Some things to notice in this example are:

* The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)).

* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [`setup`](@ref Flux.Train.setup). And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function.

* The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated.

Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.update!) separately, there is a convenience function [`train!`](@ref Flux.train!). If we didn't want anything extra (like logging the loss), we could replace the training loop with the following:

```julia
for epoch in 1:1_000
    Flux.train!(model, loader, optim) do m, x, y
        y_hat = m(x)
        Flux.crossentropy(y_hat, y)
    end
end
```

!!! compat "Implicit-style training, Flux ≤ 0.13"
Until recently Flux's training worked a bit differently.
Any code which looks like
```
gradient(() -> loss(model, x, y), Flux.params(model))
```
(gradient of a zero-argument function) or
```
train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt)
```
(with `Flux.params`) is in the old "implicit" style.
This still works on Flux 0.13, but will be removed from Flux 0.14.
See the [training section](@ref man-training) for more details.
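
For comparison, a rough sketch of the explicit-style equivalents, using the `model` and `loader` from above; `loss`, `opt_state`, and the batch `x, y` are placeholder names:

```julia
opt_state = Flux.setup(Adam(), model)   # replaces the bare optimiser object

# explicit gradient of a function of the model:
grads = Flux.gradient(m -> loss(m, x, y), model)
Flux.update!(opt_state, model, grads[1])

# or with train!, where the loss now takes the model as its first argument:
Flux.train!((m, x, y) -> loss(m, x, y), model, loader, opt_state)
```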
80 changes: 0 additions & 80 deletions docs/src/models/regularisation.md

This file was deleted.
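
The regularisation material now lives in the training page; as a rough sketch of the explicit-style pattern it describes (layer sizes and the 0.01 factor are invented here):

```julia
using Flux

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
opt_state = Flux.setup(Adam(), model)

# add an L2 penalty on the weight matrices directly to the loss:
l2(m) = sum(abs2, m[1].weight) + sum(abs2, m[2].weight)
loss(m, x, y) = Flux.mse(m(x), y) + 0.01f0 * l2(m)

x, y = rand(Float32, 4, 16), rand(Float32, 1, 16)
grads = Flux.gradient(loss, model, x, y)   # gradients include the penalty term
Flux.update!(opt_state, model, grads[1])
```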

2 changes: 1 addition & 1 deletion docs/src/training/callbacks.md
@@ -1,4 +1,4 @@
# Callback Helpers
# [Callback Helpers](@id man-callback-helpers)

```@docs
Flux.throttle
