Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Apr 4, 2024
1 parent 6d2476d commit 9a28998
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/src/training/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for det

```@docs
Flux.params
Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, gs)
Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb)
Flux.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, gs)
Flux.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb)
```

## Callbacks
Expand Down
10 changes: 5 additions & 5 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
0.011100831f0
```
"""
function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ))
_check_sizes(ŷ, y)
agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
agg((log.((ŷ .+ eps) ./ (y .+ eps))) .^2 )
end

function _huber_metric(abs_error, δ)
Expand Down Expand Up @@ -228,7 +228,7 @@ julia> Flux.crossentropy(y_model, y_smooth)
1.5776052f0
```
"""
function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ))
_check_sizes(ŷ, y)
agg(.-sum(xlogy.(y, ŷ .+ eps); dims = dims))
end
Expand Down Expand Up @@ -607,8 +607,8 @@ true
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
"""
function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing)
γ = gamma_temp isa Integer ? gamma : ofeltype(ŷ, gamma)
function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ))
γ = gamma isa Integer ? gamma : ofeltype(ŷ, gamma)
_check_sizes(ŷ, y)
ŷϵ =.+ eps
agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims))
Expand Down

0 comments on commit 9a28998

Please sign in to comment.