From ada1f0c12f0d973665dbe6c75a407d31b2f5ca26 Mon Sep 17 00:00:00 2001 From: "Jip J. Dekker" Date: Wed, 5 Jun 2024 10:57:37 +1000 Subject: [PATCH] Store solver callbacks in the solver struct --- crates/pindakaas-derive/src/lib.rs | 44 +++++---- crates/pindakaas/src/solver.rs | 15 +-- crates/pindakaas/src/solver/cadical.rs | 64 ++++--------- crates/pindakaas/src/solver/intel_sat.rs | 9 ++ crates/pindakaas/src/solver/libloading.rs | 108 +++++++++++++++++----- 5 files changed, 142 insertions(+), 98 deletions(-) diff --git a/crates/pindakaas-derive/src/lib.rs b/crates/pindakaas-derive/src/lib.rs index 46efffffa8..f5add49da5 100644 --- a/crates/pindakaas-derive/src/lib.rs +++ b/crates/pindakaas-derive/src/lib.rs @@ -18,8 +18,12 @@ struct IpasirOpts { #[darling(default)] learn_callback: bool, #[darling(default)] + learn_callback_ident: Option, + #[darling(default)] term_callback: bool, #[darling(default)] + term_callback_ident: Option, + #[darling(default)] ipasir_up: bool, #[darling(default = "default_true")] has_default: bool, @@ -88,30 +92,35 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { }; let term_callback = if opts.term_callback { + let term_cb = match opts.term_callback_ident { + Some(x) => quote! { self. #x }, + None => quote! { self.term_cb }, + }; quote! { impl crate::solver::TermCallback for #ident { - fn set_terminate_callback crate::solver::SlvTermSignal>( + fn set_terminate_callback crate::solver::SlvTermSignal + 'static>( &mut self, cb: Option, ) { if let Some(mut cb) = cb { - let mut wrapped_cb = move || -> std::ffi::c_int { + #term_cb = crate::solver::libloading::TermCB::new(move || -> std::ffi::c_int { match cb() { crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0), crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1), } - }; - let trampoline = crate::solver::libloading::get_trampoline0(&wrapped_cb); - // WARNING: Any data in the callback now exists forever - let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void; + }); + + // let trampoline = crate::solver::libloading::get_pin_trampoline0(&#term_cb .0); + // let data = unsafe { #term_cb.0 .as_mut().get_unchecked_mut() } as *mut _ as *mut std::ffi::c_void; unsafe { #krate::ipasir_set_terminate( #ptr, - data, - Some(trampoline), + #term_cb .as_ptr(), + Some(crate::solver::libloading::TermCB::exec_callback), ) } } else { + #term_cb = Default::default(); unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) } } } @@ -122,31 +131,34 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { }; let learn_callback = if opts.learn_callback { + let learn_cb = match opts.learn_callback_ident { + Some(x) => quote! { self. #x }, + None => quote! { self.learn_cb }, + }; quote! { impl crate::solver::LearnCallback for #ident { - fn set_learn_callback)>( + fn set_learn_callback) + 'static>( &mut self, cb: Option, ) { const MAX_LEN: std::ffi::c_int = 512; if let Some(mut cb) = cb { - let mut wrapped_cb = move |clause: *const i32| { + #learn_cb = crate::solver::libloading::LearnCB::new(move |clause: *const i32| { let mut iter = crate::solver::libloading::ExplIter(clause) .map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap())); cb(&mut iter) - }; - let trampoline = crate::solver::libloading::get_trampoline1(&wrapped_cb); - // WARNING: Any data in the callback now exists forever - let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void; + }); + unsafe { #krate::ipasir_set_learn( #ptr, - data, + #learn_cb .as_ptr(), MAX_LEN, - Some(trampoline), + Some(crate::solver::libloading::LearnCB::exec_callback), ) } } else { + #learn_cb = Default::default(); unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) } } } diff --git a/crates/pindakaas/src/solver.rs b/crates/pindakaas/src/solver.rs index f39e00b78a..ce0570e1a1 100644 --- a/crates/pindakaas/src/solver.rs +++ b/crates/pindakaas/src/solver.rs @@ -72,11 +72,10 @@ pub trait LearnCallback: Solver { /// /// Subsequent calls to this method override the previously set /// callback function. - /// - /// For IPASIR connected through C, the callback and any objects contained - /// within it might be leaked to satisfy the FFI requirements. Note that - /// [`Drop`] implementations might not be called on these objects. - fn set_learn_callback)>(&mut self, cb: Option); + fn set_learn_callback) + 'static>( + &mut self, + cb: Option, + ); } pub trait TermCallback: Solver { @@ -91,11 +90,7 @@ pub trait TermCallback: Solver { /// /// Subsequent calls to this method override the previously set /// callback function. - /// - /// For IPASIR connected through C, the callback and any objects contained - /// within it might be leaked to satisfy the FFI requirements. Note that - /// [`Drop`] implementations might not be called on these objects. - fn set_terminate_callback SlvTermSignal>(&mut self, cb: Option); + fn set_terminate_callback SlvTermSignal + 'static>(&mut self, cb: Option); } #[cfg(feature = "ipasir-up")] diff --git a/crates/pindakaas/src/solver/cadical.rs b/crates/pindakaas/src/solver/cadical.rs index 1a642b1801..bdacc3a53f 100644 --- a/crates/pindakaas/src/solver/cadical.rs +++ b/crates/pindakaas/src/solver/cadical.rs @@ -4,14 +4,25 @@ use pindakaas_cadical::{ccadical_copy, ccadical_phase, ccadical_unphase}; use pindakaas_derive::IpasirSolver; use super::VarFactory; -use crate::Lit; +use crate::{ + solver::libloading::{LearnCB, TermCB}, + Lit, +}; #[derive(IpasirSolver)] #[ipasir(krate = pindakaas_cadical, assumptions, learn_callback, term_callback, ipasir_up)] pub struct Cadical { + /// The raw pointer to the Cadical solver. ptr: *mut std::ffi::c_void, + /// The variable factory for this solver. vars: VarFactory, + /// The callback used when a clause is learned. + learn_cb: LearnCB, + /// The callback used to check whether the solver should terminate. + term_cb: TermCB, + #[cfg(feature = "ipasir-up")] + /// The external propagator called by the solver prop: Option>, } @@ -20,6 +31,8 @@ impl Default for Cadical { Self { ptr: unsafe { pindakaas_cadical::ipasir_init() }, vars: VarFactory::default(), + learn_cb: LearnCB::default(), + term_cb: TermCB::default(), #[cfg(feature = "ipasir-up")] prop: None, } @@ -32,6 +45,8 @@ impl Clone for Cadical { Self { ptr, vars: self.vars, + learn_cb: LearnCB::default(), + term_cb: TermCB::default(), #[cfg(feature = "ipasir-up")] prop: None, } @@ -77,7 +92,7 @@ mod tests { use super::*; use crate::{ linear::LimitComp, - solver::{LearnCallback, SlvTermSignal, SolveResult, Solver, TermCallback}, + solver::{SolveResult, Solver}, CardinalityOne, ClauseDatabase, Encoder, PairwiseEncoder, Valuation, }; @@ -114,51 +129,6 @@ mod tests { }); } - #[test] - fn test_cadical_cb_no_drop() { - let mut slv = Cadical::default(); - - let a = slv.new_var().into(); - let b = slv.new_var().into(); - PairwiseEncoder::default() - .encode( - &mut slv, - &CardinalityOne { - lits: vec![a, b], - cmp: LimitComp::Equal, - }, - ) - .unwrap(); - - struct NoDrop(i32); - impl NoDrop { - fn seen(&mut self) { - self.0 += 1; - eprintln!("seen {}", self.0); - } - } - impl Drop for NoDrop { - fn drop(&mut self) { - panic!("I have been dropped {}", self.0); - } - } - - { - let mut nodrop = NoDrop(0); - slv.set_terminate_callback(Some(move || { - nodrop.seen(); - SlvTermSignal::Continue - })); - } - { - let mut nodrop = NoDrop(0); - slv.set_learn_callback(Some(move |_: &mut dyn Iterator| { - nodrop.seen(); - })); - } - assert_eq!(slv.solve(|_| {}), SolveResult::Sat); - } - #[cfg(feature = "ipasir-up")] #[test] fn test_ipasir_up() { diff --git a/crates/pindakaas/src/solver/intel_sat.rs b/crates/pindakaas/src/solver/intel_sat.rs index c6c3a55ba6..47b68b95a8 100644 --- a/crates/pindakaas/src/solver/intel_sat.rs +++ b/crates/pindakaas/src/solver/intel_sat.rs @@ -1,12 +1,19 @@ use pindakaas_derive::IpasirSolver; use super::VarFactory; +use crate::solver::libloading::{LearnCB, TermCB}; #[derive(Debug, IpasirSolver)] #[ipasir(krate = pindakaas_intel_sat, assumptions, learn_callback, term_callback)] pub struct IntelSat { + /// The raw pointer to the Intel SAT solver. ptr: *mut std::ffi::c_void, + /// The variable factory for this solver. vars: VarFactory, + /// The callback used when a clause is learned. + learn_cb: LearnCB, + /// The callback used to check whether the solver should terminate. + term_cb: TermCB, } impl Default for IntelSat { @@ -14,6 +21,8 @@ impl Default for IntelSat { Self { ptr: unsafe { pindakaas_intel_sat::ipasir_init() }, vars: VarFactory::default(), + term_cb: TermCB::default(), + learn_cb: LearnCB::default(), } } } diff --git a/crates/pindakaas/src/solver/libloading.rs b/crates/pindakaas/src/solver/libloading.rs index cb68cde4e5..3f008bf464 100644 --- a/crates/pindakaas/src/solver/libloading.rs +++ b/crates/pindakaas/src/solver/libloading.rs @@ -2,6 +2,7 @@ use std::{any::Any, collections::VecDeque}; use std::{ ffi::{c_char, c_int, c_void, CStr}, + fmt, num::NonZeroI32, ptr, }; @@ -85,7 +86,9 @@ impl IpasirLibrary { pub fn new_solver(&self) -> IpasirSolver<'_> { IpasirSolver { slv: (self.ipasir_init_sym().unwrap())(), - next_var: VarFactory::default(), + vars: VarFactory::default(), + learn_cb: LearnCB::default(), + term_cb: TermCB::default(), signature_fn: self.ipasir_signature_sym().unwrap(), release_fn: self.ipasir_release_sym().unwrap(), add_fn: self.ipasir_add_sym().unwrap(), @@ -120,8 +123,16 @@ impl TryFrom for IpasirLibrary { #[derive(Debug)] pub struct IpasirSolver<'lib> { + /// The raw pointer to the Intel SAT solver. slv: *mut c_void, - next_var: VarFactory, + /// The variable factory for this solver. + vars: VarFactory, + + /// The callback used when a clause is learned. + learn_cb: LearnCB, + /// The callback used to check whether the solver should terminate. + term_cb: TermCB, + signature_fn: Symbol<'lib, extern "C" fn() -> *const c_char>, release_fn: Symbol<'lib, extern "C" fn(*mut c_void)>, add_fn: Symbol<'lib, extern "C" fn(*mut c_void, i32)>, @@ -152,7 +163,7 @@ impl<'lib> Drop for IpasirSolver<'lib> { impl<'lib> ClauseDatabase for IpasirSolver<'lib> { fn new_var(&mut self) -> Var { - self.next_var.next().expect("variable pool exhaused") + self.vars.next().expect("variable pool exhaused") } fn add_clause>(&mut self, clause: I) -> Result { @@ -279,53 +290,100 @@ impl FailedAssumtions for IpasirFailed<'_> { } impl<'lib> TermCallback for IpasirSolver<'lib> { - fn set_terminate_callback SlvTermSignal>(&mut self, cb: Option) { + fn set_terminate_callback SlvTermSignal + 'static>(&mut self, cb: Option) { if let Some(mut cb) = cb { - let mut wrapped_cb = || -> c_int { + self.term_cb = TermCB::new(move || -> c_int { match cb() { SlvTermSignal::Continue => c_int::from(0), SlvTermSignal::Terminate => c_int::from(1), } - }; - let data = &mut wrapped_cb as *mut _ as *mut c_void; - (self.set_terminate_fn)(self.slv, data, Some(get_trampoline0(&wrapped_cb))); + }); + + (self.set_terminate_fn)(self.slv, self.term_cb.as_ptr(), Some(TermCB::exec_callback)); } else { + self.term_cb = TermCB::default(); (self.set_terminate_fn)(self.slv, ptr::null_mut(), None); } } } impl<'lib> LearnCallback for IpasirSolver<'lib> { - fn set_learn_callback)>(&mut self, cb: Option) { + fn set_learn_callback) + 'static>( + &mut self, + cb: Option, + ) { const MAX_LEN: std::ffi::c_int = 512; if let Some(mut cb) = cb { - let mut wrapped_cb = |clause: *const i32| { + self.learn_cb = LearnCB::new(move |clause: *const i32| { let mut iter = ExplIter(clause).map(|i: i32| Lit(NonZeroI32::new(i).unwrap())); cb(&mut iter) - }; - let data = &mut wrapped_cb as *mut _ as *mut c_void; - (self.set_learn_fn)(self.slv, data, MAX_LEN, Some(get_trampoline1(&wrapped_cb))); + }); + (self.set_learn_fn)( + self.slv, + self.learn_cb.as_ptr(), + MAX_LEN, + Some(LearnCB::exec_callback), + ); } else { + self.learn_cb = LearnCB::default(); (self.set_learn_fn)(self.slv, ptr::null_mut(), MAX_LEN, None); } } } -type CB0 = unsafe extern "C" fn(*mut c_void) -> R; -unsafe extern "C" fn trampoline0 R>(user_data: *mut c_void) -> R { - let user_data = &mut *(user_data as *mut F); - user_data() + +/// Storage for user provided callbacks when a new clause is learned. +pub(crate) struct LearnCB(pub(crate) Box>); +impl LearnCB { + pub(crate) unsafe extern "C" fn exec_callback(data: *mut c_void, clause: *const c_int) { + let cb: &mut Box = + &mut *(data as *mut Box); + cb(clause) + } + pub(crate) fn new(f: impl FnMut(*const c_int) + 'static) -> Self { + Self(Box::new(Box::new(f))) + } + pub(crate) fn as_ptr(&self) -> *mut c_void { + let x: &Box = &*self.0; + x as *const _ as *mut c_void + } } -pub(crate) fn get_trampoline0 R>(_closure: &F) -> CB0 { - trampoline0:: +impl fmt::Debug for LearnCB { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("LearnCB").field(&self.as_ptr()).finish() + } } -type CB1 = unsafe extern "C" fn(*mut c_void, A) -> R; -unsafe extern "C" fn trampoline1 R>(user_data: *mut c_void, arg1: A) -> R { - let user_data = &mut *(user_data as *mut F); - user_data(arg1) +impl Default for LearnCB { + fn default() -> Self { + Self(Box::new(Box::new(|_| {}))) + } } -pub(crate) fn get_trampoline1 R>(_closure: &F) -> CB1 { - trampoline1:: + +/// Storage for user provided callbacks to check whether a solver should terminate. +pub(crate) struct TermCB(pub(crate) Box c_int>>); +impl TermCB { + pub(crate) unsafe extern "C" fn exec_callback(data: *mut c_void) -> c_int { + let cb: &mut Box c_int> = &mut *(data as *mut Box c_int>); + cb() + } + pub(crate) fn new(f: impl FnMut() -> c_int + 'static) -> Self { + Self(Box::new(Box::new(f))) + } + pub(crate) fn as_ptr(&self) -> *mut c_void { + let x: &Box c_int> = &*self.0; + x as *const _ as *mut c_void + } } +impl fmt::Debug for TermCB { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("TermCB").field(&self.as_ptr()).finish() + } +} +impl Default for TermCB { + fn default() -> Self { + Self(Box::new(Box::new(|| 0))) + } +} + /// Iterator over the elements of a null-terminated i32 array #[derive(Debug, Clone, Copy)] pub(crate) struct ExplIter(pub(crate) *const i32);