Added AbstractPullback and NetworkLoss. #17

Merged: 8 commits, Nov 27, 2024
8 changes: 7 additions & 1 deletion src/AbstractNeuralNetworks.jl
@@ -31,7 +31,6 @@ module AbstractNeuralNetworks

include("model.jl")


export Dense, Linear

include("layers/abstract.jl")
@@ -56,4 +55,11 @@ module AbstractNeuralNetworks

include("neural_network.jl")

include("losses.jl")

export NetworkLoss, FeedForwardLoss

include("pullback.jl")

export AbstractPullback
end
1 change: 0 additions & 1 deletion src/layers/linear.jl
@@ -1,4 +1,3 @@

const Linear{M, N, USEBIAS} = Dense{M, N, USEBIAS, <: IdentityActivation}

Linear(m, n; kwargs...) = Dense(m, n, IdentityActivation(); kwargs...)
97 changes: 97 additions & 0 deletions src/losses.jl
@@ -0,0 +1,97 @@
@doc raw"""
    NetworkLoss

An abstract type for all neural network losses.

If you want to implement `CustomLoss <: NetworkLoss`, you need to define a functor:
```julia
(loss::CustomLoss)(model, ps, input, output)
```
where `model` is an instance of an `AbstractExplicitLayer` or a `Chain` and `ps` are the parameters.
See [`FeedForwardLoss`](@ref), `GeometricMachineLearning.TransformerLoss`, `GeometricMachineLearning.AutoEncoderLoss` and `GeometricMachineLearning.ReducedLoss` for examples.
"""
abstract type NetworkLoss end

# Apply `fun` entry-wise to one or more `NamedTuple`s that share the same keys.
function apply(fun, ps::NamedTuple...)
    for p in ps
        @assert keys(ps[1]) == keys(p)
    end
    NamedTuple{keys(ps[1])}(fun(p...) for p in zip(ps...))
end

# overload norm
_norm(dx::NT) where {AT <: AbstractArray, NT <: NamedTuple{(:q, :p), Tuple{AT, AT}}} = (norm(dx.q) + norm(dx.p)) / 2 # we need this because of a Zygote problem
_norm(dx::NamedTuple) = sum(apply(norm, dx)) / length(dx)
_norm(A::AbstractArray) = norm(A)

# overloaded +/- operation
_diff(dx₁::NT, dx₂::NT) where {AT <: AbstractArray, NT <: NamedTuple{(:q, :p), Tuple{AT, AT}}} = (q = dx₁.q - dx₂.q, p = dx₁.p - dx₂.p) # we need this because of a Zygote problem
_diff(dx₁::NamedTuple, dx₂::NamedTuple) = apply(_diff, dx₁, dx₂)
_diff(A::AbstractArray, B::AbstractArray) = A - B
_add(dx₁::NamedTuple, dx₂::NamedTuple) = apply(_add, dx₁, dx₂)
_add(A::AbstractArray, B::AbstractArray) = A + B

const QPT{T} = NamedTuple{(:q, :p), Tuple{AT, AT}} where {T, AT <: AbstractArray{T}}
const QPTOAT{T} = Union{QPT{T}, AbstractArray{T}} where T
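# Illustrative note (not part of this PR): on (q, p) data the helpers above act component-wise.
# For example, with dx₁ = (q = [1., 2.], p = [3., 4.]) and dx₂ = (q = [1., 0.], p = [3., 0.]):
#     _diff(dx₁, dx₂) == (q = [0., 2.], p = [0., 4.])
#     _norm(_diff(dx₁, dx₂)) == (norm([0., 2.]) + norm([0., 4.])) / 2 == 3.0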

function (loss::NetworkLoss)(nn::NeuralNetwork, input::QPTOAT, output::QPTOAT)
    loss(nn.model, nn.params, input, output)
end

function _compute_loss(output_prediction::QPTOAT, output::QPTOAT)
    _norm(_diff(output_prediction, output)) / _norm(output)
end

function _compute_loss(model::Union{AbstractExplicitLayer, Chain}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
    output_prediction = model(input, ps)
    _compute_loss(output_prediction, output)
end

function (loss::NetworkLoss)(model::Union{Chain, AbstractExplicitLayer}, ps::Union{NeuralNetworkParameters, NamedTuple}, input::QPTOAT, output::QPTOAT)
    _compute_loss(model, ps, input, output)
end

@doc raw"""
    FeedForwardLoss()

Make an instance of a loss for feedforward neural networks.
This should be used together with a neural network of type `GeometricMachineLearning.NeuralNetworkIntegrator`.
# Example
`FeedForwardLoss` applies a neural network to an `input` and compares the prediction to the `output` via an ``L_2`` norm:
```jldoctest
using AbstractNeuralNetworks
using LinearAlgebra: norm
import Random
Random.seed!(123)
const d = 2
arch = Chain(Dense(d, d), Dense(d, d))
nn = NeuralNetwork(arch)
input_vec = [1., 2.]
output_vec = [3., 4.]
loss = FeedForwardLoss()
loss(nn, input_vec, output_vec) ≈ norm(output_vec - nn(input_vec)) / norm(output_vec)
# output
true
```
So `FeedForwardLoss` simply does:
```math
\mathtt{loss}(\mathcal{NN}, \mathtt{input}, \mathtt{output}) = || \mathcal{NN}(\mathtt{input}) - \mathtt{output} || / || \mathtt{output}||,
```
where ``||\cdot||`` is the ``L_2`` norm.
# Parameters
This loss does not have any parameters.
"""
struct FeedForwardLoss <: NetworkLoss end
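As a minimal sketch of the `CustomLoss` interface described in the `NetworkLoss` docstring, a scaled variant of `FeedForwardLoss` could look as follows. `ScaledFeedForwardLoss` and its `scale` field are invented for illustration; the sketch reuses the internal helpers `_norm` and `_diff` from this file (not part of the public API), and the argument types mirror the generic `NetworkLoss` functor so the new method is strictly more specific and avoids a dispatch ambiguity.

```julia
using AbstractNeuralNetworks
const ANN = AbstractNeuralNetworks

# Hypothetical custom loss: the relative L2 error rescaled by a constant factor.
struct ScaledFeedForwardLoss{T <: Real} <: NetworkLoss
    scale::T
end

# Same argument types as the generic (loss::NetworkLoss)(model, ps, input, output) method.
function (loss::ScaledFeedForwardLoss)(model::Union{ANN.Chain, ANN.AbstractExplicitLayer},
                                       ps::Union{ANN.NeuralNetworkParameters, NamedTuple},
                                       input::ANN.QPTOAT, output::ANN.QPTOAT)
    loss.scale * ANN._norm(ANN._diff(model(input, ps), output)) / ANN._norm(output)
end

nn = NeuralNetwork(Chain(Dense(2, 2), Dense(2, 2)))
loss = ScaledFeedForwardLoss(2.0)
# Dispatches through the NeuralNetwork convenience method, which forwards to
# loss(nn.model, nn.params, input, output).
loss(nn, [1., 2.], [3., 4.])
```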
24 changes: 24 additions & 0 deletions src/pullback.jl
@@ -0,0 +1,24 @@
@doc raw"""
    AbstractPullback{NNLT <: NetworkLoss}

`AbstractPullback` is an abstract type that encompasses all ways of performing differentiation (in particular computing the gradient with respect to neural network parameters) in `GeometricMachineLearning`.

If a user wants to implement a custom `Pullback`, the following two functions have to be extended:
```julia
(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT})
(_pullback::AbstractPullback)(ps, model, input_nt::QPT)
```
based on the `loss::NetworkLoss` that is stored in `_pullback`. The output of `_pullback` needs to be a `Tuple` that contains:
1. the `loss` evaluated at `ps` and `input_nt` (or `input_nt_output_nt`),
2. a function that computes the gradient of `loss` with respect to `ps` and can be called with, e.g.:
```julia
_pullback(ps, model, input_nt)[2](1) # returns the gradient with respect to `ps`
```
We use this convention because it is analogous to how `Zygote` builds pullbacks.
An example is `GeometricMachineLearning.ZygotePullback`.
"""
abstract type AbstractPullback{NNLT<:NetworkLoss} end

(_pullback::AbstractPullback)(ps, model, input_nt_output_nt::Tuple{<:QPTOAT, <:QPTOAT}) = error("Pullback not implemented for input-output pair!")
(_pullback::AbstractPullback)(ps, model, input_nt::QPT) = error("Pullback not implemented for single input!")

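For completeness, here is a rough sketch of how a Zygote-based pullback satisfying this interface could look. `NaiveZygotePullback` is a hypothetical name and may differ from `GeometricMachineLearning.ZygotePullback`; only the input-output pair method is shown, and the single-input method would follow the same pattern.

```julia
using AbstractNeuralNetworks
import Zygote
const ANN = AbstractNeuralNetworks

# Hypothetical Zygote-backed pullback that stores the loss it differentiates.
struct NaiveZygotePullback{NNLT <: NetworkLoss} <: AbstractPullback{NNLT}
    loss::NNLT
end

function (_pullback::NaiveZygotePullback)(ps, model, input_nt_output_nt::Tuple{<:ANN.QPTOAT, <:ANN.QPTOAT})
    loss_value, zygote_back = Zygote.pullback(p -> _pullback.loss(model, p, input_nt_output_nt...), ps)
    # Zygote's closure returns a one-element tuple of gradients; unwrap it so that
    # calling the second return value with 1 yields the gradient with respect to `ps`,
    # as described in the AbstractPullback docstring.
    (loss_value, seed -> zygote_back(seed)[1])
end

# Usage sketch (assumes `nn = NeuralNetwork(Chain(Dense(2, 2), Dense(2, 2)))`):
# pb = NaiveZygotePullback(FeedForwardLoss())
# loss_value, back = pb(nn.params, nn.model, ([1., 2.], [3., 4.]))
# grad_ps = back(1)  # gradient of the loss with respect to the parameters
```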