Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a problem where callbacks given to IPASIR solver where dropped early #54

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
cb: Option<F>,
) {
if let Some(mut cb) = cb {
let mut wrapped_cb = || -> std::ffi::c_int {
let mut wrapped_cb = move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
Expand All @@ -109,6 +109,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
Some(crate::solver::libloading::get_trampoline0(&wrapped_cb)),
)
}
// WARNING: Any data in the callback now exists forever
std::mem::forget(wrapped_cb);
} else {
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
Expand All @@ -128,7 +130,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
let mut wrapped_cb = |clause: *const i32| {
let mut wrapped_cb = move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
Expand All @@ -142,6 +144,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
Some(crate::solver::libloading::get_trampoline1(&wrapped_cb)),
)
}
// WARNING: Any data in the callback now exists forever
std::mem::forget(wrapped_cb);
} else {
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
Expand Down
17 changes: 16 additions & 1 deletion crates/pindakaas/src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,14 @@ pub trait LearnCallback: Solver {
/// Set a callback function used to extract learned clauses up to a given
/// length from the solver.
///
/// WARNING: Subsequent calls to this method override the previously set
/// # Warning
///
/// Subsequent calls to this method override the previously set
/// callback function.
///
/// For IPASIR connected through C, the callback and any objects contained
/// within it might be leaked to satisfy the FFI requirements. Note that
/// [`Drop`] implementations might not be called on these objects.
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = Lit>)>(&mut self, cb: Option<F>);
}

Expand All @@ -80,6 +86,15 @@ pub trait TermCallback: Solver {
/// The solver will periodically call this function and check its return value
/// during the search. Subsequent calls to this method override the previously
/// set callback function.
///
/// # Warning
///
/// Subsequent calls to this method override the previously set
/// callback function.
///
/// For IPASIR connected through C, the callback and any objects contained
/// within it might be leaked to satisfy the FFI requirements. Note that
/// [`Drop`] implementations might not be called on these objects.
fn set_terminate_callback<F: FnMut() -> SlvTermSignal>(&mut self, cb: Option<F>);
}

Expand Down
47 changes: 46 additions & 1 deletion crates/pindakaas/src/solver/cadical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mod tests {
use super::*;
use crate::{
linear::LimitComp,
solver::{SolveResult, Solver},
solver::{LearnCallback, SlvTermSignal, SolveResult, Solver, TermCallback},
CardinalityOne, ClauseDatabase, Encoder, PairwiseEncoder, Valuation,
};

Expand Down Expand Up @@ -102,6 +102,51 @@ mod tests {
});
}

#[test]
fn test_cadical_cb_no_drop() {
let mut slv = Cadical::default();

let a = slv.new_var().into();
let b = slv.new_var().into();
PairwiseEncoder::default()
.encode(
&mut slv,
&CardinalityOne {
lits: vec![a, b],
cmp: LimitComp::Equal,
},
)
.unwrap();

struct NoDrop(i32);
impl NoDrop {
fn seen(&mut self) {
self.0 += 1;
eprintln!("seen {}", self.0);
}
}
impl Drop for NoDrop {
fn drop(&mut self) {
panic!("I have been dropped {}", self.0);
}
}

{
let mut nodrop = NoDrop(0);
slv.set_terminate_callback(Some(move || {
nodrop.seen();
SlvTermSignal::Continue
}));
}
{
let mut nodrop = NoDrop(0);
slv.set_learn_callback(Some(move |_: &mut dyn Iterator<Item = Lit>| {
nodrop.seen();
}));
}
assert_eq!(slv.solve(|_| {}), SolveResult::Sat);
}

#[cfg(feature = "ipasir-up")]
#[test]
fn test_ipasir_up() {
Expand Down