Skip to content

Commit

Permalink
[test-only] More genericity in tests (#171)
Browse files Browse the repository at this point in the history
* refactor: make circuit tests generic wrt curves

- Improve modularity by introducing generic `test_recursive_circuit_with` function in `src/circuit.rs`
- Refactor `test_recursive_circuit` to utilize the new function
- Implement type constraints for `test_recursive_circuit_with` function

* refactor: make bellperson tests generic in type of group

- Introduce `test_alloc_bit_with` function utilizing generic types
- Adapt existing `test_alloc_bit` function to use the new `test_alloc_bit_with` function with correct types

* refactor: make the nifs test generic in the type of group

* refactor: make the ivc tests generic in the type of curve

* refactor: simplify generics in tests

* make the keccak tests generic

* make the poseidon tests generic

* make the spartan tests generic
  • Loading branch information
huitseeker authored May 26, 2023
1 parent 58fc746 commit 54f758e
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 96 deletions.
25 changes: 17 additions & 8 deletions src/bellperson/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ pub mod solver;

#[cfg(test)]
mod tests {
use crate::bellperson::{
r1cs::{NovaShape, NovaWitness},
shape_cs::ShapeCS,
solver::SatisfyingAssignment,
use crate::{
bellperson::{
r1cs::{NovaShape, NovaWitness},
shape_cs::ShapeCS,
solver::SatisfyingAssignment,
},
traits::Group,
};
use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError};
use ff::PrimeField;
Expand Down Expand Up @@ -39,10 +42,10 @@ mod tests {
Ok(())
}

#[test]
fn test_alloc_bit() {
type G = pasta_curves::pallas::Point;

fn test_alloc_bit_with<G>()
where
G: Group,
{
// First create the shape
let mut cs: ShapeCS<G> = ShapeCS::new();
let _ = synthesize_alloc_bit(&mut cs);
Expand All @@ -56,4 +59,10 @@ mod tests {
// Make sure that this is satisfiable
assert!(shape.is_sat(&ck, &inst, &witness).is_ok());
}

#[test]
fn test_alloc_bit() {
type G = pasta_curves::pallas::Point;
test_alloc_bit_with::<G>();
}
}
52 changes: 34 additions & 18 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,8 @@ impl<G: Group, SC: StepCircuit<G::Base>> Circuit<<G as Group>::Base>
mod tests {
use super::*;
use crate::bellperson::{shape_cs::ShapeCS, solver::SatisfyingAssignment};
type G1 = pasta_curves::pallas::Point;
type G2 = pasta_curves::vesta::Point;
type PastaG1 = pasta_curves::pallas::Point;
type PastaG2 = pasta_curves::vesta::Point;

use crate::constants::{BN_LIMB_WIDTH, BN_N_LIMBS};
use crate::{
Expand All @@ -383,39 +383,43 @@ mod tests {
traits::{circuit::TrivialTestCircuit, ROConstantsTrait},
};

#[test]
fn test_recursive_circuit() {
// In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit
let params1 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true);
let params2 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false);
let ro_consts1: ROConstantsCircuit<G2> = PoseidonConstantsCircuit::new();
let ro_consts2: ROConstantsCircuit<G1> = PoseidonConstantsCircuit::new();

// In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit
fn test_recursive_circuit_with<G1, G2>(
primary_params: NovaAugmentedCircuitParams,
secondary_params: NovaAugmentedCircuitParams,
ro_consts1: ROConstantsCircuit<G2>,
ro_consts2: ROConstantsCircuit<G1>,
num_constraints_primary: usize,
num_constraints_secondary: usize,
) where
G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>,
{
// Initialize the shape and ck for the primary
let circuit1: NovaAugmentedCircuit<G2, TrivialTestCircuit<<G2 as Group>::Base>> =
NovaAugmentedCircuit::new(
params1.clone(),
primary_params.clone(),
None,
TrivialTestCircuit::default(),
ro_consts1.clone(),
);
let mut cs: ShapeCS<G1> = ShapeCS::new();
let _ = circuit1.synthesize(&mut cs);
let (shape1, ck1) = cs.r1cs_shape();
assert_eq!(cs.num_constraints(), 9815);
assert_eq!(cs.num_constraints(), num_constraints_primary);

// Initialize the shape and ck for the secondary
let circuit2: NovaAugmentedCircuit<G1, TrivialTestCircuit<<G1 as Group>::Base>> =
NovaAugmentedCircuit::new(
params2.clone(),
secondary_params.clone(),
None,
TrivialTestCircuit::default(),
ro_consts2.clone(),
);
let mut cs: ShapeCS<G2> = ShapeCS::new();
let _ = circuit2.synthesize(&mut cs);
let (shape2, ck2) = cs.r1cs_shape();
assert_eq!(cs.num_constraints(), 10347);
assert_eq!(cs.num_constraints(), num_constraints_secondary);

// Execute the base case for the primary
let zero1 = <<G2 as Group>::Base as Field>::ZERO;
Expand All @@ -431,7 +435,7 @@ mod tests {
);
let circuit1: NovaAugmentedCircuit<G2, TrivialTestCircuit<<G2 as Group>::Base>> =
NovaAugmentedCircuit::new(
params1,
primary_params,
Some(inputs1),
TrivialTestCircuit::default(),
ro_consts1,
Expand All @@ -453,16 +457,28 @@ mod tests {
Some(inst1),
None,
);
let circuit: NovaAugmentedCircuit<G1, TrivialTestCircuit<<G1 as Group>::Base>> =
let circuit2: NovaAugmentedCircuit<G1, TrivialTestCircuit<<G1 as Group>::Base>> =
NovaAugmentedCircuit::new(
params2,
secondary_params,
Some(inputs2),
TrivialTestCircuit::default(),
ro_consts2,
);
let _ = circuit.synthesize(&mut cs2);
let _ = circuit2.synthesize(&mut cs2);
let (inst2, witness2) = cs2.r1cs_instance_and_witness(&shape2, &ck2).unwrap();
// Make sure that it is satisfiable
assert!(shape2.is_sat(&ck2, &inst2, &witness2).is_ok());
}

#[test]
fn test_recursive_circuit() {
let params1 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true);
let params2 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false);
let ro_consts1: ROConstantsCircuit<PastaG2> = PoseidonConstantsCircuit::new();
let ro_consts2: ROConstantsCircuit<PastaG1> = PoseidonConstantsCircuit::new();

test_recursive_circuit_with::<PastaG1, PastaG2>(
params1, params2, ro_consts1, ro_consts2, 9815, 10347,
);
}
}
37 changes: 15 additions & 22 deletions src/gadgets/ecc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -975,16 +975,14 @@ mod tests {

#[test]
fn test_ecc_circuit_ops() {
test_ecc_circuit_ops_with::<pallas::Base, pallas::Scalar, pallas::Point, vesta::Point>();
test_ecc_circuit_ops_with::<vesta::Base, vesta::Scalar, vesta::Point, pallas::Point>();
test_ecc_circuit_ops_with::<pallas::Point, vesta::Point>();
test_ecc_circuit_ops_with::<vesta::Point, pallas::Point>();
}

fn test_ecc_circuit_ops_with<B, S, G1, G2>()
fn test_ecc_circuit_ops_with<G1, G2>()
where
B: PrimeField,
S: PrimeField,
G1: Group<Base = B, Scalar = S>,
G2: Group<Base = S, Scalar = B>,
G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>,
{
// First create the shape
let mut cs: ShapeCS<G2> = ShapeCS::new();
Expand Down Expand Up @@ -1027,16 +1025,14 @@ mod tests {

#[test]
fn test_ecc_circuit_add_equal() {
test_ecc_circuit_add_equal_with::<pallas::Base, pallas::Scalar, pallas::Point, vesta::Point>();
test_ecc_circuit_add_equal_with::<vesta::Base, vesta::Scalar, vesta::Point, pallas::Point>();
test_ecc_circuit_add_equal_with::<pallas::Point, vesta::Point>();
test_ecc_circuit_add_equal_with::<vesta::Point, pallas::Point>();
}

fn test_ecc_circuit_add_equal_with<B, S, G1, G2>()
fn test_ecc_circuit_add_equal_with<G1, G2>()
where
B: PrimeField,
S: PrimeField,
G1: Group<Base = B, Scalar = S>,
G2: Group<Base = S, Scalar = B>,
G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>,
{
// First create the shape
let mut cs: ShapeCS<G2> = ShapeCS::new();
Expand Down Expand Up @@ -1083,17 +1079,14 @@ mod tests {

#[test]
fn test_ecc_circuit_add_negation() {
test_ecc_circuit_add_negation_with::<pallas::Base, pallas::Scalar, pallas::Point, vesta::Point>(
);
test_ecc_circuit_add_negation_with::<vesta::Base, vesta::Scalar, vesta::Point, pallas::Point>();
test_ecc_circuit_add_negation_with::<pallas::Point, vesta::Point>();
test_ecc_circuit_add_negation_with::<vesta::Point, pallas::Point>();
}

fn test_ecc_circuit_add_negation_with<B, S, G1, G2>()
fn test_ecc_circuit_add_negation_with<G1, G2>()
where
B: PrimeField,
S: PrimeField,
G1: Group<Base = B, Scalar = S>,
G2: Group<Base = S, Scalar = B>,
G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>,
{
// First create the shape
let mut cs: ShapeCS<G2> = ShapeCS::new();
Expand Down
Loading

0 comments on commit 54f758e

Please sign in to comment.