Now taking AbstractPullback and NetworkLoss from ANN. #175

Merged: 4 commits, Nov 28, 2024
1 change: 0 additions & 1 deletion docs/src/pullbacks/computation_of_pullbacks.md
@@ -104,7 +104,6 @@ is equivalent to applying the [Riemannian gradient](@ref "The Riemannian Gradient")
## Library Functions

```@docs
-AbstractPullback
GeometricMachineLearning.ZygotePullback
```

1 change: 0 additions & 1 deletion docs/src/reduced_order_modeling/losses.md
@@ -44,7 +44,6 @@ where ``\mathbf{x}^{(t)}`` is the solution of the FOM at point ``t`` and ``\math
## Library Functions

```@docs
-GeometricMachineLearning.NetworkLoss
FeedForwardLoss
TransformerLoss
AutoEncoderLoss
1 change: 1 addition & 0 deletions src/GeometricMachineLearning.jl
@@ -29,6 +29,7 @@ module GeometricMachineLearning
import AbstractNeuralNetworks: parameterlength
import AbstractNeuralNetworks: GlorotUniform
import AbstractNeuralNetworks: params, architecture, model, dim
+import AbstractNeuralNetworks: AbstractPullback, NetworkLoss, _compute_loss
# export params, architetcure, model
export dim
import GeometricIntegrators.Integrators: method, GeometricIntegrator
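Note (not part of the diff): because `import` binds existing names rather than redefining them, the line added above makes the qualified types in both packages identical. A minimal sketch of that check:

```julia
# Not part of this PR: a quick check of what the added import implies.
# `import AbstractNeuralNetworks: AbstractPullback, NetworkLoss` binds the
# existing types into GeometricMachineLearning, so the qualified names agree.
using AbstractNeuralNetworks
using GeometricMachineLearning

@assert GeometricMachineLearning.NetworkLoss === AbstractNeuralNetworks.NetworkLoss
@assert GeometricMachineLearning.AbstractPullback === AbstractNeuralNetworks.AbstractPullback
```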
2 changes: 1 addition & 1 deletion src/data_loader/optimize.jl
@@ -23,7 +23,7 @@ All the arguments are mandatory (there are no defaults):
3. the neural network parameters `ps`.
4. the data (i.e. an instance of [`DataLoader`](@ref)).
5. `batch`::[`Batch`](@ref): stores `batch_size` (and optionally `seq_length` and `prediction_window`).
-6. `loss::`[`NetworkLoss`](@ref).
+6. `loss::NetworkLoss`.
7. the *section* `λY` of the parameters `ps`.
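The documented function itself is not visible in this excerpt, so the sketch below only constructs the kinds of objects named in the argument list; the data, batch size and loss choice are arbitrary, and the call itself is omitted rather than guessed.

```julia
# Illustrative only: builds the argument types named in the docstring above.
# The toy data, batch size and loss choice are assumptions, not values from the PR.
using GeometricMachineLearning

data  = rand(2, 100)          # toy data set
dl    = DataLoader(data)      # argument 4: the data
batch = Batch(10)             # argument 5: stores `batch_size`
loss  = FeedForwardLoss()     # argument 6: a concrete `NetworkLoss`
```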

# Implementation
31 changes: 0 additions & 31 deletions src/loss/losses.jl
@@ -1,34 +1,3 @@
@doc raw"""
NetworkLoss

An abstract type for all the neural network losses.
If you want to implement `CustomLoss <: NetworkLoss` you need to define a functor:
```julia
(loss::CustomLoss)(model, ps, input, output)
```
where `model` is an instance of an `AbstractExplicitLayer` or a `Chain` and `ps` the parameters.

See [`FeedForwardLoss`](@ref), [`TransformerLoss`](@ref), [`AutoEncoderLoss`](@ref) and [`ReducedLoss`](@ref) for examples.
"""
abstract type NetworkLoss end

function (loss::NetworkLoss)(nn::NeuralNetwork, input::QPTOAT, output::QPTOAT)
loss(nn.model, nn.params, input, output)
end

function _compute_loss(output_prediction::QPTOAT, output::QPTOAT)
_norm(_diff(output_prediction, output)) / _norm(output)
end

function _compute_loss(model::Union{AbstractExplicitLayer, Chain}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
output_prediction = model(input, ps)
_compute_loss(output_prediction, output)
end

function (loss::NetworkLoss)(model::Union{Chain, AbstractExplicitLayer}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
_compute_loss(model, ps, input, output)
end

@doc raw"""
TransformerLoss(seq_length, prediction_window)

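The `NetworkLoss` docstring removed above describes the functor a custom loss has to provide; that contract is unchanged, it simply lives in AbstractNeuralNetworks now. A hedged sketch with a made-up `L2Loss` type, assuming plain array-valued input and output:

```julia
# Sketch only: `L2Loss` is hypothetical; NetworkLoss is now taken from
# AbstractNeuralNetworks (see the import added in GeometricMachineLearning.jl).
using LinearAlgebra: norm
import AbstractNeuralNetworks: NetworkLoss

struct L2Loss <: NetworkLoss end

# the required functor: (loss::CustomLoss)(model, ps, input, output)
function (loss::L2Loss)(model, ps, input, output)
    prediction = model(input, ps)               # forward pass, as in the deleted _compute_loss
    norm(prediction - output) / norm(output)    # relative error, mirroring _norm(_diff(...)) / _norm(...)
end
```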
2 changes: 1 addition & 1 deletion src/optimizers/optimizer.jl
@@ -20,7 +20,7 @@ The arguments are:
2. `dl::`[`DataLoader`](@ref)
3. `batch::`[`Batch`](@ref)
4. `n_epochs::Integer`
-5. `loss::`[`NetworkLoss`](@ref)
+5. `loss::NetworkLoss`

The last argument is optional for many neural network architectures. We have the following defaults:
- A [`TransformerIntegrator`](@ref) uses [`TransformerLoss`](@ref).
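As in the previous file, only the docstring reference changes here; `loss` is still a `NetworkLoss`, now the one defined in AbstractNeuralNetworks. A small sketch of constructing the `TransformerLoss` mentioned as the `TransformerIntegrator` default, so it could be passed explicitly as the optional fifth argument (the documented callable is cut off above, so it is not invoked here; the numeric values are arbitrary):

```julia
# Illustrative only: seq_length = 3 and prediction_window = 1 are arbitrary choices.
using GeometricMachineLearning

loss = TransformerLoss(3, 1)
```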
25 changes: 0 additions & 25 deletions src/pullback.jl
@@ -1,28 +1,3 @@
@doc raw"""
AbstractPullback{NNLT<:NetworkLoss}

`AbstractPullback` is an `abstract type` that encompasses all ways of performing differentiation (especially computing the gradient with respect to neural network parameters) in `GeometricMachineLearning`.

If a user wants to implement a custom `Pullback` the following two functions have to be extended:
```julia
(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})
(_pullback::AbstractPullback)(ps, model, input_nt::QPT)
```
based on the `loss::NetworkLoss` that's stored in `_pullback`. The output of _pullback needs to be a `Tuple` that contains:
1. the `loss` evaluated at `ps` and `input_nt` (or `input_nt_output_nt`),
2. the gradient of `loss` with respect to `ps` that call be called with e.g.:
```julia
_pullback(ps, model, input_nt)[2](1) # returns the gradient wrt to `ps`
```
``\ldots`` we use this convention as it is analogous to how `Zygote` builds pullbacks.

Also see [`ZygotePullback`](@ref).
"""
abstract type AbstractPullback{NNLT<:NetworkLoss} end

(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT}) = error("Pullback not implemented for input-output pair!")
(_pullback::AbstractPullback)(ps, model, input_nt::QPT) = error("Pullback not implemented for single input!")

"""
ZygotePullback <: AbstractPullback

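The removed `AbstractPullback` docstring above spells out the two functor methods a custom pullback has to provide and the `(loss value, gradient callable)` return convention. A hedged sketch of one possible implementation of the input-output-pair method, using Zygote directly; `MyPullback` is hypothetical and mirrors, but is not, the package's `ZygotePullback`:

```julia
# Sketch only: AbstractPullback and NetworkLoss now come from
# AbstractNeuralNetworks (see the import in GeometricMachineLearning.jl).
using Zygote
import AbstractNeuralNetworks: AbstractPullback, NetworkLoss

struct MyPullback{NNLT <: NetworkLoss} <: AbstractPullback{NNLT}
    loss::NNLT
end

# Differentiate the stored loss with respect to `ps`; Zygote.pullback returns
# (loss value, callable), so result[2](1) yields the gradient with respect to
# `ps`, matching the convention in the removed docstring. Only the
# input-output-pair method is sketched here.
function (_pullback::MyPullback)(ps, model, input_output::Tuple)
    Zygote.pullback(p -> _pullback.loss(model, p, input_output...), ps)
end
```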