Skip to content

Commit

Permalink
Resolve one level of dispatch for IPASIR callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekker1 committed Jun 6, 2024
1 parent 6695c90 commit 507ba3b
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 150 deletions.
140 changes: 86 additions & 54 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,79 +91,109 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
quote!()
};

let term_callback = if opts.term_callback {
let (term_callback, term_drop) = if opts.term_callback {
let term_cb = match opts.term_callback_ident {
Some(x) => quote! { self. #x },
None => quote! { self.term_cb },
};
quote! {
impl crate::solver::TermCallback for #ident {
fn set_terminate_callback<F: FnMut() -> crate::solver::SlvTermSignal + 'static>(
&mut self,
cb: Option<F>,
) {
if let Some(mut cb) = cb {
#term_cb = crate::solver::libloading::TermCB::new(move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
(
quote! {
impl crate::solver::TermCallback for #ident {
fn set_terminate_callback<F: FnMut() -> crate::solver::SlvTermSignal + 'static>(
&mut self,
cb: Option<F>,
) {
if let Some(mut cb) = cb {
let mut wrapped_cb = move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
}
};
let trampoline = crate::solver::libloading::get_trampoline0(&wrapped_cb);
let layout = std::alloc::Layout::for_value(&wrapped_cb);
let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void;
if layout.size() != 0 {
// Otherwise nothing was leaked.
#term_cb = Some((data, layout));
}
});

unsafe {
#krate::ipasir_set_terminate(
#ptr,
#term_cb .as_ptr(),
Some(crate::solver::libloading::TermCB::exec_callback),
)
unsafe {
#krate::ipasir_set_terminate(
#ptr,
data,
Some(trampoline),
)
}
} else {
if let Some((ptr, layout)) = #term_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
} else {
#term_cb = Default::default();
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
}
}
}
},
quote! {
if let Some((ptr, layout)) = #term_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
},
)
} else {
quote!()
(quote!(), quote!())
};

let learn_callback = if opts.learn_callback {
let (learn_callback, learn_drop) = if opts.learn_callback {
let learn_cb = match opts.learn_callback_ident {
Some(x) => quote! { self. #x },
None => quote! { self.learn_cb },
};
quote! {
impl crate::solver::LearnCallback for #ident {
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = crate::Lit>) + 'static>(
&mut self,
cb: Option<F>,
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
#learn_cb = crate::solver::libloading::LearnCB::new(move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
});

unsafe {
#krate::ipasir_set_learn(
#ptr,
#learn_cb .as_ptr(),
MAX_LEN,
Some(crate::solver::libloading::LearnCB::exec_callback),
)
(
quote! {
impl crate::solver::LearnCallback for #ident {
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = crate::Lit>) + 'static>(
&mut self,
cb: Option<F>,
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
let mut wrapped_cb = move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
};
let trampoline = crate::solver::libloading::get_trampoline1(&wrapped_cb);
let layout = std::alloc::Layout::for_value(&wrapped_cb);
let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void;
if layout.size() != 0 {
// Otherwise nothing was leaked.
#learn_cb = Some((data, layout));
}
unsafe {
#krate::ipasir_set_learn(
#ptr,
data,
MAX_LEN,
Some(trampoline),
)
}
} else {
if let Some((ptr, layout)) = #learn_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
} else {
#learn_cb = Default::default();
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
}
}
}
},
quote! {
if let Some((ptr, layout)) = #learn_cb .take() {
unsafe { std::alloc::dealloc(ptr as *mut _, layout) };
}
},
)
} else {
quote!()
(quote!(), quote!())
};

let sol_ident = format_ident!("{}Sol", ident);
Expand Down Expand Up @@ -356,6 +386,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
quote! {
impl Drop for #ident {
fn drop(&mut self) {
#learn_drop
#term_drop
unsafe { #krate::ipasir_release( #ptr ) }
}
}
Expand Down
25 changes: 13 additions & 12 deletions crates/pindakaas/src/solver/cadical.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
use std::{ffi::CString, fmt};
use std::{
alloc::Layout,
ffi::{c_void, CString},
fmt,
};

use pindakaas_cadical::{ccadical_copy, ccadical_phase, ccadical_unphase};
use pindakaas_derive::IpasirSolver;

use super::VarFactory;
use crate::{
solver::libloading::{LearnCB, TermCB},
Lit,
};
use crate::Lit;

#[derive(IpasirSolver)]
#[ipasir(krate = pindakaas_cadical, assumptions, learn_callback, term_callback, ipasir_up)]
pub struct Cadical {
/// The raw pointer to the Cadical solver.
ptr: *mut std::ffi::c_void,
ptr: *mut c_void,
/// The variable factory for this solver.
vars: VarFactory,
/// The callback used when a clause is learned.
learn_cb: LearnCB,
learn_cb: Option<(*mut c_void, Layout)>,
/// The callback used to check whether the solver should terminate.
term_cb: TermCB,
term_cb: Option<(*mut c_void, Layout)>,

#[cfg(feature = "ipasir-up")]
/// The external propagator called by the solver
Expand All @@ -31,8 +32,8 @@ impl Default for Cadical {
Self {
ptr: unsafe { pindakaas_cadical::ipasir_init() },
vars: VarFactory::default(),
learn_cb: LearnCB::default(),
term_cb: TermCB::default(),
learn_cb: None,
term_cb: None,
#[cfg(feature = "ipasir-up")]
prop: None,
}
Expand All @@ -45,8 +46,8 @@ impl Clone for Cadical {
Self {
ptr,
vars: self.vars,
learn_cb: LearnCB::default(),
term_cb: TermCB::default(),
learn_cb: None,
term_cb: None,
#[cfg(feature = "ipasir-up")]
prop: None,
}
Expand Down
13 changes: 7 additions & 6 deletions crates/pindakaas/src/solver/intel_sat.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
use std::{alloc::Layout, ffi::c_void};

use pindakaas_derive::IpasirSolver;

use super::VarFactory;
use crate::solver::libloading::{LearnCB, TermCB};

#[derive(Debug, IpasirSolver)]
#[ipasir(krate = pindakaas_intel_sat, assumptions, learn_callback, term_callback)]
pub struct IntelSat {
/// The raw pointer to the Intel SAT solver.
ptr: *mut std::ffi::c_void,
ptr: *mut c_void,
/// The variable factory for this solver.
vars: VarFactory,
/// The callback used when a clause is learned.
learn_cb: LearnCB,
learn_cb: Option<(*mut c_void, Layout)>,
/// The callback used to check whether the solver should terminate.
term_cb: TermCB,
term_cb: Option<(*mut c_void, Layout)>,
}

impl Default for IntelSat {
fn default() -> Self {
Self {
ptr: unsafe { pindakaas_intel_sat::ipasir_init() },
vars: VarFactory::default(),
term_cb: TermCB::default(),
learn_cb: LearnCB::default(),
term_cb: None,
learn_cb: None,
}
}
}
Expand Down
Loading

0 comments on commit 507ba3b

Please sign in to comment.