Skip to content

Commit

Permalink
no_std compliance and refactor SNARK trait (#87)
Browse files Browse the repository at this point in the history
* update Snark -> UniversalSNARK trait (#80)

* update Snark -> UniversalSNARK trait

* enable CI on PR targeting cap-rollup branch

* address Zhenfei's comment

* Restoring no_std compliance (#85)

* restore no_std on jf-*

* remove HashMap and HashSet for no_std

* fix bench.rs, add Display to TaggedBlobError

* more no_std fix

* put rayon to feature=parallel

* use hashbrown for HashMap, update es-commons

* simplify rayon-accelerated code

* update CHANGELOG
  • Loading branch information
alxiong authored Jul 29, 2022
1 parent 706eb8d commit d1637ad
Show file tree
Hide file tree
Showing 28 changed files with 509 additions and 415 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ on:
pull_request:
branches:
- main
- cap-rollup
schedule:
- cron: '0 0 * * 1'
- cron: "0 0 * * 1"
workflow_dispatch:

jobs:
Expand Down Expand Up @@ -64,9 +65,12 @@ jobs:
- name: Check Ignored Tests
run: cargo test --no-run -- --ignored

- name: Check no_std compilation
run: cargo test --no-run --no-default-features

- name: Test
run: bash ./scripts/run_tests.sh
run: bash ./scripts/run_tests.sh

- name: Example
run: cargo run --release --example proof_of_exp

Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## Pending

- Splitting polynomials are masked to ensure zero-knowledge of Plonk (#76)
- Refactored `UniversalSNARK` trait (#80, #87)
- Restored `no_std` compliance (#85, #87)

## v0.1.2

Expand Down
33 changes: 17 additions & 16 deletions plonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ jf-utils = { path = "../utilities" }
jf-rescue = { path = "../rescue" }

ark-std = { version = "0.3.0", default-features = false }
ark-serialize = { version = "0.3.0", default-features = false }
ark-ff = { version = "0.3.0", default-features = false, features = ["asm", "parallel"] }
ark-ec = { version = "0.3.0", default-features = false, features = ["parallel"] }
ark-poly = { version = "0.3.0", default-features = false, features = ["parallel"] }
ark-bn254 = { version = "0.3.0", default-features = false, features = ["curve"] }
ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", features = ["curve"], rev = "677b4ae751a274037880ede86e9b6f30f62635af" }
ark-bls12-381 = { version = "0.3.0", default-features = false, features = ["curve"] }
ark-serialize = "0.3.0"
ark-ff = { version = "0.3.0", features = [ "asm" ] }
ark-ec = "0.3.0"
ark-poly = "0.3.0"
ark-bn254 = "0.3.0"
ark-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" }
ark-bls12-381 = "0.3.0"
ark-bw6-761 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" }

merlin = { version = "3.0.0", default-features = false }
rayon = { version = "1.5.0", default-features = false }
rayon = { version = "1.5.0", optional = true }
itertools = { version = "0.10.1", default-features = false }
downcast-rs = { version = "1.2.0", default-features = false }
serde = { version = "1.0", default-features = false, features = ["derive"] }
Expand All @@ -30,8 +30,8 @@ derivative = { version = "2", features = ["use_core"] }
num-bigint = { version = "0.4", default-features = false}
rand_chacha = { version = "0.3.1" }
sha3 = "^0.10"
espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", tag = "0.1.1" }

espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", branch = "main" }
hashbrown = "0.12.3"

[dependencies.ark-poly-commit]
git = "https://github.com/arkworks-rs/poly-commit/"
Expand All @@ -40,10 +40,10 @@ default-features = false

[dev-dependencies]
bincode = "1.0"
ark-ed-on-bls12-381 = { version = "0.3.0", default-features = false }
ark-ed-on-bls12-381 = "0.3.0"
ark-ed-on-bls12-377 = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" }
ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", default-features = false, rev = "677b4ae751a274037880ede86e9b6f30f62635af" }
ark-ed-on-bn254 = { version = "0.3.0", default-features = false }
ark-ed-on-bls12-381-bandersnatch = { git = "https://github.com/arkworks-rs/curves", rev = "677b4ae751a274037880ede86e9b6f30f62635af" }
ark-ed-on-bn254 = "0.3.0"
hex = "^0.4.3"

# Benchmarks
Expand All @@ -53,6 +53,7 @@ path = "benches/bench.rs"
harness = false

[features]
std = []
# exposing apis for testing purpose
test_apis = []
default = [ "parallel" ]
std = [ "ark-std/std", "ark-serialize/std", "ark-ff/std", "ark-ec/std", "ark-poly/std"]
test_apis = [] # exposing apis for testing purpose
parallel = [ "ark-ff/parallel", "ark-ec/parallel", "ark-poly/parallel", "rayon" ]
9 changes: 5 additions & 4 deletions plonk/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ use ark_ff::PrimeField;
use jf_plonk::{
circuit::{Circuit, PlonkCircuit},
errors::PlonkError,
proof_system::{PlonkKzgSnark, Snark},
proof_system::{PlonkKzgSnark, UniversalSNARK},
transcript::StandardTranscript,
PlonkType,
};
use std::time::Instant;

const NUM_REPETITIONS: usize = 10;
const NUM_GATES_LARGE: usize = 32768;
Expand Down Expand Up @@ -54,7 +55,7 @@ macro_rules! plonk_prove_bench {

let (pk, _) = PlonkKzgSnark::<$bench_curve>::preprocess(&srs, &cs).unwrap();

let start = ark_std::time::Instant::now();
let start = Instant::now();

for _ in 0..NUM_REPETITIONS {
let _ = PlonkKzgSnark::<$bench_curve>::prove::<_, _, StandardTranscript>(
Expand Down Expand Up @@ -97,7 +98,7 @@ macro_rules! plonk_verify_bench {
PlonkKzgSnark::<$bench_curve>::prove::<_, _, StandardTranscript>(rng, &cs, &pk, None)
.unwrap();

let start = ark_std::time::Instant::now();
let start = Instant::now();

for _ in 0..NUM_REPETITIONS {
let _ =
Expand Down Expand Up @@ -144,7 +145,7 @@ macro_rules! plonk_batch_verify_bench {
let public_inputs_ref = vec![&pub_input[..]; $num_proofs];
let proofs_ref = vec![&proof; $num_proofs];

let start = ark_std::time::Instant::now();
let start = Instant::now();

for _ in 0..NUM_REPETITIONS {
let _ = PlonkKzgSnark::<$bench_curve>::batch_verify::<StandardTranscript>(
Expand Down
2 changes: 1 addition & 1 deletion plonk/examples/proof_of_exp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use ark_std::{rand::SeedableRng, UniformRand};
use jf_plonk::{
circuit::{customized::ecc::Point, Arithmetization, Circuit, PlonkCircuit},
errors::PlonkError,
proof_system::{PlonkKzgSnark, Snark},
proof_system::{PlonkKzgSnark, UniversalSNARK},
transcript::StandardTranscript,
};
use jf_utils::fr_to_fq;
Expand Down
41 changes: 17 additions & 24 deletions plonk/src/circuit/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,16 @@ use crate::{
circuit::{gates::*, SortedLookupVecAndPolys},
constants::{compute_coset_representatives, GATE_WIDTH, N_MUL_SELECTORS},
errors::{CircuitError::*, PlonkError},
par_utils::parallelizable_slice_iter,
MergeableCircuitType, PlonkType,
};
use ark_ff::{FftField, PrimeField};
use ark_poly::{
domain::Radix2EvaluationDomain, univariate::DensePolynomial, EvaluationDomain, UVPolynomial,
};
use ark_std::{
boxed::Box,
cmp::max,
collections::{HashMap, HashSet},
format,
string::ToString,
vec,
vec::Vec,
};
use ark_std::{boxed::Box, cmp::max, format, string::ToString, vec, vec::Vec};
use hashbrown::{HashMap, HashSet};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// The wire type identifier for range gates.
Expand Down Expand Up @@ -1100,12 +1095,9 @@ where
.into());
}
// order: (lc, mul, hash, o, c, ecc) as specified in spec
let selector_polys: Vec<_> = self
.all_selectors()
.par_iter()
let selector_polys = parallelizable_slice_iter(&self.all_selectors())
.map(|selector| DensePolynomial::from_coefficients_vec(domain.ifft(selector)))
.collect();

Ok(selector_polys)
}

Expand All @@ -1116,14 +1108,16 @@ where
let domain = &self.eval_domain;
let n = domain.size();
let extended_perm = self.compute_extended_permutation()?;
let extended_perm_polys: Vec<DensePolynomial<F>> = (0..self.num_wire_types)
.into_par_iter()
.map(|i| {
DensePolynomial::from_coefficients_vec(
domain.ifft(&extended_perm[i * n..(i + 1) * n]),
)
})
.collect();

let extended_perm_polys: Vec<DensePolynomial<F>> =
parallelizable_slice_iter(&(0..self.num_wire_types).collect::<Vec<_>>()) // current par_utils only support slice iterator, not range iterator.
.map(|i| {
DensePolynomial::from_coefficients_vec(
domain.ifft(&extended_perm[i * n..(i + 1) * n]),
)
})
.collect();

Ok(extended_perm_polys)
}

Expand Down Expand Up @@ -1167,16 +1161,15 @@ where
.into());
}
let witness = &self.witness;
let wire_polys: Vec<_> = self
.wire_variables
.par_iter()
let wire_polys: Vec<DensePolynomial<F>> = parallelizable_slice_iter(&self.wire_variables)
.take(self.num_wire_types())
.map(|wire_vars| {
let mut wire_vec: Vec<F> = wire_vars.iter().map(|&var| witness[var]).collect();
domain.ifft_in_place(&mut wire_vec);
DensePolynomial::from_coefficients_vec(wire_vec)
})
.collect();

assert_eq!(wire_polys.len(), self.num_wire_types());
Ok(wire_polys)
}
Expand Down
28 changes: 19 additions & 9 deletions plonk/src/circuit/customized/ecc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,15 +572,25 @@ fn compute_base_points<E: AffineCurve + Group>(
// base3 = (3*B, 3*4*B, ..., 3*4^(l-1)*B)
let mut bases3 = vec![b];

rayon::join(
|| {
rayon::join(
|| fill_bases(&mut bases1, len).ok(),
|| fill_bases(&mut bases2, len).ok(),
)
},
|| fill_bases(&mut bases3, len).ok(),
);
#[cfg(feature = "parallel")]
{
rayon::join(
|| {
rayon::join(
|| fill_bases(&mut bases1, len).ok(),
|| fill_bases(&mut bases2, len).ok(),
)
},
|| fill_bases(&mut bases3, len).ok(),
);
}

#[cfg(not(feature = "parallel"))]
{
fill_bases(&mut bases1, len).ok();
fill_bases(&mut bases2, len).ok();
fill_bases(&mut bases3, len).ok();
}

// converting GroupAffine -> Points here.
// Cannot do it earlier: in `fill_bases` we need to do `double`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ mod test {
circuit::Circuit,
proof_system::{
batch_arg::{new_mergeable_circuit_for_test, BatchArgument},
PlonkKzgSnark,
PlonkKzgSnark, UniversalSNARK,
},
transcript::{PlonkTranscript, RescueTranscript},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ mod test {
proof_system::{
batch_arg::{new_mergeable_circuit_for_test, BatchArgument},
structs::BatchProof,
PlonkKzgSnark,
PlonkKzgSnark, UniversalSNARK,
},
transcript::{PlonkTranscript, RescueTranscript},
};
Expand Down
3 changes: 1 addition & 2 deletions plonk/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ pub enum PlonkError {
PublicInputsDoNotMatch,
}

#[cfg(feature = "std")]
impl std::error::Error for PlonkError {}
impl ark_std::error::Error for PlonkError {}

impl From<ark_poly_commit::Error> for PlonkError {
fn from(e: ark_poly_commit::Error) -> Self {
Expand Down
1 change: 1 addition & 0 deletions plonk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ extern crate derivative;
pub mod circuit;
pub mod constants;
pub mod errors;
pub(crate) mod par_utils;
pub mod proof_system;
pub mod transcript;

Expand Down
30 changes: 30 additions & 0 deletions plonk/src/par_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the Jellyfish library.
// You should have received a copy of the MIT License
// along with the Jellyfish library. If not, see <https://mit-license.org/>.

//! Utilities for parallel code.
/// Builds an iterator over a slice that is parallel (`par_iter()`) when the
/// `parallel` feature flag is enabled, and sequential otherwise.
///
/// # Usage
/// ```ignore
/// let v = [1, 2, 3, 4, 5];
/// let sum: i32 = parallelizable_slice_iter(&v).sum();
///
/// // the above code is a shorthand for (thus equivalent to)
/// #[cfg(feature = "parallel")]
/// let sum: i32 = v.par_iter().sum();
/// #[cfg(not(feature = "parallel"))]
/// let sum: i32 = v.iter().sum();
/// ```
#[cfg(feature = "parallel")]
pub(crate) fn parallelizable_slice_iter<T: Sync>(data: &[T]) -> rayon::slice::Iter<T> {
    // `par_iter()` (via `IntoParallelRefIterator`) yields the same
    // `rayon::slice::Iter` as `(&[T]).into_par_iter()`.
    use rayon::prelude::*;
    data.par_iter()
}

#[cfg(not(feature = "parallel"))]
pub(crate) fn parallelizable_slice_iter<T>(data: &[T]) -> ark_std::slice::Iter<T> {
use ark_std::iter::IntoIterator;
data.iter()
}
Loading

0 comments on commit d1637ad

Please sign in to comment.