Skip to content

Commit

Permalink
Merge pull request #138 from JuliaOpt/fp/clean_callbacks
Browse files Browse the repository at this point in the history
Clean wrapping of Knitro callbacks in Julia
  • Loading branch information
frapac authored Dec 17, 2019
2 parents ecffc05 + d37e0d5 commit 1e299a4
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 126 deletions.
10 changes: 5 additions & 5 deletions examples/multipleCB.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ using KNITRO, Test
# The signature of this function matches KNITRO.KN_eval_callback in knitro.h.
# Only "obj" is set in the KNITRO.KN_eval_result structure.
function callbackEvalObj(kc, cb, evalRequest, evalResult, userParams)
xind = userParams[:data]
xind = userParams

if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALFC
println("*** callbackEvalObj incorrectly called with eval type ", evalRequest.evalRequestCode)
Expand All @@ -47,7 +47,7 @@ end
# The signature of this function matches KNITRO.KN_eval_callback in knitro.h.
# Only "c0" is set in the KNITRO.KN_eval_result structure.
function callbackEvalC0(kc, cb, evalRequest, evalResult, userParams)
xind = userParams[:data]
xind = userParams

if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALFC
println("*** callbackEvalC0 incorrectly called with eval type ", evalRequest.evalRequestCode)
Expand All @@ -64,7 +64,7 @@ end
# The signature of this function matches KNITRO.KN_eval_callback in knitro.h.
# Only "c1" is set in the KNITRO.KN_eval_result structure.
function callbackEvalC1(kc, cb, evalRequest, evalResult, userParams)
xind = userParams[:data]
xind = userParams

if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALFC
println("*** callbackEvalC1 incorrectly called with eval type %d" % evalRequest.evalRequestCode)
Expand All @@ -85,7 +85,7 @@ end
# The signature of this function matches KNITRO.KN_eval_callback in knitro.h.
# Only "objGrad" is set in the KNITRO.KN_eval_result structure.
function callbackEvalObjGrad(kc, cb, evalRequest, evalResult, userParams)
xind = userParams[:data]
xind = userParams

if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALGA
println("*** callbackEvalObjGrad incorrectly called with eval type %d" % evalRequest.evalRequestCode)
Expand All @@ -105,7 +105,7 @@ end
# The signature of this function matches KNITRO.KN_eval_callback in knitro.h.
# Only gradient of c0 is set in the KNITRO.KN_eval_result structure.
function callbackEvalC0Grad(kc, cb, evalRequest, evalResult, userParams)
xind = userParams[:data]
xind = userParams

if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALGA
println("*** callbackEvalC0Grad incorrectly called with eval type ", evalRequest.evalRequestCode)
Expand Down
2 changes: 1 addition & 1 deletion src/KNITRO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module KNITRO
f = Base.Meta.quot(Symbol("KTR_$(func)"))
args = [esc(a) for a in args]
quote
ccall(($f,libknitro), $(args...))
ccall(($f, libknitro), $(args...))
end
end
"Load KNITRO version number via KTR API."
Expand Down
10 changes: 6 additions & 4 deletions src/kn_attributes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,23 @@ end
##################################################
# Generic getters
##################################################
function KN_get_number_vars(m::Model)
function KN_get_number_vars(kc::Ptr{Cvoid})
num_vars = Cint[0]
ret = @kn_ccall(get_number_vars, Cint, (Ptr{Cvoid}, Ptr{Cint}),
m.env, num_vars)
kc, num_vars)
_checkraise(ret)
return num_vars[1]
end
KN_get_number_vars(model::Model) = KN_get_number_vars(model.env.ptr_env)

function KN_get_number_cons(m::Model)
function KN_get_number_cons(kc::Ptr{Cvoid})
num_cons = Cint[0]
ret = @kn_ccall(get_number_cons, Cint, (Ptr{Cvoid}, Ptr{Cint}),
m.env, num_cons)
kc, num_cons)
_checkraise(ret)
return num_cons[1]
end
KN_get_number_cons(model::Model) = KN_get_number_cons(model.env.ptr_env)

function KN_get_obj_value(m::Model)
obj = Cdouble[0]
Expand Down
219 changes: 108 additions & 111 deletions src/kn_callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
# Callbacks utilities.

# Note: we store here the number of constraints and variables defined
# in the original Knitro model. We cannot retrieve these numbers during
# callbacks invokation, as sometimes the number of variables and constraints
# change internally in Knitro (e.g. when cuts are added when resolving
# Branch&Bound). We prefer to use the number of variables and constraints
# of the original model so that user's callbacks could consider that
# the arrays of primal variable x and dual variable \lambda have fixed
# sizes.
function link!(cb::CallbackContext, model::Model)
cb.n = KN_get_number_vars(model)
cb.m = KN_get_number_cons(model)
end

##################################################
# Utils
##################################################
function KN_set_cb_user_params(m::Model, cb::CallbackContext, userParams=nothing)
if userParams != nothing
cb.userparams[:data] = userParams
cb.userparams = userParams
end
# Store the model inside userParams
cb.userparams[:model] = m
# Link current callback context with Knitro model
link!(cb, m)
# Store callback context inside KNITRO user data.
c_userdata = cb

ret = @kn_ccall(set_cb_user_params, Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Any),
m.env, cb, c_userdata)
m.env, cb, cb)
_checkraise(ret)
return nothing
end
Expand Down Expand Up @@ -78,60 +88,26 @@ end
#--------------------
# callback context getters
#--------------------
# TODO: dry this code with a macro.
function KN_get_cb_number_cons(m::Model, cb::Ptr{Cvoid})
num = Cint[0]
ret = @kn_ccall(get_cb_number_cons,
Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
m.env, cb, num)
_checkraise(ret)
return num[1]
end

function KN_get_cb_objgrad_nnz(m::Model, cb::Ptr{Cvoid})
num = Cint[0]
ret = @kn_ccall(get_cb_objgrad_nnz,
Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
m.env, cb, num)
_checkraise(ret)
return num[1]
end

function KN_get_cb_jacobian_nnz(m::Model, cb::Ptr{Cvoid})
num = KNLONG[0]
ret = @kn_ccall(get_cb_jacobian_nnz,
Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
m.env, cb, num)
_checkraise(ret)
return num[1]
end

function KN_get_cb_hessian_nnz(m::Model, cb::Ptr{Cvoid})
num = KNLONG[0]
ret = @kn_ccall(get_cb_hessian_nnz,
Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
m.env, cb, num)
_checkraise(ret)
return num[1]
end

function KN_get_cb_number_rsds(m::Model, cb::Ptr{Cvoid})
num = Cint[0]
ret = @kn_ccall(get_cb_number_rsds,
Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
m.env, cb, num)
_checkraise(ret)
return num[1]
macro callback_getter(function_name, return_type)
fname = Symbol("KN_" * string(function_name))
quote
function $(esc(fname))(kc::Ptr{Cvoid}, cb::Ptr{Cvoid})
result = zeros($return_type, 1)
ret = @kn_ccall($function_name, Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
kc, cb, result)
_checkraise(ret)
return result[1]
end
end
end

function KN_get_cb_rsd_jacobian_nnz(m::Model, cb::Ptr{Cvoid})
num = KNLONG[0]
ret = @kn_ccall(get_cb_rsd_jacobian_nnz,
Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}),
m.env, cb, num)
_checkraise(ret)
return num[1]
end
@callback_getter get_cb_number_cons Cint
@callback_getter get_cb_objgrad_nnz Cint
@callback_getter get_cb_jacobian_nnz KNLONG
@callback_getter get_cb_hessian_nnz KNLONG
@callback_getter get_cb_number_rsds Cint
@callback_getter get_cb_rsd_jacobian_nnz KNLONG


##################################################
Expand Down Expand Up @@ -175,19 +151,16 @@ mutable struct EvalRequest
end

# Import low level request to Julia object.
function EvalRequest(m::Model, evalRequest_::KN_eval_request)
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)

function EvalRequest(ptr_model::Ptr{Cvoid}, evalRequest_::KN_eval_request, n::Int, m::Int)
# Import objective's scaling.
sigma = (evalRequest_.sigma != C_NULL) ? unsafe_wrap(Array, evalRequest_.sigma, 1)[1] : 1.
# Wrap directly C arrays to avoid unnecessary copy.
return EvalRequest(evalRequest_.evalRequestCode,
evalRequest_.threadID,
unsafe_wrap(Array, evalRequest_.x, nx),
unsafe_wrap(Array, evalRequest_.lambda, nx + nc),
unsafe_wrap(Array, evalRequest_.x, n),
unsafe_wrap(Array, evalRequest_.lambda, n + m),
sigma,
unsafe_wrap(Array, evalRequest_.vec, nx))
unsafe_wrap(Array, evalRequest_.vec, n))
end


Expand Down Expand Up @@ -243,16 +216,16 @@ mutable struct EvalResult
rsdJac::Array{Float64}
end

function EvalResult(m::Model, cb::Ptr{Cvoid}, evalResult_::KN_eval_result)
function EvalResult(kc::Ptr{Cvoid}, cb::Ptr{Cvoid}, evalResult_::KN_eval_result, n::Int, m::Int)
return EvalResult(
unsafe_wrap(Array, evalResult_.obj, 1),
unsafe_wrap(Array, evalResult_.c, KN_get_cb_number_cons(m, cb)),
unsafe_wrap(Array, evalResult_.objGrad, KN_get_cb_objgrad_nnz(m, cb)),
unsafe_wrap(Array, evalResult_.jac, KN_get_cb_jacobian_nnz(m, cb)),
unsafe_wrap(Array, evalResult_.hess, KN_get_cb_hessian_nnz(m, cb)),
unsafe_wrap(Array, evalResult_.hessVec, KN_get_number_vars(m)),
unsafe_wrap(Array, evalResult_.rsd, KN_get_cb_number_rsds(m, cb)),
unsafe_wrap(Array, evalResult_.rsdJac, KN_get_cb_rsd_jacobian_nnz(m, cb))
unsafe_wrap(Array, evalResult_.c, m),
unsafe_wrap(Array, evalResult_.objGrad, KN_get_cb_objgrad_nnz(kc, cb)),
unsafe_wrap(Array, evalResult_.jac, KN_get_cb_jacobian_nnz(kc, cb)),
unsafe_wrap(Array, evalResult_.hess, KN_get_cb_hessian_nnz(kc, cb)),
unsafe_wrap(Array, evalResult_.hessVec, n),
unsafe_wrap(Array, evalResult_.rsd, KN_get_cb_number_rsds(kc, cb)),
unsafe_wrap(Array, evalResult_.rsdJac, KN_get_cb_rsd_jacobian_nnz(kc, cb))
)
end

Expand All @@ -264,11 +237,11 @@ macro wrap_function(wrap_name, name)
userdata_::Ptr{Cvoid})
try
# Load evalRequest object.
ptr0 = Ptr{KN_eval_request}(evalRequest_)
evalRequest = unsafe_load(ptr0)::KN_eval_request
ptr_request = Ptr{KN_eval_request}(evalRequest_)
evalRequest = unsafe_load(ptr_request)::KN_eval_request
# Load evalResult object.
ptr = Ptr{KN_eval_result}(evalResults_)
evalResult = unsafe_load(ptr)::KN_eval_result
ptr_result = Ptr{KN_eval_result}(evalResults_)
evalResult = unsafe_load(ptr_result)::KN_eval_result
# Eventually, load callback context.
cb = unsafe_pointer_to_objref(userdata_)
# Ensure that cb is a CallbackContext.
Expand All @@ -277,8 +250,8 @@ macro wrap_function(wrap_name, name)
if !isa(cb, CallbackContext)
return Cint(KN_RC_CALLBACK_ERR)
end
request = EvalRequest(cb.userparams[:model], evalRequest)
result = EvalResult(cb.userparams[:model], ptr_cb, evalResult)
request = EvalRequest(ptr_model, evalRequest, cb.n, cb.m)
result = EvalResult(ptr_model, ptr_cb, evalResult, cb.n, cb.m)
res = cb.$name(ptr_model, ptr_cb, request, result, cb.userparams)
return Cint(res)
catch ex
Expand Down Expand Up @@ -637,17 +610,22 @@ function newpt_wrapper(ptr_model::Ptr{Cvoid},
ptr_x::Ptr{Cdouble},
ptr_lambda::Ptr{Cdouble},
userdata_::Ptr{Cvoid})

# Load KNITRO's Julia Model.
m = unsafe_pointer_to_objref(userdata_)::Model
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)

x = unsafe_wrap(Array, ptr_x, nx)
lambda = unsafe_wrap(Array, ptr_lambda, nx + nc)
ret = m.user_callback(ptr_model, x, lambda, m)

return Cint(ret)
try
m = unsafe_pointer_to_objref(userdata_)::Model
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)
x = unsafe_wrap(Array, ptr_x, nx)
lambda = unsafe_wrap(Array, ptr_lambda, nx + nc)
ret = m.user_callback(ptr_model, x, lambda, m)
return Cint(ret)
catch ex
if isa(ex, InterruptException)
return Cint(KN_RC_USER_TERMINATION)
else
rethrow(ex)
end
end
end

"""
Expand Down Expand Up @@ -697,15 +675,21 @@ function ms_process_wrapper(ptr_model::Ptr{Cvoid},
userdata_::Ptr{Cvoid})

# Load KNITRO's Julia Model.
m = unsafe_pointer_to_objref(userdata_)::Model
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)

x = unsafe_wrap(Array, ptr_x, nx)
lambda = unsafe_wrap(Array, ptr_lambda, nx + nc)
res = m.ms_process(ptr_model, x, lambda, m)

return Cint(res)
try
m = unsafe_pointer_to_objref(userdata_)::Model
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)
x = unsafe_wrap(Array, ptr_x, nx)
lambda = unsafe_wrap(Array, ptr_lambda, nx + nc)
res = m.ms_process(ptr_model, x, lambda, m)
return Cint(res)
catch ex
if isa(ex, InterruptException)
return Cint(KN_RC_USER_TERMINATION)
else
rethrow(ex)
end
end
end

"""
Expand Down Expand Up @@ -749,15 +733,21 @@ function mip_node_callback_wrapper(ptr_model::Ptr{Cvoid},
ptr_lambda::Ptr{Cdouble},
userdata_::Ptr{Cvoid})
# Load KNITRO's Julia Model.
m = unsafe_pointer_to_objref(userdata_)::Model
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)

x = unsafe_wrap(Array, ptr_x, nx)
lambda = unsafe_wrap(Array, ptr_lambda, nx + nc)
res = m.mip_callback(ptr_model, x, lambda, m)

return Cint(res)
try
m = unsafe_pointer_to_objref(userdata_)::Model
nx = KN_get_number_vars(m)
nc = KN_get_number_cons(m)
x = unsafe_wrap(Array, ptr_x, nx)
lambda = unsafe_wrap(Array, ptr_lambda, nx + nc)
res = m.mip_callback(ptr_model, x, lambda, m)
return Cint(res)
catch ex
if isa(ex, InterruptException)
return Cint(KN_RC_USER_TERMINATION)
else
rethrow(ex)
end
end
end

"""
Expand Down Expand Up @@ -851,10 +841,17 @@ end
function puts_callback_wrapper(str::Ptr{Cchar}, userdata_::Ptr{Cvoid})

# Load KNITRO's Julia Model.
m = unsafe_pointer_to_objref(userdata_)::Model
res = m.puts_callback(unsafe_string(str), m.userdata)

return Cint(res)
try
m = unsafe_pointer_to_objref(userdata_)::Model
res = m.puts_callback(unsafe_string(str), m.userdata)
return Cint(res)
catch ex
if isa(ex, InterruptException)
return Cint(KN_RC_USER_TERMINATION)
else
rethrow(ex)
end
end
end

"""
Expand Down
Loading

0 comments on commit 1e299a4

Please sign in to comment.