diff --git a/docs/src/training/training.md b/docs/src/training/training.md
index fa8573d6a6..845a22d8a6 100644
--- a/docs/src/training/training.md
+++ b/docs/src/training/training.md
@@ -70,6 +70,10 @@ Such an object contains a reference to the model's parameters, not a copy, such
 
 Handling all the parameters on a layer by layer basis is explained in the [Layer Helpers](../models/basics.md) section. Also, for freezing model parameters, see the [Advanced Usage Guide](../models/advanced.md).
 
+```@docs
+Flux.params
+```
+
 ## Datasets
 
 The `data` argument of `train!` provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy dataset with only one data point:
diff --git a/src/functor.jl b/src/functor.jl
index ee9eb1d543..179259a3c6 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -48,6 +48,25 @@ function params!(p::Params, x, seen = IdSet())
   end
 end
 
+"""
+    params(model)
+    params(layers...)
+
+Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters.
+
+This can be used with the `gradient` function, see [Taking Gradients](@ref), or as input to the [`Flux.train!`](@ref Flux.train!) function.
+
+The behaviour of `params` on custom types can be customized using [`Functors.@functor`](@ref) or [`Flux.trainable`](@ref).
+
+# Examples
+```jldoctest
+julia> params(Chain(Dense(ones(2,3)), softmax))
+Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
+
+julia> params(BatchNorm(2, relu))
+Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])
+```
+"""
 function params(m...)
   ps = Params()
   params!(ps, m)
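
As context for the docstring above, here is a minimal sketch of the two uses it mentions, passing the returned `Params` to `gradient` and to `train!`. The model, loss, data shapes, and optimiser below are illustrative assumptions, not part of this diff:

```julia
using Flux

# Illustrative two-layer model; the Dense sizes are arbitrary.
model = Chain(Dense(3, 2), softmax)

# Collect the trainable arrays (the Dense weight and bias) into a Params:
ps = Flux.params(model)

# Dummy data and a stand-in loss, only for demonstration:
x = rand(Float32, 3, 4)
y = rand(Float32, 2, 4)
loss(x, y) = Flux.Losses.crossentropy(model(x), y)

# Implicit-style gradient over everything in `ps`, as the docstring describes:
gs = gradient(() -> loss(x, y), ps)

# The same Params object can be handed to train! together with an optimiser:
opt = Descent(0.1)
Flux.train!(loss, ps, [(x, y)], opt)
```

Here `gs` behaves like a dictionary keyed by the arrays in `ps` (e.g. `gs[ps[1]]` is the gradient of the Dense weight).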