diff --git a/crates/pindakaas-derive/src/lib.rs b/crates/pindakaas-derive/src/lib.rs index 74e3582555..a4f28fe136 100644 --- a/crates/pindakaas-derive/src/lib.rs +++ b/crates/pindakaas-derive/src/lib.rs @@ -95,7 +95,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { cb: Option, ) { if let Some(mut cb) = cb { - let mut wrapped_cb = || -> std::ffi::c_int { + let mut wrapped_cb = move || -> std::ffi::c_int { match cb() { crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0), crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1), @@ -109,6 +109,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { Some(crate::solver::libloading::get_trampoline0(&wrapped_cb)), ) } + // WARNING: Any data in the callback now exists forever + std::mem::forget(wrapped_cb); } else { unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) } } @@ -128,7 +130,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { ) { const MAX_LEN: std::ffi::c_int = 512; if let Some(mut cb) = cb { - let mut wrapped_cb = |clause: *const i32| { + let mut wrapped_cb = move |clause: *const i32| { let mut iter = crate::solver::libloading::ExplIter(clause) .map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap())); cb(&mut iter) @@ -142,6 +144,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream { Some(crate::solver::libloading::get_trampoline1(&wrapped_cb)), ) } + // WARNING: Any data in the callback now exists forever + std::mem::forget(wrapped_cb); } else { unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) } } diff --git a/crates/pindakaas/src/solver.rs b/crates/pindakaas/src/solver.rs index 32c990c012..f39e00b78a 100644 --- a/crates/pindakaas/src/solver.rs +++ b/crates/pindakaas/src/solver.rs @@ -68,8 +68,14 @@ pub trait LearnCallback: Solver { /// Set a callback function used to extract learned clauses up to a given /// length from the solver. /// - /// WARNING: Subsequent calls to this method override the previously set + /// # Warning + /// + /// Subsequent calls to this method override the previously set /// callback function. + /// + /// For IPASIR connected through C, the callback and any objects contained + /// within it might be leaked to satisfy the FFI requirements. Note that + /// [`Drop`] implementations might not be called on these objects. fn set_learn_callback)>(&mut self, cb: Option); } @@ -80,6 +86,15 @@ pub trait TermCallback: Solver { /// The solver will periodically call this function and check its return value /// during the search. Subsequent calls to this method override the previously /// set callback function. + /// + /// # Warning + /// + /// Subsequent calls to this method override the previously set + /// callback function. + /// + /// For IPASIR connected through C, the callback and any objects contained + /// within it might be leaked to satisfy the FFI requirements. Note that + /// [`Drop`] implementations might not be called on these objects. fn set_terminate_callback SlvTermSignal>(&mut self, cb: Option); } diff --git a/crates/pindakaas/src/solver/cadical.rs b/crates/pindakaas/src/solver/cadical.rs index 7029ffcec4..5df12a2339 100644 --- a/crates/pindakaas/src/solver/cadical.rs +++ b/crates/pindakaas/src/solver/cadical.rs @@ -65,7 +65,7 @@ mod tests { use super::*; use crate::{ linear::LimitComp, - solver::{SolveResult, Solver}, + solver::{LearnCallback, SlvTermSignal, SolveResult, Solver, TermCallback}, CardinalityOne, ClauseDatabase, Encoder, PairwiseEncoder, Valuation, }; @@ -102,6 +102,51 @@ mod tests { }); } + #[test] + fn test_cadical_cb_no_drop() { + let mut slv = Cadical::default(); + + let a = slv.new_var().into(); + let b = slv.new_var().into(); + PairwiseEncoder::default() + .encode( + &mut slv, + &CardinalityOne { + lits: vec![a, b], + cmp: LimitComp::Equal, + }, + ) + .unwrap(); + + struct NoDrop(i32); + impl NoDrop { + fn seen(&mut self) { + self.0 += 1; + eprintln!("seen {}", self.0); + } + } + impl Drop for NoDrop { + fn drop(&mut self) { + panic!("I have been dropped {}", self.0); + } + } + + { + let mut nodrop = NoDrop(0); + slv.set_terminate_callback(Some(move || { + nodrop.seen(); + SlvTermSignal::Continue + })); + } + { + let mut nodrop = NoDrop(0); + slv.set_learn_callback(Some(move |_: &mut dyn Iterator| { + nodrop.seen(); + })); + } + assert_eq!(slv.solve(|_| {}), SolveResult::Sat); + } + #[cfg(feature = "ipasir-up")] #[test] fn test_ipasir_up() {