Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory-safety and allocations fixes for Handle{T} #446

Merged
merged 1 commit into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 35 additions & 39 deletions src/handle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@
Manages automatic destruction of the referenced objects when it is
no longer in use.
"""
struct Handle{T <: AbstractSundialsObject} <: SundialsHandle
ptr_ref::Ref{Ptr{T}} # pointer to a pointer
mutable struct Handle{T <: AbstractSundialsObject} <: SundialsHandle
ptr::Ptr{T}

function Handle(ptr::Ptr{T}) where {T <: AbstractSundialsObject}
h = new{T}(Ref{Ptr{T}}(ptr))
finalizer(release_handle, h.ptr_ref)
h = new{T}(ptr)
finalizer(release_handle, h)
return h
end
end
Expand Down Expand Up @@ -81,42 +81,38 @@
end
end

Base.unsafe_convert(::Type{Ptr{T}}, h::Handle{T}) where {T} = h.ptr_ref[]
Base.unsafe_convert(::Type{Ptr{Cvoid}}, h::Handle{T}) where {T} = Ptr{Cvoid}(h.ptr_ref[])
Base.convert(::Type{Ptr{T}}, h::Handle{T}) where {T} = h.ptr_ref[]
function Base.convert(::Type{Ptr{Ptr{T}}}, h::Handle{T}) where {T}
convert(Ptr{Ptr{T}},
h.ptr_ref[])
end
"""
Base.cconvert(::Type{Ptr{T}}, h::Handle{T}) -> h
Base.unsafe_convert(::Type{Ptr{T}}, h::Handle{T}) -> Ptr{T}

Convert h::Handle{T} to Ptr{T}, for use by ccall

function release_handle(ptr_ref::Ref{Ptr{T}}) where {T}
throw(MethodError("Freeing objects of type $T not supported"))
end
function release_handle(ptr_ref::Ref{Ptr{KINMem}})
((ptr_ref[] != C_NULL) && KINFree(ptr_ref); nothing)
end
function release_handle(ptr_ref::Ref{Ptr{CVODEMem}})
((ptr_ref[] != C_NULL) && CVodeFree(ptr_ref); nothing)
end
function release_handle(ptr_ref::Ref{Ptr{ARKStepMem}})
((ptr_ref[] != C_NULL) &&
ARKStepFree(ptr_ref);
nothing)
end
function release_handle(ptr_ref::Ref{Ptr{ERKStepMem}})
((ptr_ref[] != C_NULL) &&
ERKStepFree(ptr_ref);
nothing)
end
function release_handle(ptr_ref::Ref{Ptr{MRIStepMem}})
((ptr_ref[] != C_NULL) &&
MRIStepFree(ptr_ref);
nothing)
end
function release_handle(ptr_ref::Ref{Ptr{IDAMem}})
((ptr_ref[] != C_NULL) && IDAFree(ptr_ref); nothing)
Conversion happens in two steps within ccall:
- cconvert returns h, which is preserved (by ccall) from garbage collection
- unsafe_convert to get the pointer from h
"""
Base.cconvert(::Type{Ptr{T}}, h::Handle{T}) where {T} = h
Base.unsafe_convert(::Type{Ptr{T}}, h::Handle{T}) where {T} = h.ptr

# Use the supplied Sundials sun_free_func to free h.ptr
# NB: CVodeFree and similar require a C pointer-to-pointer
function _release_handle(sun_free_func, h::Handle{T}) where {T}
if h.ptr != C_NULL
ptr_ref = Ref(h.ptr)
h.ptr = C_NULL
sun_free_func(ptr_ref)
end

return nothing
end

release_handle(h::Handle{KINMem}) = _release_handle(KINFree, h)
release_handle(h::Handle{CVODEMem}) = _release_handle(CVodeFree, h)
release_handle(h::Handle{ARKStepMem}) = _release_handle(ARKStepFree, h)
release_handle(h::Handle{ERKStepMem}) = _release_handle(ERKStepFree, h)
release_handle(h::Handle{MRIStepMem}) = _release_handle(MRIStepFree, h)

Check warning on line 113 in src/handle.jl

View check run for this annotation

Codecov / codecov/patch

src/handle.jl#L112-L113

Added lines #L112 - L113 were not covered by tests
release_handle(h::Handle{IDAMem}) = _release_handle(IDAFree, h)

function release_handle(h::MatrixHandle{DenseMatrix})
if !isempty(h)
Sundials.SUNMatDestroy_Dense(h.ptr)
Expand Down Expand Up @@ -238,8 +234,8 @@
Base.empty!(h::LinSolHandle) = release_handle(h)
Base.empty!(h::NonLinSolHandle) = release_handle(h)
Base.empty!(h::MatrixHandle) = release_handle(h)
Base.empty!(h::Handle{T}) where {T} = release_handle(h.ptr_ref)
Base.isempty(h::Handle) = h.ptr_ref[] == C_NULL
Base.empty!(h::Handle) = release_handle(h)
Base.isempty(h::Handle) = (h.ptr == C_NULL)
Base.isempty(h::MatrixHandle) = h.destroyed
Base.isempty(h::LinSolHandle) = h.destroyed
Base.isempty(h::NonLinSolHandle) = h.destroyed
Expand Down
2 changes: 1 addition & 1 deletion test/handle_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ empty!(h1)

empty!(h1) # Make sure this does not throw

h = Sundials.Handle(h1.ptr_ref[]) # Check construction with null pointers
h = Sundials.Handle(h1.ptr) # Check construction with null pointers
@test isempty(h)

neq = 3
Expand Down
Loading