From 97a5703ba341c06af7f3ea6bbb77b209cd21e7b7 Mon Sep 17 00:00:00 2001 From: Anirudh Date: Mon, 17 Apr 2023 11:40:54 -0700 Subject: [PATCH] fix: force exact numeric sum - simplex projection Credit @hopeyen Signed-off-by: Anirudh --- src/project.jl | 16 +++++++++++----- test/project.jl | 5 +++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/project.jl b/src/project.jl index 192badc..64af5eb 100644 --- a/src/project.jl +++ b/src/project.jl @@ -9,13 +9,19 @@ export σsimplex, gssp Project `x` onto the `σ`-simplex. In other words, project `x`s to be non-negative and sum to `σ`. + +This operation is precision-senstive, so we conver the data to bigfloat within the function. +We then convert back to `T` before returning. """ function σsimplex(x::AbstractVector{T}, σ::Real) where {T<:Real} - n = length(x) - μ = sort(x; rev=true) - ρ = maximum((1:n)[μ - (cumsum(μ) .- σ) ./ (1:n) .> zero(T)]) - θ = (sum(μ[1:ρ]) - σ) / ρ - w = max.(x .- θ, zero(T)) + _x = convert(Vector{BigFloat}, x) + _σ = convert(BigFloat, σ) + n = length(_x) + μ = sort(_x; rev=true) + ρ = maximum((1:n)[μ-(cumsum(μ).-_σ)./(1:n).>zero(BigFloat)]) + θ = (sum(μ[1:ρ]) - _σ) / ρ + _w = max.(_x .- θ, zero(BigFloat)) + w = convert(Vector{T}, _w) return w end diff --git a/test/project.jl b/test/project.jl index f2735a5..7677ac8 100644 --- a/test/project.jl +++ b/test/project.jl @@ -14,6 +14,11 @@ x = Float64[-1, 0, 0] σ = 5 @test σsimplex(x, σ) ≈ [1, 2, 2] # Scale up + + # Credit @hopeyen + x = Float64[23133337391432116] + σ = 652174.7265297174 + @test sum(σsimplex(x, σ)) ≈ σ # exact end @testset "gssp" begin