From 4020fbc22625621baa8125ede87abaeac3c1ca26 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Sat, 27 Jan 2024 03:58:26 -0500 Subject: [PATCH] Add convenience impls for common types (#137) * Work * Tweak * Fmt * Work * Format + typo fixes * `no-std` fix --- src/alloc.rs | 62 ++++++++-- src/boolean/cmp.rs | 6 +- src/boolean/mod.rs | 2 +- src/cmp.rs | 38 +++++- src/convert.rs | 54 +++++++++ src/eq.rs | 109 ++++++++++++++++-- src/fields/emulated_fp/allocated_field_var.rs | 2 +- src/fields/fp/mod.rs | 4 +- src/lib.rs | 59 +--------- src/poly/domain/mod.rs | 2 +- src/r1cs_var.rs | 86 ++++++++++++++ src/select.rs | 16 +-- src/uint/convert.rs | 22 ---- 13 files changed, 341 insertions(+), 121 deletions(-) create mode 100644 src/r1cs_var.rs diff --git a/src/alloc.rs b/src/alloc.rs index 55aa5d2a..ccce47f5 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -37,11 +37,7 @@ impl AllocationMode { /// Specifies how variables of type `Self` should be allocated in a /// `ConstraintSystem`. -pub trait AllocVar -where - Self: Sized, - V: ?Sized, -{ +pub trait AllocVar: Sized { /// Allocates a new variable of type `Self` in the `ConstraintSystem` `cs`. /// The mode of allocation is decided by `mode`. fn new_variable>( @@ -92,10 +88,56 @@ impl> AllocVar<[I], F> for Vec { ) -> Result { let ns = cs.into(); let cs = ns.cs(); - let mut vec = Vec::new(); - for value in f()?.borrow().iter() { - vec.push(A::new_variable(cs.clone(), || Ok(value), mode)?); - } - Ok(vec) + f().and_then(|v| { + v.borrow() + .iter() + .map(|e| A::new_variable(cs.clone(), || Ok(e), mode)) + .collect() + }) + } +} + +/// Dummy impl for `()`. +impl AllocVar<(), F> for () { + fn new_variable>( + _cs: impl Into>, + _f: impl FnOnce() -> Result, + _mode: AllocationMode, + ) -> Result { + Ok(()) + } +} + +/// This blanket implementation just allocates variables in `Self` +/// element by element. +impl, const N: usize> AllocVar<[I; N], F> for [A; N] { + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + f().map(|v| { + let v = v.borrow(); + core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap()) + }) + } +} + +/// This blanket implementation just allocates variables in `Self` +/// element by element. +impl, const N: usize> AllocVar<[I], F> for [A; N] { + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result { + let ns = cs.into(); + let cs = ns.cs(); + f().map(|v| { + let v = v.borrow(); + core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap()) + }) } } diff --git a/src/boolean/cmp.rs b/src/boolean/cmp.rs index a8c133a8..744279a0 100644 --- a/src/boolean/cmp.rs +++ b/src/boolean/cmp.rs @@ -48,7 +48,7 @@ impl Boolean { let mut bits_iter = bits.iter().rev(); // Iterate in big-endian // Runs of ones in r - let mut last_run = Boolean::constant(true); + let mut last_run = Boolean::TRUE; let mut current_run = vec![]; let mut element_num_bits = 0; @@ -57,12 +57,12 @@ impl Boolean { } if bits.len() > element_num_bits { - let mut or_result = Boolean::constant(false); + let mut or_result = Boolean::FALSE; for should_be_zero in &bits[element_num_bits..] { or_result |= should_be_zero; let _ = bits_iter.next().unwrap(); } - or_result.enforce_equal(&Boolean::constant(false))?; + or_result.enforce_equal(&Boolean::FALSE)?; } for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) { diff --git a/src/boolean/mod.rs b/src/boolean/mod.rs index f32c4a45..e4fab85c 100644 --- a/src/boolean/mod.rs +++ b/src/boolean/mod.rs @@ -100,7 +100,7 @@ impl Boolean { /// let true_var = Boolean::::TRUE; /// let false_var = Boolean::::FALSE; /// - /// true_var.enforce_equal(&Boolean::constant(true))?; + /// true_var.enforce_equal(&Boolean::TRUE)?; /// false_var.enforce_equal(&Boolean::constant(false))?; /// # Ok(()) /// # } diff --git a/src/cmp.rs b/src/cmp.rs index 3281195d..ad29a97e 100644 --- a/src/cmp.rs +++ b/src/cmp.rs @@ -1,10 +1,10 @@ -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_relations::r1cs::SynthesisError; -use crate::{boolean::Boolean, R1CSVar}; +use crate::{boolean::Boolean, eq::EqGadget, R1CSVar}; /// Specifies how to generate constraints for comparing two variables. -pub trait CmpGadget: R1CSVar { +pub trait CmpGadget: R1CSVar + EqGadget { fn is_gt(&self, other: &Self) -> Result, SynthesisError> { other.is_lt(self) } @@ -19,3 +19,35 @@ pub trait CmpGadget: R1CSVar { other.is_ge(self) } } + +/// Mimics the behavior of `std::cmp::PartialOrd` for `()`. +impl CmpGadget for () { + fn is_gt(&self, _other: &Self) -> Result, SynthesisError> { + Ok(Boolean::FALSE) + } + + fn is_ge(&self, _other: &Self) -> Result, SynthesisError> { + Ok(Boolean::TRUE) + } + + fn is_lt(&self, _other: &Self) -> Result, SynthesisError> { + Ok(Boolean::FALSE) + } + + fn is_le(&self, _other: &Self) -> Result, SynthesisError> { + Ok(Boolean::TRUE) + } +} + +/// Mimics the lexicographic comparison behavior of `std::cmp::PartialOrd` for `[T]`. +impl, F: PrimeField> CmpGadget for [T] { + fn is_ge(&self, other: &Self) -> Result, SynthesisError> { + let mut result = Boolean::TRUE; + let mut all_equal_so_far = Boolean::TRUE; + for (a, b) in self.iter().zip(other) { + all_equal_so_far &= a.is_eq(b)?; + result &= a.is_gt(b)? | &all_equal_so_far; + } + Ok(result) + } +} diff --git a/src/convert.rs b/src/convert.rs index c6f6b1e7..41f87a3f 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -90,6 +90,60 @@ impl<'a, F: Field, T: 'a + ToBytesGadget> ToBytesGadget for &'a T { } } +impl, F: Field> ToBytesGadget for [T] { + fn to_bytes_le(&self) -> Result>, SynthesisError> { + let mut bytes = Vec::new(); + for elem in self { + let elem = elem.to_bytes_le()?; + bytes.extend_from_slice(&elem); + // Make sure that there's enough capacity to avoid reallocations. + bytes.reserve(elem.len() * (self.len() - 1)); + } + Ok(bytes) + } + + fn to_non_unique_bytes_le(&self) -> Result>, SynthesisError> { + let mut bytes = Vec::new(); + for elem in self { + let elem = elem.to_non_unique_bytes_le()?; + bytes.extend_from_slice(&elem); + // Make sure that there's enough capacity to avoid reallocations. + bytes.reserve(elem.len() * (self.len() - 1)); + } + Ok(bytes) + } +} + +impl, F: Field> ToBytesGadget for Vec { + fn to_bytes_le(&self) -> Result>, SynthesisError> { + self.as_slice().to_bytes_le() + } + + fn to_non_unique_bytes_le(&self) -> Result>, SynthesisError> { + self.as_slice().to_non_unique_bytes_le() + } +} + +impl, F: Field, const N: usize> ToBytesGadget for [T; N] { + fn to_bytes_le(&self) -> Result>, SynthesisError> { + self.as_slice().to_bytes_le() + } + + fn to_non_unique_bytes_le(&self) -> Result>, SynthesisError> { + self.as_slice().to_non_unique_bytes_le() + } +} + +impl ToBytesGadget for () { + fn to_bytes_le(&self) -> Result>, SynthesisError> { + Ok(Vec::new()) + } + + fn to_non_unique_bytes_le(&self) -> Result>, SynthesisError> { + Ok(Vec::new()) + } +} + /// Specifies how to convert a variable of type `Self` to variables of /// type `FpVar` pub trait ToConstraintFieldGadget { diff --git a/src/eq.rs b/src/eq.rs index 4f2c066b..1692edc9 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -33,7 +33,7 @@ pub trait EqGadget { should_enforce: &Boolean, ) -> Result<(), SynthesisError> { self.is_eq(&other)? - .conditional_enforce_equal(&Boolean::constant(true), should_enforce) + .conditional_enforce_equal(&Boolean::TRUE, should_enforce) } /// Enforce that `self` and `other` are equal. @@ -46,7 +46,7 @@ pub trait EqGadget { /// are encouraged to carefully analyze the efficiency and safety of these. #[tracing::instrument(target = "r1cs", skip(self, other))] fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> { - self.conditional_enforce_equal(other, &Boolean::constant(true)) + self.conditional_enforce_equal(other, &Boolean::TRUE) } /// If `should_enforce == true`, enforce that `self` and `other` are *not* @@ -65,7 +65,7 @@ pub trait EqGadget { should_enforce: &Boolean, ) -> Result<(), SynthesisError> { self.is_neq(&other)? - .conditional_enforce_equal(&Boolean::constant(true), should_enforce) + .conditional_enforce_equal(&Boolean::TRUE, should_enforce) } /// Enforce that `self` and `other` are *not* equal. @@ -78,7 +78,7 @@ pub trait EqGadget { /// are encouraged to carefully analyze the efficiency and safety of these. #[tracing::instrument(target = "r1cs", skip(self, other))] fn enforce_not_equal(&self, other: &Self) -> Result<(), SynthesisError> { - self.conditional_enforce_not_equal(other, &Boolean::constant(true)) + self.conditional_enforce_not_equal(other, &Boolean::TRUE) } } @@ -86,12 +86,15 @@ impl + R1CSVar, F: PrimeField> EqGadget for [T] { #[tracing::instrument(target = "r1cs", skip(self, other))] fn is_eq(&self, other: &Self) -> Result, SynthesisError> { assert_eq!(self.len(), other.len()); - assert!(!self.is_empty()); - let mut results = Vec::with_capacity(self.len()); - for (a, b) in self.iter().zip(other) { - results.push(a.is_eq(b)?); + if self.is_empty() & other.is_empty() { + Ok(Boolean::TRUE) + } else { + let mut results = Vec::with_capacity(self.len()); + for (a, b) in self.iter().zip(other) { + results.push(a.is_eq(b)?); + } + Boolean::kary_and(&results) } - Boolean::kary_and(&results) } #[tracing::instrument(target = "r1cs", skip(self, other))] @@ -128,3 +131,91 @@ impl + R1CSVar, F: PrimeField> EqGadget for [T] { } } } + +/// This blanket implementation just forwards to the impl on [`[T]`]. +impl + R1CSVar, F: PrimeField> EqGadget for Vec { + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + self.as_slice().is_eq(other.as_slice()) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.as_slice() + .conditional_enforce_equal(other.as_slice(), condition) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + self.as_slice() + .conditional_enforce_not_equal(other.as_slice(), should_enforce) + } +} + +/// Dummy impl for `()`. +impl EqGadget for () { + /// Output a `Boolean` value representing whether `self.value() == + /// other.value()`. + #[inline] + fn is_eq(&self, _other: &Self) -> Result, SynthesisError> { + Ok(Boolean::TRUE) + } + + /// If `should_enforce == true`, enforce that `self` and `other` are equal; + /// else, enforce a vacuously true statement. + /// + /// This is a no-op as `self.is_eq(other)?` is always `true`. + #[tracing::instrument(target = "r1cs", skip(self, _other))] + fn conditional_enforce_equal( + &self, + _other: &Self, + _should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + Ok(()) + } + + /// Enforce that `self` and `other` are equal. + /// + /// This does not generate any constraints as `self.is_eq(other)?` is always + /// `true`. + #[tracing::instrument(target = "r1cs", skip(self, _other))] + fn enforce_equal(&self, _other: &Self) -> Result<(), SynthesisError> { + Ok(()) + } +} + +/// This blanket implementation just forwards to the impl on [`[T]`]. +impl + R1CSVar, F: PrimeField, const N: usize> EqGadget for [T; N] { + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + self.as_slice().is_eq(other.as_slice()) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.as_slice() + .conditional_enforce_equal(other.as_slice(), condition) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + self.as_slice() + .conditional_enforce_not_equal(other.as_slice(), should_enforce) + } +} diff --git a/src/fields/emulated_fp/allocated_field_var.rs b/src/fields/emulated_fp/allocated_field_var.rs index 095563a7..6905fe37 100644 --- a/src/fields/emulated_fp/allocated_field_var.rs +++ b/src/fields/emulated_fp/allocated_field_var.rs @@ -686,7 +686,7 @@ impl ToBytesGadget let num_bits = TargetF::BigInt::NUM_LIMBS * 64; assert!(bits.len() <= num_bits); - bits.resize_with(num_bits, || Boolean::constant(false)); + bits.resize_with(num_bits, || Boolean::FALSE); let bytes = bits.chunks(8).map(UInt8::from_bits_le).collect(); Ok(bytes) diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index eff7cf67..bc35c81b 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -555,7 +555,7 @@ impl ToBytesGadget for AllocatedFp { fn to_bytes_le(&self) -> Result>, SynthesisError> { let num_bits = F::BigInt::NUM_LIMBS * 64; let mut bits = self.to_bits_le()?; - let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len()); + let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len()); bits.extend(remainder); let bytes = bits .chunks(8) @@ -568,7 +568,7 @@ impl ToBytesGadget for AllocatedFp { fn to_non_unique_bytes_le(&self) -> Result>, SynthesisError> { let num_bits = F::BigInt::NUM_LIMBS * 64; let mut bits = self.to_non_unique_bits_le()?; - let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len()); + let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len()); bits.extend(remainder); let bytes = bits .chunks(8) diff --git a/src/lib.rs b/src/lib.rs index a2bdc104..7f3e5274 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,9 @@ pub mod macros; pub(crate) use ark_std::vec::Vec; -use ark_ff::Field; +#[doc(hidden)] +pub mod r1cs_var; +pub use r1cs_var::*; /// This module contains `Boolean`, an R1CS equivalent of the `bool` type. pub mod boolean; @@ -105,61 +107,6 @@ pub mod prelude { }; } -/// This trait describes some core functionality that is common to high-level -/// variables, such as `Boolean`s, `FieldVar`s, `GroupVar`s, etc. -pub trait R1CSVar { - /// The type of the "native" value that `Self` represents in the constraint - /// system. - type Value: core::fmt::Debug + Eq + Clone; - - /// Returns the underlying `ConstraintSystemRef`. - /// - /// If `self` is a constant value, then this *must* return - /// `ark_relations::r1cs::ConstraintSystemRef::None`. - fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef; - - /// Returns `true` if `self` is a circuit-generation-time constant. - fn is_constant(&self) -> bool { - self.cs().is_none() - } - - /// Returns the value that is assigned to `self` in the underlying - /// `ConstraintSystem`. - fn value(&self) -> Result; -} - -impl> R1CSVar for [T] { - type Value = Vec; - - fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef { - let mut result = ark_relations::r1cs::ConstraintSystemRef::None; - for var in self { - result = var.cs().or(result); - } - result - } - - fn value(&self) -> Result { - let mut result = Vec::new(); - for var in self { - result.push(var.value()?); - } - Ok(result) - } -} - -impl<'a, F: Field, T: 'a + R1CSVar> R1CSVar for &'a T { - type Value = T::Value; - - fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef { - (*self).cs() - } - - fn value(&self) -> Result { - (*self).value() - } -} - /// A utility trait to convert `Self` to `Result pub trait Assignment { /// Converts `self` to `Result`. diff --git a/src/poly/domain/mod.rs b/src/poly/domain/mod.rs index c30bd03e..ebb72127 100644 --- a/src/poly/domain/mod.rs +++ b/src/poly/domain/mod.rs @@ -44,7 +44,7 @@ impl Radix2DomainVar { impl EqGadget for Radix2DomainVar { fn is_eq(&self, other: &Self) -> Result, SynthesisError> { if self.gen != other.gen || self.dim != other.dim { - Ok(Boolean::constant(false)) + Ok(Boolean::FALSE) } else { self.offset.is_eq(&other.offset) } diff --git a/src/r1cs_var.rs b/src/r1cs_var.rs new file mode 100644 index 00000000..f99e9b2d --- /dev/null +++ b/src/r1cs_var.rs @@ -0,0 +1,86 @@ +use ark_ff::Field; +use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; +use ark_std::vec::Vec; + +/// This trait describes some core functionality that is common to high-level +/// variables, such as `Boolean`s, `FieldVar`s, `GroupVar`s, etc. +pub trait R1CSVar { + /// The type of the "native" value that `Self` represents in the constraint + /// system. + type Value: core::fmt::Debug + Eq + Clone; + + /// Returns the underlying `ConstraintSystemRef`. + /// + /// If `self` is a constant value, then this *must* return + /// `ark_relations::r1cs::ConstraintSystemRef::None`. + fn cs(&self) -> ConstraintSystemRef; + + /// Returns `true` if `self` is a circuit-generation-time constant. + fn is_constant(&self) -> bool { + self.cs().is_none() + } + + /// Returns the value that is assigned to `self` in the underlying + /// `ConstraintSystem`. + fn value(&self) -> Result; +} + +impl> R1CSVar for [T] { + type Value = Vec; + + fn cs(&self) -> ConstraintSystemRef { + let mut result = ConstraintSystemRef::None; + for var in self { + result = var.cs().or(result); + } + result + } + + fn value(&self) -> Result { + let mut result = Vec::new(); + for var in self { + result.push(var.value()?); + } + Ok(result) + } +} + +impl<'a, F: Field, T: 'a + R1CSVar> R1CSVar for &'a T { + type Value = T::Value; + + fn cs(&self) -> ConstraintSystemRef { + (*self).cs() + } + + fn value(&self) -> Result { + (*self).value() + } +} + +impl, const N: usize> R1CSVar for [T; N] { + type Value = [T::Value; N]; + + fn cs(&self) -> ConstraintSystemRef { + let mut result = ConstraintSystemRef::None; + for var in self { + result = var.cs().or(result); + } + result + } + + fn value(&self) -> Result { + Ok(core::array::from_fn(|i| self[i].value().unwrap())) + } +} + +impl R1CSVar for () { + type Value = (); + + fn cs(&self) -> ConstraintSystemRef { + ConstraintSystemRef::None + } + + fn value(&self) -> Result { + Ok(()) + } +} diff --git a/src/select.rs b/src/select.rs index 5d128a67..528fc90d 100644 --- a/src/select.rs +++ b/src/select.rs @@ -3,11 +3,7 @@ use ark_ff::Field; use ark_relations::r1cs::SynthesisError; use ark_std::vec::Vec; /// Generates constraints for selecting between one of two values. -pub trait CondSelectGadget -where - Self: Sized, - Self: Clone, -{ +pub trait CondSelectGadget: Sized + Clone { /// If `cond == &Boolean::TRUE`, then this returns `true_value`; else, /// returns `false_value`. /// @@ -68,10 +64,7 @@ where } /// Performs a lookup in a 4-element table using two bits. -pub trait TwoBitLookupGadget -where - Self: Sized, -{ +pub trait TwoBitLookupGadget: Sized { /// The type of values being looked up. type TableConstant; @@ -92,10 +85,7 @@ where /// Uses three bits to perform a lookup into a table, where the last bit /// conditionally negates the looked-up value. -pub trait ThreeBitCondNegLookupGadget -where - Self: Sized, -{ +pub trait ThreeBitCondNegLookupGadget: Sized { /// The type of values being looked up. type TableConstant; diff --git a/src/uint/convert.rs b/src/uint/convert.rs index e4411019..242c3ad8 100644 --- a/src/uint/convert.rs +++ b/src/uint/convert.rs @@ -171,28 +171,6 @@ impl ToBytesGadget } } -impl ToBytesGadget for [UInt] { - fn to_bytes_le(&self) -> Result>, SynthesisError> { - let mut bytes = Vec::with_capacity(self.len() * (N / 8)); - for elem in self { - bytes.extend_from_slice(&elem.to_bytes_le()?); - } - Ok(bytes) - } -} - -impl ToBytesGadget for Vec> { - fn to_bytes_le(&self) -> Result>, SynthesisError> { - self.as_slice().to_bytes_le() - } -} - -impl<'a, const N: usize, T: PrimUInt, F: Field> ToBytesGadget for &'a [UInt] { - fn to_bytes_le(&self) -> Result>, SynthesisError> { - (*self).to_bytes_le() - } -} - #[cfg(test)] mod tests { use super::*;