Skip to content

Commit

Permalink
Merge e21b094 into bb88c55
Browse files Browse the repository at this point in the history
  • Loading branch information
logankilpatrick authored Nov 27, 2021
2 parents bb88c55 + e21b094 commit aff2f9e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/src/training/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ Such an object contains a reference to the model's parameters, not a copy, such

Handling all the parameters on a layer by layer basis is explained in the [Layer Helpers](../models/basics.md) section. Also, for freezing model parameters, see the [Advanced Usage Guide](../models/advanced.md).

```@docs
Flux.params
```

## Datasets

The `data` argument of `train!` provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy dataset with only one data point:
Expand Down
19 changes: 19 additions & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ function params!(p::Params, x, seen = IdSet())
end
end

"""
params(model)
params(layers...)
Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters.
This can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function.
The behaviour of `params` on custom types can be customized using [`Functor.@functor`](@ref) or [`Flux.trainable`](@ref).
# Examples
```jldoctest
julia> params(Chain(Dense(ones(2,3))), softmax)
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
julia> params(BatchNorm(2, relu))
Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])
```
"""
function params(m...)
ps = Params()
params!(ps, m)
Expand Down

0 comments on commit aff2f9e

Please sign in to comment.