diff --git a/src/bellperson/mod.rs b/src/bellperson/mod.rs index 6686fa25..0e0d8c85 100644 --- a/src/bellperson/mod.rs +++ b/src/bellperson/mod.rs @@ -8,10 +8,13 @@ pub mod solver; #[cfg(test)] mod tests { - use crate::bellperson::{ - r1cs::{NovaShape, NovaWitness}, - shape_cs::ShapeCS, - solver::SatisfyingAssignment, + use crate::{ + bellperson::{ + r1cs::{NovaShape, NovaWitness}, + shape_cs::ShapeCS, + solver::SatisfyingAssignment, + }, + traits::Group, }; use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use ff::PrimeField; @@ -39,10 +42,10 @@ mod tests { Ok(()) } - #[test] - fn test_alloc_bit() { - type G = pasta_curves::pallas::Point; - + fn test_alloc_bit_with() + where + G: Group, + { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); let _ = synthesize_alloc_bit(&mut cs); @@ -56,4 +59,10 @@ mod tests { // Make sure that this is satisfiable assert!(shape.is_sat(&ck, &inst, &witness).is_ok()); } + + #[test] + fn test_alloc_bit() { + type G = pasta_curves::pallas::Point; + test_alloc_bit_with::(); + } } diff --git a/src/circuit.rs b/src/circuit.rs index 614e61c7..d9ef5907 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -372,8 +372,8 @@ impl> Circuit<::Base> mod tests { use super::*; use crate::bellperson::{shape_cs::ShapeCS, solver::SatisfyingAssignment}; - type G1 = pasta_curves::pallas::Point; - type G2 = pasta_curves::vesta::Point; + type PastaG1 = pasta_curves::pallas::Point; + type PastaG2 = pasta_curves::vesta::Point; use crate::constants::{BN_LIMB_WIDTH, BN_N_LIMBS}; use crate::{ @@ -383,18 +383,22 @@ mod tests { traits::{circuit::TrivialTestCircuit, ROConstantsTrait}, }; - #[test] - fn test_recursive_circuit() { - // In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit - let params1 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); - let params2 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); - let ro_consts1: ROConstantsCircuit = PoseidonConstantsCircuit::new(); - let ro_consts2: ROConstantsCircuit = PoseidonConstantsCircuit::new(); - + // In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit + fn test_recursive_circuit_with( + primary_params: NovaAugmentedCircuitParams, + secondary_params: NovaAugmentedCircuitParams, + ro_consts1: ROConstantsCircuit, + ro_consts2: ROConstantsCircuit, + num_constraints_primary: usize, + num_constraints_secondary: usize, + ) where + G1: Group::Scalar>, + G2: Group::Scalar>, + { // Initialize the shape and ck for the primary let circuit1: NovaAugmentedCircuit::Base>> = NovaAugmentedCircuit::new( - params1.clone(), + primary_params.clone(), None, TrivialTestCircuit::default(), ro_consts1.clone(), @@ -402,12 +406,12 @@ mod tests { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit1.synthesize(&mut cs); let (shape1, ck1) = cs.r1cs_shape(); - assert_eq!(cs.num_constraints(), 9815); + assert_eq!(cs.num_constraints(), num_constraints_primary); // Initialize the shape and ck for the secondary let circuit2: NovaAugmentedCircuit::Base>> = NovaAugmentedCircuit::new( - params2.clone(), + secondary_params.clone(), None, TrivialTestCircuit::default(), ro_consts2.clone(), @@ -415,7 +419,7 @@ mod tests { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit2.synthesize(&mut cs); let (shape2, ck2) = cs.r1cs_shape(); - assert_eq!(cs.num_constraints(), 10347); + assert_eq!(cs.num_constraints(), num_constraints_secondary); // Execute the base case for the primary let zero1 = <::Base as Field>::ZERO; @@ -431,7 +435,7 @@ mod tests { ); let circuit1: NovaAugmentedCircuit::Base>> = NovaAugmentedCircuit::new( - params1, + primary_params, Some(inputs1), TrivialTestCircuit::default(), ro_consts1, @@ -453,16 +457,28 @@ mod tests { Some(inst1), None, ); - let circuit: NovaAugmentedCircuit::Base>> = + let circuit2: NovaAugmentedCircuit::Base>> = NovaAugmentedCircuit::new( - params2, + secondary_params, Some(inputs2), TrivialTestCircuit::default(), ro_consts2, ); - let _ = circuit.synthesize(&mut cs2); + let _ = circuit2.synthesize(&mut cs2); let (inst2, witness2) = cs2.r1cs_instance_and_witness(&shape2, &ck2).unwrap(); // Make sure that it is satisfiable assert!(shape2.is_sat(&ck2, &inst2, &witness2).is_ok()); } + + #[test] + fn test_recursive_circuit() { + let params1 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let params2 = NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + let ro_consts1: ROConstantsCircuit = PoseidonConstantsCircuit::new(); + let ro_consts2: ROConstantsCircuit = PoseidonConstantsCircuit::new(); + + test_recursive_circuit_with::( + params1, params2, ro_consts1, ro_consts2, 9815, 10347, + ); + } } diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 09478269..104a0270 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -975,16 +975,14 @@ mod tests { #[test] fn test_ecc_circuit_ops() { - test_ecc_circuit_ops_with::(); - test_ecc_circuit_ops_with::(); + test_ecc_circuit_ops_with::(); + test_ecc_circuit_ops_with::(); } - fn test_ecc_circuit_ops_with() + fn test_ecc_circuit_ops_with() where - B: PrimeField, - S: PrimeField, - G1: Group, - G2: Group, + G1: Group::Scalar>, + G2: Group::Scalar>, { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); @@ -1027,16 +1025,14 @@ mod tests { #[test] fn test_ecc_circuit_add_equal() { - test_ecc_circuit_add_equal_with::(); - test_ecc_circuit_add_equal_with::(); + test_ecc_circuit_add_equal_with::(); + test_ecc_circuit_add_equal_with::(); } - fn test_ecc_circuit_add_equal_with() + fn test_ecc_circuit_add_equal_with() where - B: PrimeField, - S: PrimeField, - G1: Group, - G2: Group, + G1: Group::Scalar>, + G2: Group::Scalar>, { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); @@ -1083,17 +1079,14 @@ mod tests { #[test] fn test_ecc_circuit_add_negation() { - test_ecc_circuit_add_negation_with::( - ); - test_ecc_circuit_add_negation_with::(); + test_ecc_circuit_add_negation_with::(); + test_ecc_circuit_add_negation_with::(); } - fn test_ecc_circuit_add_negation_with() + fn test_ecc_circuit_add_negation_with() where - B: PrimeField, - S: PrimeField, - G1: Group, - G2: Group, + G1: Group::Scalar>, + G2: Group::Scalar>, { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); diff --git a/src/lib.rs b/src/lib.rs index 8ee6854f..c4165173 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -787,13 +787,16 @@ fn compute_digest(o: &T) -> G::Scalar { #[cfg(test)] mod tests { + use crate::provider::pedersen::CommitmentKeyExtTrait; + use super::*; - type G1 = pasta_curves::pallas::Point; - type G2 = pasta_curves::vesta::Point; - type EE1 = provider::ipa_pc::EvaluationEngine; - type EE2 = provider::ipa_pc::EvaluationEngine; - type S1 = spartan::RelaxedR1CSSNARK; - type S2 = spartan::RelaxedR1CSSNARK; + type EE1 = provider::ipa_pc::EvaluationEngine; + type EE2 = provider::ipa_pc::EvaluationEngine; + type S1 = spartan::RelaxedR1CSSNARK>; + type S2 = spartan::RelaxedR1CSSNARK>; + type S1Prime = spartan::pp::RelaxedR1CSSNARK>; + type S2Prime = spartan::pp::RelaxedR1CSSNARK>; + use ::bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::marker::PhantomData; use ff::PrimeField; @@ -848,15 +851,21 @@ mod tests { } } - #[test] - fn test_ivc_trivial() { + fn test_ivc_trivial_with() + where + G1: Group::Scalar>, + G2: Group::Scalar>, + { + let test_circuit1 = TrivialTestCircuit::<::Scalar>::default(); + let test_circuit2 = TrivialTestCircuit::<::Scalar>::default(); + // produce public parameters let pp = PublicParams::< G1, G2, TrivialTestCircuit<::Scalar>, TrivialTestCircuit<::Scalar>, - >::setup(TrivialTestCircuit::default(), TrivialTestCircuit::default()); + >::setup(test_circuit1.clone(), test_circuit2.clone()); let num_steps = 1; @@ -864,8 +873,8 @@ mod tests { let res = RecursiveSNARK::prove_step( &pp, None, - TrivialTestCircuit::default(), - TrivialTestCircuit::default(), + test_circuit1, + test_circuit2, vec![::Scalar::ZERO], vec![::Scalar::ZERO], ); @@ -883,7 +892,17 @@ mod tests { } #[test] - fn test_ivc_nontrivial() { + fn test_ivc_trivial() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + test_ivc_trivial_with::(); + } + + fn test_ivc_nontrivial_with() + where + G1: Group::Scalar>, + G2: Group::Scalar>, + { let circuit_primary = TrivialTestCircuit::default(); let circuit_secondary = CubicCircuit::default(); @@ -950,14 +969,30 @@ mod tests { assert_eq!(zn_primary, vec![::Scalar::ONE]); let mut zn_secondary_direct = vec![::Scalar::ZERO]; for _i in 0..num_steps { - zn_secondary_direct = CubicCircuit::default().output(&zn_secondary_direct); + zn_secondary_direct = circuit_secondary.clone().output(&zn_secondary_direct); } assert_eq!(zn_secondary, zn_secondary_direct); assert_eq!(zn_secondary, vec![::Scalar::from(2460515u64)]); } #[test] - fn test_ivc_nontrivial_with_compression() { + fn test_ivc_nontrivial() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + test_ivc_nontrivial_with::(); + } + + fn test_ivc_nontrivial_with_compression_with() + where + G1: Group::Scalar>, + G2: Group::Scalar>, + // this is due to the reliance on CommitmentKeyExtTrait as a bound in ipa_pc + >::CommitmentKey: + CommitmentKeyExtTrait::CE>, + >::CommitmentKey: + CommitmentKeyExtTrait::CE>, + { let circuit_primary = TrivialTestCircuit::default(); let circuit_secondary = CubicCircuit::default(); @@ -1012,16 +1047,16 @@ mod tests { assert_eq!(zn_primary, vec![::Scalar::ONE]); let mut zn_secondary_direct = vec![::Scalar::ZERO]; for _i in 0..num_steps { - zn_secondary_direct = CubicCircuit::default().output(&zn_secondary_direct); + zn_secondary_direct = circuit_secondary.clone().output(&zn_secondary_direct); } assert_eq!(zn_secondary, zn_secondary_direct); assert_eq!(zn_secondary, vec![::Scalar::from(2460515u64)]); // produce the prover and verifier keys for compressed snark - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); // produce a compressed SNARK - let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); assert!(res.is_ok()); let compressed_snark = res.unwrap(); @@ -1036,7 +1071,23 @@ mod tests { } #[test] - fn test_ivc_nontrivial_with_spark_compression() { + fn test_ivc_nontrivial_with_compression() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + test_ivc_nontrivial_with_compression_with::(); + } + + fn test_ivc_nontrivial_with_spark_compression_with() + where + G1: Group::Scalar>, + G2: Group::Scalar>, + // this is due to the reliance on CommitmentKeyExtTrait as a bound in ipa_pc + >::CommitmentKey: + CommitmentKeyExtTrait::CE>, + >::CommitmentKey: + CommitmentKeyExtTrait::CE>, + { let circuit_primary = TrivialTestCircuit::default(); let circuit_secondary = CubicCircuit::default(); @@ -1097,14 +1148,13 @@ mod tests { assert_eq!(zn_secondary, vec![::Scalar::from(2460515u64)]); // run the compressed snark with Spark compiler - type S1Prime = spartan::pp::RelaxedR1CSSNARK; - type S2Prime = spartan::pp::RelaxedR1CSSNARK; // produce the prover and verifier keys for compressed snark - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1Prime, S2Prime>::setup(&pp).unwrap(); + let (pk, vk) = CompressedSNARK::<_, _, _, _, S1Prime, S2Prime>::setup(&pp).unwrap(); // produce a compressed SNARK - let res = CompressedSNARK::<_, _, _, _, S1Prime, S2Prime>::prove(&pp, &pk, &recursive_snark); + let res = + CompressedSNARK::<_, _, _, _, S1Prime, S2Prime>::prove(&pp, &pk, &recursive_snark); assert!(res.is_ok()); let compressed_snark = res.unwrap(); @@ -1119,7 +1169,23 @@ mod tests { } #[test] - fn test_ivc_nondet_with_compression() { + fn test_ivc_nontrivial_with_spark_compression() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + test_ivc_nontrivial_with_spark_compression_with::(); + } + + fn test_ivc_nondet_with_compression_with() + where + G1: Group::Scalar>, + G2: Group::Scalar>, + // this is due to the reliance on CommitmentKeyExtTrait as a bound in ipa_pc + >::CommitmentKey: + CommitmentKeyExtTrait::CE>, + >::CommitmentKey: + CommitmentKeyExtTrait::CE>, + { // y is a non-deterministic advice representing the fifth root of the input at a step. #[derive(Clone, Debug)] struct FifthRootCheckingCircuit { @@ -1252,10 +1318,10 @@ mod tests { assert!(res.is_ok()); // produce the prover and verifier keys for compressed snark - let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); // produce a compressed SNARK - let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &pk, &recursive_snark); assert!(res.is_ok()); let compressed_snark = res.unwrap(); @@ -1265,7 +1331,18 @@ mod tests { } #[test] - fn test_ivc_base() { + fn test_ivc_nondet_with_compression() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + test_ivc_nondet_with_compression_with::(); + } + + fn test_ivc_base_with() + where + G1: Group::Scalar>, + G2: Group::Scalar>, + { // produce public parameters let pp = PublicParams::< G1, @@ -1302,4 +1379,12 @@ mod tests { assert_eq!(zn_primary, vec![::Scalar::ONE]); assert_eq!(zn_secondary, vec![::Scalar::from(5u64)]); } + + #[test] + fn test_ivc_base() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + test_ivc_base_with::(); + } } diff --git a/src/nifs.rs b/src/nifs.rs index 0993349d..2187a5d9 100644 --- a/src/nifs.rs +++ b/src/nifs.rs @@ -162,8 +162,10 @@ mod tests { Ok(()) } - #[test] - fn test_tiny_r1cs_bellperson() { + fn test_tiny_r1cs_bellperson_with() + where + G: Group, + { use crate::bellperson::{ r1cs::{NovaShape, NovaWitness}, shape_cs::ShapeCS, @@ -179,7 +181,7 @@ mod tests { // Now get the instance and assignment for one instance let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); - let _ = synthesize_tiny_r1cs_bellperson(&mut cs, Some(S::from(5))); + let _ = synthesize_tiny_r1cs_bellperson(&mut cs, Some(G::Scalar::from(5))); let (U1, W1) = cs.r1cs_instance_and_witness(&shape, &ck).unwrap(); // Make sure that the first instance is satisfiable @@ -187,7 +189,7 @@ mod tests { // Now get the instance and assignment for second instance let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); - let _ = synthesize_tiny_r1cs_bellperson(&mut cs, Some(S::from(135))); + let _ = synthesize_tiny_r1cs_bellperson(&mut cs, Some(G::Scalar::from(135))); let (U2, W2) = cs.r1cs_instance_and_witness(&shape, &ck).unwrap(); // Make sure that the second instance is satisfiable @@ -206,8 +208,13 @@ mod tests { ); } + #[test] + fn test_tiny_r1cs_bellperson() { + test_tiny_r1cs_bellperson_with::(); + } + #[allow(clippy::too_many_arguments)] - fn execute_sequence( + fn execute_sequence( ck: &CommitmentKey, ro_consts: &<::RO as ROTrait<::Base, ::Scalar>>::Constants, pp_digest: &::Scalar, @@ -216,7 +223,9 @@ mod tests { W1: &R1CSWitness, U2: &R1CSInstance, W2: &R1CSWitness, - ) { + ) where + G: Group, + { // produce a default running instance let mut r_W = RelaxedR1CSWitness::default(shape); let mut r_U = RelaxedR1CSInstance::default(ck, shape); diff --git a/src/provider/keccak.rs b/src/provider/keccak.rs index 2629e7c9..a4daf189 100644 --- a/src/provider/keccak.rs +++ b/src/provider/keccak.rs @@ -101,10 +101,7 @@ mod tests { use ff::PrimeField; use sha3::{Digest, Keccak256}; - type G = pasta_curves::pallas::Point; - - #[test] - fn test_keccak_transcript() { + fn test_keccak_transcript_with() { let mut transcript: Keccak256Transcript = Keccak256Transcript::new(b"test"); // two scalars @@ -136,6 +133,12 @@ mod tests { ); } + #[test] + fn test_keccak_transcript() { + type G = pasta_curves::pallas::Point; + test_keccak_transcript_with::() + } + #[test] fn test_keccak_example() { let mut hasher = Keccak256::new(); diff --git a/src/provider/poseidon.rs b/src/provider/poseidon.rs index eee606d8..2445984f 100644 --- a/src/provider/poseidon.rs +++ b/src/provider/poseidon.rs @@ -201,27 +201,31 @@ where #[cfg(test)] mod tests { use super::*; - type S = pasta_curves::pallas::Scalar; - type B = pasta_curves::vesta::Scalar; - type G = pasta_curves::pallas::Point; use crate::{ bellperson::solver::SatisfyingAssignment, constants::NUM_CHALLENGE_BITS, - gadgets::utils::le_bits_to_num, + gadgets::utils::le_bits_to_num, traits::Group, }; use ff::Field; use rand::rngs::OsRng; - #[test] - fn test_poseidon_ro() { + fn test_poseidon_ro_with() + where + // we can print the field elements we get from G's Base & Scalar fields, + // and compare their byte representations + <::Base as PrimeField>::Repr: std::fmt::Debug, + <::Scalar as PrimeField>::Repr: std::fmt::Debug, + <::Base as PrimeField>::Repr: PartialEq<<::Scalar as PrimeField>::Repr>, + { // Check that the number computed inside the circuit is equal to the number computed outside the circuit let mut csprng: OsRng = OsRng; - let constants = PoseidonConstantsCircuit::new(); + let constants = PoseidonConstantsCircuit::::new(); let num_absorbs = 32; - let mut ro: PoseidonRO = PoseidonRO::new(constants.clone(), num_absorbs); - let mut ro_gadget: PoseidonROCircuit = PoseidonROCircuit::new(constants, num_absorbs); + let mut ro: PoseidonRO = PoseidonRO::new(constants.clone(), num_absorbs); + let mut ro_gadget: PoseidonROCircuit = + PoseidonROCircuit::new(constants, num_absorbs); let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); for i in 0..num_absorbs { - let num = S::random(&mut csprng); + let num = G::Scalar::random(&mut csprng); ro.absorb(num); let num_gadget = AllocatedNum::alloc(cs.namespace(|| format!("data {i}")), || Ok(num)).unwrap(); @@ -235,4 +239,11 @@ mod tests { let num2 = le_bits_to_num(&mut cs, num2_bits).unwrap(); assert_eq!(num.to_repr(), num2.get_value().unwrap().to_repr()); } + + #[test] + fn test_poseidon_ro() { + type G = pasta_curves::pallas::Point; + + test_poseidon_ro_with::() + } } diff --git a/src/spartan/pp.rs b/src/spartan/pp.rs index 08dc6f7d..70c4c71f 100644 --- a/src/spartan/pp.rs +++ b/src/spartan/pp.rs @@ -2191,8 +2191,6 @@ impl, C: StepCircuit; use ::bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::marker::PhantomData; use ff::PrimeField; @@ -2248,6 +2246,13 @@ mod tests { #[test] fn test_spartan_snark() { + type G = pasta_curves::pallas::Point; + type EE = crate::provider::ipa_pc::EvaluationEngine; + + test_spartan_snark_with::(); + } + + fn test_spartan_snark_with>() { let circuit = CubicCircuit::default(); // produce keys