diff --git a/src/ir/opt/mem/obliv.rs b/src/ir/opt/mem/obliv.rs index 2074047e0..c3560c85b 100644 --- a/src/ir/opt/mem/obliv.rs +++ b/src/ir/opt/mem/obliv.rs @@ -3,228 +3,148 @@ //! This module attempts to identify *oblivious* arrays: those that are only accessed at constant //! indices. These arrays can be replaced with tuples. Then, a tuple elimination pass can be run. //! -//! It operates in two passes: +//! It operates in a single IO (inputs->outputs) pass, that computes two maps: //! -//! 1. determine which arrays are oblivious -//! 2. replace oblivious arrays with tuples -//! -//! -//! ## Pass 1: Identifying oblivious arrays -//! -//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole computation -//! system, performing the following inferences: -//! -//! * If `a[i]` for non-constant `i`, then `a` and `a[i]` are not oblivious; -//! * If `a[i]`, `a` and `a[i]` are equi-oblivious -//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious -//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious -//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious -//! * If `a=b`, then `a` and `b` are equi-oblivious -//! -//! This procedure is iterated to fixpoint. -//! -//! Notice that we flag some *array* terms as non-oblivious, and we also flag their derived select -//! terms as non-oblivious. This makes it easy to see which selects should be replaced later. -//! -//! ### Sharing & Constant Arrays -//! -//! This pass is effective given the somewhat naive assumption that array terms in the term graph -//! can be separated into different "threads", which are not connected. Sometimes they are, -//! especially by constant arrays. -//! -//! For example, consider code like this: -//! -//! ```ignore -//! x = [0, 0, 0, 0] -//! y = [0, 0, 0, 0] -//! // oblivious modifications to x -//! // non-oblivious modifications to y -//! ``` -//! -//! In this situation, we would hope that x and its derived arrays will be identified as -//! "oblivious" while y will not. -//! -//! However, because of term sharing, the constant array [0,0,0,0] happens to be the root of both -//! x's and y's store chains. If the constant array is `c`, then the first store to x might be -//! `c[0\v1]` while the first store to y might be `c[i2\v2]`. The "store" rule for non-oblivious -//! analysis would say that `c` is non-oblivious (b/c of the second store) and therefore the whole -//! x store chain would b too... -//! -//! The problem isn't just with constants. If any non-oblivious stores branch off an otherwise -//! oblivious store chain, the same thing happens. -//! -//! Since constants are a pervasive problem, we special-case them, omitting them from the analysis. -//! -//! We probably want a better idea of what this pass does (and how to handle arrays) at some -//! point... -//! -//! ## Pass 2: Replacing oblivious arrays with term lists. -//! -//! In this pass, the goal is to -//! -//! * map array terms to tuple terms -//! * map array selections to tuple field gets -//! -//! In both cases we look at the non-oblivious array/select set to see whether to do the -//! replacement. +//! * `R`: the rewrite map; keys map to values of the same sort. This is the canonical rewrite map. +//! * `T`: the map from a term, to one whose sort has arrays replaced with tuples at the top of the sort tree +//! * if some select has a constant index and is against an entry of T, then: +//! * we add ((field i) T_ENTRY) to T for that select +//! * if the above has scalar sort, we add it to R //! +//! So, essentially, what's going on is that T maps each term t to an (approximate) analysis of t +//! that indicates which accesses can be perfectly resolved. -use super::super::visit::*; use crate::ir::term::extras::as_uint_constant; use crate::ir::term::*; -use log::debug; +use log::{debug, trace}; -struct NonOblivComputer { - not_obliv: TermSet, +#[derive(Default)] +struct OblivRewriter { + tups: TermMap, + terms: TermMap, } -impl NonOblivComputer { - fn mark(&mut self, a: &Term) -> bool { - if !a.is_const() && self.not_obliv.insert(a.clone()) { - debug!("Not obliv: {}", a); - true - } else { - false - } - } +fn suitable_const(t: &Term) -> bool { + t.is_const() && matches!(check(t), Sort::BitVector(_) | Sort::Field(_) | Sort::Bool) +} - fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool { - if !a.is_const() && !b.is_const() { - match (self.not_obliv.contains(a), self.not_obliv.contains(b)) { - (false, true) => { - self.not_obliv.insert(a.clone()); - true - } - (true, false) => { - self.not_obliv.insert(b.clone()); - true - } - _ => false, - } - } else { - false - } +impl OblivRewriter { + fn get_t(&self, t: &Term) -> &Term { + self.tups.get(t).unwrap_or(self.terms.get(t).unwrap()) } - - fn new() -> Self { - Self { - not_obliv: TermSet::default(), - } + fn get(&self, t: &Term) -> &Term { + self.terms.get(t).unwrap() } -} - -impl ProgressAnalysisPass for NonOblivComputer { - fn visit(&mut self, term: &Term) -> bool { - match &term.op() { - Op::Store | Op::CStore => { - let a = &term.cs()[0]; - let i = &term.cs()[1]; - let v = &term.cs()[2]; - let mut progress = false; - if let Sort::Array(..) = check(v) { - // Imprecisely, mark v as non-obliv iff the array is. - progress = self.bi_implicate(term, v) || progress; - } - if let Op::Const(_) = i.op() { - progress = self.bi_implicate(term, a) || progress; - } else { - progress = self.mark(a) || progress; - progress = self.mark(term) || progress; - } - if let Sort::Array(..) = check(v) { - // Imprecisely, mark v as non-obliv iff the array is. - progress = self.bi_implicate(term, v) || progress; - } - progress - } - Op::Array(..) => { - let mut progress = false; - if !term.cs().is_empty() { - if let Sort::Array(..) = check(&term.cs()[0]) { - progress = self.bi_implicate(term, &term.cs()[0]) || progress; - for i in 0..term.cs().len() - 1 { - progress = - self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress; - } - for i in (0..term.cs().len() - 1).rev() { - progress = - self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress; + fn visit(&mut self, t: &Term) { + let (tup_opt, term_opt) = match t.op() { + Op::Const(v @ Value::Array(_)) => (Some(leaf_term(Op::Const(arr_val_to_tup(v)))), None), + Op::Array(_k, _v) => ( + Some(term( + Op::Tuple, + t.cs().iter().map(|c| self.get_t(c)).cloned().collect(), + )), + None, + ), + Op::Fill(_k, size) => ( + Some(term(Op::Tuple, vec![self.get_t(&t.cs()[0]).clone(); *size])), + None, + ), + Op::Store => { + let a = &t.cs()[0]; + let i = &t.cs()[1]; + let v = &t.cs()[2]; + ( + if let Some(aa) = self.tups.get(a) { + if suitable_const(i) { + debug!("simplify store {}", i); + Some(term![Op::Update(get_const(i)); aa.clone(), self.get_t(v).clone()]) + } else { + None } - progress = self.bi_implicate(term, &term.cs()[0]) || progress; - } - } - progress - } - Op::Fill(..) => { - let v = &term.cs()[0]; - if let Sort::Array(..) = check(v) { - self.bi_implicate(term, &term.cs()[0]) - } else { - false - } + } else { + None + }, + None, + ) } Op::Select => { - // Even though the selected value may not have array sort, we still flag it as - // non-oblivious so we know whether to replace it or not. - let a = &term.cs()[0]; - let i = &term.cs()[1]; - let mut progress = false; - if let Op::Const(_) = i.op() { - // pass - } else { - progress = self.mark(a) || progress; - progress = self.mark(term) || progress; - } - progress = self.bi_implicate(term, a) || progress; - progress - } - Op::Var(..) => { - if let Sort::Array(..) = check(term) { - self.mark(term) + let a = &t.cs()[0]; + let i = &t.cs()[1]; + if let Some(aa) = self.tups.get(a) { + if suitable_const(i) { + debug!("simplify select {}", i); + let tt = term![Op::Field(get_const(i)); aa.clone()]; + ( + Some(tt.clone()), + if check(&tt).is_scalar() { + Some(tt) + } else { + None + }, + ) + } else { + (None, None) + } } else { - false + (None, None) } } Op::Ite => { - let t = &term.cs()[1]; - let f = &term.cs()[2]; - if let Sort::Array(..) = check(t) { - let mut progress = self.bi_implicate(term, t); - progress = self.bi_implicate(t, f) || progress; - progress = self.bi_implicate(term, f) || progress; - progress - } else { - false - } + let cond = &t.cs()[0]; + let case_t = &t.cs()[1]; + let case_f = &t.cs()[2]; + ( + if let (Some(tt), Some(ff)) = (self.tups.get(case_t), self.tups.get(case_f)) { + Some(term![Op::Ite; self.get(cond).clone(), tt.clone(), ff.clone()]) + } else { + None + }, + None, + ) } Op::Eq => { - let a = &term.cs()[0]; - let b = &term.cs()[1]; - if let Sort::Array(..) = check(a) { - self.bi_implicate(a, b) - } else { - false - } + let a = &t.cs()[0]; + let b = &t.cs()[1]; + ( + None, + if let (Some(aa), Some(bb)) = (self.tups.get(a), self.tups.get(b)) { + Some(term![Op::Eq; aa.clone(), bb.clone()]) + } else { + None + }, + ) } - Op::Tuple => { - panic!("Tuple in obliv") - } - _ => false, + Op::Tuple => panic!("Tuple in obliv"), + _ => (None, None), + }; + if let Some(tup) = tup_opt { + trace!("Tuple rw: \n{}\nto\n{}", t, tup); + self.tups.insert(t.clone(), tup); } - } -} + let new_t = term_opt.unwrap_or_else(|| { + term( + t.op().clone(), + t.cs().iter().map(|c| self.get(c)).cloned().collect(), + ) + }); -struct Replacer { - /// The maximum size of arrays that will be replaced. - not_obliv: TermSet, + self.terms.insert(t.clone(), new_t); + } } -impl Replacer { - fn should_replace(&self, a: &Term) -> bool { - !self.not_obliv.contains(a) +/// Eliminate oblivious arrays. See module documentation. +pub fn elim_obliv(c: &mut Computation) { + let mut pass = OblivRewriter::default(); + for t in c.terms_postorder() { + pass.visit(&t); + } + for o in &mut c.outputs { + debug_assert!(check(o).is_scalar()); + *o = pass.get(o).clone(); } } + fn arr_val_to_tup(v: &Value) -> Value { match v { Value::Array(Array { @@ -240,13 +160,6 @@ fn arr_val_to_tup(v: &Value) -> Value { } } -fn term_arr_val_to_tup(a: Term) -> Term { - match &a.op() { - Op::Const(v @ Value::Array(..)) => leaf_term(Op::Const(arr_val_to_tup(v))), - _ => a, - } -} - #[track_caller] fn get_const(t: &Term) -> usize { as_uint_constant(t) @@ -255,108 +168,6 @@ fn get_const(t: &Term) -> usize { .expect("oversize") } -impl RewritePass for Replacer { - fn visit Vec>( - &mut self, - computation: &mut Computation, - orig: &Term, - rewritten_children: F, - ) -> Option { - //debug!("Visit {}", extras::Letified(orig.clone())); - let get_cs = || -> Vec { - rewritten_children() - .into_iter() - .map(term_arr_val_to_tup) - .collect() - }; - match &orig.op() { - Op::Var(name, Sort::Array(..)) => { - if self.should_replace(orig) { - let precomp = extras::array_to_tuple(orig); - let new_name = format!("{name}.tup"); - let new_sort = check(&precomp); - computation.extend_precomputation(new_name.clone(), precomp); - Some(leaf_term(Op::Var(new_name, new_sort))) - } else { - None - } - } - Op::Select => { - // we mark the selected term as non-obliv... - if self.should_replace(orig) { - let mut cs = get_cs(); - debug_assert_eq!(cs.len(), 2); - let k_const = get_const(&cs.pop().unwrap()); - Some(term(Op::Field(k_const), cs)) - } else { - None - } - } - Op::Store => { - if self.should_replace(orig) { - let mut cs = get_cs(); - debug_assert_eq!(cs.len(), 3); - let k_const = get_const(&cs.remove(1)); - Some(term(Op::Update(k_const), cs)) - } else { - None - } - } - Op::CStore => { - if self.should_replace(orig) { - let mut cs = get_cs(); - debug_assert_eq!(cs.len(), 4); - let cond = cs.remove(3); - let k_const = get_const(&cs.remove(1)); - let orig = cs[0].clone(); - Some(term![ITE; cond, term(Op::Update(k_const), cs), orig]) - } else { - None - } - } - Op::Array(..) => { - if self.should_replace(orig) { - Some(term(Op::Tuple, get_cs())) - } else { - None - } - } - Op::Fill(_, size) => { - if self.should_replace(orig) { - Some(term(Op::Tuple, vec![get_cs().pop().unwrap(); *size])) - } else { - None - } - } - Op::Ite => { - if self.should_replace(orig) { - Some(term(Op::Ite, get_cs())) - } else { - None - } - } - Op::Eq => { - if self.should_replace(&orig.cs()[0]) { - Some(term(Op::Eq, get_cs())) - } else { - None - } - } - _ => None, - } - } -} - -/// Eliminate oblivious arrays. See module documentation. -pub fn elim_obliv(t: &mut Computation) { - let mut prop_pass = NonOblivComputer::new(); - prop_pass.traverse(t); - let mut replace_pass = Replacer { - not_obliv: prop_pass.not_obliv, - }; - ::traverse_full(&mut replace_pass, t, false, false) -} - #[cfg(test)] mod test { use super::*; @@ -374,6 +185,12 @@ mod test { true } + fn count_selects(t: &Term) -> usize { + PostOrderIter::new(t.clone()) + .filter(|t| matches!(t.op(), Op::Select)) + .count() + } + #[test] fn obliv() { let z = term![Op::Const(Value::Array(Array::new( @@ -471,4 +288,263 @@ mod test { assert!(!array_free(&c.outputs[0])); assert!(array_free(&c.outputs[1])); } + + #[test] + fn linear_stores_branching_selects() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs ) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#a (mod 11) #f0 4 ())) + (a1 (store a0 #f0 #f1)) + (x0 (select a1 #f0)) + (x1 (select a1 #f1)) + (a2 (store a1 #f0 #f1)) + (x2 (select a2 #f2)) + (x3 (select a2 #f3)) + (a3 (store a2 #f1 #f1)) + (x4 (select a3 #f0)) + (x5 (select a3 #f1)) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), 0); + } + + #[test] + fn linear_stores_branching_selects_partial() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs (i (mod 11))) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#a (mod 11) #f0 4 ())) + (a1 (store a0 #f0 #f1)) + (x0 (select a1 #f0)) + (x1 (select a1 #f1)) + (a2 (store a1 #f0 #f1)) + (x2 (select a2 #f2)) + (x3 (select a2 #f3)) + (a3 (store a2 i #f1)) + (x4 (select a3 #f0)) + (x5 (select a3 #f1)) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), 2); + } + + #[test] + fn linear_stores_branching_selects_partial_2() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs (i (mod 11))) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#a (mod 11) #f0 4 ())) + (a1 (store a0 #f0 #f1)) + (x0 (select a1 #f0)) + (x1 (select a1 #f1)) + (a2 (store a1 i #f1)) + (x2 (select a2 #f2)) + (x3 (select a2 #f3)) + (a3 (store a2 #f0 #f1)) + (x4 (select a3 #f0)) + (x5 (select a3 #f1)) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), 4); + } + + #[test] + fn nest_obliv() { + env_logger::try_init().ok(); + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs (i (mod 11))) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#l (mod 11) ((#l (mod 11) (#f1 #f0)) (#l (mod 11) (#f0 #f1))))) + (a1 (store a0 #f0 (store (select a0 #f0) #f1 #f1))) + (x0 (select (select a1 #f0) #f0)) + (x1 (select (select a1 #f1) #f0)) + (a2 (store a1 #f1 (store (select a1 #f1) #f1 #f1))) + (x2 (select (select a2 #f0) #f1)) + (x3 (select (select a2 #f1) #f1)) + (a3 (store a2 #f1 (store (select a2 #f1) #f0 #f1))) + (x4 (select (select a3 #f1) #f0)) + (x5 (select (select a3 #f0) #f1)) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), 0); + } + + #[test] + fn nest_obliv_partial() { + env_logger::try_init().ok(); + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs (i (mod 11))) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#l (mod 11) ((#l (mod 11) (#f1 #f0)) (#l (mod 11) (#f0 #f1))))) + (a1 (store a0 #f0 (store (select a0 #f0) #f1 #f1))) + (x0 (select (select a1 #f0) #f0)) + (x1 (select (select a1 #f1) #f0)) + (a2 (store a1 i (store (select a1 i) #f1 #f1))) ; not elim + (x2 (select (select a2 #f0) #f1)) ; not elim (2) + (x3 (select (select a2 #f1) #f1)) ; not elim (2) + (a3 (store a2 #f1 (store (select a2 #f1) #f0 #f1))) ; not elim (dup) + (x4 (select (select a3 #f1) #f0)) ; not elim (2) + (x5 (select (select a3 #f0) #f1)) ; not elim (2) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + let before = count_selects(&c.outputs[0]); + elim_obliv(&mut c); + assert!(count_selects(&c.outputs[0]) < before); + } + + #[test] + fn nest_no_obliv() { + env_logger::try_init().ok(); + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs (i (mod 11))) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#l (mod 11) ((#l (mod 11) (#f1 #f0)) (#l (mod 11) (#f0 #f1))))) + (a1 (store a0 i (store (select a0 i) #f1 #f1))) + (x0 (select (select a1 #f0) #f0)) + (x1 (select (select a1 #f1) #f0)) + (a2 (store a1 #f0 (store (select a1 #f0) #f1 #f1))) ; not elim + (x2 (select (select a2 #f0) #f1)) ; not elim (2) + (x3 (select (select a2 #f1) #f1)) ; not elim (2) + (a3 (store a2 #f1 (store (select a2 #f1) #f0 #f1))) ; not elim (dup) + (x4 (select (select a3 #f1) #f0)) ; not elim (2) + (x5 (select (select a3 #f0) #f1)) ; not elim (2) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + let before = count_selects(&c.outputs[0]); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), before); + } + + #[test] + fn two_array_ptr_chase_eq_size() { + env_logger::try_init().ok(); + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) + (inputs (x0 (mod 11)) + (x1 (mod 11)) + (x2 (mod 11)) + (x3 (mod 11)) + (x4 (mod 11)) + (i0 (mod 11)) + (i1 (mod 11)) + (i2 (mod 11)) + (i3 (mod 11)) + ) + (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (ax (store (store (store (store (#a (mod 11) #f0 4 ()) #f0 x0) #f1 x1) #f2 x2) #f3 x3)) + (ai (store (store (store (store (#a (mod 11) #f0 4 ()) #f0 i0) #f1 i1) #f2 i2) #f3 i3)) + (xi0 (select ax (select ai #f0))) + (xi1 (select ax (select ai #f1))) + (xi2 (select ax (select ai #f2))) + (xi3 (select ax (select ai #f3))) + ) + (+ xi0 xi1 xi2 xi3) + )) + ) + ", + ); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), 4); + } + + #[test] + fn two_array_ptr_chase_ne_size() { + env_logger::try_init().ok(); + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) + (inputs (x0 (mod 11)) + (x1 (mod 11)) + (x2 (mod 11)) + (x3 (mod 11)) + (x4 (mod 11)) + (i0 (mod 11)) + (i1 (mod 11)) + (i2 (mod 11)) + ) + (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (ax (store (store (store (store (#a (mod 11) #f0 4 ()) #f0 x0) #f1 x1) #f2 x2) #f3 x3)) + (ai (store (store (store (#a (mod 11) #f0 4 ()) #f0 i0) #f1 i1) #f2 i2)) + (xi0 (select ax (select ai #f0))) + (xi1 (select ax (select ai #f1))) + (xi2 (select ax (select ai #f2))) + ) + (+ xi0 xi1 xi2) + )) + ) + ", + ); + elim_obliv(&mut c); + assert_eq!(count_selects(&c.outputs[0]), 3); + } } diff --git a/src/ir/opt/mem/ram/volatile.rs b/src/ir/opt/mem/ram/volatile.rs index 8f77a6eb8..b915390b9 100644 --- a/src/ir/opt/mem/ram/volatile.rs +++ b/src/ir/opt/mem/ram/volatile.rs @@ -1,6 +1,10 @@ //! A general-purpose RAM extractor use super::*; +use fxhash::FxHashMap as HashMap; +use fxhash::FxHashSet as HashSet; +use std::collections::BinaryHeap; + use log::trace; /// Graph of the *arrays* in the computation. @@ -198,6 +202,46 @@ impl Extactor { } } +/// Given a set of terms, return an ordering of them in post-order, but also with array selects on +/// A before stores to A. +fn array_order<'a>(terms: HashSet<&'a Term>) -> Vec<&'a Term> { + let mut parents: HashMap<&'a Term, HashSet<&'a Term>> = Default::default(); + let mut children: HashMap<&'a Term, HashSet<&'a Term>> = Default::default(); + for t in &terms { + parents.entry(t).or_default(); + children.entry(t).or_default(); + for c in t.cs() { + debug_assert!(terms.contains(c)); + parents.entry(c).or_default().insert(t); + children.entry(t).or_default().insert(c); + } + } + let mut output: Vec<&'a Term> = Default::default(); + // max-heap contains (is_select, term) pairs; so, selects go first. + let mut to_output: BinaryHeap<(bool, &'a Term)> = terms + .iter() + .filter(|t| t.cs().is_empty()) + .map(|t| (false, *t)) + .collect(); + let mut children_not_outputted: HashMap<&'a Term, usize> = children + .iter() + .map(|(term, children)| (*term, children.len())) + .collect(); + while let Some((_, output_me)) = to_output.pop() { + output.push(output_me); + for p in parents.get(&output_me).unwrap() { + let count = children_not_outputted.get_mut(p).unwrap(); + assert!(*count > 0); + *count -= 1; + if *count == 0 { + to_output.push((matches!(p.op(), Op::Select), *p)); + } + } + } + assert_eq!(output.len(), terms.len()); + output +} + impl RewritePass for Extactor { fn visit Vec>( &mut self, @@ -242,9 +286,25 @@ impl RewritePass for Extactor { match &t.op() { // Rewrite select's whose array is a RAM term Op::Select if self.graph.ram_terms.contains(&t.cs()[0]) => { - let ram_id = self.get_or_start(&t.cs()[0]); + let array = &t.cs()[0]; + let idx = &t.cs()[1]; + // If we're based on a leaf + let ram_id = if array_leaf(array) { + // check if that leaf has a RAM already + if let Some(id) = self.term_ram.get(array) { + *id + } else { + let id = self.start_ram_for_leaf(array); + + self.term_ram.insert(array.clone(), id); + id + } + } else { + // otherwise, assume that our parent has a RAM already + *self.term_ram.get(array).unwrap() + }; let ram = &mut self.rams[ram_id]; - let read_value = ram.new_read(t.cs()[1].clone(), computation, t.clone()); + let read_value = ram.new_read(idx.clone(), computation, t.clone()); self.read_terms.insert(t.clone(), read_value.clone()); Some(read_value) } @@ -252,6 +312,32 @@ impl RewritePass for Extactor { } } } + + fn traverse(&mut self, computation: &mut Computation) { + let terms: Vec = computation.terms_postorder().collect(); + let term_refs: HashSet<&Term> = terms.iter().collect(); + let mut cache = TermMap::::default(); + for top in array_order(term_refs) { + debug_assert!(!cache.contains_key(top)); + let new_t_opt = self.visit_cache(computation, top, &cache); + let new_t = new_t_opt.unwrap_or_else(|| { + term( + top.op().clone(), + top.cs() + .iter() + .map(|c| cache.get(c).unwrap()) + .cloned() + .collect(), + ) + }); + cache.insert(top.clone(), new_t); + } + computation.outputs = computation + .outputs + .iter() + .map(|o| cache.get(o).unwrap().clone()) + .collect(); + } } /// Find arrays which are RAMs (i.e., accessed with a linear sequences of @@ -517,6 +603,83 @@ mod test { assert_eq!(cs, cs2); } + #[test] + fn rom() { + let cs = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs ) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (c_array (#a (mod 11) #f0 4 ())) + (x0 (select c_array #f0)) + (x1 (select c_array #f1)) + (x2 (select c_array #f2)) + (x3 (select c_array #f3)) + ) + (+ x0 x1 x2 x3) + )) + ) + ", + ); + let mut cs2 = cs.clone(); + cstore::parse(&mut cs2); + let field = FieldT::from(rug::Integer::from(11)); + let rams = extract(&mut cs2, AccessCfg::default_from_field(field.clone())); + extras::assert_all_vars_declared(&cs2); + assert_ne!(cs, cs2); + assert_eq!(1, rams.len()); + assert_eq!(4, rams[0].accesses.len()); + } + + #[test] + fn multi_arm_tree() { + let cs = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs ) (commitments)) + (precompute () () (#t )) + (set_default_modulus 11 + (let + ( + (a0 (#a (mod 11) #f0 4 ())) + (a1 (store a0 #f0 #f1)) + (x0 (select a1 #f0)) + (x1 (select a1 #f1)) + (a2 (store a1 #f0 #f1)) + (x2 (select a2 #f2)) + (x3 (select a2 #f3)) + (a3 (store a2 #f1 #f1)) + (x4 (select a3 #f0)) + (x5 (select a3 #f1)) + ) + (+ x0 x1 x2 x3 x4 x5) + )) + ) + ", + ); + let mut cs2 = cs.clone(); + cstore::parse(&mut cs2); + let field = FieldT::from(rug::Integer::from(11)); + let rams = extract(&mut cs2, AccessCfg::default_from_field(field.clone())); + extras::assert_all_vars_declared(&cs2); + assert_ne!(cs, cs2); + assert_eq!(1, rams.len()); + assert_eq!(9, rams[0].accesses.len()); + println!("{:?}", rams[0].accesses); + assert_eq!(bool_lit(true), rams[0].accesses[0].write.b); + assert_eq!(bool_lit(false), rams[0].accesses[1].write.b); + assert_eq!(bool_lit(false), rams[0].accesses[2].write.b); + assert_eq!(bool_lit(true), rams[0].accesses[3].write.b); + assert_eq!(bool_lit(false), rams[0].accesses[4].write.b); + assert_eq!(bool_lit(false), rams[0].accesses[5].write.b); + assert_eq!(bool_lit(true), rams[0].accesses[6].write.b); + assert_eq!(bool_lit(false), rams[0].accesses[7].write.b); + assert_eq!(bool_lit(false), rams[0].accesses[8].write.b); + } + #[cfg(feature = "poly")] #[test] fn length_4() { diff --git a/src/ir/term/extras.rs b/src/ir/term/extras.rs index a1319b719..2715bcb91 100644 --- a/src/ir/term/extras.rs +++ b/src/ir/term/extras.rs @@ -110,6 +110,7 @@ pub fn as_uint_constant(t: &Term) -> Option { match &t.op() { Op::Const(Value::BitVector(bv)) => Some(bv.uint().clone()), Op::Const(Value::Field(f)) => Some(f.i()), + Op::Const(Value::Bool(b)) => Some((*b).into()), _ => None, } }