diff --git a/docs/src/pullbacks/computation_of_pullbacks.md b/docs/src/pullbacks/computation_of_pullbacks.md
index a21ab865f..984e25416 100644
--- a/docs/src/pullbacks/computation_of_pullbacks.md
+++ b/docs/src/pullbacks/computation_of_pullbacks.md
@@ -104,7 +104,6 @@ is equivalent to applying the [Riemannian gradient](@ref "The Riemannian Gradient")
 ## Library Functions
 
 ```@docs
-AbstractPullback
 GeometricMachineLearning.ZygotePullback
 ```
 
diff --git a/docs/src/reduced_order_modeling/losses.md b/docs/src/reduced_order_modeling/losses.md
index 5e9c705af..8ced8d2fa 100644
--- a/docs/src/reduced_order_modeling/losses.md
+++ b/docs/src/reduced_order_modeling/losses.md
@@ -44,7 +44,6 @@ where ``\mathbf{x}^{(t)}`` is the solution of the FOM at point ``t`` and ``\math
 ## Library Functions
 
 ```@docs
-GeometricMachineLearning.NetworkLoss
 FeedForwardLoss
 TransformerLoss
 AutoEncoderLoss
diff --git a/src/GeometricMachineLearning.jl b/src/GeometricMachineLearning.jl
index 8911cd4b2..17b8605b9 100644
--- a/src/GeometricMachineLearning.jl
+++ b/src/GeometricMachineLearning.jl
@@ -29,6 +29,7 @@ module GeometricMachineLearning
 import AbstractNeuralNetworks: parameterlength
 import AbstractNeuralNetworks: GlorotUniform
 import AbstractNeuralNetworks: params, architecture, model, dim
+import AbstractNeuralNetworks: AbstractPullback, NetworkLoss, _compute_loss
 # export params, architetcure, model
 export dim
 import GeometricIntegrators.Integrators: method, GeometricIntegrator
diff --git a/src/data_loader/optimize.jl b/src/data_loader/optimize.jl
index a5e15e748..67366847d 100644
--- a/src/data_loader/optimize.jl
+++ b/src/data_loader/optimize.jl
@@ -23,7 +23,7 @@ All the arguments are mandatory (there are no defaults):
 3. the neural network parameters `ps`.
 4. the data (i.e. an instance of [`DataLoader`](@ref)).
 5. `batch`::[`Batch`](@ref): stores `batch_size` (and optionally `seq_length` and `prediction_window`).
-6. `loss::`[`NetworkLoss`](@ref).
+6. `loss::NetworkLoss`.
 7. the *section* `λY` of the parameters `ps`.
 
 # Implementation
diff --git a/src/loss/losses.jl b/src/loss/losses.jl
index c4ee747ac..257a68e1f 100644
--- a/src/loss/losses.jl
+++ b/src/loss/losses.jl
@@ -1,34 +1,3 @@
-@doc raw"""
-    NetworkLoss
-
-An abstract type for all the neural network losses.
-If you want to implement `CustomLoss <: NetworkLoss` you need to define a functor:
-```julia
-(loss::CustomLoss)(model, ps, input, output)
-```
-where `model` is an instance of an `AbstractExplicitLayer` or a `Chain` and `ps` the parameters.
-
-See [`FeedForwardLoss`](@ref), [`TransformerLoss`](@ref), [`AutoEncoderLoss`](@ref) and [`ReducedLoss`](@ref) for examples.
-"""
-abstract type NetworkLoss end
-
-function (loss::NetworkLoss)(nn::NeuralNetwork, input::QPTOAT, output::QPTOAT)
-    loss(nn.model, nn.params, input, output)
-end
-
-function _compute_loss(output_prediction::QPTOAT, output::QPTOAT)
-    _norm(_diff(output_prediction, output)) / _norm(output)
-end
-
-function _compute_loss(model::Union{AbstractExplicitLayer, Chain}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
-    output_prediction = model(input, ps)
-    _compute_loss(output_prediction, output)
-end
-
-function (loss::NetworkLoss)(model::Union{Chain, AbstractExplicitLayer}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
-    _compute_loss(model, ps, input, output)
-end
-
 @doc raw"""
     TransformerLoss(seq_length, prediction_window)
 
diff --git a/src/optimizers/optimizer.jl b/src/optimizers/optimizer.jl
index d9008cfc9..ab90857b0 100644
--- a/src/optimizers/optimizer.jl
+++ b/src/optimizers/optimizer.jl
@@ -20,7 +20,7 @@ The arguments are:
 2. `dl::`[`DataLoader`](@ref)
 3. `batch::`[`Batch`](@ref)
 4. `n_epochs::Integer`
-5. `loss::`[`NetworkLoss`](@ref)
+5. `loss::NetworkLoss`
 
 The last argument is optional for many neural network architectures. We have the following defaults:
 - A [`TransformerIntegrator`](@ref) uses [`TransformerLoss`](@ref).
diff --git a/src/pullback.jl b/src/pullback.jl
index 338fc89cf..8e31fd43e 100644
--- a/src/pullback.jl
+++ b/src/pullback.jl
@@ -1,28 +1,3 @@
-@doc raw"""
-    AbstractPullback{NNLT<:NetworkLoss}
-
-`AbstractPullback` is an `abstract type` that encompasses all ways of performing differentiation (especially computing the gradient with respect to neural network parameters) in `GeometricMachineLearning`.
-
-If a user wants to implement a custom `Pullback` the following two functions have to be extended:
-```julia
-(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})
-(_pullback::AbstractPullback)(ps, model, input_nt::QPT)
-```
-based on the `loss::NetworkLoss` that's stored in `_pullback`. The output of _pullback needs to be a `Tuple` that contains:
-1. the `loss` evaluated at `ps` and `input_nt` (or `input_nt_output_nt`),
-2. the gradient of `loss` with respect to `ps` that call be called with e.g.:
-```julia
-_pullback(ps, model, input_nt)[2](1) # returns the gradient wrt to `ps`
-```
-``\ldots`` we use this convention as it is analogous to how `Zygote` builds pullbacks.
-
-Also see [`ZygotePullback`](@ref).
-"""
-abstract type AbstractPullback{NNLT<:NetworkLoss} end
-
-(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT}) = error("Pullback not implemented for input-output pair!")
-(_pullback::AbstractPullback)(ps, model, input_nt::QPT) = error("Pullback not implemented for single input!")
-
 """
     ZygotePullback <: AbstractPullback
 