Skip to content

Commit

Permalink
Add convenience impls for common types (#137)
Browse files Browse the repository at this point in the history
* Work

* Tweak

* Fmt

* Work

* Format + typo fixes

* `no-std` fix
  • Loading branch information
Pratyush authored Jan 27, 2024
1 parent a124995 commit 4020fbc
Show file tree
Hide file tree
Showing 13 changed files with 341 additions and 121 deletions.
62 changes: 52 additions & 10 deletions src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ impl AllocationMode {

/// Specifies how variables of type `Self` should be allocated in a
/// `ConstraintSystem`.
pub trait AllocVar<V, F: Field>
where
Self: Sized,
V: ?Sized,
{
pub trait AllocVar<V: ?Sized, F: Field>: Sized {
/// Allocates a new variable of type `Self` in the `ConstraintSystem` `cs`.
/// The mode of allocation is decided by `mode`.
fn new_variable<T: Borrow<V>>(
Expand Down Expand Up @@ -92,10 +88,56 @@ impl<I, F: Field, A: AllocVar<I, F>> AllocVar<[I], F> for Vec<A> {
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
let mut vec = Vec::new();
for value in f()?.borrow().iter() {
vec.push(A::new_variable(cs.clone(), || Ok(value), mode)?);
}
Ok(vec)
f().and_then(|v| {
v.borrow()
.iter()
.map(|e| A::new_variable(cs.clone(), || Ok(e), mode))
.collect()
})
}
}

/// Dummy impl for `()`.
impl<F: Field> AllocVar<(), F> for () {
fn new_variable<T: Borrow<()>>(
_cs: impl Into<Namespace<F>>,
_f: impl FnOnce() -> Result<T, SynthesisError>,
_mode: AllocationMode,
) -> Result<Self, SynthesisError> {
Ok(())
}
}

/// This blanket implementation just allocates variables in `Self`
/// element by element.
impl<I, F: Field, A: AllocVar<I, F>, const N: usize> AllocVar<[I; N], F> for [A; N] {
fn new_variable<T: Borrow<[I; N]>>(
cs: impl Into<Namespace<F>>,
f: impl FnOnce() -> Result<T, SynthesisError>,
mode: AllocationMode,
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
f().map(|v| {
let v = v.borrow();
core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap())
})
}
}

/// This blanket implementation just allocates variables in `Self`
/// element by element.
impl<I, F: Field, A: AllocVar<I, F>, const N: usize> AllocVar<[I], F> for [A; N] {
fn new_variable<T: Borrow<[I]>>(
cs: impl Into<Namespace<F>>,
f: impl FnOnce() -> Result<T, SynthesisError>,
mode: AllocationMode,
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
f().map(|v| {
let v = v.borrow();
core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap())
})
}
}
6 changes: 3 additions & 3 deletions src/boolean/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl<F: PrimeField> Boolean<F> {
let mut bits_iter = bits.iter().rev(); // Iterate in big-endian

// Runs of ones in r
let mut last_run = Boolean::constant(true);
let mut last_run = Boolean::TRUE;
let mut current_run = vec![];

let mut element_num_bits = 0;
Expand All @@ -57,12 +57,12 @@ impl<F: PrimeField> Boolean<F> {
}

if bits.len() > element_num_bits {
let mut or_result = Boolean::constant(false);
let mut or_result = Boolean::FALSE;
for should_be_zero in &bits[element_num_bits..] {
or_result |= should_be_zero;
let _ = bits_iter.next().unwrap();
}
or_result.enforce_equal(&Boolean::constant(false))?;
or_result.enforce_equal(&Boolean::FALSE)?;
}

for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) {
Expand Down
2 changes: 1 addition & 1 deletion src/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<F: Field> Boolean<F> {
/// let true_var = Boolean::<Fr>::TRUE;
/// let false_var = Boolean::<Fr>::FALSE;
///
/// true_var.enforce_equal(&Boolean::constant(true))?;
/// true_var.enforce_equal(&Boolean::TRUE)?;
/// false_var.enforce_equal(&Boolean::constant(false))?;
/// # Ok(())
/// # }
Expand Down
38 changes: 35 additions & 3 deletions src/cmp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use ark_ff::Field;
use ark_ff::{Field, PrimeField};
use ark_relations::r1cs::SynthesisError;

use crate::{boolean::Boolean, R1CSVar};
use crate::{boolean::Boolean, eq::EqGadget, R1CSVar};

/// Specifies how to generate constraints for comparing two variables.
pub trait CmpGadget<F: Field>: R1CSVar<F> {
pub trait CmpGadget<F: Field>: R1CSVar<F> + EqGadget<F> {
fn is_gt(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
other.is_lt(self)
}
Expand All @@ -19,3 +19,35 @@ pub trait CmpGadget<F: Field>: R1CSVar<F> {
other.is_ge(self)
}
}

/// Mimics the behavior of `std::cmp::PartialOrd` for `()`.
impl<F: Field> CmpGadget<F> for () {
fn is_gt(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::FALSE)
}

fn is_ge(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::TRUE)
}

fn is_lt(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::FALSE)
}

fn is_le(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::TRUE)
}
}

/// Mimics the lexicographic comparison behavior of `std::cmp::PartialOrd` for `[T]`.
impl<T: CmpGadget<F>, F: PrimeField> CmpGadget<F> for [T] {
fn is_ge(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
let mut result = Boolean::TRUE;
let mut all_equal_so_far = Boolean::TRUE;
for (a, b) in self.iter().zip(other) {
all_equal_so_far &= a.is_eq(b)?;
result &= a.is_gt(b)? | &all_equal_so_far;
}
Ok(result)
}
}
54 changes: 54 additions & 0 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,60 @@ impl<'a, F: Field, T: 'a + ToBytesGadget<F>> ToBytesGadget<F> for &'a T {
}
}

impl<T: ToBytesGadget<F>, F: Field> ToBytesGadget<F> for [T] {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = Vec::new();
for elem in self {
let elem = elem.to_bytes_le()?;
bytes.extend_from_slice(&elem);
// Make sure that there's enough capacity to avoid reallocations.
bytes.reserve(elem.len() * (self.len() - 1));
}
Ok(bytes)
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = Vec::new();
for elem in self {
let elem = elem.to_non_unique_bytes_le()?;
bytes.extend_from_slice(&elem);
// Make sure that there's enough capacity to avoid reallocations.
bytes.reserve(elem.len() * (self.len() - 1));
}
Ok(bytes)
}
}

impl<T: ToBytesGadget<F>, F: Field> ToBytesGadget<F> for Vec<T> {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes_le()
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_non_unique_bytes_le()
}
}

impl<T: ToBytesGadget<F>, F: Field, const N: usize> ToBytesGadget<F> for [T; N] {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes_le()
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_non_unique_bytes_le()
}
}

impl<F: Field> ToBytesGadget<F> for () {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
Ok(Vec::new())
}

fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
Ok(Vec::new())
}
}

/// Specifies how to convert a variable of type `Self` to variables of
/// type `FpVar<ConstraintF>`
pub trait ToConstraintFieldGadget<ConstraintF: ark_ff::PrimeField> {
Expand Down
109 changes: 100 additions & 9 deletions src/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub trait EqGadget<F: Field> {
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.is_eq(&other)?
.conditional_enforce_equal(&Boolean::constant(true), should_enforce)
.conditional_enforce_equal(&Boolean::TRUE, should_enforce)
}

/// Enforce that `self` and `other` are equal.
Expand All @@ -46,7 +46,7 @@ pub trait EqGadget<F: Field> {
/// are encouraged to carefully analyze the efficiency and safety of these.
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> {
self.conditional_enforce_equal(other, &Boolean::constant(true))
self.conditional_enforce_equal(other, &Boolean::TRUE)
}

/// If `should_enforce == true`, enforce that `self` and `other` are *not*
Expand All @@ -65,7 +65,7 @@ pub trait EqGadget<F: Field> {
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.is_neq(&other)?
.conditional_enforce_equal(&Boolean::constant(true), should_enforce)
.conditional_enforce_equal(&Boolean::TRUE, should_enforce)
}

/// Enforce that `self` and `other` are *not* equal.
Expand All @@ -78,20 +78,23 @@ pub trait EqGadget<F: Field> {
/// are encouraged to carefully analyze the efficiency and safety of these.
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn enforce_not_equal(&self, other: &Self) -> Result<(), SynthesisError> {
self.conditional_enforce_not_equal(other, &Boolean::constant(true))
self.conditional_enforce_not_equal(other, &Boolean::TRUE)
}
}

impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
assert_eq!(self.len(), other.len());
assert!(!self.is_empty());
let mut results = Vec::with_capacity(self.len());
for (a, b) in self.iter().zip(other) {
results.push(a.is_eq(b)?);
if self.is_empty() & other.is_empty() {
Ok(Boolean::TRUE)
} else {
let mut results = Vec::with_capacity(self.len());
for (a, b) in self.iter().zip(other) {
results.push(a.is_eq(b)?);
}
Boolean::kary_and(&results)
}
Boolean::kary_and(&results)
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
Expand Down Expand Up @@ -128,3 +131,91 @@ impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
}
}
}

/// This blanket implementation just forwards to the impl on [`[T]`].
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for Vec<T> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
self.as_slice().is_eq(other.as_slice())
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_equal(
&self,
other: &Self,
condition: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_equal(other.as_slice(), condition)
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_not_equal(
&self,
other: &Self,
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_not_equal(other.as_slice(), should_enforce)
}
}

/// Dummy impl for `()`.
impl<F: Field> EqGadget<F> for () {
/// Output a `Boolean` value representing whether `self.value() ==
/// other.value()`.
#[inline]
fn is_eq(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
Ok(Boolean::TRUE)
}

/// If `should_enforce == true`, enforce that `self` and `other` are equal;
/// else, enforce a vacuously true statement.
///
/// This is a no-op as `self.is_eq(other)?` is always `true`.
#[tracing::instrument(target = "r1cs", skip(self, _other))]
fn conditional_enforce_equal(
&self,
_other: &Self,
_should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
Ok(())
}

/// Enforce that `self` and `other` are equal.
///
/// This does not generate any constraints as `self.is_eq(other)?` is always
/// `true`.
#[tracing::instrument(target = "r1cs", skip(self, _other))]
fn enforce_equal(&self, _other: &Self) -> Result<(), SynthesisError> {
Ok(())
}
}

/// This blanket implementation just forwards to the impl on [`[T]`].
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField, const N: usize> EqGadget<F> for [T; N] {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
self.as_slice().is_eq(other.as_slice())
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_equal(
&self,
other: &Self,
condition: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_equal(other.as_slice(), condition)
}

#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_not_equal(
&self,
other: &Self,
should_enforce: &Boolean<F>,
) -> Result<(), SynthesisError> {
self.as_slice()
.conditional_enforce_not_equal(other.as_slice(), should_enforce)
}
}
2 changes: 1 addition & 1 deletion src/fields/emulated_fp/allocated_field_var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ impl<TargetF: PrimeField, BaseF: PrimeField> ToBytesGadget<BaseF>

let num_bits = TargetF::BigInt::NUM_LIMBS * 64;
assert!(bits.len() <= num_bits);
bits.resize_with(num_bits, || Boolean::constant(false));
bits.resize_with(num_bits, || Boolean::FALSE);

let bytes = bits.chunks(8).map(UInt8::from_bits_le).collect();
Ok(bytes)
Expand Down
4 changes: 2 additions & 2 deletions src/fields/fp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let num_bits = F::BigInt::NUM_LIMBS * 64;
let mut bits = self.to_bits_le()?;
let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len());
let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
bits.extend(remainder);
let bytes = bits
.chunks(8)
Expand All @@ -568,7 +568,7 @@ impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let num_bits = F::BigInt::NUM_LIMBS * 64;
let mut bits = self.to_non_unique_bits_le()?;
let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len());
let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
bits.extend(remainder);
let bytes = bits
.chunks(8)
Expand Down
Loading

0 comments on commit 4020fbc

Please sign in to comment.