Skip to content

Commit

Permalink
Now taking AbstractPullback and NetworkLoss from ANN.
Browse files Browse the repository at this point in the history
  • Loading branch information
benedict-96 committed Nov 28, 2024
1 parent 5aa558c commit d0abe33
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 56 deletions.
1 change: 1 addition & 0 deletions src/GeometricMachineLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module GeometricMachineLearning
import AbstractNeuralNetworks: parameterlength
import AbstractNeuralNetworks: GlorotUniform
import AbstractNeuralNetworks: params, architecture, model, dim
import AbstractNeuralNetworks: AbstractPullback, NetworkLoss
# export params, architecture, model
export dim
import GeometricIntegrators.Integrators: method, GeometricIntegrator
Expand Down
31 changes: 0 additions & 31 deletions src/loss/losses.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,3 @@
@doc raw"""
    NetworkLoss

Abstract supertype of all neural network losses.

To implement a custom loss `CustomLoss <: NetworkLoss`, define a functor:
```julia
(loss::CustomLoss)(model, ps, input, output)
```
where `model` is an instance of an `AbstractExplicitLayer` or a `Chain` and `ps` are the parameters.

See [`FeedForwardLoss`](@ref), [`TransformerLoss`](@ref), [`AutoEncoderLoss`](@ref) and [`ReducedLoss`](@ref) for examples.
"""
abstract type NetworkLoss end

# Evaluate the loss on a whole `NeuralNetwork` by unpacking its model and
# parameter fields and delegating to the (model, ps) method.
(loss::NetworkLoss)(nn::NeuralNetwork, input::QPTOAT, output::QPTOAT) =
    loss(nn.model, nn.params, input, output)

# Relative error: norm of (prediction - target) scaled by the norm of the target.
# `_diff` and `_norm` are helpers defined elsewhere in this file.
function _compute_loss(output_prediction::QPTOAT, output::QPTOAT)
    residual = _diff(output_prediction, output)
    return _norm(residual) / _norm(output)
end

# Forward `input` through `model` with parameters `ps`, then score the
# prediction against `output` via the relative-error method above.
function _compute_loss(model::Union{AbstractExplicitLayer, Chain}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
    return _compute_loss(model(input, ps), output)
end

# Default functor for every `NetworkLoss` subtype: delegate to `_compute_loss`.
# Custom losses override this functor to change the scoring rule.
(loss::NetworkLoss)(model::Union{Chain, AbstractExplicitLayer}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT) =
    _compute_loss(model, ps, input, output)

@doc raw"""
TransformerLoss(seq_length, prediction_window)
Expand Down
25 changes: 0 additions & 25 deletions src/pullback.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,3 @@
@doc raw"""
    AbstractPullback{NNLT<:NetworkLoss}

Abstract supertype encompassing all ways of performing differentiation
(especially computing the gradient with respect to neural network parameters)
in `GeometricMachineLearning`.

To implement a custom `Pullback`, extend the following two functions:
```julia
(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})
(_pullback::AbstractPullback)(ps, model, input_nt::QPT)
```
based on the `loss::NetworkLoss` stored in `_pullback`. The output of
`_pullback` needs to be a `Tuple` containing:
1. the `loss` evaluated at `ps` and `input_nt` (or `input_nt_output_nt`),
2. the gradient of `loss` with respect to `ps`, which can be called with e.g.:
```julia
_pullback(ps, model, input_nt)[2](1) # returns the gradient wrt to `ps`
```
``\ldots`` we use this convention as it is analogous to how `Zygote` builds pullbacks.

Also see [`ZygotePullback`](@ref).
"""
abstract type AbstractPullback{NNLT<:NetworkLoss} end

# Fallback methods: every concrete pullback type must override both of these.
function (_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})
    error("Pullback not implemented for input-output pair!")
end

function (_pullback::AbstractPullback)(ps, model, input_nt::QPT)
    error("Pullback not implemented for single input!")
end

"""
ZygotePullback <: AbstractPullback
Expand Down

0 comments on commit d0abe33

Please sign in to comment.