From fb2f233c5aa720c896fc443f6d6e1955d255c579 Mon Sep 17 00:00:00 2001 From: fpacaud Date: Mon, 4 Nov 2019 09:52:56 +0100 Subject: [PATCH 1/5] clean getters for problem's variables in callbacks --- src/KNITRO.jl | 2 +- src/kn_attributes.jl | 10 +++-- src/kn_callbacks.jl | 99 ++++++++++++++------------------------------ 3 files changed, 38 insertions(+), 73 deletions(-) diff --git a/src/KNITRO.jl b/src/KNITRO.jl index 2d04680..d89c11e 100644 --- a/src/KNITRO.jl +++ b/src/KNITRO.jl @@ -17,7 +17,7 @@ module KNITRO f = Base.Meta.quot(Symbol("KTR_$(func)")) args = [esc(a) for a in args] quote - ccall(($f,libknitro), $(args...)) + ccall(($f, libknitro), $(args...)) end end "Load KNITRO version number via KTR API." diff --git a/src/kn_attributes.jl b/src/kn_attributes.jl index d8d7b7a..edb4e34 100644 --- a/src/kn_attributes.jl +++ b/src/kn_attributes.jl @@ -103,21 +103,23 @@ end ################################################## # Generic getters ################################################## -function KN_get_number_vars(m::Model) +function KN_get_number_vars(kc::Ptr{Cvoid}) num_vars = Cint[0] ret = @kn_ccall(get_number_vars, Cint, (Ptr{Cvoid}, Ptr{Cint}), - m.env, num_vars) + kc, num_vars) _checkraise(ret) return num_vars[1] end +KN_get_number_vars(model::Model) = KN_get_number_vars(model.env.ptr_env) -function KN_get_number_cons(m::Model) +function KN_get_number_cons(kc::Ptr{Cvoid}) num_cons = Cint[0] ret = @kn_ccall(get_number_cons, Cint, (Ptr{Cvoid}, Ptr{Cint}), - m.env, num_cons) + kc, num_cons) _checkraise(ret) return num_cons[1] end +KN_get_number_cons(model::Model) = KN_get_number_cons(model.env.ptr_env) function KN_get_obj_value(m::Model) obj = Cdouble[0] diff --git a/src/kn_callbacks.jl b/src/kn_callbacks.jl index 9b37336..d0aa4cf 100644 --- a/src/kn_callbacks.jl +++ b/src/kn_callbacks.jl @@ -1,6 +1,5 @@ # Callbacks utilities. - ################################################## # Utils ################################################## @@ -8,8 +7,6 @@ function KN_set_cb_user_params(m::Model, cb::CallbackContext, userParams=nothing if userParams != nothing cb.userparams[:data] = userParams end - # Store the model inside userParams - cb.userparams[:model] = m # Store callback context inside KNITRO user data. c_userdata = cb @@ -78,60 +75,26 @@ end #-------------------- # callback context getters #-------------------- -# TODO: dry this code with a macro. -function KN_get_cb_number_cons(m::Model, cb::Ptr{Cvoid}) - num = Cint[0] - ret = @kn_ccall(get_cb_number_cons, - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), - m.env, cb, num) - _checkraise(ret) - return num[1] -end - -function KN_get_cb_objgrad_nnz(m::Model, cb::Ptr{Cvoid}) - num = Cint[0] - ret = @kn_ccall(get_cb_objgrad_nnz, - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), - m.env, cb, num) - _checkraise(ret) - return num[1] -end - -function KN_get_cb_jacobian_nnz(m::Model, cb::Ptr{Cvoid}) - num = KNLONG[0] - ret = @kn_ccall(get_cb_jacobian_nnz, - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), - m.env, cb, num) - _checkraise(ret) - return num[1] -end - -function KN_get_cb_hessian_nnz(m::Model, cb::Ptr{Cvoid}) - num = KNLONG[0] - ret = @kn_ccall(get_cb_hessian_nnz, - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), - m.env, cb, num) - _checkraise(ret) - return num[1] -end - -function KN_get_cb_number_rsds(m::Model, cb::Ptr{Cvoid}) - num = Cint[0] - ret = @kn_ccall(get_cb_number_rsds, - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), - m.env, cb, num) - _checkraise(ret) - return num[1] +macro callback_getter(function_name, return_type) + fname = Symbol("KN_" * string(function_name)) + quote + function $(esc(fname))(kc::Ptr{Cvoid}, cb::Ptr{Cvoid}) + result = zeros($return_type, 1) + ret = @kn_ccall($function_name, Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), + kc, cb, result) + _checkraise(ret) + return result[1] + end + end end -function KN_get_cb_rsd_jacobian_nnz(m::Model, cb::Ptr{Cvoid}) - num = KNLONG[0] - ret = @kn_ccall(get_cb_rsd_jacobian_nnz, - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cint}), - m.env, cb, num) - _checkraise(ret) - return num[1] -end +@callback_getter get_cb_number_cons Cint +@callback_getter get_cb_objgrad_nnz Cint +@callback_getter get_cb_jacobian_nnz KNLONG +@callback_getter get_cb_hessian_nnz KNLONG +@callback_getter get_cb_number_rsds Cint +@callback_getter get_cb_rsd_jacobian_nnz KNLONG ################################################## @@ -175,9 +138,9 @@ mutable struct EvalRequest end # Import low level request to Julia object. -function EvalRequest(m::Model, evalRequest_::KN_eval_request) - nx = KN_get_number_vars(m) - nc = KN_get_number_cons(m) +function EvalRequest(ptr_model::Ptr{Cvoid}, evalRequest_::KN_eval_request) + nx = KN_get_number_vars(ptr_model) + nc = KN_get_number_cons(ptr_model) # Import objective's scaling. sigma = (evalRequest_.sigma != C_NULL) ? unsafe_wrap(Array, evalRequest_.sigma, 1)[1] : 1. @@ -243,16 +206,16 @@ mutable struct EvalResult rsdJac::Array{Float64} end -function EvalResult(m::Model, cb::Ptr{Cvoid}, evalResult_::KN_eval_result) +function EvalResult(kc::Ptr{Cvoid}, cb::Ptr{Cvoid}, evalResult_::KN_eval_result) return EvalResult( unsafe_wrap(Array, evalResult_.obj, 1), - unsafe_wrap(Array, evalResult_.c, KN_get_cb_number_cons(m, cb)), - unsafe_wrap(Array, evalResult_.objGrad, KN_get_cb_objgrad_nnz(m, cb)), - unsafe_wrap(Array, evalResult_.jac, KN_get_cb_jacobian_nnz(m, cb)), - unsafe_wrap(Array, evalResult_.hess, KN_get_cb_hessian_nnz(m, cb)), - unsafe_wrap(Array, evalResult_.hessVec, KN_get_number_vars(m)), - unsafe_wrap(Array, evalResult_.rsd, KN_get_cb_number_rsds(m, cb)), - unsafe_wrap(Array, evalResult_.rsdJac, KN_get_cb_rsd_jacobian_nnz(m, cb)) + unsafe_wrap(Array, evalResult_.c, KN_get_cb_number_cons(kc, cb)), + unsafe_wrap(Array, evalResult_.objGrad, KN_get_cb_objgrad_nnz(kc, cb)), + unsafe_wrap(Array, evalResult_.jac, KN_get_cb_jacobian_nnz(kc, cb)), + unsafe_wrap(Array, evalResult_.hess, KN_get_cb_hessian_nnz(kc, cb)), + unsafe_wrap(Array, evalResult_.hessVec, KN_get_number_vars(kc)), + unsafe_wrap(Array, evalResult_.rsd, KN_get_cb_number_rsds(kc, cb)), + unsafe_wrap(Array, evalResult_.rsdJac, KN_get_cb_rsd_jacobian_nnz(kc, cb)) ) end @@ -277,8 +240,8 @@ macro wrap_function(wrap_name, name) if !isa(cb, CallbackContext) return Cint(KN_RC_CALLBACK_ERR) end - request = EvalRequest(cb.userparams[:model], evalRequest) - result = EvalResult(cb.userparams[:model], ptr_cb, evalResult) + request = EvalRequest(ptr_model, evalRequest) + result = EvalResult(ptr_model, ptr_cb, evalResult) res = cb.$name(ptr_model, ptr_cb, request, result, cb.userparams) return Cint(res) catch ex From ad49471c13cef4575a31f3d45f510bfa0266167e Mon Sep 17 00:00:00 2001 From: fpacaud Date: Mon, 4 Nov 2019 10:43:44 +0100 Subject: [PATCH 2/5] clean definition of user params in callbacks --- src/kn_callbacks.jl | 10 +++++----- src/kn_model.jl | 11 ++++++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/kn_callbacks.jl b/src/kn_callbacks.jl index d0aa4cf..802205e 100644 --- a/src/kn_callbacks.jl +++ b/src/kn_callbacks.jl @@ -5,7 +5,7 @@ ################################################## function KN_set_cb_user_params(m::Model, cb::CallbackContext, userParams=nothing) if userParams != nothing - cb.userparams[:data] = userParams + cb.userparams = userParams end # Store callback context inside KNITRO user data. c_userdata = cb @@ -227,11 +227,11 @@ macro wrap_function(wrap_name, name) userdata_::Ptr{Cvoid}) try # Load evalRequest object. - ptr0 = Ptr{KN_eval_request}(evalRequest_) - evalRequest = unsafe_load(ptr0)::KN_eval_request + ptr_request = Ptr{KN_eval_request}(evalRequest_) + evalRequest = unsafe_load(ptr_request)::KN_eval_request # Load evalResult object. - ptr = Ptr{KN_eval_result}(evalResults_) - evalResult = unsafe_load(ptr)::KN_eval_result + ptr_result = Ptr{KN_eval_result}(evalResults_) + evalResult = unsafe_load(ptr_result)::KN_eval_result # Eventually, load callback context. cb = unsafe_pointer_to_objref(userdata_) # Ensure that cb is a CallbackContext. diff --git a/src/kn_model.jl b/src/kn_model.jl index 50bf554..4ef7bc4 100644 --- a/src/kn_model.jl +++ b/src/kn_model.jl @@ -1,5 +1,10 @@ # Knitro model. +struct UserParams{T} + params::T +end +UserParams() = UserParams(nothing) + """ Structure specifying the callback context. @@ -10,7 +15,7 @@ is attached to a unique callback context. mutable struct CallbackContext context::Ptr{Cvoid} # Add a dictionnary to store user params. - userparams::Dict + userparams::UserParams # Oracle's callbacks are context dependent, so store # them inside dedicated CallbackContext. @@ -20,8 +25,8 @@ mutable struct CallbackContext eval_rsd::Function eval_jac_rsd::Function - function CallbackContext(ptr::Ptr{Cvoid}) - return new(ptr, Dict()) + function CallbackContext(ptr_cb::Ptr{Cvoid}) + return new(ptr_cb, UserParams()) end end From 896cdb558e003a034acc4979acbce0c9a9a6f948 Mon Sep 17 00:00:00 2001 From: fpacaud Date: Tue, 3 Dec 2019 16:11:06 +0100 Subject: [PATCH 3/5] add missing setters for primal/dual initial values --- src/kn_variables.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/kn_variables.jl b/src/kn_variables.jl index be15e7a..15d0c19 100644 --- a/src/kn_variables.jl +++ b/src/kn_variables.jl @@ -269,7 +269,12 @@ function KN_set_var_primal_init_values(m::Model, xinitval::Vector{Cdouble}) m.env, xinitval) _checkraise(ret) end - +function KN_set_var_primal_init_values(m::Model, xindex::Vector{Cint}, xinitval::Vector{Cdouble}) + ret = @kn_ccall(set_var_primal_init_values, Cint, + (Ptr{Cvoid}, Cint, Ptr{Cint}, Ptr{Cdouble}), + m.env, length(xindex), xindex, xinitval) + _checkraise(ret) +end function KN_set_var_primal_init_values(m::Model, indx::Integer, xinitval::Cdouble) ret = @kn_ccall(set_var_primal_init_value, Cint, (Ptr{Cvoid}, Cint, Cdouble), m.env, indx, xinitval) @@ -283,7 +288,12 @@ function KN_set_var_dual_init_values(m::Model, xinitval::Vector{Cdouble}) m.env, xinitval) _checkraise(ret) end - +function KN_set_var_dual_init_values(m::Model, cindex::Vector{Cint}, initval::Vector{Cdouble}) + ret = @kn_ccall(set_var_dual_init_values, Cint, + (Ptr{Cvoid}, Ptr{Cint}, Ptr{Cdouble}), + m.env, cindex, initval) + _checkraise(ret) +end function KN_set_var_dual_init_values(m::Model, indx::Integer, xinitval::Cdouble) ret = @kn_ccall(set_var_dual_init_value, Cint, (Ptr{Cvoid}, Cint, Cdouble), m.env, indx, xinitval) From afc708f337d591f30e556cef531be3e57b182c93 Mon Sep 17 00:00:00 2001 From: fpacaud Date: Tue, 3 Dec 2019 17:02:06 +0100 Subject: [PATCH 4/5] fix setting of userparam --- examples/multipleCB.jl | 10 ++--- src/kn_callbacks.jl | 88 +++++++++++++++++++++++++++--------------- src/kn_model.jl | 8 +--- test/knitroapi.jl | 9 +++++ 4 files changed, 72 insertions(+), 43 deletions(-) diff --git a/examples/multipleCB.jl b/examples/multipleCB.jl index a46ca31..2999962 100644 --- a/examples/multipleCB.jl +++ b/examples/multipleCB.jl @@ -30,7 +30,7 @@ using KNITRO, Test # The signature of this function matches KNITRO.KN_eval_callback in knitro.h. # Only "obj" is set in the KNITRO.KN_eval_result structure. function callbackEvalObj(kc, cb, evalRequest, evalResult, userParams) - xind = userParams[:data] + xind = userParams if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALFC println("*** callbackEvalObj incorrectly called with eval type ", evalRequest.evalRequestCode) @@ -47,7 +47,7 @@ end # The signature of this function matches KNITRO.KN_eval_callback in knitro.h. # Only "c0" is set in the KNITRO.KN_eval_result structure. function callbackEvalC0(kc, cb, evalRequest, evalResult, userParams) - xind = userParams[:data] + xind = userParams if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALFC println("*** callbackEvalC0 incorrectly called with eval type ", evalRequest.evalRequestCode) @@ -64,7 +64,7 @@ end # The signature of this function matches KNITRO.KN_eval_callback in knitro.h. # Only "c1" is set in the KNITRO.KN_eval_result structure. function callbackEvalC1(kc, cb, evalRequest, evalResult, userParams) - xind = userParams[:data] + xind = userParams if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALFC println("*** callbackEvalC1 incorrectly called with eval type %d" % evalRequest.evalRequestCode) @@ -85,7 +85,7 @@ end # The signature of this function matches KNITRO.KN_eval_callback in knitro.h. # Only "objGrad" is set in the KNITRO.KN_eval_result structure. function callbackEvalObjGrad(kc, cb, evalRequest, evalResult, userParams) - xind = userParams[:data] + xind = userParams if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALGA println("*** callbackEvalObjGrad incorrectly called with eval type %d" % evalRequest.evalRequestCode) @@ -105,7 +105,7 @@ end # The signature of this function matches KNITRO.KN_eval_callback in knitro.h. # Only gradient of c0 is set in the KNITRO.KN_eval_result structure. function callbackEvalC0Grad(kc, cb, evalRequest, evalResult, userParams) - xind = userParams[:data] + xind = userParams if evalRequest.evalRequestCode != KNITRO.KN_RC_EVALGA println("*** callbackEvalC0Grad incorrectly called with eval type ", evalRequest.evalRequestCode) diff --git a/src/kn_callbacks.jl b/src/kn_callbacks.jl index 802205e..da01add 100644 --- a/src/kn_callbacks.jl +++ b/src/kn_callbacks.jl @@ -600,17 +600,22 @@ function newpt_wrapper(ptr_model::Ptr{Cvoid}, ptr_x::Ptr{Cdouble}, ptr_lambda::Ptr{Cdouble}, userdata_::Ptr{Cvoid}) - # Load KNITRO's Julia Model. - m = unsafe_pointer_to_objref(userdata_)::Model - nx = KN_get_number_vars(m) - nc = KN_get_number_cons(m) - - x = unsafe_wrap(Array, ptr_x, nx) - lambda = unsafe_wrap(Array, ptr_lambda, nx + nc) - ret = m.user_callback(ptr_model, x, lambda, m) - - return Cint(ret) + try + m = unsafe_pointer_to_objref(userdata_)::Model + nx = KN_get_number_vars(m) + nc = KN_get_number_cons(m) + x = unsafe_wrap(Array, ptr_x, nx) + lambda = unsafe_wrap(Array, ptr_lambda, nx + nc) + ret = m.user_callback(ptr_model, x, lambda, m) + return Cint(ret) + catch ex + if isa(ex, InterruptException) + return Cint(KN_RC_USER_TERMINATION) + else + rethrow(ex) + end + end end """ @@ -660,15 +665,21 @@ function ms_process_wrapper(ptr_model::Ptr{Cvoid}, userdata_::Ptr{Cvoid}) # Load KNITRO's Julia Model. - m = unsafe_pointer_to_objref(userdata_)::Model - nx = KN_get_number_vars(m) - nc = KN_get_number_cons(m) - - x = unsafe_wrap(Array, ptr_x, nx) - lambda = unsafe_wrap(Array, ptr_lambda, nx + nc) - res = m.ms_process(ptr_model, x, lambda, m) - - return Cint(res) + try + m = unsafe_pointer_to_objref(userdata_)::Model + nx = KN_get_number_vars(m) + nc = KN_get_number_cons(m) + x = unsafe_wrap(Array, ptr_x, nx) + lambda = unsafe_wrap(Array, ptr_lambda, nx + nc) + res = m.ms_process(ptr_model, x, lambda, m) + return Cint(res) + catch ex + if isa(ex, InterruptException) + return Cint(KN_RC_USER_TERMINATION) + else + rethrow(ex) + end + end end """ @@ -712,15 +723,21 @@ function mip_node_callback_wrapper(ptr_model::Ptr{Cvoid}, ptr_lambda::Ptr{Cdouble}, userdata_::Ptr{Cvoid}) # Load KNITRO's Julia Model. - m = unsafe_pointer_to_objref(userdata_)::Model - nx = KN_get_number_vars(m) - nc = KN_get_number_cons(m) - - x = unsafe_wrap(Array, ptr_x, nx) - lambda = unsafe_wrap(Array, ptr_lambda, nx + nc) - res = m.mip_callback(ptr_model, x, lambda, m) - - return Cint(res) + try + m = unsafe_pointer_to_objref(userdata_)::Model + nx = KN_get_number_vars(m) + nc = KN_get_number_cons(m) + x = unsafe_wrap(Array, ptr_x, nx) + lambda = unsafe_wrap(Array, ptr_lambda, nx + nc) + res = m.mip_callback(ptr_model, x, lambda, m) + return Cint(res) + catch ex + if isa(ex, InterruptException) + return Cint(KN_RC_USER_TERMINATION) + else + rethrow(ex) + end + end end """ @@ -814,10 +831,17 @@ end function puts_callback_wrapper(str::Ptr{Cchar}, userdata_::Ptr{Cvoid}) # Load KNITRO's Julia Model. - m = unsafe_pointer_to_objref(userdata_)::Model - res = m.puts_callback(unsafe_string(str), m.userdata) - - return Cint(res) + try + m = unsafe_pointer_to_objref(userdata_)::Model + res = m.puts_callback(unsafe_string(str), m.userdata) + return Cint(res) + catch ex + if isa(ex, InterruptException) + return Cint(KN_RC_USER_TERMINATION) + else + rethrow(ex) + end + end end """ diff --git a/src/kn_model.jl b/src/kn_model.jl index 4ef7bc4..e67f516 100644 --- a/src/kn_model.jl +++ b/src/kn_model.jl @@ -1,9 +1,5 @@ # Knitro model. -struct UserParams{T} - params::T -end -UserParams() = UserParams(nothing) """ Structure specifying the callback context. @@ -15,7 +11,7 @@ is attached to a unique callback context. mutable struct CallbackContext context::Ptr{Cvoid} # Add a dictionnary to store user params. - userparams::UserParams + userparams # Oracle's callbacks are context dependent, so store # them inside dedicated CallbackContext. @@ -26,7 +22,7 @@ mutable struct CallbackContext eval_jac_rsd::Function function CallbackContext(ptr_cb::Ptr{Cvoid}) - return new(ptr_cb, UserParams()) + return new(ptr_cb, nothing) end end diff --git a/test/knitroapi.jl b/test/knitroapi.jl index 301f8ad..9d5af66 100644 --- a/test/knitroapi.jl +++ b/test/knitroapi.jl @@ -495,6 +495,8 @@ end @testset "Fifth problem test" begin + # Test in this environment the setting of user params + myParams = "stringUserParam" kc = KNITRO.KN_new() # START: Some specific parameter settings @@ -504,6 +506,9 @@ end function evalR(kc, cb, evalRequest, evalResult, userParams) x = evalRequest.x + # Each time we call this callback, the userParams should + # be as specified. + @test userParams == myParams evalResult.rsd[1] = x[1] * 1.309^x[2] - 2.138 evalResult.rsd[2] = x[1] * 1.471^x[2] - 3.421 evalResult.rsd[3] = x[1] * 1.49^x[2] - 3.597 @@ -515,6 +520,9 @@ end function evalJ(kc, cb, evalRequest, evalResult, userParams) x = evalRequest.x + # Each time we call this callback, the userParams should + # be as specified. + @test userParams == myParams evalResult.rsdJac[1] = 1.309^x[2] evalResult.rsdJac[2] = x[1] * log(1.309) * 1.309^x[2] evalResult.rsdJac[3] = 1.471^x[2] @@ -546,6 +554,7 @@ end KNITRO.KN_set_cb_rsd_jac(kc, cb, nnzJ, evalJ, jacIndexRsds=Int32[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5 ], jacIndexVars=Int32[ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1 ]) + KNITRO.KN_set_cb_user_params(kc, cb, myParams) # Solve the problem. status = KNITRO.KN_solve(kc) From d37e0d5e52f0fd6303ea626a8cc4826becb16d55 Mon Sep 17 00:00:00 2001 From: fpacaud Date: Mon, 16 Dec 2019 09:57:53 +0100 Subject: [PATCH 5/5] fix bug in branch&bound in MINLPTests --- src/kn_callbacks.jl | 40 +++++++++++++++++++++++++--------------- src/kn_model.jl | 4 +++- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/kn_callbacks.jl b/src/kn_callbacks.jl index da01add..77a0dee 100644 --- a/src/kn_callbacks.jl +++ b/src/kn_callbacks.jl @@ -1,5 +1,18 @@ # Callbacks utilities. +# Note: we store here the number of constraints and variables defined +# in the original Knitro model. We cannot retrieve these numbers during +# callbacks invokation, as sometimes the number of variables and constraints +# change internally in Knitro (e.g. when cuts are added when resolving +# Branch&Bound). We prefer to use the number of variables and constraints +# of the original model so that user's callbacks could consider that +# the arrays of primal variable x and dual variable \lambda have fixed +# sizes. +function link!(cb::CallbackContext, model::Model) + cb.n = KN_get_number_vars(model) + cb.m = KN_get_number_cons(model) +end + ################################################## # Utils ################################################## @@ -7,12 +20,12 @@ function KN_set_cb_user_params(m::Model, cb::CallbackContext, userParams=nothing if userParams != nothing cb.userparams = userParams end + # Link current callback context with Knitro model + link!(cb, m) # Store callback context inside KNITRO user data. - c_userdata = cb - ret = @kn_ccall(set_cb_user_params, Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Any), - m.env, cb, c_userdata) + m.env, cb, cb) _checkraise(ret) return nothing end @@ -138,19 +151,16 @@ mutable struct EvalRequest end # Import low level request to Julia object. -function EvalRequest(ptr_model::Ptr{Cvoid}, evalRequest_::KN_eval_request) - nx = KN_get_number_vars(ptr_model) - nc = KN_get_number_cons(ptr_model) - +function EvalRequest(ptr_model::Ptr{Cvoid}, evalRequest_::KN_eval_request, n::Int, m::Int) # Import objective's scaling. sigma = (evalRequest_.sigma != C_NULL) ? unsafe_wrap(Array, evalRequest_.sigma, 1)[1] : 1. # Wrap directly C arrays to avoid unnecessary copy. return EvalRequest(evalRequest_.evalRequestCode, evalRequest_.threadID, - unsafe_wrap(Array, evalRequest_.x, nx), - unsafe_wrap(Array, evalRequest_.lambda, nx + nc), + unsafe_wrap(Array, evalRequest_.x, n), + unsafe_wrap(Array, evalRequest_.lambda, n + m), sigma, - unsafe_wrap(Array, evalRequest_.vec, nx)) + unsafe_wrap(Array, evalRequest_.vec, n)) end @@ -206,14 +216,14 @@ mutable struct EvalResult rsdJac::Array{Float64} end -function EvalResult(kc::Ptr{Cvoid}, cb::Ptr{Cvoid}, evalResult_::KN_eval_result) +function EvalResult(kc::Ptr{Cvoid}, cb::Ptr{Cvoid}, evalResult_::KN_eval_result, n::Int, m::Int) return EvalResult( unsafe_wrap(Array, evalResult_.obj, 1), - unsafe_wrap(Array, evalResult_.c, KN_get_cb_number_cons(kc, cb)), + unsafe_wrap(Array, evalResult_.c, m), unsafe_wrap(Array, evalResult_.objGrad, KN_get_cb_objgrad_nnz(kc, cb)), unsafe_wrap(Array, evalResult_.jac, KN_get_cb_jacobian_nnz(kc, cb)), unsafe_wrap(Array, evalResult_.hess, KN_get_cb_hessian_nnz(kc, cb)), - unsafe_wrap(Array, evalResult_.hessVec, KN_get_number_vars(kc)), + unsafe_wrap(Array, evalResult_.hessVec, n), unsafe_wrap(Array, evalResult_.rsd, KN_get_cb_number_rsds(kc, cb)), unsafe_wrap(Array, evalResult_.rsdJac, KN_get_cb_rsd_jacobian_nnz(kc, cb)) ) @@ -240,8 +250,8 @@ macro wrap_function(wrap_name, name) if !isa(cb, CallbackContext) return Cint(KN_RC_CALLBACK_ERR) end - request = EvalRequest(ptr_model, evalRequest) - result = EvalResult(ptr_model, ptr_cb, evalResult) + request = EvalRequest(ptr_model, evalRequest, cb.n, cb.m) + result = EvalResult(ptr_model, ptr_cb, evalResult, cb.n, cb.m) res = cb.$name(ptr_model, ptr_cb, request, result, cb.userparams) return Cint(res) catch ex diff --git a/src/kn_model.jl b/src/kn_model.jl index e67f516..5834b0b 100644 --- a/src/kn_model.jl +++ b/src/kn_model.jl @@ -10,6 +10,8 @@ is attached to a unique callback context. """ mutable struct CallbackContext context::Ptr{Cvoid} + n::Int + m::Int # Add a dictionnary to store user params. userparams @@ -22,7 +24,7 @@ mutable struct CallbackContext eval_jac_rsd::Function function CallbackContext(ptr_cb::Ptr{Cvoid}) - return new(ptr_cb, nothing) + return new(ptr_cb, 0, 0, nothing) end end