diff --git a/Cargo.toml b/Cargo.toml index 8c5c90cb..a3510d3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ wasm-bindgen = { version = "0.2.84" } [dev-dependencies] serde_json = { version = "1.0.81" } clap = { version = "4.2.1", features = ["derive"] } +rayon = "1.7.0" [lib] name = "rsdd" @@ -81,3 +82,7 @@ path = "examples/semantic_top_down_experiment.rs" [[example]] name = "marginal_map_experiment" path = "examples/marginal_map_experiment.rs" + +[[example]] +name = "parallel_semantic" +path = "examples/parallel_semantic.rs" diff --git a/examples/parallel_semantic.rs b/examples/parallel_semantic.rs new file mode 100644 index 00000000..cbd7a543 --- /dev/null +++ b/examples/parallel_semantic.rs @@ -0,0 +1,207 @@ +use std::{ + fs, + time::{Duration, Instant}, +}; + +use clap::Parser; +use rayon::prelude::*; +use rsdd::{ + builder::{parallel::SemanticBddBuilder, BottomUpBuilder}, + constants::primes, + repr::{create_semantic_hash_map, Cnf, DDNNFPtr}, +}; + +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + /// input CNF in DIMACS form + #[clap(short, long, value_parser)] + file: String, + + /// number of splits to perform on the data + #[clap(short, long, value_parser, default_value_t = 4)] + num_splits: usize, + + /// use multithreading! + #[clap(short, long, value_parser)] + thread: bool, +} + +fn split_cnf(cnf: &Cnf, num_splits: usize) -> Vec { + let chunk_size = cnf.clauses().len() / num_splits + + (if cnf.clauses().len() % num_splits == 0 { + 0 + } else { + 1 + }); + + cnf.clauses().chunks(chunk_size).map(Cnf::new).collect() +} + +fn single_threaded(cnf: &Cnf, num_splits: usize) { + let num_splits = std::cmp::min(num_splits, cnf.clauses().len()); + + let num_vars = cnf.num_vars(); + let map = create_semantic_hash_map(num_vars); + let order = cnf.min_fill_order(); + + let builders: Vec<_> = (0..num_splits) + .map(|_| { + SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()) + }) + .collect(); + + let mut ptrs = Vec::new(); + + let mut timings = Vec::new(); + + for (i, subcnf) in split_cnf(cnf, num_splits).iter().enumerate() { + let start = Instant::now(); + ptrs.push(builders[i].compile_cnf(subcnf)); + let end = start.elapsed(); + timings.push(end); + } + + let compile_duration: Duration = timings.iter().sum(); + let compile_max = timings.iter().max().unwrap(); + + println!("DONE COMPILING: {:.2}s", compile_duration.as_secs_f64()); + println!("MAX COMPILATION: {:.2}s", compile_max.as_secs_f64()); + + let start = Instant::now(); + + let builder = &builders[0]; + let mut ptr = ptrs[0]; + + for i in 1..ptrs.len() { + let new_ptr = builder.merge_from(&builders[i], &[ptrs[i]])[0]; + ptr = builder.and(ptr, new_ptr); + } + + let merge_duration = start.elapsed(); + + let st_builder = + SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()); + + let start = Instant::now(); + let st_ptr = st_builder.compile_cnf(cnf); + let single_duration = start.elapsed(); + + let wmc = ptr.wmc(&order, &map); + let st_wmc = st_ptr.wmc(&order, &map); + + println!("=== TIMING ==="); + println!( + "Compile: {:.2}s (Total {:.2}s), Merge: {:.2}s", + compile_max.as_secs_f64(), + compile_duration.as_secs_f64(), + merge_duration.as_secs_f64() + ); + println!("Single-threaded: {:.2}s", single_duration.as_secs_f64()); + println!( + "Speedup ratio: {:.2}x", + single_duration.as_secs_f64() / (compile_max.as_secs_f64() + merge_duration.as_secs_f64()) + ); + if wmc != st_wmc { + println!( + "BROKEN. 
Not equal WMC; single: {}, merge: {}", + st_wmc.value(), + wmc.value() + ); + } +} + +fn multi_threaded(cnf: &Cnf, num_splits: usize) { + let num_splits = std::cmp::min(num_splits, cnf.clauses().len()); + + let num_vars = cnf.num_vars(); + let map = create_semantic_hash_map(num_vars); + let order = cnf.min_fill_order(); + + let builders: Vec<_> = split_cnf(cnf, num_splits) + .into_par_iter() + .map(|subcnf| { + ( + SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map( + order.clone(), + map.clone(), + ), + subcnf, + ) + }) + .collect(); + + let start = Instant::now(); + + let ptrs: Vec<_> = builders + .par_iter() + .map(|(builder, cnf)| builder.compile_cnf(cnf)) + .collect(); + + let compile_duration: Duration = start.elapsed(); + + println!("DONE COMPILING: {:.2}s", compile_duration.as_secs_f64()); + + let mut merge_ds = 0.0; + let mut merge_and = 0.0; + + let builder = &builders[0].0; + let mut ptr = ptrs[0]; + + for i in 1..ptrs.len() { + let start = Instant::now(); + let new_ptr = builder.merge_from(&builders[i].0, &[ptrs[i]])[0]; + merge_ds += start.elapsed().as_secs_f64(); + + let start = Instant::now(); + ptr = builder.and(ptr, new_ptr); + merge_and += start.elapsed().as_secs_f64(); + println!("completed one AND; total time spent: {:.2}s", merge_and); + } + + let merge_duration = merge_ds + merge_and; + + let st_builder = + SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()); + + let start = Instant::now(); + let st_ptr = st_builder.compile_cnf(cnf); + let single_duration = start.elapsed(); + + let wmc = ptr.wmc(&order, &map); + let st_wmc = st_ptr.wmc(&order, &map); + + println!("=== TIMING ==="); + println!( + "Compile: {:.2}s, Merge: {:.2}s (ds: {:.2}s, ands: {:.2}s)", + compile_duration.as_secs_f64(), + merge_duration, + merge_ds, + merge_and + ); + println!("Single-threaded: {:.2}s", single_duration.as_secs_f64()); + println!( + "Speedup ratio: {:.2}x", + single_duration.as_secs_f64() / (compile_duration.as_secs_f64() + merge_duration) + ); + if wmc != st_wmc { + println!( + "BROKEN. Not equal WMC; single: {}, merge: {}", + st_wmc.value(), + wmc.value() + ); + } +} + +fn main() { + let args = Args::parse(); + + let cnf_input = fs::read_to_string(args.file).expect("Should have been able to read the file"); + + let cnf = Cnf::from_dimacs(&cnf_input); + + match args.thread { + false => single_threaded(&cnf, args.num_splits), + true => multi_threaded(&cnf, args.num_splits), + } +} diff --git a/src/backing_store/bump_table.rs b/src/backing_store/bump_table.rs index e0893bb9..1f0a4018 100644 --- a/src/backing_store/bump_table.rs +++ b/src/backing_store/bump_table.rs @@ -84,6 +84,7 @@ fn propagate<'a, T: Clone>( /// Implements a mutable vector-backed robin-hood linear probing hash table, /// whose keys are given by BDD pointers. +#[derive(Debug)] pub struct BackedRobinhoodTable<'a, T> where T: Hash + PartialEq + Clone, diff --git a/src/builder/bdd/stats.rs b/src/builder/bdd/stats.rs index 81112e54..c582f24f 100644 --- a/src/builder/bdd/stats.rs +++ b/src/builder/bdd/stats.rs @@ -1,5 +1,6 @@ /// An auxiliary data structure for tracking statistics about BDD manager /// performance (for fine-tuning) +#[derive(Debug)] pub struct BddBuilderStats { /// For now, always track the number of recursive calls. 
In the future, /// this should probably be gated behind a debug build (since I suspect diff --git a/src/builder/cache/all_app.rs b/src/builder/cache/all_app.rs index e8e9173c..f94dc929 100644 --- a/src/builder/cache/all_app.rs +++ b/src/builder/cache/all_app.rs @@ -8,6 +8,7 @@ use rustc_hash::FxHashMap; /// An Ite structure, assumed to be in standard form. /// The top-level data structure that caches applications +#[derive(Debug)] pub struct AllIteTable { /// a vector of applications, indexed by the top label of the first pointer. table: FxHashMap<(T, T, T), T>, @@ -54,6 +55,14 @@ impl<'a, T: DDNNFPtr<'a>> AllIteTable { table: FxHashMap::default(), } } + + pub fn iter(&self) -> impl Iterator + '_ { + self.table.iter() + } + + pub fn insert_directly(&mut self, k: (T, T, T), v: T) { + self.table.insert(k, v); + } } impl<'a, T: DDNNFPtr<'a>> Default for AllIteTable { diff --git a/src/builder/mod.rs b/src/builder/mod.rs index 3c61b61d..3f7a7b0c 100644 --- a/src/builder/mod.rs +++ b/src/builder/mod.rs @@ -5,6 +5,7 @@ pub mod cache; pub mod bdd; pub mod decision_nnf; +pub mod parallel; pub mod sdd; use crate::{ @@ -31,6 +32,7 @@ pub trait BottomUpBuilder<'a, Ptr> { fn or(&'a self, a: Ptr, b: Ptr) -> Ptr { self.negate(self.and(self.negate(a), self.negate(b))) } + fn negate(&'a self, f: Ptr) -> Ptr; /// if f then g else h diff --git a/src/builder/parallel/mod.rs b/src/builder/parallel/mod.rs new file mode 100644 index 00000000..94a0c4f2 --- /dev/null +++ b/src/builder/parallel/mod.rs @@ -0,0 +1,3 @@ +mod semantic_bdd; + +pub use self::semantic_bdd::*; diff --git a/src/builder/parallel/semantic_bdd.rs b/src/builder/parallel/semantic_bdd.rs new file mode 100644 index 00000000..7e888a5a --- /dev/null +++ b/src/builder/parallel/semantic_bdd.rs @@ -0,0 +1,545 @@ +use crate::{ + builder::{ + bdd::BddBuilderStats, + cache::{AllIteTable, Ite, IteTable}, + BottomUpBuilder, + }, + repr::{ + create_semantic_hash_map, Cnf, DDNNFPtr, SemanticBddNode, SemanticBddPtr, VarLabel, + VarOrder, WmcParams, + }, + util::semirings::{FiniteField, Semiring}, +}; +use std::{cmp::Ordering, collections::HashMap, sync::RwLock}; + +#[derive(Debug)] +pub struct SemanticBddBuilder<'a, const P: u128> { + compute_table: RwLock, SemanticBddNode
<P>

>>, + apply_table: RwLock>>, + stats: RwLock, + order: RwLock, + map: WmcParams>, +} + +impl<'a, const P: u128> BottomUpBuilder<'a, SemanticBddPtr<'a, P>> for SemanticBddBuilder<'a, P> { + fn true_ptr(&self) -> SemanticBddPtr<'a, P> { + SemanticBddPtr::PtrTrue + } + + fn false_ptr(&self) -> SemanticBddPtr<'a, P> { + SemanticBddPtr::PtrFalse + } + + fn var(&'a self, label: VarLabel, polarity: bool) -> SemanticBddPtr<'a, P> { + let bdd = + SemanticBddNode::new_from_builder(label, FiniteField::zero(), FiniteField::one(), self); + let r = self.get_or_insert(bdd); + if polarity { + r + } else { + r.neg() + } + } + + fn eq(&'a self, a: SemanticBddPtr<'a, P>, b: SemanticBddPtr<'a, P>) -> bool { + a.semantic_hash() == b.semantic_hash() + } + + fn and(&'a self, a: SemanticBddPtr<'a, P>, b: SemanticBddPtr<'a, P>) -> SemanticBddPtr<'a, P> { + self.ite(a, b, SemanticBddPtr::false_ptr()) + } + + fn negate(&'a self, f: SemanticBddPtr<'a, P>) -> SemanticBddPtr<'a, P> { + f.neg() + } + + fn ite( + &'a self, + f: SemanticBddPtr<'a, P>, + g: SemanticBddPtr<'a, P>, + h: SemanticBddPtr<'a, P>, + ) -> SemanticBddPtr<'a, P> { + self.ite_helper(f, g, h) + } + + fn iff(&'a self, a: SemanticBddPtr<'a, P>, b: SemanticBddPtr<'a, P>) -> SemanticBddPtr<'a, P> { + self.ite(a, b, b.neg()) + } + + fn xor(&'a self, a: SemanticBddPtr<'a, P>, b: SemanticBddPtr<'a, P>) -> SemanticBddPtr<'a, P> { + self.ite(a, b.neg(), b) + } + + fn exists(&'a self, bdd: SemanticBddPtr<'a, P>, lbl: VarLabel) -> SemanticBddPtr<'a, P> { + let v1 = self.condition(bdd, lbl, true); + let v2 = self.condition(bdd, lbl, false); + self.or(v1, v2) + } + + /// Compute the Boolean function `f | var = value` + fn condition( + &'a self, + bdd: SemanticBddPtr<'a, P>, + lbl: VarLabel, + value: bool, + ) -> SemanticBddPtr<'a, P> { + self.cond_with_alloc(bdd, lbl, value, &mut Vec::new()) + } + + fn compile_cnf(&'a self, cnf: &Cnf) -> SemanticBddPtr<'a, P> { + let mut cvec: Vec> = Vec::with_capacity(cnf.clauses().len()); + if cnf.clauses().is_empty() { + return SemanticBddPtr::true_ptr(); + } + // check if there is an empty clause -- if so, UNSAT + if cnf.clauses().iter().any(|x| x.is_empty()) { + return SemanticBddPtr::false_ptr(); + } + + // sort the clauses based on a best-effort bottom-up ordering of clauses + let mut cnf_sorted = cnf.clauses().to_vec(); + cnf_sorted.sort_by(|c1, c2| { + // order the clause with the first-most variable last + let fst1 = c1 + .iter() + .max_by(|l1, l2| { + if self.less_than(l1.label(), l2.label()) { + Ordering::Less + } else { + Ordering::Equal + } + }) + .unwrap(); + let fst2 = c2 + .iter() + .max_by(|l1, l2| { + if self.less_than(l1.label(), l2.label()) { + Ordering::Less + } else { + Ordering::Equal + } + }) + .unwrap(); + if self.less_than(fst1.label(), fst2.label()) { + Ordering::Less + } else { + Ordering::Equal + } + }); + + for lit_vec in cnf_sorted.iter() { + let (vlabel, val) = (lit_vec[0].label(), lit_vec[0].polarity()); + let mut bdd = self.var(vlabel, val); + for lit in lit_vec { + let (vlabel, val) = (lit.label(), lit.polarity()); + let var = self.var(vlabel, val); + bdd = self.or(bdd, var); + } + cvec.push(bdd); + } + // now cvec has a list of all the clauses; collapse it down + let r = self.collapse_clauses(&cvec); + match r { + None => SemanticBddPtr::true_ptr(), + Some(x) => x, + } + } +} + +impl<'a, const P: u128> SemanticBddBuilder<'a, P> { + pub fn new(order: VarOrder) -> SemanticBddBuilder<'a, P> { + let map = create_semantic_hash_map(order.num_vars()); + SemanticBddBuilder { + compute_table: 
RwLock::new(HashMap::default()),
+            order: RwLock::new(order),
+            apply_table: RwLock::new(AllIteTable::default()),
+            stats: RwLock::new(BddBuilderStats::new()),
+            map,
+        }
+    }
+
+    pub fn new_with_map(
+        order: VarOrder,
+        map: WmcParams<FiniteField<P>>,
+    ) -> SemanticBddBuilder<'a, P> {
+        SemanticBddBuilder {
+            compute_table: RwLock::new(HashMap::default()),
+            order: RwLock::new(order),
+            apply_table: RwLock::new(AllIteTable::default()),
+            stats: RwLock::new(BddBuilderStats::new()),
+            map,
+        }
+    }
+
+    pub fn map(&'a self) -> &WmcParams<FiniteField<P>> {
+        &self.map
+    }
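+
+    // Looks up the stored node for a semantic hash, checking both the hash
+    // and its negation, since a node may be interned under either polarity.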

+    pub fn deref_semantic_node(
+        &'a self,
+        semantic_hash: &FiniteField<P>,
+    ) -> Option<SemanticBddNode<P>> {
+        if let Some(bdd) = self.compute_table.read().unwrap().get(semantic_hash) {
+            return Some(bdd.clone());
+        }
+
+        // check negated hash
+        let semantic_hash = semantic_hash.negate();
+        if let Some(bdd) = self.compute_table.read().unwrap().get(&semantic_hash) {
+            return Some(bdd.clone());
+        }
+        None
+    }
+
+    pub fn deref_semantic_hash(&'a self, hash: &FiniteField<P>

) -> SemanticBddPtr<'a, P> { + self.check_cached_hash_and_neg(hash) + .unwrap_or_else(|| panic!("Could not find item for hash: {}.", hash)) + } + + fn less_than(&self, a: VarLabel, b: VarLabel) -> bool { + self.order.read().unwrap().lt(a, b) + } + + fn ite_helper( + &'a self, + f: SemanticBddPtr<'a, P>, + g: SemanticBddPtr<'a, P>, + h: SemanticBddPtr<'a, P>, + ) -> SemanticBddPtr<'a, P> { + self.stats.write().unwrap().num_recursive_calls += 1; + let o = |a: SemanticBddPtr

<P>, b: SemanticBddPtr<P>

| match (a, b) { + (SemanticBddPtr::PtrTrue, _) | (SemanticBddPtr::PtrFalse, _) => true, + (_, SemanticBddPtr::PtrTrue) | (_, SemanticBddPtr::PtrFalse) => false, + ( + SemanticBddPtr::Reg(ff_a, _) | SemanticBddPtr::Compl(ff_a, _), + SemanticBddPtr::Reg(ff_b, _) | SemanticBddPtr::Compl(ff_b, _), + ) => { + let node_a = self.deref_semantic_node(&ff_a).unwrap(); + let node_b = self.deref_semantic_node(&ff_b).unwrap(); + self.less_than(node_a.var(), node_b.var()) + } + }; + + let ite = Ite::new(o, f, g, h); + + if let Ite::IteConst(f) = ite { + return f; + } + + let hash = self.apply_table.read().unwrap().hash(&ite); + if let Some(v) = self.apply_table.read().unwrap().get(ite, hash) { + return v; + } + + // ok the work! + // find the first essential variable for f, g, or h + let lbl = self.order.read().unwrap().first_essential(&f, &g, &h); + let fx = self.condition_essential(f, lbl, true); + let gx = self.condition_essential(g, lbl, true); + let hx = self.condition_essential(h, lbl, true); + let fxn = self.condition_essential(f, lbl, false); + let gxn = self.condition_essential(g, lbl, false); + let hxn = self.condition_essential(h, lbl, false); + let t = self.ite(fx, gx, hx); + let f = self.ite(fxn, gxn, hxn); + + if t == f { + return t; + }; + + // now we have a new BDD + let node = + SemanticBddNode::new_from_builder(lbl, f.semantic_hash(), t.semantic_hash(), self); + let r = self.get_or_insert(node); + self.apply_table.write().unwrap().insert(ite, r, hash); + r + } + + // condition a BDD *only* if the top variable is `v`; used in `ite` + fn condition_essential( + &'a self, + f: SemanticBddPtr<'a, P>, + lbl: VarLabel, + v: bool, + ) -> SemanticBddPtr<'a, P> { + match f { + SemanticBddPtr::PtrTrue | SemanticBddPtr::PtrFalse => f, + SemanticBddPtr::Reg(semantic_hash, _) | SemanticBddPtr::Compl(semantic_hash, _) => { + let node = self.deref_semantic_node(&semantic_hash).unwrap(); + if node.var() != lbl { + return f; + } + let r = if v { node.high(self) } else { node.low(self) }; + if f.is_neg() { + r.neg() + } else { + r + } + } + } + } + + fn cond_with_alloc( + &'a self, + bdd: SemanticBddPtr<'a, P>, + lbl: VarLabel, + value: bool, + alloc: &mut Vec>, + ) -> SemanticBddPtr<'a, P> { + self.stats.write().unwrap().num_recursive_calls += 1; + match bdd { + SemanticBddPtr::PtrTrue | SemanticBddPtr::PtrFalse => bdd, + SemanticBddPtr::Reg(semantic_hash, _) | SemanticBddPtr::Compl(semantic_hash, _) => { + let node = self.deref_semantic_node(&semantic_hash).unwrap(); + if self.order.read().unwrap().lt(lbl, node.var()) { + // we passed the variable in the order, we will never find it + return bdd; + } + + if node.var() == lbl { + let r = if value { + node.high(self) + } else { + node.low(self) + }; + return if bdd.is_neg() { r.neg() } else { r }; + } + + // check cache + match bdd.scratch::() { + None => (), + Some(v) => { + return if bdd.is_neg() { + alloc[v].neg() + } else { + alloc[v] + } + } + }; + + // recurse on the children + let l = self.cond_with_alloc(node.low(self), lbl, value, alloc); + let h = self.cond_with_alloc(node.high(self), lbl, value, alloc); + + if l == h { + // reduce the BDD -- two children identical + if bdd.is_neg() { + return l.neg(); + } else { + return l; + }; + }; + let res = if l != node.low(self) || h != node.high(self) { + // cache and return the new BDD + let new_bdd = SemanticBddNode::new_from_builder( + node.var(), + l.semantic_hash(), + h.semantic_hash(), + self, + ); + let r = self.get_or_insert(new_bdd); + if bdd.is_neg() { + r.neg() + } else { + r + } + } else { 
+ // nothing changed + bdd + }; + + let idx = if bdd.is_neg() { + alloc.push(res.neg()); + alloc.len() - 1 + } else { + alloc.push(res); + alloc.len() - 1 + }; + bdd.set_scratch(idx); + res + } + } + } + + fn collapse_clauses(&'a self, vec: &[SemanticBddPtr<'a, P>]) -> Option> { + if vec.is_empty() { + None + } else if vec.len() == 1 { + Some(vec[0]) + } else { + let (l, r) = vec.split_at(vec.len() / 2); + let sub_l = self.collapse_clauses(l); + let sub_r = self.collapse_clauses(r); + match (sub_l, sub_r) { + (None, None) => None, + (Some(v), None) | (None, Some(v)) => Some(v), + (Some(l), Some(r)) => Some(self.and(l, r)), + } + } + } + + fn get_bdd_ptr(&'a self, semantic_hash: &FiniteField

<P>) -> Option<SemanticBddPtr<'a, P>> {
+        match semantic_hash.value() {
+            0 => Some(SemanticBddPtr::PtrFalse),
+            1 => Some(SemanticBddPtr::PtrTrue),
+            _ => {
+                if self
+                    .compute_table
+                    .read()
+                    .unwrap()
+                    .get(semantic_hash)
+                    .is_some()
+                {
+                    return Some(SemanticBddPtr::Reg(*semantic_hash, self));
+                }
+                None
+            }
+        }
+    }
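+
+    // A pointer may be interned under its hash or under the negated hash; if
+    // the negated entry hits, the cached pointer comes back complemented.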

+    fn check_cached_hash_and_neg(
+        &'a self,
+        semantic_hash: &FiniteField<P>,
+    ) -> Option<SemanticBddPtr<'a, P>> {
+        // check regular hash
+        if let Some(bdd) = self.get_bdd_ptr(semantic_hash) {
+            return Some(bdd);
+        }
+
+        // check negated hash
+        let semantic_hash = semantic_hash.negate();
+        if let Some(bdd) = self.get_bdd_ptr(&semantic_hash) {
+            return Some(bdd.neg());
+        }
+        None
+    }
+
+    // Normalizes and fetches a node from the store
+    fn get_or_insert(&'a self, bdd: SemanticBddNode<P>

) -> SemanticBddPtr<'a, P> { + if let Some(ptr) = self.check_cached_hash_and_neg(&bdd.semantic_hash()) { + return ptr; + } + + let semantic_hash = bdd.semantic_hash(); + + self.compute_table + .write() + .unwrap() + .insert(semantic_hash, bdd); + + SemanticBddPtr::Reg(semantic_hash, self) + } + + pub fn merge_from<'b>( + &'a self, + other: &'b Self, + ptrs: &[SemanticBddPtr<'b, P>], + ) -> Vec> { + let new_roots = ptrs + .iter() + .map(|ptr| match *ptr { + SemanticBddPtr::PtrTrue => SemanticBddPtr::PtrTrue, + SemanticBddPtr::PtrFalse => SemanticBddPtr::PtrFalse, + SemanticBddPtr::Reg(node, _) => SemanticBddPtr::Reg(node, self), + SemanticBddPtr::Compl(node, _) => SemanticBddPtr::Compl(node, self), + }) + .collect(); + + for (k, v) in other.compute_table.read().unwrap().iter() { + self.compute_table.write().unwrap().insert(*k, v.clone()); + } + + for (k, v) in other.apply_table.read().unwrap().iter() { + self.apply_table.write().unwrap().insert_directly(*k, *v); + } + + new_roots + } +} + +unsafe impl Send for SemanticBddBuilder<'_, P> {} +unsafe impl Sync for SemanticBddBuilder<'_, P> {} + +#[cfg(test)] +mod test { + use crate::{ + builder::BottomUpBuilder, + constants::primes, + repr::{VarLabel, VarOrder}, + }; + + use super::SemanticBddBuilder; + + #[test] + fn trivial_semantic_builder() { + let order = VarOrder::linear_order(2); + let builder: SemanticBddBuilder<'_, { primes::U64_LARGEST }> = + SemanticBddBuilder::new(order); + + println!("{:?}", builder.map()); + + let v1 = builder.var(VarLabel::new(0), true); + let v2 = builder.var(VarLabel::new(1), true); + let r1 = builder.or(v1, v2); + let r2 = builder.and(r1, v1); + + assert!(builder.eq(v1, r2), "Not eq:\n {:?}\n{:?}", v1, r2); + } + + #[test] + fn e2e_merge() { + let order = VarOrder::linear_order(2); + let builder: SemanticBddBuilder<'_, { primes::U64_LARGEST }> = + SemanticBddBuilder::new(order); + + println!("{:?}", builder.map()); + + let v1 = builder.var(VarLabel::new(0), true); + let v2 = builder.var(VarLabel::new(1), true); + let r1 = builder.or(v1, v2); + let r2 = builder.and(r1, v1); + + let order = VarOrder::linear_order(2); + let builder2: SemanticBddBuilder<'_, { primes::U64_LARGEST }> = + SemanticBddBuilder::new_with_map(order, builder.map().clone()); + + let v3 = builder2.var(VarLabel::new(0), true); + let v4 = builder2.var(VarLabel::new(1), true); + let r3 = builder2.and(v3, v4); + let r4 = builder2.and(r3, v3); + + // this should always be true... 
+ assert!( + builder.eq(v1, r2), + "Invariant, pre-merge: Not eq:\n {:?}\n{:?}", + v1, + r2 + ); + assert!( + builder2.eq(r3, r4), + "Invariant, pre-merge: Not eq:\n {:?}\n{:?}", + r3, + r4 + ); + + println!("starting merge..."); + + let res = builder.merge_from(&builder2, &[r3, r4]); + + println!("merge done..."); + + // and still be true *after* the merge + assert!( + builder.eq(v1, r2), + "Invariant, post-merge: Not eq:\n {:?}\n{:?}", + v1, + r2 + ); + assert!( + builder.eq(res[0], res[1]), + "Invariant, post-merge: Not eq:\n {:?}\n{:?}", + res[0], + res[1] + ); + } +} diff --git a/src/repr/mod.rs b/src/repr/mod.rs index c6b8190e..c10c7822 100644 --- a/src/repr/mod.rs +++ b/src/repr/mod.rs @@ -11,6 +11,7 @@ mod dtree; mod logical_expr; mod model; mod sdd; +mod semantic_bdd; mod unit_prop; mod var_label; mod var_order; @@ -24,7 +25,7 @@ pub use self::dtree::*; pub use self::logical_expr::*; pub use self::model::*; pub use self::sdd::*; -pub use self::sdd::*; +pub use self::semantic_bdd::*; pub use self::unit_prop::*; pub use self::var_label::*; pub use self::var_order::*; diff --git a/src/repr/semantic_bdd.rs b/src/repr/semantic_bdd.rs new file mode 100644 index 00000000..87dd6c7e --- /dev/null +++ b/src/repr/semantic_bdd.rs @@ -0,0 +1,428 @@ +use std::{ + any::Any, + fmt::Debug, + hash::{Hash, Hasher}, + sync::RwLock, +}; + +use crate::{ + builder::parallel::SemanticBddBuilder, + repr::var_label::VarSet, + util::semirings::{FiniteField, Semiring}, +}; + +use super::{ + ddnnf::{DDNNFPtr, DDNNF}, + var_label::VarLabel, + var_order::{PartialVariableOrder, VarOrder}, +}; + +use SemanticBddPtr::*; + +#[derive(Clone, Copy)] +pub enum SemanticBddPtr<'a, const P: u128> { + PtrTrue, + PtrFalse, + Reg(FiniteField

<P>, &'a SemanticBddBuilder<'a, P>),
+    Compl(FiniteField<P>

, &'a SemanticBddBuilder<'a, P>), +} + +impl<'a, const P: u128> SemanticBddPtr<'a, P> { + /// Gets the scratch value stored in `&self` + /// + /// Panics if not node. + pub fn scratch(&self) -> Option { + match self { + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + let n = builder.deref_semantic_node(semantic_hash).unwrap(); + if self.is_scratch_cleared() { + return None; + } + let x = n + .data + .read() + .unwrap() + .as_ref() + .unwrap() + .as_ref() + .downcast_ref::() + .cloned(); + + x + } + PtrTrue => None, + PtrFalse => None, + } + } + + /// Set the scratch in this node to the value `v`. + /// + /// Panics if not a node. + /// + /// Invariant: values stored in `set_scratch` must not outlive + /// the provided allocator `alloc` (i.e., calling `scratch` + /// involves dereferencing a pointer stored in `alloc`) + pub fn set_scratch(&self, v: T) { + match self { + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + let n = builder.deref_semantic_node(semantic_hash).unwrap(); + *n.data.write().unwrap() = Some(Box::new(v)); + } + _ => panic!("attempting to store scratch on constant"), + } + } + + /// Traverses the BDD and clears all scratch memory (sets it equal to 0) + pub fn clear_scratch(&self) { + match &self { + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + let n = builder.deref_semantic_node(semantic_hash).unwrap(); + if n.data.read().unwrap().is_some() { + *n.data.write().unwrap() = None; + n.low(builder).clear_scratch(); + n.high(builder).clear_scratch(); + } + } + PtrTrue | PtrFalse => (), + } + } + + /// true if the scratch is current cleared + pub fn is_scratch_cleared(&self) -> bool { + match self { + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + let n = builder.deref_semantic_node(semantic_hash).unwrap(); + let x = n.data.read().unwrap().is_none(); + x + } + PtrTrue => true, + PtrFalse => true, + } + } +} + +type DDNNFCache = (Option, Option); + +impl<'a, const P: u128> DDNNFPtr<'a> for SemanticBddPtr<'a, P> { + type Order = VarOrder; + + fn fold) -> T>(&self, _o: &VarOrder, f: F) -> T + where + T: 'static, + { + debug_assert!(self.is_scratch_cleared()); + fn bottomup_pass_h) -> T, const P: u128>( + ptr: SemanticBddPtr
<P>

, + f: &F, + ) -> T + where + T: 'static, + { + match ptr { + PtrTrue => f(DDNNF::True), + PtrFalse => f(DDNNF::False), + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + let node = builder.deref_semantic_node(&semantic_hash).unwrap(); + // inside the cache, store a (compl, non_compl) pair corresponding to the + // complemented and uncomplemented pass over this node + + // helper performs actual fold-and-cache work + let bottomup_helper = |cached| { + let (l, h) = if ptr.is_neg() { + (node.low(builder).neg(), node.high(builder).neg()) + } else { + (node.low(builder), node.high(builder)) + }; + + let low_v = bottomup_pass_h(l, f); + let high_v = bottomup_pass_h(h, f); + let top = node.var(); + + let lit_high = f(DDNNF::Lit(top, true)); + let lit_low = f(DDNNF::Lit(top, false)); + + let and_low = f(DDNNF::And(lit_low, low_v)); + let and_high = f(DDNNF::And(lit_high, high_v)); + + // in a BDD, each decision only depends on the topvar + let mut varset = VarSet::new(); + varset.insert(top); + + let or_v = f(DDNNF::Or(and_low, and_high, varset)); + + // cache and return or_v + if ptr.is_neg() { + ptr.set_scratch::>((Some(or_v), cached)); + } else { + ptr.set_scratch::>((cached, Some(or_v))); + } + or_v + }; + + match ptr.scratch::>() { + // first, check if cached; explicit arms here for clarity + Some((Some(l), Some(h))) => { + if ptr.is_neg() { + l + } else { + h + } + } + Some((Some(v), None)) if ptr.is_neg() => v, + Some((None, Some(v))) if !ptr.is_neg() => v, + // no cached value found, compute it + Some((None, cached)) | Some((cached, None)) => bottomup_helper(cached), + None => bottomup_helper(None), + } + } + } + } + + let r = bottomup_pass_h(*self, &f); + self.clear_scratch(); + r + } + + fn neg(&self) -> Self { + match self { + PtrTrue => PtrFalse, + PtrFalse => PtrTrue, + Reg(n, b) => Compl(*n, b), + Compl(n, b) => Reg(*n, b), + } + } + + fn false_ptr() -> Self { + PtrFalse + } + + fn true_ptr() -> Self { + PtrTrue + } + + fn is_neg(&self) -> bool { + match &self { + Compl(_, _) => true, + Reg(_, _) | PtrTrue | PtrFalse => false, + } + } + + fn is_true(&self) -> bool { + match &self { + Compl(_, _) | Reg(_, _) | PtrFalse => false, + PtrTrue => true, + } + } + + fn is_false(&self) -> bool { + match &self { + Compl(_, _) | Reg(_, _) | PtrTrue => false, + PtrFalse => true, + } + } + + fn count_nodes(&self) -> usize { + debug_assert!(self.is_scratch_cleared()); + + let mut count = 0; + self.count_h(&mut count); + self.clear_scratch(); + count + } +} + +impl<'a, const P: u128> SemanticBddPtr<'a, P> { + pub fn semantic_hash(&self) -> FiniteField
<P>

{ + match self { + PtrTrue => FiniteField::one(), + PtrFalse => FiniteField::zero(), + Reg(node, _) => *node, + Compl(node, _) => node.negate(), + } + } + + fn count_h(self, count: &mut usize) { + match self { + PtrTrue | PtrFalse => (), + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + match self.scratch::() { + Some(_) => (), + None => { + // found a new node + *count += 1; + self.set_scratch::(0); + let n = builder.deref_semantic_node(&semantic_hash).unwrap(); + n.low(builder).count_h(count); + n.high(builder).count_h(count); + } + } + } + } + } +} + +impl<'a, const P: u128> PartialEq for SemanticBddPtr<'a, P> { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Compl(l0, _), Self::Compl(r0, _)) => l0 == r0, + (Self::Reg(l0, _), Self::Reg(r0, _)) => l0 == r0, + _ => core::mem::discriminant(self) == core::mem::discriminant(other), + } + } +} + +impl<'a, const P: u128> Eq for SemanticBddPtr<'a, P> {} + +impl<'a, const P: u128> Hash for SemanticBddPtr<'a, P> { + fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + match self { + Compl(n, _) | Reg(n, _) => n.hash(state), + _ => (), + } + } +} + +impl<'a, const P: u128> Debug for SemanticBddPtr<'a, P> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::PtrTrue => write!(f, "PtrTrue"), + Self::PtrFalse => write!(f, "PtrFalse"), + Self::Reg(arg0, _) => f.debug_tuple("Reg").field(arg0).finish(), + Self::Compl(arg0, _) => f.debug_tuple("Compl").field(arg0).finish(), + } + } +} + +impl<'a, const P: u128> PartialVariableOrder for SemanticBddPtr<'a, P> { + fn var(&self) -> Option { + match self { + PtrTrue | PtrFalse => None, + Compl(semantic_hash, builder) | Reg(semantic_hash, builder) => { + let n = builder.deref_semantic_node(semantic_hash).unwrap(); + Some(n.var()) + } + } + } +} + +#[derive(Debug)] +pub struct SemanticBddNode { + var: VarLabel, + hash: FiniteField

<P>,
+    low_hash: FiniteField<P>,

+    high_hash: FiniteField<P>,

+    /// scratch space used for caching data during traversals; ignored during
+    /// equality checking and hashing
+    data: RwLock<Option<Box<dyn Any>>>,
+}
+
+impl<const P: u128> SemanticBddNode<P> {

+    pub fn new(
+        var: VarLabel,
+        hash: FiniteField<P>,

+        low_hash: FiniteField<P>,

+        high_hash: FiniteField<P>,

+    ) -> SemanticBddNode<P> {

+        SemanticBddNode {
+            var,
+            hash,
+            low_hash,
+            high_hash,
+            data: RwLock::new(None),
+        }
+    }
+
+    pub fn new_from_builder<'a>(
+        var: VarLabel,
+        low_hash: FiniteField<P>,

+        high_hash: FiniteField<P>,

+        builder: &'a SemanticBddBuilder<'a, P>,
+    ) -> SemanticBddNode<P> {

+        let (low_w, high_w) = builder.map().var_weight(var);
+        let hash = low_hash * (*low_w) + high_hash * (*high_w);
+
+        SemanticBddNode {
+            var,
+            hash,
+            low_hash,
+            high_hash,
+            data: RwLock::new(None),
+        }
+    }
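+
+    // `hash` above folds the children into a single field element:
+    //   hash = low_hash * w(var = false) + high_hash * w(var = true),
+    // i.e., the WMC recurrence evaluated in FiniteField<P>, so two nodes that
+    // denote the same Boolean function collide with overwhelming probability.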

{ + self.hash + } + + pub fn low<'a>(&self, builder: &'a SemanticBddBuilder<'a, P>) -> SemanticBddPtr<'a, P> { + builder.deref_semantic_hash(&self.low_hash) + } + + pub fn high<'a>(&self, builder: &'a SemanticBddBuilder<'a, P>) -> SemanticBddPtr<'a, P> { + builder.deref_semantic_hash(&self.high_hash) + } + + pub fn var(&self) -> VarLabel { + self.var + } +} + +impl PartialEq for SemanticBddNode

+impl<const P: u128> PartialEq for SemanticBddNode<P> {
+    fn eq(&self, other: &Self) -> bool {
+        self.var == other.var
+            && self.hash == other.hash
+            && self.low_hash == other.low_hash
+            && self.high_hash == other.high_hash
+    }
+}

+impl<const P: u128> Hash for SemanticBddNode<P> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.var.hash(state);
+        self.hash.hash(state);
+        self.low_hash.hash(state);
+        self.high_hash.hash(state);
+    }
+}

+impl<const P: u128> Eq for SemanticBddNode<P> {}

+impl<const P: u128> PartialOrd for SemanticBddNode<P> {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        match self.var.partial_cmp(&other.var) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        match self.low_hash.partial_cmp(&other.low_hash) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        match self.high_hash.partial_cmp(&other.high_hash) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        Some(core::cmp::Ordering::Equal)
+    }
+}
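+
+// Ordering ignores the derived `hash` field; once the weight map is fixed it
+// is fully determined by (var, low_hash, high_hash).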

+impl<const P: u128> Clone for SemanticBddNode<P> {
+    fn clone(&self) -> Self {
+        Self {
+            var: self.var,
+            hash: self.hash,
+            low_hash: self.low_hash,
+            high_hash: self.high_hash,
+            data: RwLock::new(None),
+        }
+    }
+}
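+
+// Cloning resets the scratch `data` to `None`: scratch is traversal-local
+// state and should not travel with copies of a node.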

+impl<const P: u128> Ord for SemanticBddNode<P> {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.partial_cmp(other).unwrap()
+    }
+}
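+
+// SAFETY: `Box<dyn Any>` is not automatically `Send`/`Sync`, so the unsafe
+// impls below assert that the scratch payload is only ever accessed through
+// the `RwLock` guard.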

+unsafe impl<const P: u128> Send for SemanticBddNode<P> {}
+unsafe impl<const P: u128> Sync for SemanticBddNode<P>

{} diff --git a/src/util/semirings/finitefield.rs b/src/util/semirings/finitefield.rs index 03f5c83e..1884410d 100644 --- a/src/util/semirings/finitefield.rs +++ b/src/util/semirings/finitefield.rs @@ -4,7 +4,7 @@ use core::fmt::Debug; /// a finite-field abstraction. The parameter `p` is the size of the field. use std::{fmt::Display, ops}; -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct FiniteField { v: u128, } diff --git a/tests/test.rs b/tests/test.rs index 3f646237..310ec255 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1194,3 +1194,106 @@ mod test_dnnf_builder { } } } + +#[cfg(test)] +mod test_parallel_semantic_builder { + use rsdd::{ + builder::{ + bdd::RobddBuilder, cache::AllIteTable, parallel::SemanticBddBuilder, BottomUpBuilder, + }, + constants::primes, + repr::{create_semantic_hash_map, BddPtr, Cnf, DDNNFPtr, VarOrder}, + }; + + quickcheck! { + fn test_semantic_and_robdd_agree_on_wmc(c: Cnf) -> bool { + let order = VarOrder::linear_order(c.num_vars()); + + let robdd_builder = RobddBuilder::>::new(order.clone()); + let robdd_cnf = robdd_builder.compile_cnf(&c); + + let semantic_builder = SemanticBddBuilder::<{ primes::U64_LARGEST }>::new(order.clone()); + let semantic_cnf = semantic_builder.compile_cnf(&c); + + let params = semantic_builder.map(); + + let robdd_wmc = robdd_cnf.wmc(&order, params); + let semantic_wmc = semantic_cnf.wmc(&order, params); + + let eps = robdd_wmc == semantic_wmc; + if !eps { + println!("error on input {}: std wmc: {}, sem wmc: {}", + c, robdd_wmc, semantic_wmc); + } + eps + } + } + + quickcheck! { + fn arbitrary_merge_maintains_wmc(c1: Cnf, c2: Cnf) -> bool { + let num_vars = std::cmp::max(c1.num_vars(), c2.num_vars()); + let map = create_semantic_hash_map(num_vars); + let order = VarOrder::linear_order(num_vars); + + let builder_1 = SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()); + let cnf_1 = builder_1.compile_cnf(&c1); + let cnf_1_wmc = cnf_1.wmc(&order, &map); + + let builder_2 = SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()); + let cnf_2 = builder_2.compile_cnf(&c2); + let cnf_2_wmc = cnf_2.wmc(&order, &map); + + builder_1.merge_from(&builder_2, &[]); + + let new_cnf_1_wmc = cnf_1.wmc(&order, &map); + + let cnf_2 = builder_1.compile_cnf(&c2); + let new_cnf_2_wmc = cnf_2.wmc(&order, &map); + + + let eq_1 = cnf_1_wmc == new_cnf_1_wmc; + if !eq_1 { + println!("error on input {}: pre-merge wmc: {}, post-merge wmc: {}", + c1, cnf_1_wmc, new_cnf_1_wmc); + } + + let eq_2 = cnf_2_wmc == new_cnf_2_wmc; + if !eq_2 { + println!("error on input {}: pre-merge (b2) wmc: {}, post-merge (b1) wmc: {}", + c2, cnf_2_wmc, new_cnf_2_wmc); + } + + eq_1 && eq_2 + } + } + + quickcheck! 
{ + fn arbitrary_merge_maintains_pointers(c1: Cnf) -> bool { + let num_vars = c1.num_vars(); + let map = create_semantic_hash_map(num_vars); + let order = VarOrder::linear_order(num_vars); + + let builder_1 = SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()); + + let builder_2 = SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(order.clone(), map.clone()); + let cnf = builder_2.compile_cnf(&c1); + let cnf_wmc = cnf.wmc(&order, &map); + let cnf_str = format!("{:?}", cnf); + + let new_cnf = builder_1.merge_from(&builder_2, &[cnf])[0]; + let new_cnf_str = format!("{:?}", cnf); + + let new_cnf_wmc = new_cnf.wmc(&order, &map); + + let eq = cnf_wmc == new_cnf_wmc; + if !eq { + println!("error on input {}: pre-merge (b2) wmc: {}, post-merge (b1) wmc: {}", c1 , cnf_wmc, new_cnf_wmc); + println!("old: {}", cnf_str); + println!("new: {}", new_cnf_str); + println!("order: {}, map: {:?}", order, map); + } + + eq + } + } +}
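
To try the new code end to end, the example registered above should be runnable as `cargo run --release --example parallel_semantic -- -f bench.cnf -n 4 -t` (clap derives the short flags from the `file`, `num_splits`, and `thread` fields). Below is a minimal single-threaded sketch of the same split/compile/merge pipeline, using only APIs that appear in this diff; the two-clause CNF and chunk size of 1 are illustrative stand-ins for a real DIMACS input.

```rust
use rsdd::{
    builder::{parallel::SemanticBddBuilder, BottomUpBuilder},
    constants::primes,
    repr::{create_semantic_hash_map, Cnf, DDNNFPtr},
};

fn main() {
    // illustrative two-clause CNF: (1 ∨ 2) ∧ (¬1 ∨ 2)
    let cnf = Cnf::from_dimacs("p cnf 2 2\n1 2 0\n-1 2 0\n");
    let order = cnf.min_fill_order();
    let map = create_semantic_hash_map(cnf.num_vars());

    // one builder per chunk, all sharing the same order and weight map
    let chunks: Vec<Cnf> = cnf.clauses().chunks(1).map(Cnf::new).collect();
    let builders: Vec<_> = chunks
        .iter()
        .map(|_| {
            SemanticBddBuilder::<{ primes::U64_LARGEST }>::new_with_map(
                order.clone(),
                map.clone(),
            )
        })
        .collect();

    // compile each chunk on its own builder (rayon's par_iter would go here)
    let ptrs: Vec<_> = builders
        .iter()
        .zip(chunks.iter())
        .map(|(b, c)| b.compile_cnf(c))
        .collect();

    // move every root into the first builder, then conjoin
    let builder = &builders[0];
    let mut acc = ptrs[0];
    for i in 1..ptrs.len() {
        let moved = builder.merge_from(&builders[i], &[ptrs[i]])[0];
        acc = builder.and(acc, moved);
    }

    println!("wmc = {}", acc.wmc(&order, &map).value());
}
```

Because every builder shares one weight map, semantic hashes agree across builders; that is what lets `merge_from` copy compute- and apply-table entries wholesale and re-point roots at the destination builder without recompilation.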