Skip to content

Commit

Permalink
Tidy C API to use @CCall (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Apr 29, 2024
1 parent 2486535 commit d91e9b1
Showing 1 changed file with 51 additions and 91 deletions.
142 changes: 51 additions & 91 deletions src/C_API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@ const INFINITY = 1e20
###

function c_api_License_SetString(license::String)
return ccall(
(:License_SetString, PATH_SOLVER),
Cint,
(Ptr{Cchar},),
license,
)
return @ccall PATH_SOLVER.License_SetString(license::Ptr{Cchar})::Cint
end

###
Expand Down Expand Up @@ -69,11 +64,8 @@ function OutputInterface(output_data)
end

function c_api_Output_SetInterface(o::OutputInterface)
return ccall(
(:Output_SetInterface, PATH_SOLVER),
Cvoid,
(Ref{OutputInterface},),
o,
return @ccall(
PATH_SOLVER.Output_SetInterface(o::Ref{OutputInterface})::Cvoid,
)
end

Expand All @@ -94,34 +86,29 @@ Base.cconvert(::Type{Ptr{Cvoid}}, x::Options) = x
Base.unsafe_convert(::Type{Ptr{Cvoid}}, x::Options) = x.ptr

function c_api_Options_Create()
ptr = ccall((:Options_Create, PATH_SOLVER), Ptr{Cvoid}, ())
return Options(ptr)
return Options(@ccall PATH_SOLVER.Options_Create()::Ptr{Cvoid})
end

function c_api_Options_Destroy(o::Options)
return ccall((:Options_Destroy, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Options_Destroy(o::Ptr{Cvoid})::Cvoid
end

function c_api_Options_Default(o::Options)
return ccall((:Options_Default, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Options_Default(o::Ptr{Cvoid})::Cvoid
end

function c_api_Options_Display(o::Options)
return ccall((:Options_Display, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Options_Display(o::Ptr{Cvoid})::Cvoid
end

function c_api_Options_Read(o::Options, filename::String)
return ccall(
(:Options_Read, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Ptr{Cchar}),
o,
filename,
return @ccall(
PATH_SOLVER.Options_Read(o::Ptr{Cvoid}, filename::Ptr{Cchar})::Cvoid,
)
end

function c_api_Path_AddOptions(o::Options)
return ccall((:Path_AddOptions, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Path_AddOptions(o::Ptr{Cvoid})::Cvoid
end

###
Expand All @@ -137,8 +124,7 @@ end

function _c_jac_typ(data_ptr::Ptr{Cvoid}, nnz::Cint, typ_ptr::Ptr{Cint})
data = unsafe_pointer_to_objref(data_ptr)::PresolveData
typ = unsafe_wrap(Array{Cint}, typ_ptr, nnz)
data.jac_typ(nnz, typ)
data.jac_typ(nnz, unsafe_wrap(Array{Cint}, typ_ptr, nnz))
return
end

Expand Down Expand Up @@ -189,10 +175,8 @@ function _c_problem_size(
nnz_ptr::Ptr{Cint},
)
id_data = unsafe_pointer_to_objref(id_ptr)::InterfaceData
n = unsafe_wrap(Array{Cint}, n_ptr, 1)
n[1] = id_data.n
nnz = unsafe_wrap(Array{Cint}, nnz_ptr, 1)
nnz[1] = id_data.nnz
unsafe_store!(n_ptr, id_data.n)
unsafe_store!(nnz_ptr, id_data.nnz)
return
end

Expand All @@ -204,14 +188,9 @@ function _c_bounds(
ub_ptr::Ptr{Cdouble},
)
id_data = unsafe_pointer_to_objref(id_ptr)::InterfaceData
z = unsafe_wrap(Array{Cdouble}, z_ptr, n)
lb = unsafe_wrap(Array{Cdouble}, lb_ptr, n)
ub = unsafe_wrap(Array{Cdouble}, ub_ptr, n)
for i in 1:n
z[i] = id_data.z[i]
lb[i] = id_data.lb[i]
ub[i] = id_data.ub[i]
end
copy!(unsafe_wrap(Array{Cdouble}, z_ptr, n), id_data.z)
copy!(unsafe_wrap(Array{Cdouble}, lb_ptr, n), id_data.lb)
copy!(unsafe_wrap(Array{Cdouble}, ub_ptr, n), id_data.ub)
return
end

Expand All @@ -224,8 +203,7 @@ function _c_function_evaluation(
id_data = unsafe_pointer_to_objref(id_ptr)::InterfaceData
x = unsafe_wrap(Array{Cdouble}, x_ptr, n)
f = unsafe_wrap(Array{Cdouble}, f_ptr, n)
err = id_data.F(n, x, f)
return err
return id_data.F(n, x, f)
end

function _c_jacobian_evaluation(
Expand All @@ -247,13 +225,13 @@ function _c_jacobian_evaluation(
f = unsafe_wrap(Array{Cdouble}, f_ptr, n)
err += id_data.F(n, x, f)
end
nnz = unsafe_wrap(Array{Cint}, nnz_ptr, 1)
nnz = unsafe_load(nnz_ptr)::Cint
col = unsafe_wrap(Array{Cint}, col_ptr, n)
len = unsafe_wrap(Array{Cint}, len_ptr, n)
row = unsafe_wrap(Array{Cint}, row_ptr, nnz[1])
data = unsafe_wrap(Array{Cdouble}, data_ptr, nnz[1])
err += id_data.J(n, nnz[1], x, col, len, row, data)
nnz[1] = sum(len)
row = unsafe_wrap(Array{Cint}, row_ptr, nnz)
data = unsafe_wrap(Array{Cdouble}, data_ptr, nnz)
err += id_data.J(n, nnz, x, col, len, row, data)
unsafe_store!(nnz_ptr, Cint(sum(len)))
return err
end

Expand Down Expand Up @@ -401,64 +379,52 @@ Base.cconvert(::Type{Ptr{Cvoid}}, x::MCP) = x
Base.unsafe_convert(::Type{Ptr{Cvoid}}, x::MCP) = x.ptr

function c_api_MCP_Create(n::Int, nnz::Int)
ptr = ccall((:MCP_Create, PATH_SOLVER), Ptr{Cvoid}, (Cint, Cint), n, nnz)
ptr = @ccall PATH_SOLVER.MCP_Create(n::Cint, nnz::Cint)::Ptr{Cvoid}
return MCP(n, ptr)
end

function c_api_MCP_Jacobian_Structure_Constant(m::MCP, flag::Bool)
ccall(
(:MCP_Jacobian_Structure_Constant, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Cint),
m,
flag,
)
@ccall PATH_SOLVER.MCP_Jacobian_Structure_Constant(
m::Ptr{Cvoid},
flag::Cint,
)::Cvoid
return
end

function c_api_MCP_Jacobian_Data_Contiguous(m::MCP, flag::Bool)
ccall(
(:MCP_Jacobian_Data_Contiguous, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Cint),
m,
flag,
)
@ccall PATH_SOLVER.MCP_Jacobian_Data_Contiguous(
m::Ptr{Cvoid},
flag::Cint,
)::Cvoid
return
end

function c_api_MCP_Destroy(m::MCP)
if m.ptr === C_NULL
return
end
ccall((:MCP_Destroy, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), m)
@ccall PATH_SOLVER.MCP_Destroy(m::Ptr{Cvoid})::Cvoid
return
end

function c_api_MCP_SetInterface(m::MCP, interface::MCP_Interface)
ccall(
(:MCP_SetInterface, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Ref{MCP_Interface}),
m,
interface,
)
@ccall PATH_SOLVER.MCP_SetInterface(
m::Ptr{Cvoid},
interface::Ref{MCP_Interface},
)::Cvoid
return
end

function c_api_MCP_SetPresolveInterface(m::MCP, interface::Presolve_Interface)
ccall(
(:MCP_SetPresolveInterface, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Ref{Presolve_Interface}),
m,
interface,
)
@ccall PATH_SOLVER.MCP_SetPresolveInterface(
m::Ptr{Cvoid},
interface::Ref{Presolve_Interface},
)::Cvoid
return
end

function c_api_MCP_GetX(m::MCP)
ptr = ccall((:MCP_GetX, PATH_SOLVER), Ptr{Cdouble}, (Ptr{Cvoid},), m)
ptr = @ccall PATH_SOLVER.MCP_GetX(m::Ptr{Cvoid})::Ptr{Cdouble}
return copy(unsafe_wrap(Array{Cdouble}, ptr, m.n))
end

Expand Down Expand Up @@ -580,7 +546,7 @@ Check that the current license (stored in the environment variable
Returns a nonzero value on successful completion, and a zero value on failure.
"""
function c_api_Path_CheckLicense(n::Int, nnz::Int)
return ccall((:Path_CheckLicense, PATH_SOLVER), Cint, (Cint, Cint), n, nnz)
return @ccall PATH_SOLVER.Path_CheckLicense(n::Cint, nnz::Cint)::Cint
end

"""
Expand All @@ -589,8 +555,7 @@ end
Return a string of the PATH version.
"""
function c_api_Path_Version()
ptr = ccall((:Path_Version, PATH_SOLVER), Ptr{Cchar}, ())
return unsafe_string(ptr)
return unsafe_string(@ccall PATH_SOLVER.Path_Version()::Ptr{Cchar})
end

"""
Expand All @@ -599,12 +564,8 @@ end
Returns a MCP_Termination status.
"""
function c_api_Path_Solve(m::MCP, info::Information)
return ccall(
(:Path_Solve, PATH_SOLVER),
Cint,
(Ptr{Cvoid}, Ref{Information}),
m,
info,
return @ccall(
PATH_SOLVER.Path_Solve(m::Ptr{Cvoid}, info::Ref{Information})::Cint,
)
end

Expand Down Expand Up @@ -812,12 +773,11 @@ function solve_mcp(
gc_root[m_interface] = true
c_api_MCP_SetInterface(m, m_interface)
if jacobian_structure_constant && !isempty(jacobian_linear_elements)
presolve_data = PresolveData() do nnz, types
for i in jacobian_linear_elements
types[i] = PRESOLVE_LINEAR
end
function presolve_fn(::Cint, types::Vector{Cint})
types[jacobian_linear_elements] .= PRESOLVE_LINEAR
return
end
presolve_data = PresolveData(presolve_fn)
# We shouldn't GC presolve_data until we exit the GC.@preserve block.
gc_root[presolve_data] = true
presolve_interface = Presolve_Interface(presolve_data)
Expand Down Expand Up @@ -863,7 +823,7 @@ function _linear_function(M::AbstractMatrix, q::Vector)
elseif size(M, 1) != length(q)
error("q is wrong shape. Expected $(size(M, 1)), got $(length(q)).")
end
return (n::Cint, x::Vector{Cdouble}, f::Vector{Cdouble}) -> begin
return function F(n::Cint, x::Vector{Cdouble}, f::Vector{Cdouble})
f .= M * x .+ q
return Cint(0)
end
Expand All @@ -872,15 +832,15 @@ end
function _linear_jacobian(M::SparseArrays.SparseMatrixCSC{Cdouble,Cint})
# Size is checked with error message in _linear_function.
@assert size(M, 1) == size(M, 2)
return (
return function J(
n::Cint,
nnz::Cint,
x::Vector{Cdouble},
col::Vector{Cint},
len::Vector{Cint},
row::Vector{Cint},
data::Vector{Cdouble},
) -> begin
)
@assert n == length(x) == length(col) == length(len) == size(M, 1)
@assert nnz == length(row) == length(data)
@assert nnz >= SparseArrays.nnz(M)
Expand Down

0 comments on commit d91e9b1

Please sign in to comment.