Add GeometricMachineLearning for docstrings.
Qualified package the routines are coming from.

Added docstests for GML.

Removed GML specification again.

Removed GML again (couldn't make it work).
benedict-96 committed Nov 26, 2024
1 parent 2a43d7d commit cafd97f
Showing 1 changed file: src/losses.jl (8 additions, 1 deletion).
@@ -8,10 +8,17 @@ If you want to implement `CustomLoss <: NetworkLoss` you need to define a functor
```
where `model` is an instance of an `AbstractExplicitLayer` or a `Chain` and `ps` the parameters.
-See [`FeedForwardLoss`](@ref), [`TransformerLoss`](@ref), [`AutoEncoderLoss`](@ref) and [`ReducedLoss`](@ref) for examples.
+See [`FeedForwardLoss`](@ref), `GeometricMachineLearning.TransformerLoss`, `GeometricMachineLearning.AutoEncoderLoss` and `GeometricMachineLearning.ReducedLoss` for examples.
"""
abstract type NetworkLoss end
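The docstring above asks users of a `CustomLoss <: NetworkLoss` to define a functor; the exact signature is truncated in this hunk, so the sketch below assumes a `(loss)(model, ps, input, output)` call shape and uses a plain function in place of an `AbstractExplicitLayer`/`Chain`. The `MAELoss` name and the `linear` stand-in model are hypothetical:

```julia
abstract type NetworkLoss end

# Hypothetical custom loss; the functor signature is an assumption,
# since the diff truncates the actual docstring.
struct MAELoss <: NetworkLoss end

# Mean absolute error between the model output and the target.
(loss::MAELoss)(model, ps, input, output) =
    sum(abs, model(input, ps) .- output) / length(output)

# A plain function standing in for a `Chain`; `ps` holds its parameters.
linear(x, ps) = ps.W * x .+ ps.b

ps = (W = [1.0 0.0; 0.0 1.0], b = [0.0, 0.0])
loss = MAELoss()
loss(linear, ps, [1.0, 2.0], [1.0, 2.0])  # identity model, zero error -> 0.0
```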

function apply_toNT(fun, ps::NamedTuple...)
for p in ps
@assert keys(ps[1]) == keys(p)
end
NamedTuple{keys(ps[1])}(fun(p...) for p in zip(ps...))

end

Codecov / codecov/patch: added lines #L15-L19 in src/losses.jl were not covered by tests.
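The `apply_toNT` helper added in this hunk maps a function elementwise over several `NamedTuple`s with identical keys. A small self-contained sketch of its behavior; the definition is copied from the diff, while the example inputs are invented:

```julia
# Copied from the diff: apply `fun` across matching NamedTuples.
function apply_toNT(fun, ps::NamedTuple...)
    for p in ps
        @assert keys(ps[1]) == keys(p)  # all NamedTuples must share keys
    end
    NamedTuple{keys(ps[1])}(fun(p...) for p in zip(ps...))
end

a = (q = [1.0, 2.0], p = [3.0, 4.0])
b = (q = [0.5, 0.5], p = [1.0, 1.0])

# Add the two NamedTuples entry by entry:
c = apply_toNT(+, a, b)  # (q = [1.5, 2.5], p = [4.0, 5.0])
```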

# overload norm
_norm(dx::NT) where {AT <: AbstractArray, NT <: NamedTuple{(:q, :p), Tuple{AT, AT}}} = (norm(dx.q) + norm(dx.p)) / 2 # we need this because of a Zygote problem
_norm(dx::NamedTuple) = sum(apply_toNT(norm, dx)) / length(dx)
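The two `_norm` overloads average norms over the entries of a `NamedTuple`; per the inline comment, the `(:q, :p)`-specialized method exists to sidestep a Zygote issue. A self-contained sketch (definitions copied from the diff, example data invented); for a two-field `(:q, :p)` tuple both methods agree:

```julia
using LinearAlgebra: norm

apply_toNT(fun, ps::NamedTuple...) =
    NamedTuple{keys(ps[1])}(fun(p...) for p in zip(ps...))

# Specialized method for (q, p) phase-space NamedTuples (Zygote workaround).
_norm(dx::NT) where {AT <: AbstractArray, NT <: NamedTuple{(:q, :p), Tuple{AT, AT}}} =
    (norm(dx.q) + norm(dx.p)) / 2
# Generic fallback: mean of the per-entry norms.
_norm(dx::NamedTuple) = sum(apply_toNT(norm, dx)) / length(dx)

dx = (q = [3.0, 4.0], p = [0.0, 0.0])
_norm(dx)  # (norm([3, 4]) + norm([0, 0])) / 2 = (5 + 0) / 2 = 2.5
```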
