diff --git a/crates/pindakaas-derive/src/lib.rs b/crates/pindakaas-derive/src/lib.rs index 91cf66cb8e..f1744f6814 100644 --- a/crates/pindakaas-derive/src/lib.rs +++ b/crates/pindakaas-derive/src/lib.rs @@ -91,79 +91,109 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { quote!() }; - let term_callback = if opts.term_callback { + let (term_callback, term_drop) = if opts.term_callback { let term_cb = match opts.term_callback_ident { Some(x) => quote! { self. #x }, None => quote! { self.term_cb }, }; - quote! { - impl crate::solver::TermCallback for #ident { - fn set_terminate_callback crate::solver::SlvTermSignal + 'static>( - &mut self, - cb: Option, - ) { - if let Some(mut cb) = cb { - #term_cb = crate::solver::libloading::TermCB::new(move || -> std::ffi::c_int { - match cb() { - crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0), - crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1), + ( + quote! { + impl crate::solver::TermCallback for #ident { + fn set_terminate_callback crate::solver::SlvTermSignal + 'static>( + &mut self, + cb: Option, + ) { + if let Some(mut cb) = cb { + let mut wrapped_cb = move || -> std::ffi::c_int { + match cb() { + crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0), + crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1), + } + }; + let trampoline = crate::solver::libloading::get_trampoline0(&wrapped_cb); + let layout = std::alloc::Layout::for_value(&wrapped_cb); + let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void; + if layout.size() != 0 { + // Otherwise nothing was leaked. + #term_cb = Some((data, layout)); } - }); - - unsafe { - #krate::ipasir_set_terminate( - #ptr, - #term_cb .as_ptr(), - Some(crate::solver::libloading::TermCB::exec_callback), - ) + unsafe { + #krate::ipasir_set_terminate( + #ptr, + data, + Some(trampoline), + ) + } + } else { + if let Some((ptr, layout)) = #term_cb .take() { + unsafe { std::alloc::dealloc(ptr as *mut _, layout) }; + } + unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) } } - } else { - #term_cb = Default::default(); - unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) } } } - } - } + }, + quote! { + if let Some((ptr, layout)) = #term_cb .take() { + unsafe { std::alloc::dealloc(ptr as *mut _, layout) }; + } + }, + ) } else { - quote!() + (quote!(), quote!()) }; - let learn_callback = if opts.learn_callback { + let (learn_callback, learn_drop) = if opts.learn_callback { let learn_cb = match opts.learn_callback_ident { Some(x) => quote! { self. #x }, None => quote! { self.learn_cb }, }; - quote! { - impl crate::solver::LearnCallback for #ident { - fn set_learn_callback) + 'static>( - &mut self, - cb: Option, - ) { - const MAX_LEN: std::ffi::c_int = 512; - if let Some(mut cb) = cb { - #learn_cb = crate::solver::libloading::LearnCB::new(move |clause: *const i32| { - let mut iter = crate::solver::libloading::ExplIter(clause) - .map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap())); - cb(&mut iter) - }); - - unsafe { - #krate::ipasir_set_learn( - #ptr, - #learn_cb .as_ptr(), - MAX_LEN, - Some(crate::solver::libloading::LearnCB::exec_callback), - ) + ( + quote! { + impl crate::solver::LearnCallback for #ident { + fn set_learn_callback) + 'static>( + &mut self, + cb: Option, + ) { + const MAX_LEN: std::ffi::c_int = 512; + if let Some(mut cb) = cb { + let mut wrapped_cb = move |clause: *const i32| { + let mut iter = crate::solver::libloading::ExplIter(clause) + .map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap())); + cb(&mut iter) + }; + let trampoline = crate::solver::libloading::get_trampoline1(&wrapped_cb); + let layout = std::alloc::Layout::for_value(&wrapped_cb); + let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void; + if layout.size() != 0 { + // Otherwise nothing was leaked. + #learn_cb = Some((data, layout)); + } + unsafe { + #krate::ipasir_set_learn( + #ptr, + data, + MAX_LEN, + Some(trampoline), + ) + } + } else { + if let Some((ptr, layout)) = #learn_cb .take() { + unsafe { std::alloc::dealloc(ptr as *mut _, layout) }; + } + unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) } } - } else { - #learn_cb = Default::default(); - unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) } } } - } - } + }, + quote! { + if let Some((ptr, layout)) = #learn_cb .take() { + unsafe { std::alloc::dealloc(ptr as *mut _, layout) }; + } + }, + ) } else { - quote!() + (quote!(), quote!()) }; let sol_ident = format_ident!("{}Sol", ident); @@ -356,6 +386,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { quote! { impl Drop for #ident { fn drop(&mut self) { + #learn_drop + #term_drop unsafe { #krate::ipasir_release( #ptr ) } } } diff --git a/crates/pindakaas/src/solver/cadical.rs b/crates/pindakaas/src/solver/cadical.rs index bdacc3a53f..0fe369e4fa 100644 --- a/crates/pindakaas/src/solver/cadical.rs +++ b/crates/pindakaas/src/solver/cadical.rs @@ -1,25 +1,26 @@ -use std::{ffi::CString, fmt}; +use std::{ + alloc::Layout, + ffi::{c_void, CString}, + fmt, +}; use pindakaas_cadical::{ccadical_copy, ccadical_phase, ccadical_unphase}; use pindakaas_derive::IpasirSolver; use super::VarFactory; -use crate::{ - solver::libloading::{LearnCB, TermCB}, - Lit, -}; +use crate::Lit; #[derive(IpasirSolver)] #[ipasir(krate = pindakaas_cadical, assumptions, learn_callback, term_callback, ipasir_up)] pub struct Cadical { /// The raw pointer to the Cadical solver. - ptr: *mut std::ffi::c_void, + ptr: *mut c_void, /// The variable factory for this solver. vars: VarFactory, /// The callback used when a clause is learned. - learn_cb: LearnCB, + learn_cb: Option<(*mut c_void, Layout)>, /// The callback used to check whether the solver should terminate. - term_cb: TermCB, + term_cb: Option<(*mut c_void, Layout)>, #[cfg(feature = "ipasir-up")] /// The external propagator called by the solver @@ -31,8 +32,8 @@ impl Default for Cadical { Self { ptr: unsafe { pindakaas_cadical::ipasir_init() }, vars: VarFactory::default(), - learn_cb: LearnCB::default(), - term_cb: TermCB::default(), + learn_cb: None, + term_cb: None, #[cfg(feature = "ipasir-up")] prop: None, } @@ -45,8 +46,8 @@ impl Clone for Cadical { Self { ptr, vars: self.vars, - learn_cb: LearnCB::default(), - term_cb: TermCB::default(), + learn_cb: None, + term_cb: None, #[cfg(feature = "ipasir-up")] prop: None, } diff --git a/crates/pindakaas/src/solver/intel_sat.rs b/crates/pindakaas/src/solver/intel_sat.rs index 47b68b95a8..ce694dafc9 100644 --- a/crates/pindakaas/src/solver/intel_sat.rs +++ b/crates/pindakaas/src/solver/intel_sat.rs @@ -1,19 +1,20 @@ +use std::{alloc::Layout, ffi::c_void}; + use pindakaas_derive::IpasirSolver; use super::VarFactory; -use crate::solver::libloading::{LearnCB, TermCB}; #[derive(Debug, IpasirSolver)] #[ipasir(krate = pindakaas_intel_sat, assumptions, learn_callback, term_callback)] pub struct IntelSat { /// The raw pointer to the Intel SAT solver. - ptr: *mut std::ffi::c_void, + ptr: *mut c_void, /// The variable factory for this solver. vars: VarFactory, /// The callback used when a clause is learned. - learn_cb: LearnCB, + learn_cb: Option<(*mut c_void, Layout)>, /// The callback used to check whether the solver should terminate. - term_cb: TermCB, + term_cb: Option<(*mut c_void, Layout)>, } impl Default for IntelSat { @@ -21,8 +22,8 @@ impl Default for IntelSat { Self { ptr: unsafe { pindakaas_intel_sat::ipasir_init() }, vars: VarFactory::default(), - term_cb: TermCB::default(), - learn_cb: LearnCB::default(), + term_cb: None, + learn_cb: None, } } } diff --git a/crates/pindakaas/src/solver/libloading.rs b/crates/pindakaas/src/solver/libloading.rs index f8dcd9d53e..934fd6a4ae 100644 --- a/crates/pindakaas/src/solver/libloading.rs +++ b/crates/pindakaas/src/solver/libloading.rs @@ -1,11 +1,11 @@ -#[cfg(feature = "ipasir-up")] -use std::{any::Any, collections::VecDeque}; use std::{ + alloc::{self, Layout}, ffi::{c_char, c_int, c_void, CStr}, - fmt, num::NonZeroI32, ptr, }; +#[cfg(feature = "ipasir-up")] +use std::{any::Any, collections::VecDeque}; use libloading::{Library, Symbol}; @@ -87,8 +87,8 @@ impl IpasirLibrary { IpasirSolver { slv: (self.ipasir_init_sym().unwrap())(), vars: VarFactory::default(), - learn_cb: LearnCB::default(), - term_cb: TermCB::default(), + learn_cb: None, + term_cb: None, signature_fn: self.ipasir_signature_sym().unwrap(), release_fn: self.ipasir_release_sym().unwrap(), add_fn: self.ipasir_add_sym().unwrap(), @@ -129,9 +129,9 @@ pub struct IpasirSolver<'lib> { vars: VarFactory, /// The callback used when a clause is learned. - learn_cb: LearnCB, + learn_cb: Option<(*mut c_void, Layout)>, /// The callback used to check whether the solver should terminate. - term_cb: TermCB, + term_cb: Option<(*mut c_void, Layout)>, signature_fn: Symbol<'lib, extern "C" fn() -> *const c_char>, release_fn: Symbol<'lib, extern "C" fn(*mut c_void)>, @@ -157,7 +157,16 @@ pub struct IpasirSolver<'lib> { impl<'lib> Drop for IpasirSolver<'lib> { fn drop(&mut self) { - (self.release_fn)(self.slv) + // Drop the termination callback. + if let Some((ptr, layout)) = self.term_cb.take() { + unsafe { alloc::dealloc(ptr as *mut _, layout) }; + } + // Drop the learning callback. + if let Some((ptr, layout)) = self.learn_cb.take() { + unsafe { alloc::dealloc(ptr as *mut _, layout) }; + } + // Release the solver. + (self.release_fn)(self.slv); } } @@ -290,100 +299,72 @@ impl FailedAssumtions for IpasirFailed<'_> { } impl<'lib> TermCallback for IpasirSolver<'lib> { - fn set_terminate_callback SlvTermSignal + 'static>(&mut self, cb: Option) { + fn set_terminate_callback SlvTermSignal>(&mut self, cb: Option) { if let Some(mut cb) = cb { - self.term_cb = TermCB::new(move || -> c_int { + let wrapped_cb = move || -> c_int { match cb() { SlvTermSignal::Continue => c_int::from(0), SlvTermSignal::Terminate => c_int::from(1), } - }); - - (self.set_terminate_fn)(self.slv, self.term_cb.as_ptr(), Some(TermCB::exec_callback)); + }; + + let trampoline = get_trampoline0(&wrapped_cb); + let layout = Layout::for_value(&wrapped_cb); + let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void; + if layout.size() != 0 { + // Otherwise nothing was leaked. + self.term_cb = Some((data, layout)); + } + (self.set_terminate_fn)(self.slv, data, Some(trampoline)); } else { - self.term_cb = TermCB::default(); + if let Some((ptr, layout)) = self.term_cb.take() { + unsafe { alloc::dealloc(ptr as *mut _, layout) }; + } (self.set_terminate_fn)(self.slv, ptr::null_mut(), None); } } } impl<'lib> LearnCallback for IpasirSolver<'lib> { - fn set_learn_callback) + 'static>( - &mut self, - cb: Option, - ) { + fn set_learn_callback)>(&mut self, cb: Option) { const MAX_LEN: std::ffi::c_int = 512; if let Some(mut cb) = cb { - self.learn_cb = LearnCB::new(move |clause: *const i32| { + let wrapped_cb = |clause: *const i32| { let mut iter = ExplIter(clause).map(|i: i32| Lit(NonZeroI32::new(i).unwrap())); cb(&mut iter) - }); - (self.set_learn_fn)( - self.slv, - self.learn_cb.as_ptr(), - MAX_LEN, - Some(LearnCB::exec_callback), - ); + }; + + let trampoline = get_trampoline1(&wrapped_cb); + let layout = Layout::for_value(&wrapped_cb); + let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void; + if layout.size() != 0 { + // Otherwise nothing was leaked. + self.learn_cb = Some((data, layout)); + } + (self.set_learn_fn)(self.slv, data, MAX_LEN, Some(trampoline)); } else { - self.learn_cb = LearnCB::default(); + if let Some((ptr, layout)) = self.learn_cb.take() { + unsafe { alloc::dealloc(ptr as *mut _, layout) }; + } (self.set_learn_fn)(self.slv, ptr::null_mut(), MAX_LEN, None); } } } - -/// Storage for user provided callbacks when a new clause is learned. -pub(crate) struct LearnCB(pub(crate) Box>); -impl LearnCB { - pub(crate) unsafe extern "C" fn exec_callback(data: *mut c_void, clause: *const c_int) { - let cb: &mut Box = - &mut *(data as *mut Box); - cb(clause) - } - pub(crate) fn new(f: impl FnMut(*const c_int) + 'static) -> Self { - Self(Box::new(Box::new(f))) - } - pub(crate) fn as_ptr(&self) -> *mut c_void { - #[allow(clippy::borrowed_box)] - let x: &Box = &self.0; - x as *const _ as *mut c_void - } -} -impl fmt::Debug for LearnCB { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("LearnCB").field(&self.as_ptr()).finish() - } +type CB0 = unsafe extern "C" fn(*mut c_void) -> R; +unsafe extern "C" fn trampoline0 R>(user_data: *mut c_void) -> R { + let user_data = &mut *(user_data as *mut F); + user_data() } -impl Default for LearnCB { - fn default() -> Self { - Self(Box::new(Box::new(|_| {}))) - } +pub(crate) fn get_trampoline0 R>(_closure: &F) -> CB0 { + trampoline0:: } - -/// Storage for user provided callbacks to check whether a solver should terminate. -pub(crate) struct TermCB(pub(crate) Box c_int>>); -impl TermCB { - pub(crate) unsafe extern "C" fn exec_callback(data: *mut c_void) -> c_int { - let cb: &mut Box c_int> = &mut *(data as *mut Box c_int>); - cb() - } - pub(crate) fn new(f: impl FnMut() -> c_int + 'static) -> Self { - Self(Box::new(Box::new(f))) - } - pub(crate) fn as_ptr(&self) -> *mut c_void { - #[allow(clippy::borrowed_box)] - let x: &Box c_int> = &self.0; - x as *const _ as *mut c_void - } +type CB1 = unsafe extern "C" fn(*mut c_void, A) -> R; +unsafe extern "C" fn trampoline1 R>(user_data: *mut c_void, arg1: A) -> R { + let user_data = &mut *(user_data as *mut F); + user_data(arg1) } -impl fmt::Debug for TermCB { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("TermCB").field(&self.as_ptr()).finish() - } -} -impl Default for TermCB { - fn default() -> Self { - Self(Box::new(Box::new(|| 0))) - } +pub(crate) fn get_trampoline1 R>(_closure: &F) -> CB1 { + trampoline1:: } /// Iterator over the elements of a null-terminated i32 array