From 9a28998fc81ff294909286ef8d7c2ef0cb6a6efb Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 26 Mar 2024 12:07:42 +0100 Subject: [PATCH] more fixes --- docs/src/training/reference.md | 4 ++-- src/losses/functions.jl | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index 1bf0cfd1bf..595310e944 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -60,8 +60,8 @@ See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for det ```@docs Flux.params -Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, gs) -Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) +Flux.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, gs) +Flux.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) ``` ## Callbacks diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 9e6264ced8..f84ca22186 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -66,9 +66,9 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3) 0.011100831f0 ``` """ -function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) +function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 ) + agg((log.((ŷ .+ eps) ./ (y .+ eps))) .^2 ) end function _huber_metric(abs_error, δ) @@ -228,7 +228,7 @@ julia> Flux.crossentropy(y_model, y_smooth) 1.5776052f0 ``` """ -function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) +function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) agg(.-sum(xlogy.(y, ŷ .+ eps); dims = dims)) end @@ -607,8 +607,8 @@ true See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ -function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing) - γ = gamma_temp isa Integer ? gamma : ofeltype(ŷ, gamma) +function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ)) + γ = gamma isa Integer ? gamma : ofeltype(ŷ, gamma) _check_sizes(ŷ, y) ŷϵ = ŷ .+ eps agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims))