diff --git a/src/GeometricMachineLearning.jl b/src/GeometricMachineLearning.jl index 8911cd4b2..d13493161 100644 --- a/src/GeometricMachineLearning.jl +++ b/src/GeometricMachineLearning.jl @@ -29,6 +29,7 @@ module GeometricMachineLearning import AbstractNeuralNetworks: parameterlength import AbstractNeuralNetworks: GlorotUniform import AbstractNeuralNetworks: params, architecture, model, dim + import AbstractNeuralNetworks: AbstractPullback, NetworkLoss # export params, architetcure, model export dim import GeometricIntegrators.Integrators: method, GeometricIntegrator diff --git a/src/loss/losses.jl b/src/loss/losses.jl index c4ee747ac..257a68e1f 100644 --- a/src/loss/losses.jl +++ b/src/loss/losses.jl @@ -1,34 +1,3 @@ -@doc raw""" - NetworkLoss - -An abstract type for all the neural network losses. -If you want to implement `CustomLoss <: NetworkLoss` you need to define a functor: -```julia -(loss::CustomLoss)(model, ps, input, output) -``` -where `model` is an instance of an `AbstractExplicitLayer` or a `Chain` and `ps` the parameters. - -See [`FeedForwardLoss`](@ref), [`TransformerLoss`](@ref), [`AutoEncoderLoss`](@ref) and [`ReducedLoss`](@ref) for examples. -""" -abstract type NetworkLoss end - -function (loss::NetworkLoss)(nn::NeuralNetwork, input::QPTOAT, output::QPTOAT) - loss(nn.model, nn.params, input, output) -end - -function _compute_loss(output_prediction::QPTOAT, output::QPTOAT) - _norm(_diff(output_prediction, output)) / _norm(output) -end - -function _compute_loss(model::Union{AbstractExplicitLayer, Chain}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT) - output_prediction = model(input, ps) - _compute_loss(output_prediction, output) -end - -function (loss::NetworkLoss)(model::Union{Chain, AbstractExplicitLayer}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT) - _compute_loss(model, ps, input, output) -end - @doc raw""" TransformerLoss(seq_length, prediction_window) diff --git a/src/pullback.jl b/src/pullback.jl index 338fc89cf..8e31fd43e 100644 --- a/src/pullback.jl +++ b/src/pullback.jl @@ -1,28 +1,3 @@ -@doc raw""" - AbstractPullback{NNLT<:NetworkLoss} - -`AbstractPullback` is an `abstract type` that encompasses all ways of performing differentiation (especially computing the gradient with respect to neural network parameters) in `GeometricMachineLearning`. - -If a user wants to implement a custom `Pullback` the following two functions have to be extended: -```julia -(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT}) -(_pullback::AbstractPullback)(ps, model, input_nt::QPT) -``` -based on the `loss::NetworkLoss` that's stored in `_pullback`. The output of _pullback needs to be a `Tuple` that contains: -1. the `loss` evaluated at `ps` and `input_nt` (or `input_nt_output_nt`), -2. the gradient of `loss` with respect to `ps` that call be called with e.g.: -```julia -_pullback(ps, model, input_nt)[2](1) # returns the gradient wrt to `ps` -``` -``\ldots`` we use this convention as it is analogous to how `Zygote` builds pullbacks. - -Also see [`ZygotePullback`](@ref). -""" -abstract type AbstractPullback{NNLT<:NetworkLoss} end - -(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT}) = error("Pullback not implemented for input-output pair!") -(_pullback::AbstractPullback)(ps, model, input_nt::QPT) = error("Pullback not implemented for single input!") - """ ZygotePullback <: AbstractPullback