Skip to content

Commit

Permalink
Simplify and test Base.copy(::Opt) (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Aug 20, 2024
1 parent 8a486d2 commit d110e6d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 40 deletions.
66 changes: 26 additions & 40 deletions src/NLopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,55 +149,41 @@ algorithm(o::Opt)::Algorithm = nlopt_get_algorithm(o)
Base.show(io::IO, o::Opt) = print(io, "Opt($(algorithm(o)), $(ndims(o)))")

############################################################################
# copying is a little tricky because we have to tell NLopt to use
# new Callback_Data.
# copying is a little tricky because we have to tell NLopt to use new
# Callback_Data.

# callback wrapper for nlopt_munge_data in NLopt 2.4
function munge_callback(p::Ptr{Cvoid}, f_::Ptr{Cvoid})
f = unsafe_pointer_to_objref(f_)::Function
return f(p)::Ptr{Cvoid}
function munge_callback(p::Ptr{Cvoid}, p_user_data::Ptr{Cvoid})
old_to_new_pointer_map =
unsafe_pointer_to_objref(p_user_data)::Dict{Ptr{Cvoid},Ptr{Cvoid}}
return old_to_new_pointer_map[p]
end

function Base.copy(o::Opt)
p = nlopt_copy(o)
function Base.copy(opt::Opt)
p = nlopt_copy(opt)
if p == C_NULL
error("Error in nlopt_copy")
end
n = Opt(p)
cb = getfield(o, :cb)
ncb = similar(cb)
setfield!(n, :cb, ncb)
for i in 1:length(cb)
try
ncb[i] = Callback_Data(cb[i].f, n)
catch e
# if objective has not been set, cb[1] will throw
# an UndefRefError, which is okay.
if i != 1 || !isa(e, UndefRefError)
rethrow(e) # some not-okay exception
end
end
end
# n.o, for each callback, stores a pointer to an element of o.cb,
# and we need to convert this into a pointer to the corresponding
# element of n.cb. nlopt_munge_data allows us to call a function
# to transform each stored pointer in n.o, and we use the cbi
# dictionary to convert pointers to indices into o.cb, whence
# we obtain the corresponding element of n.cb.
cbi = Dict{Ptr{Cvoid},Int}()
for i in 1:length(cb)
try
cbi[pointer_from_objref(cb[i])] = i
catch
new_opt = Opt(p)
opt_callbacks = getfield(opt, :cb)
new_callbacks = Vector{Callback_Data}(undef, length(opt_callbacks))
setfield!(new_opt, :cb, new_callbacks)
old_to_new_pointer_map = Dict{Ptr{Cvoid},Ptr{Cvoid}}(C_NULL => C_NULL)
for i in 1:length(opt_callbacks)
if isassigned(opt_callbacks, i)
new_callbacks[i] = Callback_Data(opt_callbacks[i].f, new_opt)
old_to_new_pointer_map[pointer_from_objref(opt_callbacks[i])] =
pointer_from_objref(new_callbacks[i])
end
end
# nlopt_munge_data is a routine that allows us to convert all pointers to
# existing Callback_Data objects into pointers for the corresponding object
# in new_callbacks.
c_fn = @cfunction(munge_callback, Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}))
nlopt_munge_data(
n,
c_fn,
p -> p == C_NULL ? C_NULL : pointer_from_objref(ncb[cbi[p]]),
)
return n
GC.@preserve old_to_new_pointer_map begin
p_old_to_new_pointer_map = pointer_from_objref(old_to_new_pointer_map)
nlopt_munge_data(new_opt, c_fn, p_old_to_new_pointer_map)
end
return new_opt
end

############################################################################
Expand Down
36 changes: 36 additions & 0 deletions test/C_API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,42 @@ function test_initial_step()
return
end

function test_copy()
function my_objective_fn(x::Vector, grad::Vector)
if length(grad) > 0
grad[1] = 0
grad[2] = 0.5 / sqrt(x[2])
end
return sqrt(x[2])
end
function my_constraint_fn(x::Vector, grad::Vector, a, b)
if length(grad) > 0
grad[1] = 3 * a * (a * x[1] + b)^2
grad[2] = -1
end
return (a * x[1] + b)^3 - x[2]
end
opt = Opt(:LD_MMA, 2)
lower_bounds!(opt, [-Inf, 0.0])
xtol_rel!(opt, 1e-4)
min_objective!(opt, my_objective_fn)
inequality_constraint!(opt, (x, g) -> my_constraint_fn(x, g, 2, 0), 1e-8)
inequality_constraint!(opt, (x, g) -> my_constraint_fn(x, g, -1, 1), 1e-8)
opt_2 = copy(opt)
min_f, min_x, ret = optimize(opt_2, [1.234, 5.678])
@test min_f 0.5443310477213124
@test min_x [0.3333333342139688, 0.29629628951338166]
@test ret == :XTOL_REACHED
return
end

function test_copy_failure()
opt = Opt(:LD_MMA, 2)
setfield!(opt, :opt, C_NULL)
@test_throws ErrorException("Error in nlopt_copy") copy(opt)
return
end

end # module

TestCAPI.runtests()

0 comments on commit d110e6d

Please sign in to comment.