Skip to content

Commit

Permalink
Merge pull request #42 from vvpisarev/catch_domain_errors_linesearch
Browse files Browse the repository at this point in the history
Catch domain errors linesearch
  • Loading branch information
stepanzh authored Nov 7, 2023
2 parents 77fd1a9 + 5d13fe9 commit c2582d9
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
24 changes: 21 additions & 3 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,36 @@ function strong_backtracking!(
nbracket_max = 200
@logmsg LSLogLevel "==BRACKETING THE MINIMUM=="
for nbracket in 1:nbracket_max
y, grad = fdf(x0, α, d)
local y, grad
while true
try
y, grad = fdf(x0, α, d)
break
catch err
if err isa DomainError
α_new =+ α_prev) / 2
if α_new == α_prev || α_new == α
@warn "==BRACKETING FAIL, LAST STEP VALUE RETURNED==" α
@logmsg LSLogLevel "==LINEAR SEARCH INTERRUPTED=="
return zero(α)
end
α = α_new
else
rethrow(err)
end
end
end
g = dot(grad, d)
Δyp = (g + g0) * α / 2 # parabolic approximation
if abs(Δyp) < ϵ
@logmsg LSLogLevel "" """
@logmsg LSLogLevel """
Δyp = $(Δyp) (*)
Δy = $(y-y0)
"""
Δy = Δyp
else
Δy = y - y0
@logmsg LSLogLevel "" """
@logmsg LSLogLevel """
Δyp = $(Δyp)
Δy = $(Δy) (*)
"""
Expand Down
1 change: 1 addition & 0 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ function solver(
maxcalls=nothing,
constrain_step=nothing,
)
x0 = argumentvec(M)
grad_tol = (isnothing(gtol) || gtol < 0) ? zero(eltype(x0)) : gtol
convcond = isnothing(convcond) ? stopbygradient(grad_tol) : convcond
iter_limit = (isnothing(maxiter) || maxiter < 0) ? typemax(Int) : convert(Int, maxiter)
Expand Down
7 changes: 7 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,12 @@ end
)
@test optimize!(rosenbrock!, opt, init_vec) isa NamedTuple
end

@testset "$(typeof(descent).name) (default parameters)" for descent in descent_methods
opt = Downhill.solver(
descent,
)
@test optimize!(rosenbrock!, opt, init_vec) isa NamedTuple
end
end
end
10 changes: 10 additions & 0 deletions test/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,13 @@ end
@test isapprox(α, α₀, rtol=0.05)
end
end

@testset "Limited domain" begin
function fdf(x, α, d)
z = x[] + α * d[]
return z - sqrt(z), [1 - 0.5 / sqrt(z)]
end

α = Downhill.strong_backtracking!(fdf, [0.5], [-1.0])
@test isapprox(α, 0.25, rtol=0.05)
end

0 comments on commit c2582d9

Please sign in to comment.