Skip to content

Commit

Permalink
Gadgets for comparing to constant (#117)
Browse files Browse the repository at this point in the history
* naive implementation

* Unified test methods.

* comments.
  • Loading branch information
mrain authored Aug 31, 2022
1 parent bcd92b2 commit 851c937
Showing 1 changed file with 208 additions and 100 deletions.
308 changes: 208 additions & 100 deletions relation/src/gadgets/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,93 @@ impl<F: PrimeField> PlonkCircuit<F> {
let c = self.is_lt_internal(a, b)?;
self.logic_neg(c)
}

/// Returns a `BoolVar` indicating whether the variable `a` is less than a
/// given constant `val`.
pub fn is_lt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.is_lt(a, b)
}

/// Returns a `BoolVar` indicating whether the variable `a` is less than or
/// equal to a given constant `val`.
pub fn is_leq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.is_leq(a, b)
}

/// Returns a `BoolVar` indicating whether the variable `a` is greater than
/// a given constant `val`.
pub fn is_gt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
self.is_gt_constant_internal(a, &val)
}

/// Returns a `BoolVar` indicating whether the variable `a` is greater than
/// or equal a given constant `val`.
pub fn is_geq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.is_geq(a, b)
}

/// Enforce the variable `a` to be less than a
/// given constant `val`.
pub fn enforce_lt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.enforce_lt(a, b)
}

/// Enforce the variable `a` to be less than or
/// equal to a given constant `val`.
pub fn enforce_leq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.enforce_leq(a, b)
}

/// Enforce the variable `a` to be greater than
/// a given constant `val`.
pub fn enforce_gt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.enforce_gt(a, b)
}

/// Enforce the variable `a` to be greater than
/// or equal a given constant `val`.
pub fn enforce_geq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
where
F: PrimeField,
{
self.check_var_bound(a)?;
let b = self.create_constant_variable(val)?;
self.enforce_geq(a, b)
}
}

/// Private helper functions for comparison gate
Expand Down Expand Up @@ -171,12 +258,14 @@ impl<F: PrimeField> PlonkCircuit<F> {

#[cfg(test)]
mod test {
use crate::{errors::CircuitError, Circuit, PlonkCircuit};
use crate::{errors::CircuitError, BoolVar, Circuit, PlonkCircuit};
use ark_bls12_377::Fq as Fq377;
use ark_ed_on_bls12_377::Fq as FqEd377;
use ark_ed_on_bls12_381::Fq as FqEd381;
use ark_ed_on_bn254::Fq as FqEd254;
use ark_ff::PrimeField;
use ark_std::cmp::Ordering;
use itertools::multizip;

#[test]
fn test_cmp_gates() -> Result<(), CircuitError> {
Expand All @@ -199,116 +288,135 @@ mod test {
F::from(F::modulus_minus_one_div_two()).mul(F::from(2u32)),
),
];
list.iter()
.try_for_each(|(a, b)| -> Result<(), CircuitError> {
test_is_le(a, b)?;
test_is_leq(a, b)?;
test_is_ge(a, b)?;
test_is_geq(a, b)?;
test_enforce_le(a, b)?;
test_enforce_leq(a, b)?;
test_enforce_ge(a, b)?;
test_enforce_geq(a, b)?;
test_is_le(b, a)?;
test_is_leq(b, a)?;
test_is_ge(b, a)?;
test_is_geq(b, a)?;
test_enforce_le(b, a)?;
test_enforce_leq(b, a)?;
test_enforce_ge(b, a)?;
test_enforce_geq(b, a)
})
multizip((
list,
[Ordering::Less, Ordering::Greater],
[false, true],
[false, true],
)).into_iter()
.try_for_each(
|((a, b), ordering, should_also_check_equality,
is_b_constant)|
-> Result<(), CircuitError> {
test_enforce_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?;
test_enforce_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)?;
test_is_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?;
test_is_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)
},
)
}

fn test_is_le<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
fn test_is_cmp_helper<F: PrimeField>(
a: &F,
b: &F,
ordering: Ordering,
should_also_check_equality: bool,
is_b_constant: bool,
) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = if a < b { F::one() } else { F::zero() };
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;

let c = circuit.is_lt(a, b)?;
assert!(circuit.witness(c.into())?.eq(&expected_result));
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
Ok(())
}
fn test_is_leq<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = if a <= b { F::one() } else { F::zero() };
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;

let c = circuit.is_leq(a, b)?;
assert!(circuit.witness(c.into())?.eq(&expected_result));
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
Ok(())
}
fn test_is_ge<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = if a > b { F::one() } else { F::zero() };
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;

let c = circuit.is_gt(a, b)?;
assert!(circuit.witness(c.into())?.eq(&expected_result));
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
Ok(())
}
fn test_is_geq<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = if a >= b { F::one() } else { F::zero() };
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;

let c = circuit.is_geq(a, b)?;
assert!(circuit.witness(c.into())?.eq(&expected_result));
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
Ok(())
}
fn test_enforce_le<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = a < b;
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;
circuit.enforce_lt(a, b)?;
if expected_result {
assert!(circuit.check_circuit_satisfiability(&[]).is_ok())
let expected_result = if a.cmp(b) == ordering
|| (a.cmp(b) == Ordering::Equal && should_also_check_equality)
{
F::one()
} else {
assert!(circuit.check_circuit_satisfiability(&[]).is_err());
}
Ok(())
}
fn test_enforce_leq<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = a <= b;
F::zero()
};
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;
circuit.enforce_leq(a, b)?;
if expected_result {
assert!(circuit.check_circuit_satisfiability(&[]).is_ok())
let c: BoolVar = if is_b_constant {
match ordering {
Ordering::Less => {
if should_also_check_equality {
circuit.is_leq_constant(a, *b)?
} else {
circuit.is_lt_constant(a, *b)?
}
},
Ordering::Greater => {
if should_also_check_equality {
circuit.is_geq_constant(a, *b)?
} else {
circuit.is_gt_constant(a, *b)?
}
},
// Equality test will be handled elsewhere, comparison gate test will not enter here
Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?,
}
} else {
assert!(circuit.check_circuit_satisfiability(&[]).is_err());
}
let b = circuit.create_variable(*b)?;
match ordering {
Ordering::Less => {
if should_also_check_equality {
circuit.is_leq(a, b)?
} else {
circuit.is_lt(a, b)?
}
},
Ordering::Greater => {
if should_also_check_equality {
circuit.is_geq(a, b)?
} else {
circuit.is_gt(a, b)?
}
},
// Equality test will be handled elsewhere, comparison gate test will not enter here
Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?,
}
};
assert!(circuit.witness(c.into())?.eq(&expected_result));
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
Ok(())
}
fn test_enforce_ge<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
fn test_enforce_cmp_helper<F: PrimeField>(
a: &F,
b: &F,
ordering: Ordering,
should_also_check_equality: bool,
is_b_constant: bool,
) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = a > b;
let expected_result =
a.cmp(b) == ordering || (a.cmp(b) == Ordering::Equal && should_also_check_equality);
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;
circuit.enforce_gt(a, b)?;
if expected_result {
assert!(circuit.check_circuit_satisfiability(&[]).is_ok())
if is_b_constant {
match ordering {
Ordering::Less => {
if should_also_check_equality {
circuit.enforce_leq_constant(a, *b)?
} else {
circuit.enforce_lt_constant(a, *b)?
}
},
Ordering::Greater => {
if should_also_check_equality {
circuit.enforce_geq_constant(a, *b)?
} else {
circuit.enforce_gt_constant(a, *b)?
}
},
// Equality test will be handled elsewhere, comparison gate test will not enter here
Ordering::Equal => (),
}
} else {
assert!(circuit.check_circuit_satisfiability(&[]).is_err());
}
Ok(())
}
fn test_enforce_geq<F: PrimeField>(a: &F, b: &F) -> Result<(), CircuitError> {
let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
let expected_result = a >= b;
let a = circuit.create_variable(*a)?;
let b = circuit.create_variable(*b)?;
circuit.enforce_geq(a, b)?;
let b = circuit.create_variable(*b)?;
match ordering {
Ordering::Less => {
if should_also_check_equality {
circuit.enforce_leq(a, b)?
} else {
circuit.enforce_lt(a, b)?
}
},
Ordering::Greater => {
if should_also_check_equality {
circuit.enforce_geq(a, b)?
} else {
circuit.enforce_gt(a, b)?
}
},
// Equality test will be handled elsewhere, comparison gate test will not enter here
Ordering::Equal => (),
}
};
if expected_result {
assert!(circuit.check_circuit_satisfiability(&[]).is_ok())
} else {
Expand Down

0 comments on commit 851c937

Please sign in to comment.