diff --git a/.gitignore b/.gitignore index c8e9e48..923e69d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ target Cargo.lock + +# IDE config files .vscode +.idea \ No newline at end of file diff --git a/crates/bellpepper-core/src/util_cs/mod.rs b/crates/bellpepper-core/src/util_cs/mod.rs index aff1f1d..ce316c3 100644 --- a/crates/bellpepper-core/src/util_cs/mod.rs +++ b/crates/bellpepper-core/src/util_cs/mod.rs @@ -1,6 +1,7 @@ -use crate::LinearCombination; use ff::PrimeField; +use crate::LinearCombination; + pub mod test_cs; pub type Constraint = ( diff --git a/crates/bellpepper/src/util_cs/mod.rs b/crates/bellpepper/src/util_cs/mod.rs index 6773180..fe16084 100644 --- a/crates/bellpepper/src/util_cs/mod.rs +++ b/crates/bellpepper/src/util_cs/mod.rs @@ -1,5 +1,7 @@ +pub use bellpepper_core::{Comparable, Constraint}; + pub mod bench_cs; pub mod metric_cs; +pub mod shape_cs; +pub mod test_shape_cs; pub mod witness_cs; - -pub use bellpepper_core::{Comparable, Constraint}; diff --git a/crates/bellpepper/src/util_cs/shape_cs.rs b/crates/bellpepper/src/util_cs/shape_cs.rs new file mode 100644 index 0000000..349609e --- /dev/null +++ b/crates/bellpepper/src/util_cs/shape_cs.rs @@ -0,0 +1,106 @@ +//! Support for generating R1CS shape. + +use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; +use ff::PrimeField; + +/// `ShapeCSConstraint` represent a constraint in a `ShapeCS`. +pub type ShapeCSConstraint = ( + LinearCombination, + LinearCombination, + LinearCombination, +); + +/// `ShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. +#[derive(Debug)] +pub struct ShapeCS { + /// All constraints added to the `ShapeCS`. + pub constraints: Vec>, + inputs: usize, + aux: usize, +} + +impl ShapeCS { + /// Create a new, default `ShapeCS`, + pub fn new() -> Self { + Self::default() + } + + /// Returns the number of constraints defined for this `ShapeCS`. + pub fn num_constraints(&self) -> usize { + self.constraints.len() + } + + /// Returns the number of inputs defined for this `ShapeCS`. + pub fn num_inputs(&self) -> usize { + self.inputs + } + + /// Returns the number of aux inputs defined for this `ShapeCS`. + pub fn num_aux(&self) -> usize { + self.aux + } +} + +impl Default for ShapeCS { + fn default() -> Self { + Self { + constraints: vec![], + inputs: 1, + aux: 0, + } + } +} + +impl ConstraintSystem for ShapeCS { + type Root = Self; + + fn alloc(&mut self, _annotation: A, _f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.aux += 1; + + Ok(Variable::new_unchecked(Index::Aux(self.aux - 1))) + } + + fn alloc_input(&mut self, _annotation: A, _f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.inputs += 1; + + Ok(Variable::new_unchecked(Index::Input(self.inputs - 1))) + } + + fn enforce(&mut self, _annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + let a = a(LinearCombination::zero()); + let b = b(LinearCombination::zero()); + let c = c(LinearCombination::zero()); + + self.constraints.push((a, b, c)); + } + + fn push_namespace(&mut self, _name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + } + + fn pop_namespace(&mut self) {} + + fn get_root(&mut self) -> &mut Self::Root { + self + } +} diff --git a/crates/bellpepper/src/util_cs/test_shape_cs.rs b/crates/bellpepper/src/util_cs/test_shape_cs.rs new file mode 100644 index 0000000..415f0f6 --- /dev/null +++ b/crates/bellpepper/src/util_cs/test_shape_cs.rs @@ -0,0 +1,314 @@ +//! Support for generating R1CS shape using bellpepper. +//! `TestShapeCS` implements a superset of `ShapeCS`, adding non-trivial +//! namespace support for use in testing. + +use core::fmt::Write; +use std::{ + cmp::Ordering, + collections::{BTreeMap, HashMap}, +}; + +use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; +use ff::PrimeField; + +#[derive(Clone, Copy)] +struct OrderedVariable(Variable); + +#[derive(Debug)] +enum NamedObject { + Constraint(usize), + Var(Variable), + Namespace, +} + +impl Eq for OrderedVariable {} +impl PartialEq for OrderedVariable { + fn eq(&self, other: &Self) -> bool { + match (self.0.get_unchecked(), other.0.get_unchecked()) { + (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => { + a == b + } + _ => false, + } + } +} +impl PartialOrd for OrderedVariable { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for OrderedVariable { + fn cmp(&self, other: &Self) -> Ordering { + match (self.0.get_unchecked(), other.0.get_unchecked()) { + (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => { + a.cmp(b) + } + (Index::Input(_), Index::Aux(_)) => Ordering::Less, + (Index::Aux(_), Index::Input(_)) => Ordering::Greater, + } + } +} + +/// `TestShapeCSConstraint` represent a constraint in a `ShapeCS`. +pub type TestShapeCSConstraint = ( + LinearCombination, + LinearCombination, + LinearCombination, + String, +); + +#[derive(Debug)] +/// `TestShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. +pub struct TestShapeCS { + named_objects: HashMap, + current_namespace: Vec, + /// All constraints added to the `TestShapeCS`. + pub constraints: Vec>, + inputs: Vec, + aux: Vec, +} + +fn proc_lc( + terms: &LinearCombination, +) -> BTreeMap { + let mut map = BTreeMap::new(); + for (var, &coeff) in terms.iter() { + map.entry(OrderedVariable(var)) + .or_insert_with(|| Scalar::ZERO) + .add_assign(&coeff); + } + + // Remove terms that have a zero coefficient to normalize + let mut to_remove = vec![]; + for (var, coeff) in map.iter() { + if coeff.is_zero().into() { + to_remove.push(*var) + } + } + + for var in to_remove { + map.remove(&var); + } + + map +} + +impl TestShapeCS { + #[allow(unused)] + /// Create a new, default `TestShapeCS`, + pub fn new() -> Self { + Self::default() + } + + /// Returns the number of constraints defined for this `TestShapeCS`. + pub fn num_constraints(&self) -> usize { + self.constraints.len() + } + + /// Returns the number of inputs defined for this `TestShapeCS`. + pub fn num_inputs(&self) -> usize { + self.inputs.len() + } + + /// Returns the number of aux inputs defined for this `TestShapeCS`. + pub fn num_aux(&self) -> usize { + self.aux.len() + } + + /// Print all public inputs, aux inputs, and constraint names. + #[allow(dead_code)] + pub fn pretty_print_list(&self) -> Vec { + let mut result = Vec::new(); + + for input in &self.inputs { + result.push(format!("INPUT {input}")); + } + for aux in &self.aux { + result.push(format!("AUX {aux}")); + } + + for (_a, _b, _c, name) in &self.constraints { + result.push(name.to_string()); + } + + result + } + + /// Print all iputs and a detailed representation of each constraint. + #[allow(dead_code)] + pub fn pretty_print(&self) -> String { + let mut s = String::new(); + + for input in &self.inputs { + writeln!(s, "INPUT {}", &input).unwrap() + } + + let negone = -Scalar::ONE; + + let powers_of_two = (0..Scalar::NUM_BITS) + .map(|i| Scalar::from(2u64).pow_vartime([u64::from(i)])) + .collect::>(); + + let pp = |s: &mut String, lc: &LinearCombination| { + s.push('('); + let mut is_first = true; + for (var, coeff) in proc_lc::(lc) { + if coeff == negone { + s.push_str(" - ") + } else if !is_first { + s.push_str(" + ") + } + is_first = false; + + if coeff != Scalar::ONE && coeff != negone { + for (i, x) in powers_of_two.iter().enumerate() { + if x == &coeff { + write!(s, "2^{i} . ").unwrap(); + break; + } + } + + write!(s, "{coeff:?} . ").unwrap() + } + + match var.0.get_unchecked() { + Index::Input(i) => { + write!(s, "`I{}`", &self.inputs[i]).unwrap(); + } + Index::Aux(i) => { + write!(s, "`A{}`", &self.aux[i]).unwrap(); + } + } + } + if is_first { + // Nothing was visited, print 0. + s.push('0'); + } + s.push(')'); + }; + + for (a, b, c, name) in &self.constraints { + s.push('\n'); + + write!(s, "{name}: ").unwrap(); + pp(&mut s, a); + write!(s, " * ").unwrap(); + pp(&mut s, b); + s.push_str(" = "); + pp(&mut s, c); + } + + s.push('\n'); + + s + } + + /// Associate `NamedObject` with `path`. + /// `path` must not already have an associated object. + fn set_named_obj(&mut self, path: String, to: NamedObject) { + assert!( + !self.named_objects.contains_key(&path), + "tried to create object at existing path: {path}" + ); + + self.named_objects.insert(path, to); + } +} + +impl Default for TestShapeCS { + fn default() -> Self { + let mut map = HashMap::new(); + map.insert("ONE".into(), NamedObject::Var(Self::one())); + Self { + named_objects: map, + current_namespace: vec![], + constraints: vec![], + inputs: vec![String::from("ONE")], + aux: vec![], + } + } +} + +impl ConstraintSystem for TestShapeCS { + type Root = Self; + + fn alloc(&mut self, annotation: A, _f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + let path = compute_path(&self.current_namespace, &annotation().into()); + self.aux.push(path); + + Ok(Variable::new_unchecked(Index::Aux(self.aux.len() - 1))) + } + + fn alloc_input(&mut self, annotation: A, _f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + let path = compute_path(&self.current_namespace, &annotation().into()); + self.inputs.push(path); + + Ok(Variable::new_unchecked(Index::Input(self.inputs.len() - 1))) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + let path = compute_path(&self.current_namespace, &annotation().into()); + let index = self.constraints.len(); + self.set_named_obj(path.clone(), NamedObject::Constraint(index)); + + let a = a(LinearCombination::zero()); + let b = b(LinearCombination::zero()); + let c = c(LinearCombination::zero()); + + self.constraints.push((a, b, c, path)); + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + let name = name_fn().into(); + let path = compute_path(&self.current_namespace, &name); + self.set_named_obj(path, NamedObject::Namespace); + self.current_namespace.push(name); + } + + fn pop_namespace(&mut self) { + assert!(self.current_namespace.pop().is_some()); + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } +} + +fn compute_path(ns: &[String], this: &str) -> String { + assert!(!this.contains('/'), "'/' is not allowed in names"); + + let mut name = String::new(); + + let mut needs_separation = false; + for ns in ns.iter().chain(Some(this.to_string()).iter()) { + if needs_separation { + name += "/"; + } + + name += ns; + needs_separation = true; + } + + name +}