From d1637addab540e232fd9698aa83dfcc95ab5763c Mon Sep 17 00:00:00 2001 From: Alex Xiong Date: Fri, 29 Jul 2022 21:53:59 +0800 Subject: [PATCH] no_std compliance and refactor SNARK trait (#87) * update Snark -> UniversalSNARK trait (#80) * update Snark -> UniversalSNARK trait * enable CI on PR targetting cap-rollup branch * address Zhenfei's comment * Restoring no_std compliance (#85) * restore no_std on jf-* * remove HashMap and HashSet for no_std * fix bench.rs, add Display to TaggedBlobError * more no_std fix * put rayon to feature=parallel * use hashbrown for HashMap, update es-commons * simplify rayon-accelerated code * update CHANGELOG --- .github/workflows/build.yml | 10 +- CHANGELOG.md | 2 + plonk/Cargo.toml | 33 +- plonk/benches/bench.rs | 9 +- plonk/examples/proof_of_exp.rs | 2 +- plonk/src/circuit/basic.rs | 41 +-- plonk/src/circuit/customized/ecc/mod.rs | 28 +- .../ultraplonk/plonk_verifier/gadgets.rs | 2 +- .../ultraplonk/plonk_verifier/mod.rs | 2 +- plonk/src/errors.rs | 3 +- plonk/src/lib.rs | 1 + plonk/src/par_utils.rs | 30 ++ plonk/src/proof_system/batch_arg.rs | 22 +- plonk/src/proof_system/mod.rs | 50 +-- plonk/src/proof_system/prover.rs | 181 ++++++----- plonk/src/proof_system/snark.rs | 294 +++++++++--------- plonk/src/proof_system/structs.rs | 43 ++- plonk/src/testing_apis.rs | 5 +- primitives/Cargo.toml | 31 +- primitives/src/elgamal.rs | 57 ++-- primitives/src/merkle_tree.rs | 2 +- primitives/src/signatures/bls.rs | 2 +- primitives/src/signatures/schnorr.rs | 2 +- rescue/Cargo.toml | 25 +- rescue/src/lib.rs | 1 + scripts/check_no_std.sh | 9 + utilities/Cargo.toml | 26 +- utilities/src/serialize.rs | 11 +- 28 files changed, 509 insertions(+), 415 deletions(-) create mode 100644 plonk/src/par_utils.rs create mode 100755 scripts/check_no_std.sh diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4feaa2e0f..9d21f69f1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,8 +7,9 @@ on: pull_request: branches: - main + - cap-rollup schedule: - - cron: '0 0 * * 1' + - cron: "0 0 * * 1" workflow_dispatch: jobs: @@ -64,9 +65,12 @@ jobs: - name: Check Ignored Tests run: cargo test --no-run -- --ignored + - name: Check no_std compilation + run: cargo test --no-run --no-default-features + - name: Test - run: bash ./scripts/run_tests.sh - + run: bash ./scripts/run_tests.sh + - name: Example run: cargo run --release --example proof_of_exp diff --git a/CHANGELOG.md b/CHANGELOG.md index ac4b00037..4061cd334 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## Pending - Splitting polynomials are masked to ensure zero-knowledge of Plonk (#76) +- Refactored `UniversalSNARK` trait (#80, #87) +- Restore `no_std` compliance (#85, #87) ## v0.1.2 diff --git a/plonk/Cargo.toml b/plonk/Cargo.toml index 2d5d52414..70fbe18cd 100644 --- a/plonk/Cargo.toml +++ b/plonk/Cargo.toml @@ -11,17 +11,17 @@ jf-utils = { path = "../utilities" } jf-rescue = { path = "../rescue" } ark-std = { version = "0.3.0", default-features = false } -ark-serialize = { version = "0.3.0", default-features = false } -ark-ff = { version = "0.3.0", default-features = false, features = ["asm", "parallel"] } -ark-ec = { version = "0.3.0", default-features = false, features = ["parallel"] } -ark-poly = { version = "0.3.0", default-features = false, features = ["parallel"] } -ark-bn254 = { version = "0.3.0", default-features = false, features = ["curve"] } -ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", features = ["curve"], rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-bls12-381 = { version = "0.3.0", default-features = false, features = ["curve"] } +ark-serialize = "0.3.0" +ark-ff = { version = "0.3.0", features = [ "asm" ] } +ark-ec = "0.3.0" +ark-poly = "0.3.0" +ark-bn254 = "0.3.0" +ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-bls12-381 = "0.3.0" ark-bw6-761 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } merlin = { version = "3.0.0", default-features = false } -rayon = { version = "1.5.0", default-features = false } +rayon = { version = "1.5.0", optional = true } itertools = { version = "0.10.1", default-features = false } downcast-rs = { version = "1.2.0", default-features = false } serde = { version = "1.0", default-features = false, features = ["derive"] } @@ -30,8 +30,8 @@ derivative = { version = "2", features = ["use_core"] } num-bigint = { version = "0.4", default-features = false} rand_chacha = { version = "0.3.1" } sha3 = "^0.10" -espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", tag = "0.1.1" } - +espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", branch = "main" } +hashbrown = "0.12.3" [dependencies.ark-poly-commit] git = "https://github.com/arkworks-rs/poly-commit/" @@ -40,10 +40,10 @@ default-features = false [dev-dependencies] bincode = "1.0" -ark-ed-on-bls12-381 = { version = "0.3.0", default-features = false } +ark-ed-on-bls12-381 = "0.3.0" ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-ed-on-bn254 = { version = "0.3.0", default-features = false } +ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-ed-on-bn254 = "0.3.0" hex = "^0.4.3" # Benchmarks @@ -53,6 +53,7 @@ path = "benches/bench.rs" harness = false [features] -std = [] -# exposing apis for testing purpose -test_apis = [] +default = [ "parallel" ] +std = [ "ark-std/std", "ark-serialize/std", "ark-ff/std", "ark-ec/std", "ark-poly/std"] +test_apis = [] # exposing apis for testing purpose +parallel = [ "ark-ff/parallel", "ark-ec/parallel", "ark-poly/parallel", "rayon" ] diff --git a/plonk/benches/bench.rs b/plonk/benches/bench.rs index 3ffb6b806..a2588eefd 100644 --- a/plonk/benches/bench.rs +++ b/plonk/benches/bench.rs @@ -16,10 +16,11 @@ use ark_ff::PrimeField; use jf_plonk::{ circuit::{Circuit, PlonkCircuit}, errors::PlonkError, - proof_system::{PlonkKzgSnark, Snark}, + proof_system::{PlonkKzgSnark, UniversalSNARK}, transcript::StandardTranscript, PlonkType, }; +use std::time::Instant; const NUM_REPETITIONS: usize = 10; const NUM_GATES_LARGE: usize = 32768; @@ -54,7 +55,7 @@ macro_rules! plonk_prove_bench { let (pk, _) = PlonkKzgSnark::<$bench_curve>::preprocess(&srs, &cs).unwrap(); - let start = ark_std::time::Instant::now(); + let start = Instant::now(); for _ in 0..NUM_REPETITIONS { let _ = PlonkKzgSnark::<$bench_curve>::prove::<_, _, StandardTranscript>( @@ -97,7 +98,7 @@ macro_rules! plonk_verify_bench { PlonkKzgSnark::<$bench_curve>::prove::<_, _, StandardTranscript>(rng, &cs, &pk, None) .unwrap(); - let start = ark_std::time::Instant::now(); + let start = Instant::now(); for _ in 0..NUM_REPETITIONS { let _ = @@ -144,7 +145,7 @@ macro_rules! plonk_batch_verify_bench { let public_inputs_ref = vec![&pub_input[..]; $num_proofs]; let proofs_ref = vec![&proof; $num_proofs]; - let start = ark_std::time::Instant::now(); + let start = Instant::now(); for _ in 0..NUM_REPETITIONS { let _ = PlonkKzgSnark::<$bench_curve>::batch_verify::( diff --git a/plonk/examples/proof_of_exp.rs b/plonk/examples/proof_of_exp.rs index c86fd6376..484e8bee7 100644 --- a/plonk/examples/proof_of_exp.rs +++ b/plonk/examples/proof_of_exp.rs @@ -22,7 +22,7 @@ use ark_std::{rand::SeedableRng, UniformRand}; use jf_plonk::{ circuit::{customized::ecc::Point, Arithmetization, Circuit, PlonkCircuit}, errors::PlonkError, - proof_system::{PlonkKzgSnark, Snark}, + proof_system::{PlonkKzgSnark, UniversalSNARK}, transcript::StandardTranscript, }; use jf_utils::fr_to_fq; diff --git a/plonk/src/circuit/basic.rs b/plonk/src/circuit/basic.rs index c7c72bab8..771632532 100644 --- a/plonk/src/circuit/basic.rs +++ b/plonk/src/circuit/basic.rs @@ -10,21 +10,16 @@ use crate::{ circuit::{gates::*, SortedLookupVecAndPolys}, constants::{compute_coset_representatives, GATE_WIDTH, N_MUL_SELECTORS}, errors::{CircuitError::*, PlonkError}, + par_utils::parallelizable_slice_iter, MergeableCircuitType, PlonkType, }; use ark_ff::{FftField, PrimeField}; use ark_poly::{ domain::Radix2EvaluationDomain, univariate::DensePolynomial, EvaluationDomain, UVPolynomial, }; -use ark_std::{ - boxed::Box, - cmp::max, - collections::{HashMap, HashSet}, - format, - string::ToString, - vec, - vec::Vec, -}; +use ark_std::{boxed::Box, cmp::max, format, string::ToString, vec, vec::Vec}; +use hashbrown::{HashMap, HashSet}; +#[cfg(feature = "parallel")] use rayon::prelude::*; /// The wire type identifier for range gates. @@ -1100,12 +1095,9 @@ where .into()); } // order: (lc, mul, hash, o, c, ecc) as specified in spec - let selector_polys: Vec<_> = self - .all_selectors() - .par_iter() + let selector_polys = parallelizable_slice_iter(&self.all_selectors()) .map(|selector| DensePolynomial::from_coefficients_vec(domain.ifft(selector))) .collect(); - Ok(selector_polys) } @@ -1116,14 +1108,16 @@ where let domain = &self.eval_domain; let n = domain.size(); let extended_perm = self.compute_extended_permutation()?; - let extended_perm_polys: Vec> = (0..self.num_wire_types) - .into_par_iter() - .map(|i| { - DensePolynomial::from_coefficients_vec( - domain.ifft(&extended_perm[i * n..(i + 1) * n]), - ) - }) - .collect(); + + let extended_perm_polys: Vec> = + parallelizable_slice_iter(&(0..self.num_wire_types).collect::>()) // current par_utils only support slice iterator, not range iterator. + .map(|i| { + DensePolynomial::from_coefficients_vec( + domain.ifft(&extended_perm[i * n..(i + 1) * n]), + ) + }) + .collect(); + Ok(extended_perm_polys) } @@ -1167,9 +1161,7 @@ where .into()); } let witness = &self.witness; - let wire_polys: Vec<_> = self - .wire_variables - .par_iter() + let wire_polys: Vec> = parallelizable_slice_iter(&self.wire_variables) .take(self.num_wire_types()) .map(|wire_vars| { let mut wire_vec: Vec = wire_vars.iter().map(|&var| witness[var]).collect(); @@ -1177,6 +1169,7 @@ where DensePolynomial::from_coefficients_vec(wire_vec) }) .collect(); + assert_eq!(wire_polys.len(), self.num_wire_types()); Ok(wire_polys) } diff --git a/plonk/src/circuit/customized/ecc/mod.rs b/plonk/src/circuit/customized/ecc/mod.rs index 1c741c4c0..e57ff2c23 100644 --- a/plonk/src/circuit/customized/ecc/mod.rs +++ b/plonk/src/circuit/customized/ecc/mod.rs @@ -572,15 +572,25 @@ fn compute_base_points( // base3 = (3*B, 3*4*B, ..., 3*4^(l-1)*B) let mut bases3 = vec![b]; - rayon::join( - || { - rayon::join( - || fill_bases(&mut bases1, len).ok(), - || fill_bases(&mut bases2, len).ok(), - ) - }, - || fill_bases(&mut bases3, len).ok(), - ); + #[cfg(feature = "parallel")] + { + rayon::join( + || { + rayon::join( + || fill_bases(&mut bases1, len).ok(), + || fill_bases(&mut bases2, len).ok(), + ) + }, + || fill_bases(&mut bases3, len).ok(), + ); + } + + #[cfg(not(feature = "parallel"))] + { + fill_bases(&mut bases1, len).ok(); + fill_bases(&mut bases2, len).ok(); + fill_bases(&mut bases3, len).ok(); + } // converting GroupAffine -> Points here. // Cannot do it earlier: in `fill_bases` we need to do `double` diff --git a/plonk/src/circuit/customized/ultraplonk/plonk_verifier/gadgets.rs b/plonk/src/circuit/customized/ultraplonk/plonk_verifier/gadgets.rs index f6242852e..b9f663316 100644 --- a/plonk/src/circuit/customized/ultraplonk/plonk_verifier/gadgets.rs +++ b/plonk/src/circuit/customized/ultraplonk/plonk_verifier/gadgets.rs @@ -462,7 +462,7 @@ mod test { circuit::Circuit, proof_system::{ batch_arg::{new_mergeable_circuit_for_test, BatchArgument}, - PlonkKzgSnark, + PlonkKzgSnark, UniversalSNARK, }, transcript::{PlonkTranscript, RescueTranscript}, }; diff --git a/plonk/src/circuit/customized/ultraplonk/plonk_verifier/mod.rs b/plonk/src/circuit/customized/ultraplonk/plonk_verifier/mod.rs index 8b271be9e..a5dd655aa 100644 --- a/plonk/src/circuit/customized/ultraplonk/plonk_verifier/mod.rs +++ b/plonk/src/circuit/customized/ultraplonk/plonk_verifier/mod.rs @@ -321,7 +321,7 @@ mod test { proof_system::{ batch_arg::{new_mergeable_circuit_for_test, BatchArgument}, structs::BatchProof, - PlonkKzgSnark, + PlonkKzgSnark, UniversalSNARK, }, transcript::{PlonkTranscript, RescueTranscript}, }; diff --git a/plonk/src/errors.rs b/plonk/src/errors.rs index 95e57ae3e..68b311079 100644 --- a/plonk/src/errors.rs +++ b/plonk/src/errors.rs @@ -44,8 +44,7 @@ pub enum PlonkError { PublicInputsDoNotMatch, } -#[cfg(feature = "std")] -impl std::error::Error for PlonkError {} +impl ark_std::error::Error for PlonkError {} impl From for PlonkError { fn from(e: ark_poly_commit::Error) -> Self { diff --git a/plonk/src/lib.rs b/plonk/src/lib.rs index 4db941c7f..02172b72e 100644 --- a/plonk/src/lib.rs +++ b/plonk/src/lib.rs @@ -21,6 +21,7 @@ extern crate derivative; pub mod circuit; pub mod constants; pub mod errors; +pub(crate) mod par_utils; pub mod proof_system; pub mod transcript; diff --git a/plonk/src/par_utils.rs b/plonk/src/par_utils.rs new file mode 100644 index 000000000..885b6ffd1 --- /dev/null +++ b/plonk/src/par_utils.rs @@ -0,0 +1,30 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Utilities for parallel code. + +/// this function helps with slice iterator creation that optionally use +/// `par_iter()` when feature flag `parallel` is on. +/// +/// # Usage +/// let v = [1, 2, 3, 4, 5]; +/// let sum = parallelizable_slice_iter(&v).sum(); +/// +/// // the above code is a shorthand for (thus equivalent to) +/// #[cfg(feature = "parallel")] +/// let sum = v.par_iter().sum(); +/// #[cfg(not(feature = "parallel"))] +/// let sum = v.iter().sum(); +#[cfg(feature = "parallel")] +pub(crate) fn parallelizable_slice_iter(data: &[T]) -> rayon::slice::Iter { + use rayon::iter::IntoParallelIterator; + data.into_par_iter() +} + +#[cfg(not(feature = "parallel"))] +pub(crate) fn parallelizable_slice_iter(data: &[T]) -> ark_std::slice::Iter { + use ark_std::iter::IntoIterator; + data.iter() +} diff --git a/plonk/src/proof_system/batch_arg.rs b/plonk/src/proof_system/batch_arg.rs index 6e90f42da..aae96371b 100644 --- a/plonk/src/proof_system/batch_arg.rs +++ b/plonk/src/proof_system/batch_arg.rs @@ -11,7 +11,7 @@ use crate::{ proof_system::{ structs::{BatchProof, OpenKey, ProvingKey, ScalarsAndBases, UniversalSrs, VerifyingKey}, verifier::Verifier, - PlonkKzgSnark, + PlonkKzgSnark, UniversalSNARK, }, transcript::PlonkTranscript, MergeableCircuitType, @@ -30,19 +30,19 @@ use jf_rescue::RescueParameter; use jf_utils::multi_pairing; /// A batching argument. -pub struct BatchArgument<'a, E: PairingEngine>(PhantomData<&'a E>); +pub struct BatchArgument(PhantomData); /// A circuit instance that consists of the corresponding proving /// key/verification key/circuit. #[derive(Clone)] -pub struct Instance<'a, E: PairingEngine> { +pub struct Instance { // TODO: considering giving instance an ID - prove_key: ProvingKey<'a, E>, // the verification key can be obtained inside the proving key. + prove_key: ProvingKey, // the verification key can be obtained inside the proving key. circuit: PlonkCircuit, _circuit_type: MergeableCircuitType, } -impl<'a, E: PairingEngine> Instance<'a, E> { +impl Instance { /// Get verification key by reference. pub fn verify_key_ref(&self) -> &VerifyingKey { &self.prove_key.vk @@ -54,7 +54,7 @@ impl<'a, E: PairingEngine> Instance<'a, E> { } } -impl<'a, E, F, P> BatchArgument<'a, E> +impl BatchArgument where E: PairingEngine>, F: RescueParameter + SWToTEConParam, @@ -62,10 +62,10 @@ where { /// Setup the circuit and the proving key for a (mergeable) instance. pub fn setup_instance( - srs: &'a UniversalSrs, + srs: &UniversalSrs, mut circuit: PlonkCircuit, circuit_type: MergeableCircuitType, - ) -> Result, PlonkError> { + ) -> Result, PlonkError> { circuit.finalize_for_mergeable_circuit(circuit_type)?; let (prove_key, _) = PlonkKzgSnark::preprocess(srs, &circuit)?; Ok(Instance { @@ -78,8 +78,8 @@ where /// Prove satisfiability of multiple instances in a batch. pub fn batch_prove( prng: &mut R, - instances_type_a: &[Instance<'a, E>], - instances_type_b: &[Instance<'a, E>], + instances_type_a: &[Instance], + instances_type_b: &[Instance], ) -> Result, PlonkError> where R: CryptoRng + RngCore, @@ -175,7 +175,7 @@ where } } -impl<'a, E> BatchArgument<'a, E> +impl BatchArgument where E: PairingEngine, { diff --git a/plonk/src/proof_system/mod.rs b/plonk/src/proof_system/mod.rs index 8dde62a91..90622ddbf 100644 --- a/plonk/src/proof_system/mod.rs +++ b/plonk/src/proof_system/mod.rs @@ -5,9 +5,11 @@ // along with the Jellyfish library. If not, see . //! Interfaces for Plonk-based proof systems -use crate::{circuit::Arithmetization, errors::PlonkError}; +use crate::circuit::Arithmetization; use ark_ec::PairingEngine; use ark_std::{ + error::Error, + fmt::Debug, rand::{CryptoRng, RngCore}, vec::Vec, }; @@ -19,8 +21,10 @@ pub(crate) mod verifier; use crate::transcript::PlonkTranscript; pub use snark::PlonkKzgSnark; -/// An interface for SNARKs. -pub trait Snark { +// TODO: (alex) should we name it `PlonkishSNARK` instead? since we use +// `PlonkTranscript` on prove and verify. +/// An interface for SNARKs with universal setup. +pub trait UniversalSNARK { /// The SNARK proof computed by the prover. type Proof: Clone; @@ -32,18 +36,28 @@ pub trait Snark { /// specific circuit. type VerifyingKey: Clone; - // TODO: (alex) add back when `trait PolynomialCommitment` is implemented for - // KZG10, and the following can be compiled so that the Snark trait can be - // generic over prime field F. - // pub type UniversalSrs = >>::UniversalParams; - // - // /// Compute the proving/verifying keys from the circuit `circuit`. - // fn preprocess>( - // &self, - // srs: &UniversalSrs, - // circuit: &C, - // ) -> Result<(Self::ProvingKey, Self::VerifyingKey), PlonkError>; + /// Universal Structured Reference String from `universal_setup`, used for + /// all subsequent circuit-specific preprocessing + type UniversalSRS: Clone + Debug; + + /// SNARK related error + type Error: 'static + Error; + + /// Generate the universal SRS for the argument system. + /// This setup is for trusted party to run, and mostly only used for + /// testing purpose. In practice, a MPC flavor of the setup will be carried + /// out to have higher assurance on the "toxic waste"/trapdoor being thrown + /// away to ensure soundness of the argument system. + fn universal_setup( + max_degree: usize, + rng: &mut R, + ) -> Result; + + /// Circuit-specific preprocessing to compute the proving/verifying keys. + fn preprocess>( + srs: &Self::UniversalSRS, + circuit: &C, + ) -> Result<(Self::ProvingKey, Self::VerifyingKey), Self::Error>; /// Compute a SNARK proof of a circuit `circuit`, using the corresponding /// proving key `prove_key`. The witness used to @@ -55,11 +69,11 @@ pub trait Snark { /// resulting proof without any check on the data. It does not incur any /// additional cost in proof size or prove time. fn prove( - prng: &mut R, + rng: &mut R, circuit: &C, prove_key: &Self::ProvingKey, extra_transcript_init_msg: Option>, - ) -> Result + ) -> Result where C: Arithmetization, R: CryptoRng + RngCore, @@ -74,5 +88,5 @@ pub trait Snark { public_input: &[E::Fr], proof: &Self::Proof, extra_transcript_init_msg: Option>, - ) -> Result<(), PlonkError>; + ) -> Result<(), Self::Error>; } diff --git a/plonk/src/proof_system/prover.rs b/plonk/src/proof_system/prover.rs index 4974cdea6..b38171b7b 100644 --- a/plonk/src/proof_system/prover.rs +++ b/plonk/src/proof_system/prover.rs @@ -14,6 +14,7 @@ use crate::{ circuit::Arithmetization, constants::{domain_size_ratio, GATE_WIDTH}, errors::{PlonkError, SnarkError::*}, + par_utils::parallelizable_slice_iter, proof_system::structs::CommitKey, }; use ark_ec::PairingEngine; @@ -23,7 +24,7 @@ use ark_poly::{ Radix2EvaluationDomain, UVPolynomial, }; use ark_poly_commit::{ - kzg10::{Commitment, Randomness, KZG10}, + kzg10::{Commitment, Powers, Randomness, KZG10}, PCRandomness, }; use ark_std::{ @@ -32,6 +33,7 @@ use ark_std::{ vec, vec::Vec, }; +#[cfg(feature = "parallel")] use rayon::prelude::*; type CommitmentsAndPolys = ( @@ -191,14 +193,10 @@ impl Prover { online_oracles: &Oracles, num_wire_types: usize, ) -> ProofEvaluations { - let wires_evals: Vec = online_oracles - .wire_polys - .par_iter() + let wires_evals: Vec = parallelizable_slice_iter(&online_oracles.wire_polys) .map(|poly| poly.evaluate(&challenges.zeta)) .collect(); - let wire_sigma_evals: Vec = pk - .sigmas - .par_iter() + let wire_sigma_evals: Vec = parallelizable_slice_iter(&pk.sigmas) .take(num_wire_types - 1) .map(|poly| poly.evaluate(&challenges.zeta)) .collect(); @@ -458,8 +456,7 @@ impl Prover { ck: &CommitKey, polys: &[DensePolynomial], ) -> Result>, PlonkError> { - let poly_comms = polys - .par_iter() + let poly_comms = parallelizable_slice_iter(polys) .map(|poly| Self::commit_polynomial(ck, poly)) .collect::, _>>()?; Ok(poly_comms) @@ -471,7 +468,9 @@ impl Prover { ck: &CommitKey, poly: &DensePolynomial, ) -> Result, PlonkError> { - let (poly_comm, _) = KZG10::commit(ck, poly, None, None).map_err(PlonkError::PcsError)?; + let powers: Powers<'_, E> = ck.into(); + let (poly_comm, _) = + KZG10::commit(&powers, poly, None, None).map_err(PlonkError::PcsError)?; Ok(poly_comm) } @@ -539,22 +538,17 @@ impl Prover { let lookup_flag = pk.plookup_pk.is_some(); // Compute coset evaluations. - let selectors_coset_fft: Vec> = pk - .selectors - .par_iter() + let selectors_coset_fft: Vec> = parallelizable_slice_iter(&pk.selectors) .map(|poly| self.quot_domain.coset_fft(poly.coeffs())) .collect(); - let sigmas_coset_fft: Vec> = pk - .sigmas - .par_iter() + let sigmas_coset_fft: Vec> = parallelizable_slice_iter(&pk.sigmas) .map(|poly| self.quot_domain.coset_fft(poly.coeffs())) .collect(); + let wire_polys_coset_fft: Vec> = + parallelizable_slice_iter(&oracles.wire_polys) + .map(|poly| self.quot_domain.coset_fft(poly.coeffs())) + .collect(); - let wire_polys_coset_fft: Vec> = oracles - .wire_polys - .par_iter() - .map(|poly| self.quot_domain.coset_fft(poly.coeffs())) - .collect(); // TODO: (binyi) we can also compute below in parallel with // `wire_polys_coset_fft`. let prod_perm_poly_coset_fft = @@ -583,12 +577,10 @@ impl Prover { let key_table_coset_fft = self .quot_domain .coset_fft(pk.plookup_pk.as_ref().unwrap().key_table_poly.coeffs()); // safe unwrap - let h_coset_ffts: Vec> = oracles - .plookup_oracles - .h_polys - .par_iter() - .map(|poly| self.quot_domain.coset_fft(poly.coeffs())) - .collect(); + let h_coset_ffts: Vec> = + parallelizable_slice_iter(&oracles.plookup_oracles.h_polys) + .map(|poly| self.quot_domain.coset_fft(poly.coeffs())) + .collect(); let prod_lookup_poly_coset_fft = self .quot_domain .coset_fft(oracles.plookup_oracles.prod_lookup_poly.coeffs()); @@ -605,58 +597,63 @@ impl Prover { }; // Compute coset evaluations of the quotient polynomial. - let quot_poly_coset_evals: Vec = (0..m) - .into_par_iter() - .map(|i| { - let w: Vec = (0..num_wire_types) - .map(|j| wire_polys_coset_fft[j][i]) - .collect(); - let w_next: Vec = (0..num_wire_types) - .map(|j| wire_polys_coset_fft[j][(i + domain_size_ratio) % m]) - .collect(); - - let t_circ = Self::compute_quotient_circuit_contribution( - i, - &w, - &pub_input_poly_coset_fft[i], - &selectors_coset_fft, - ); - let (t_perm_1, t_perm_2) = Self::compute_quotient_copy_constraint_contribution( - i, - self.quot_domain.element(i) * E::Fr::multiplicative_generator(), - pk, - &w, - &prod_perm_poly_coset_fft[i], - &prod_perm_poly_coset_fft[(i + domain_size_ratio) % m], - challenges, - &sigmas_coset_fft, - ); - let mut t1 = t_circ + t_perm_1; - let mut t2 = t_perm_2; - - // add Plookup-related terms - if lookup_flag { - let (t_lookup_1, t_lookup_2) = self.compute_quotient_plookup_contribution( + let quot_poly_coset_evals: Vec = + parallelizable_slice_iter(&(0..m).collect::>()) + .map(|&i| { + let w: Vec = (0..num_wire_types) + .map(|j| wire_polys_coset_fft[j][i]) + .collect(); + let w_next: Vec = (0..num_wire_types) + .map(|j| wire_polys_coset_fft[j][(i + domain_size_ratio) % m]) + .collect(); + + let t_circ = Self::compute_quotient_circuit_contribution( i, - self.quot_domain.element(i) * E::Fr::multiplicative_generator(), - pk, &w, - &w_next, - h_coset_ffts.as_ref().unwrap(), - prod_lookup_poly_coset_fft.as_ref().unwrap(), - range_table_coset_fft.as_ref().unwrap(), - key_table_coset_fft.as_ref().unwrap(), - selectors_coset_fft.last().unwrap(), // TODO: add a method to extract q_lookup_coset_fft - table_dom_sep_coset_fft.as_ref().unwrap(), - q_dom_sep_coset_fft.as_ref().unwrap(), - challenges, + &pub_input_poly_coset_fft[i], + &selectors_coset_fft, ); - t1 += t_lookup_1; - t2 += t_lookup_2; - } - t1 * z_h_inv[i % domain_size_ratio] + t2 - }) - .collect(); + let (t_perm_1, t_perm_2) = + Self::compute_quotient_copy_constraint_contribution( + i, + self.quot_domain.element(i) * E::Fr::multiplicative_generator(), + pk, + &w, + &prod_perm_poly_coset_fft[i], + &prod_perm_poly_coset_fft[(i + domain_size_ratio) % m], + challenges, + &sigmas_coset_fft, + ); + let mut t1 = t_circ + t_perm_1; + let mut t2 = t_perm_2; + + // add Plookup-related terms + if lookup_flag { + let (t_lookup_1, t_lookup_2) = self + .compute_quotient_plookup_contribution( + i, + self.quot_domain.element(i) * E::Fr::multiplicative_generator(), + pk, + &w, + &w_next, + h_coset_ffts.as_ref().unwrap(), + prod_lookup_poly_coset_fft.as_ref().unwrap(), + range_table_coset_fft.as_ref().unwrap(), + key_table_coset_fft.as_ref().unwrap(), + selectors_coset_fft.last().unwrap(), /* TODO: add a method + * to extract + * q_lookup_coset_fft */ + table_dom_sep_coset_fft.as_ref().unwrap(), + q_dom_sep_coset_fft.as_ref().unwrap(), + challenges, + ); + t1 += t_lookup_1; + t2 += t_lookup_2; + } + t1 * z_h_inv[i % domain_size_ratio] + t2 + }) + .collect(); + for (a, b) in quot_poly_coset_evals_sum .iter_mut() .zip(quot_poly_coset_evals.iter()) @@ -917,20 +914,20 @@ impl Prover { let n = self.domain.size(); // compute the splitting polynomials t'_i(X) s.t. t(X) = // \sum_{i=0}^{num_wire_types} X^{i*(n+2)} * t'_i(X) - let mut split_quot_polys: Vec> = (0..num_wire_types) - .into_par_iter() - .map(|i| { - let end = if i < num_wire_types - 1 { - (i + 1) * (n + 2) - } else { - quot_poly.degree() + 1 - }; - // Degree-(n+1) polynomial has n + 2 coefficients. - DensePolynomial::::from_coefficients_slice( - "_poly.coeffs[i * (n + 2)..end], - ) - }) - .collect(); + let mut split_quot_polys: Vec> = + parallelizable_slice_iter(&(0..num_wire_types).collect::>()) + .map(|&i| { + let end = if i < num_wire_types - 1 { + (i + 1) * (n + 2) + } else { + quot_poly.degree() + 1 + }; + // Degree-(n+1) polynomial has n + 2 coefficients. + DensePolynomial::::from_coefficients_slice( + "_poly.coeffs[i * (n + 2)..end], + ) + }) + .collect(); // mask splitting polynomials t_i(X), for i in {0..num_wire_types}. // t_i(X) = t'_i(X) - b_last_i + b_now_i * X^(n+2) @@ -1108,7 +1105,9 @@ impl Prover { #[inline] fn mul_poly(poly: &DensePolynomial, coeff: &E::Fr) -> DensePolynomial { DensePolynomial::::from_coefficients_vec( - poly.coeffs.par_iter().map(|c| *coeff * c).collect(), + parallelizable_slice_iter(&poly.coeffs) + .map(|c| *coeff * c) + .collect(), ) } } diff --git a/plonk/src/proof_system/snark.rs b/plonk/src/proof_system/snark.rs index 44f9798b5..7b694662f 100644 --- a/plonk/src/proof_system/snark.rs +++ b/plonk/src/proof_system/snark.rs @@ -12,12 +12,13 @@ use super::{ PlookupVerifyingKey, Proof, ProvingKey, VerifyingKey, }, verifier::Verifier, - Snark, + UniversalSNARK, }; use crate::{ circuit::{customized::ecc::SWToTEConParam, Arithmetization}, constants::{compute_coset_representatives, EXTRA_TRANSCRIPT_MSG_LABEL}, errors::{PlonkError, SnarkError::ParameterError}, + par_utils::parallelizable_slice_iter, proof_system::structs::UniversalSrs, transcript::*, }; @@ -34,12 +35,13 @@ use ark_std::{ vec::Vec, }; use jf_rescue::RescueParameter; +#[cfg(feature = "parallel")] use rayon::prelude::*; /// A Plonk instantiated with KZG PCS -pub struct PlonkKzgSnark<'a, E: PairingEngine>(PhantomData<&'a E>); +pub struct PlonkKzgSnark(PhantomData); -impl<'a, E, F, P> PlonkKzgSnark<'a, E> +impl PlonkKzgSnark where E: PairingEngine>, F: RescueParameter + SWToTEConParam, @@ -51,137 +53,11 @@ where Self(PhantomData) } - /// Generate the universal SRS for the argument system. - /// This setup is for trusted party to run, and mostly only used for - /// testing purpose. In practice, a MPC flavor of the setup will be carried - /// out to have higher assurance on the "toxic waste"/trapdoor being thrown - /// away to ensure soundness of the argument system. - pub fn universal_setup( - max_degree: usize, - rng: &mut R, - ) -> Result, PlonkError> { - let srs = KZG10::>::setup(max_degree, false, rng)?; - Ok(UniversalSrs(srs)) - } - - // TODO: (alex) move back to Snark trait when `trait PolynomialCommitment` is - // implemented for KZG10 - /// Input a circuit and the SRS, precompute the proving key and verification - /// key. - pub fn preprocess>( - srs: &'a UniversalSrs, - circuit: &C, - ) -> Result<(ProvingKey<'a, E>, VerifyingKey), PlonkError> { - // Make sure the SRS can support the circuit (with hiding degree of 2 for zk) - let domain_size = circuit.eval_domain_size()?; - let srs_size = circuit.srs_size()?; - let num_inputs = circuit.num_inputs(); - if srs.0.max_degree() < circuit.srs_size()? { - return Err(PlonkError::IndexTooLarge); - } - // 1. Compute selector and permutation polynomials. - let selectors_polys = circuit.compute_selector_polynomials()?; - let sigma_polys = circuit.compute_extended_permutation_polynomials()?; - - // Compute Plookup proving key if support lookup. - let plookup_pk = if circuit.support_lookup() { - let range_table_poly = circuit.compute_range_table_polynomial()?; - let key_table_poly = circuit.compute_key_table_polynomial()?; - let table_dom_sep_poly = circuit.compute_table_dom_sep_polynomial()?; - let q_dom_sep_poly = circuit.compute_q_dom_sep_polynomial()?; - Some(PlookupProvingKey { - range_table_poly, - key_table_poly, - table_dom_sep_poly, - q_dom_sep_poly, - }) - } else { - None - }; - - // 2. Compute VerifyingKey - let (commit_key, open_key) = trim(&srs.0, srs_size); - let selector_comms: Vec<_> = selectors_polys - .par_iter() - .map(|poly| { - let (comm, _) = KZG10::commit(&commit_key, poly, None, None)?; - Ok(comm) - }) - .collect::, PlonkError>>()? - .into_iter() - .collect(); - let sigma_comms: Vec<_> = sigma_polys - .par_iter() - .map(|poly| { - let (comm, _) = KZG10::commit(&commit_key, poly, None, None)?; - Ok(comm) - }) - .collect::, PlonkError>>()? - .into_iter() - .collect(); - // Compute Plookup verifying key if support lookup. - let plookup_vk = match circuit.support_lookup() { - false => None, - true => Some(PlookupVerifyingKey { - range_table_comm: KZG10::commit( - &commit_key, - &plookup_pk.as_ref().unwrap().range_table_poly, - None, - None, - )? - .0, - key_table_comm: KZG10::commit( - &commit_key, - &plookup_pk.as_ref().unwrap().key_table_poly, - None, - None, - )? - .0, - table_dom_sep_comm: KZG10::commit( - &commit_key, - &plookup_pk.as_ref().unwrap().table_dom_sep_poly, - None, - None, - )? - .0, - q_dom_sep_comm: KZG10::commit( - &commit_key, - &plookup_pk.as_ref().unwrap().q_dom_sep_poly, - None, - None, - )? - .0, - }), - }; - - let vk = VerifyingKey { - domain_size, - num_inputs, - selector_comms, - sigma_comms, - k: compute_coset_representatives(circuit.num_wire_types(), Some(domain_size)), - open_key, - plookup_vk, - is_merged: false, - }; - - // Compute ProvingKey (which includes the VerifyingKey) - let pk = ProvingKey { - sigmas: sigma_polys, - selectors: selectors_polys, - commit_key, - vk: vk.clone(), - plookup_pk, - }; - - Ok((pk, vk)) - } - /// Generate an aggregated Plonk proof for multiple instances. pub fn batch_prove( prng: &mut R, circuits: &[&C], - prove_keys: &[&ProvingKey<'a, E>], + prove_keys: &[&ProvingKey], ) -> Result, PlonkError> where C: Arithmetization, @@ -247,11 +123,10 @@ where ); } - let pcs_infos = verify_keys - .par_iter() - .zip(proofs.par_iter()) - .zip(public_inputs.par_iter()) - .zip(extra_transcript_init_msgs.par_iter()) + let pcs_infos = parallelizable_slice_iter(verify_keys) + .zip(parallelizable_slice_iter(proofs)) + .zip(parallelizable_slice_iter(public_inputs)) + .zip(parallelizable_slice_iter(extra_transcript_init_msgs)) .map(|(((&vk, &proof), &pub_input), extra_msg)| { let verifier = Verifier::new(vk.domain_size)?; verifier.prepare_pcs_info::( @@ -284,7 +159,7 @@ where fn batch_prove_internal( prng: &mut R, circuits: &[&C], - prove_keys: &[&ProvingKey<'a, E>], + prove_keys: &[&ProvingKey], extra_transcript_init_msg: Option>, ) -> Result<(BatchProof, Vec>, Challenges), PlonkError> where @@ -539,17 +414,135 @@ where } } -impl<'a, E, F, P> Snark for PlonkKzgSnark<'a, E> +impl UniversalSNARK for PlonkKzgSnark where E: PairingEngine>, F: RescueParameter + SWToTEConParam, P: SWModelParameters + Clone, { type Proof = Proof; + type ProvingKey = ProvingKey; + type VerifyingKey = VerifyingKey; + type UniversalSRS = UniversalSrs; + type Error = PlonkError; - type ProvingKey = ProvingKey<'a, E>; + fn universal_setup( + max_degree: usize, + rng: &mut R, + ) -> Result { + let srs = KZG10::>::setup(max_degree, false, rng)?; + Ok(UniversalSrs(srs)) + } - type VerifyingKey = VerifyingKey; + /// Input a circuit and the SRS, precompute the proving key and verification + /// key. + fn preprocess>( + srs: &Self::UniversalSRS, + circuit: &C, + ) -> Result<(Self::ProvingKey, Self::VerifyingKey), Self::Error> { + // Make sure the SRS can support the circuit (with hiding degree of 2 for zk) + let domain_size = circuit.eval_domain_size()?; + let srs_size = circuit.srs_size()?; + let num_inputs = circuit.num_inputs(); + if srs.0.max_degree() < circuit.srs_size()? { + return Err(PlonkError::IndexTooLarge); + } + // 1. Compute selector and permutation polynomials. + let selectors_polys = circuit.compute_selector_polynomials()?; + let sigma_polys = circuit.compute_extended_permutation_polynomials()?; + + // Compute Plookup proving key if support lookup. + let plookup_pk = if circuit.support_lookup() { + let range_table_poly = circuit.compute_range_table_polynomial()?; + let key_table_poly = circuit.compute_key_table_polynomial()?; + let table_dom_sep_poly = circuit.compute_table_dom_sep_polynomial()?; + let q_dom_sep_poly = circuit.compute_q_dom_sep_polynomial()?; + Some(PlookupProvingKey { + range_table_poly, + key_table_poly, + table_dom_sep_poly, + q_dom_sep_poly, + }) + } else { + None + }; + + // 2. Compute VerifyingKey + let (commit_key, open_key) = trim(&srs.0, srs_size); + let selector_comms = parallelizable_slice_iter(&selectors_polys) + .map(|poly| { + let (comm, _) = KZG10::commit(&commit_key, poly, None, None)?; + Ok(comm) + }) + .collect::, PlonkError>>()? + .into_iter() + .collect(); + let sigma_comms = parallelizable_slice_iter(&sigma_polys) + .map(|poly| { + let (comm, _) = KZG10::commit(&commit_key, poly, None, None)?; + Ok(comm) + }) + .collect::, PlonkError>>()? + .into_iter() + .collect(); + + // Compute Plookup verifying key if support lookup. + let plookup_vk = match circuit.support_lookup() { + false => None, + true => Some(PlookupVerifyingKey { + range_table_comm: KZG10::commit( + &commit_key, + &plookup_pk.as_ref().unwrap().range_table_poly, + None, + None, + )? + .0, + key_table_comm: KZG10::commit( + &commit_key, + &plookup_pk.as_ref().unwrap().key_table_poly, + None, + None, + )? + .0, + table_dom_sep_comm: KZG10::commit( + &commit_key, + &plookup_pk.as_ref().unwrap().table_dom_sep_poly, + None, + None, + )? + .0, + q_dom_sep_comm: KZG10::commit( + &commit_key, + &plookup_pk.as_ref().unwrap().q_dom_sep_poly, + None, + None, + )? + .0, + }), + }; + + let vk = VerifyingKey { + domain_size, + num_inputs, + selector_comms, + sigma_comms, + k: compute_coset_representatives(circuit.num_wire_types(), Some(domain_size)), + open_key, + plookup_vk, + is_merged: false, + }; + + // Compute ProvingKey (which includes the VerifyingKey) + let pk = ProvingKey { + sigmas: sigma_polys, + selectors: selectors_polys, + commit_key: commit_key.into(), + vk: vk.clone(), + plookup_pk, + }; + + Ok((pk, vk)) + } /// Compute a Plonk proof. /// Refer to Sec 8.4 of @@ -557,18 +550,18 @@ where /// `circuit` and `prove_key` has to be consistent (with the same evaluation /// domain etc.), otherwise return error. fn prove( - prng: &mut R, + rng: &mut R, circuit: &C, prove_key: &Self::ProvingKey, extra_transcript_init_msg: Option>, - ) -> Result + ) -> Result where C: Arithmetization, R: CryptoRng + RngCore, T: PlonkTranscript, { let (batch_proof, ..) = Self::batch_prove_internal::<_, _, T>( - prng, + rng, &[circuit], &[prove_key], extra_transcript_init_msg, @@ -589,7 +582,7 @@ where public_input: &[E::Fr], proof: &Self::Proof, extra_transcript_init_msg: Option>, - ) -> Result<(), PlonkError> + ) -> Result<(), Self::Error> where T: PlonkTranscript, { @@ -613,7 +606,7 @@ pub mod test { eval_merged_lookup_witness, eval_merged_table, Challenges, Oracles, Proof, ProvingKey, UniversalSrs, VerifyingKey, }, - PlonkKzgSnark, Snark, + PlonkKzgSnark, UniversalSNARK, }, transcript::{ rescue::RescueTranscript, solidity::SolidityTranscript, standard::StandardTranscript, @@ -631,7 +624,7 @@ pub mod test { univariate::DensePolynomial, EvaluationDomain, Polynomial, Radix2EvaluationDomain, UVPolynomial, }; - use ark_poly_commit::kzg10::{Commitment, KZG10}; + use ark_poly_commit::kzg10::{Commitment, Powers, KZG10}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{ convert::TryInto, @@ -777,20 +770,23 @@ pub mod test { .iter() .zip(vk.selector_comms.iter()) .for_each(|(p, &p_comm)| { - let (expected_comm, _) = KZG10::commit(&pk.commit_key, p, None, None).unwrap(); + let powers: Powers<'_, E> = (&pk.commit_key).into(); + let (expected_comm, _) = KZG10::commit(&powers, p, None, None).unwrap(); assert_eq!(expected_comm, p_comm); }); sigmas .iter() .zip(vk.sigma_comms.iter()) .for_each(|(p, &p_comm)| { - let (expected_comm, _) = KZG10::commit(&pk.commit_key, p, None, None).unwrap(); + let powers: Powers<'_, E> = (&pk.commit_key).into(); + let (expected_comm, _) = KZG10::commit(&powers, p, None, None).unwrap(); assert_eq!(expected_comm, p_comm); }); // check plookup verification key if plonk_type == PlonkType::UltraPlonk { + let powers: Powers<'_, E> = (&pk.commit_key).into(); let (expected_comm, _) = KZG10::commit( - &pk.commit_key, + &powers, &pk.plookup_pk.as_ref().unwrap().range_table_poly, None, None, @@ -802,7 +798,7 @@ pub mod test { ); let (expected_comm, _) = KZG10::commit( - &pk.commit_key, + &powers, &pk.plookup_pk.as_ref().unwrap().key_table_poly, None, None, diff --git a/plonk/src/proof_system/structs.rs b/plonk/src/proof_system/structs.rs index 4849d8558..556a83cbd 100644 --- a/plonk/src/proof_system/structs.rs +++ b/plonk/src/proof_system/structs.rs @@ -30,14 +30,14 @@ use ark_poly::univariate::DensePolynomial; use ark_poly_commit::kzg10::{Commitment, Powers, UniversalParams, VerifierKey}; use ark_serialize::*; use ark_std::{ - collections::HashMap, convert::{TryFrom, TryInto}, format, string::ToString, vec, vec::Vec, }; -use espresso_systems_common::jellyfish as tag; +use espresso_systems_common::jellyfish::tag; +use hashbrown::HashMap; use jf_rescue::RescueParameter; use jf_utils::{field_switching, fq_to_fr, fr_to_fq, tagged_blob}; @@ -52,7 +52,29 @@ impl UniversalSrs { } } -pub(crate) type CommitKey<'a, E> = Powers<'a, E>; +#[derive(Debug, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)] +pub(crate) struct CommitKey { + pub(crate) powers_of_g: Vec, +} + +impl From> for CommitKey { + fn from(powers: Powers<'_, E>) -> Self { + Self { + powers_of_g: powers.powers_of_g.to_vec(), + } + } +} + +impl<'a, E: PairingEngine> From<&'a CommitKey> for Powers<'a, E> { + fn from(ck: &'a CommitKey) -> Self { + Self { + // Copy-on-write ensure reference passing as smart pointer for read-only access + powers_of_g: ark_std::borrow::Cow::Borrowed(&ck.powers_of_g), + // didn't use hiding variant of KZG, thus leave it empty + powers_of_gamma_g: ark_std::borrow::Cow::Owned(Vec::new()), + } + } +} /// Key for verifying PCS opening proof (alias to kzg10::VerifierKey). pub type OpenKey = VerifierKey; @@ -551,7 +573,7 @@ impl PlookupEvaluations { /// Preprocessed prover parameters used to compute Plonk proofs for a certain /// circuit. #[derive(Debug, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)] -pub struct ProvingKey<'a, E: PairingEngine> { +pub struct ProvingKey { /// Extended permutation (sigma) polynomials. pub(crate) sigmas: Vec>, @@ -559,7 +581,7 @@ pub struct ProvingKey<'a, E: PairingEngine> { pub(crate) selectors: Vec>, // KZG PCS committing key. - pub(crate) commit_key: CommitKey<'a, E>, + pub(crate) commit_key: CommitKey, /// The verifying key. It is used by prover to initialize transcripts. pub vk: VerifyingKey, @@ -585,7 +607,7 @@ pub struct PlookupProvingKey { pub(crate) q_dom_sep_poly: DensePolynomial, } -impl<'a, E: PairingEngine> ProvingKey<'a, E> { +impl ProvingKey { /// The size of the evaluation domain. Should be a power of two. pub(crate) fn domain_size(&self) -> usize { self.vk.domain_size @@ -884,19 +906,20 @@ impl ScalarsAndBases { let entry_scalar = self.base_scalar_map.entry(base).or_insert_with(E::Fr::zero); *entry_scalar += scalar; } + /// Add a list of scalars and bases into self, where each scalar is /// multiplied by a constant c. pub(crate) fn merge(&mut self, c: E::Fr, scalars_and_bases: &Self) { - for (&base, scalar) in &scalars_and_bases.base_scalar_map { - self.push(c * scalar, base); + for (base, scalar) in &scalars_and_bases.base_scalar_map { + self.push(c * scalar, *base); } } /// Compute the multi-scalar multiplication. pub(crate) fn multi_scalar_mul(&self) -> E::G1Projective { let mut bases = vec![]; let mut scalars = vec![]; - for (&base, scalar) in &self.base_scalar_map { - bases.push(base); + for (base, scalar) in &self.base_scalar_map { + bases.push(*base); scalars.push(scalar.into_repr()); } VariableBaseMSM::multi_scalar_mul(&bases, &scalars) diff --git a/plonk/src/testing_apis.rs b/plonk/src/testing_apis.rs index 65f67cf00..dcce3f893 100644 --- a/plonk/src/testing_apis.rs +++ b/plonk/src/testing_apis.rs @@ -9,6 +9,8 @@ //! The functions and structs in this file should not be used for other //! purposes. +#![allow(missing_docs)] + use crate::{ circuit::customized::ecc::SWToTEConParam, errors::PlonkError, @@ -22,7 +24,8 @@ use ark_ec::{short_weierstrass_jacobian::GroupAffine, PairingEngine, SWModelPara use ark_ff::Field; use ark_poly::Radix2EvaluationDomain; use ark_poly_commit::kzg10::Commitment; -use ark_std::{collections::HashMap, vec::Vec}; +use ark_std::vec::Vec; +use hashbrown::HashMap; use jf_rescue::RescueParameter; /// A wrapper of crate::proof_system::structs::Challenges diff --git a/primitives/Cargo.toml b/primitives/Cargo.toml index c7750ffbf..6b40e0c03 100644 --- a/primitives/Cargo.toml +++ b/primitives/Cargo.toml @@ -9,16 +9,16 @@ license = "MIT" [dependencies] # ark -ark-ff = { version = "0.3.0", default-features = false } +ark-ff = "0.3.0" ark-std = { version = "0.3.0", default-features = false } -ark-ec = { version = "0.3.0", default-features = false } -ark-serialize = { version = "0.3.0", default-features = false } +ark-ec = "0.3.0" +ark-serialize = "0.3.0" # ark curves -ark-bls12-381 = { version = "0.3.0", default-features = false, features = ["curve"] } -ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = ["curve"], rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af"} -ark-ed-on-bls12-381 = { version = "0.3.0", default-features = false } +ark-bls12-381 = "0.3.0" +ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af"} +ark-ed-on-bls12-381 = "0.3.0" # jellyfish jf-plonk = { path = "../plonk" } @@ -26,10 +26,10 @@ jf-rescue = { path = "../rescue" } jf-utils = { path = "../utilities" } # others -rayon = { version = "1.5.0", default-features = false } +rayon = { version = "1.5.0", optional = true } zeroize = { version = "1.3", default-features = false } itertools = { version = "0.10.1", default-features = false, features = [ "use_alloc" ] } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive"] } generic-array = { version = "^0.14", default-features = false } crypto_box = { version = "0.7.1", default-features = false, features = [ "u64_backend", "alloc" ] } displaydoc = { version = "0.2.3", default-features = false } @@ -37,23 +37,24 @@ derivative = { version = "2", features = ["use_core"] } rand_chacha = { version = "0.3.1", default-features = false } sha2 = { version = "0.10.1", default-features = false } digest = { version = "0.10.1", default-features = false } -espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", tag = "0.1.1" } +espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", branch = "main" } [dev-dependencies] -rand_chacha = "^0.3" bincode = "1.0" quickcheck = "1.0.0" criterion = "0.3.1" # ark curves -ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-ed-on-bn254 = { version = "0.3.0", default-features = false } -ark-bn254 = { version = "0.3.0", default-features = false, features = ["curve"] } -ark-bw6-761 = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-ed-on-bn254 = "0.3.0" +ark-bn254 = "0.3.0" +ark-bw6-761 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } [[bench]] name = "merkle_path" harness = false [features] +default = [ "parallel" ] std = [] +parallel = [ "jf-plonk/parallel", "rayon" ] diff --git a/primitives/src/elgamal.rs b/primitives/src/elgamal.rs index 5208b267d..03f7a90aa 100644 --- a/primitives/src/elgamal.rs +++ b/primitives/src/elgamal.rs @@ -26,10 +26,8 @@ use ark_std::{ }; use jf_rescue::{Permutation, RescueParameter, RescueVector, PRP, STATE_SIZE}; use jf_utils::pad_with_zeros; -use rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, - prelude::ParallelSliceMut, -}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use zeroize::Zeroize; // ===================================================== @@ -309,28 +307,37 @@ where // temporarily append dummy padding element pad_with_zeros(&mut output, STATE_SIZE); - output - .par_chunks_exact_mut(STATE_SIZE) - .enumerate() - .for_each(|(i, output_chunk)| { - let stream_chunk = prp.prp_with_round_keys( - &round_keys, - &RescueVector::from(&[ - nonce.add(F::from(i as u64)), - F::zero(), - F::zero(), - F::zero(), - ]), - ); - for (output_elem, stream_elem) in - output_chunk.iter_mut().zip(stream_chunk.elems().iter()) - { - match direction { - Direction::Encrypt => output_elem.add_assign(stream_elem), - Direction::Decrypt => output_elem.sub_assign(stream_elem), - } + let round_fn = |(idx, output_chunk): (usize, &mut [F])| { + let stream_chunk = prp.prp_with_round_keys( + &round_keys, + &RescueVector::from(&[ + nonce.add(F::from(idx as u64)), + F::zero(), + F::zero(), + F::zero(), + ]), + ); + for (output_elem, stream_elem) in output_chunk.iter_mut().zip(stream_chunk.elems().iter()) { + match direction { + Direction::Encrypt => output_elem.add_assign(stream_elem), + Direction::Decrypt => output_elem.sub_assign(stream_elem), } - }); + } + }; + #[cfg(feature = "parallel")] + { + output + .par_chunks_exact_mut(STATE_SIZE) + .enumerate() + .for_each(round_fn); + } + #[cfg(not(feature = "parallel"))] + { + output + .chunks_exact_mut(STATE_SIZE) + .enumerate() + .for_each(round_fn); + } // remove dummy padding elements output.truncate(data.len()); output diff --git a/primitives/src/merkle_tree.rs b/primitives/src/merkle_tree.rs index a1211ed12..138db2d26 100644 --- a/primitives/src/merkle_tree.rs +++ b/primitives/src/merkle_tree.rs @@ -29,7 +29,7 @@ use ark_std::{ vec::Vec, }; use core::{convert::TryFrom, fmt::Debug}; -use espresso_systems_common::jellyfish as tag; +use espresso_systems_common::jellyfish::tag; use jf_rescue::{Permutation, RescueParameter}; use jf_utils::tagged_blob; use serde::{Deserialize, Serialize}; diff --git a/primitives/src/signatures/bls.rs b/primitives/src/signatures/bls.rs index 986a17de6..6c008cdc4 100644 --- a/primitives/src/signatures/bls.rs +++ b/primitives/src/signatures/bls.rs @@ -19,7 +19,7 @@ use ark_std::{ One, UniformRand, }; use core::marker::PhantomData; -use espresso_systems_common::jellyfish as tag; +use espresso_systems_common::jellyfish::tag; use jf_utils::{multi_pairing, tagged_blob}; /// BLS signature scheme. diff --git a/primitives/src/signatures/schnorr.rs b/primitives/src/signatures/schnorr.rs index bb543925e..3d7203834 100644 --- a/primitives/src/signatures/schnorr.rs +++ b/primitives/src/signatures/schnorr.rs @@ -23,7 +23,7 @@ use ark_std::{ string::ToString, vec, }; -use espresso_systems_common::jellyfish as tag; +use espresso_systems_common::jellyfish::tag; use jf_rescue::{Permutation, RescueParameter}; use jf_utils::{fq_to_fr, fq_to_fr_with_mask, fr_to_fq, tagged_blob}; use zeroize::Zeroize; diff --git a/rescue/Cargo.toml b/rescue/Cargo.toml index e591a854a..d6f72c09a 100644 --- a/rescue/Cargo.toml +++ b/rescue/Cargo.toml @@ -9,27 +9,26 @@ license = "MIT" [dependencies] # ark -ark-ff = { version = "0.3.0", default-features = false } +ark-ff = "0.3.0" ark-std = { version = "0.3.0", default-features = false } -ark-ec = { version = "0.3.0", default-features = false } -ark-serialize = { version = "0.3.0", default-features = false } +ark-ec = "0.3.0" +ark-serialize = "0.3.0" # ark cruves -ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af"} -ark-ed-on-bls12-381 = { version = "0.3.0", default-features = false } -ark-ed-on-bn254 = { version = "0.3.0", default-features = false } -ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = ["curve"], rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-bls12-381 = { version = "0.3.0", default-features = false, features = ["curve"] } -ark-bn254 = { version = "0.3.0", default-features = false, features = ["curve"] } -ark-bw6-761 = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af"} +ark-ed-on-bls12-381 = "0.3.0" +ark-ed-on-bn254 = "0.3.0" +ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-bls12-381 = "0.3.0" +ark-bn254 = "0.3.0" +ark-bw6-761 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } # jellyfish jf-utils = { path = "../utilities" } # others -rayon = { version = "1.5.0", default-features = false } zeroize = { version = "1.3", default-features = false } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive"] } generic-array = { version = "^0.14", default-features = false } displaydoc = { version = "0.2.3", default-features = false } derivative = { version = "2", features = ["use_core"] } @@ -40,6 +39,6 @@ bincode = "1.0" quickcheck = "1.0.0" criterion = "0.3.1" - [features] +default = [] std = [] diff --git a/rescue/src/lib.rs b/rescue/src/lib.rs index f14c2c456..113e325da 100644 --- a/rescue/src/lib.rs +++ b/rescue/src/lib.rs @@ -5,6 +5,7 @@ // along with the Jellyfish library. If not, see . #![deny(missing_docs)] +#![cfg_attr(not(feature = "std"), no_std)] //! This module implements Rescue hash function over the following fields //! - bls12_377 base field //! - ed_on_bls12_377 base field diff --git a/scripts/check_no_std.sh b/scripts/check_no_std.sh new file mode 100755 index 000000000..b51398774 --- /dev/null +++ b/scripts/check_no_std.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -x + +cargo-nono check --no-default-features --package jf-utils-derive +cargo-nono check --no-default-features --package jf-utils +cargo-nono check --no-default-features --package jf-rescue +cargo-nono check --no-default-features --package jf-primitives +cargo-nono check --no-default-features --package jf-plonk diff --git a/utilities/Cargo.toml b/utilities/Cargo.toml index ddf33fbc7..9765ac1a3 100644 --- a/utilities/Cargo.toml +++ b/utilities/Cargo.toml @@ -11,27 +11,29 @@ jf-utils-derive = { path = "../utilities_derive" } tagged-base64 = { git = "https://github.com/EspressoSystems/tagged-base64", tag = "0.2.0" } ark-std = { version = "0.3.0", default-features = false } -ark-ff = { version = "0.3.0", default-features = false, features = ["asm", "parallel"] } -ark-ec = { version = "0.3.0", default-features = false, features = ["parallel"] } +ark-ff = { version = "0.3.0", default-features = false, features = [ "asm" ] } +ark-ec = { version = "0.3.0", default-features = false } ark-serialize = { version = "0.3.0", default-features = false } -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive"] } anyhow = { version = "^1.0", default-features = false } -snafu = { version = "0.7", features = ["backtraces"] } +displaydoc = { version = "0.2.3", default-features = false } sha2 = { version = "0.10.1", default-features = false } digest = { version = "0.10.1", default-features = false } [dev-dependencies] -ark-ed-on-bn254 = { version = "0.3.0", default-features = false } -ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-ed-on-bls12-381 = { version = "0.3.0", default-features = false } -ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-bn254 = { version = "0.3.0", default-features = false, features = ["curve"] } +ark-ed-on-bn254 = "0.3.0" +ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-ed-on-bls12-381 = "0.3.0" +ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } +ark-bn254 = "0.3.0" ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" } -ark-bls12-381 = { version = "0.3.0", default-features = false, features = ["curve"] } -ark-serialize = { version = "0.3.0", default-features = false, features = ["derive"] } +ark-bls12-381 = "0.3.0" +ark-serialize = { version = "0.3.0", features = ["derive"] } serde_json = "1.0" [features] -std = [] +default = [ "parallel" ] +std = [ "ark-ff/std", "ark-std/std", "ark-ec/std", "ark-serialize/std" ] +parallel = [ "ark-ff/parallel", "ark-std/parallel", "ark-ec/parallel" ] diff --git a/utilities/src/serialize.rs b/utilities/src/serialize.rs index d10638fe6..7e64a7b13 100644 --- a/utilities/src/serialize.rs +++ b/utilities/src/serialize.rs @@ -8,8 +8,8 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{marker::PhantomData, string::String, vec::Vec}; +use displaydoc::Display; use serde::{Deserialize, Serialize}; -use snafu::Snafu; use tagged_base64::{TaggedBase64, Tb64Error}; /// A helper for converting CanonicalSerde bytes to standard Serde bytes. @@ -116,12 +116,11 @@ impl From<&T> for TaggedB } } -#[derive(Debug, Snafu)] +#[derive(Debug, Display)] pub enum TaggedBlobError { - Base64Error { - #[snafu(source(false))] - source: Tb64Error, - }, + /// TaggedBase64 parsing failure + Base64Error { source: Tb64Error }, + /// CanonicalSerialize failure DeserializationError { source: ark_serialize::SerializationError, },