feat: Use FFT to encode polynomials in eval form #385

Merged (8 commits, Oct 19, 2023)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -47,6 +47,7 @@ and follow [semantic versioning](https://semver.org/) for our releases.
- [#302](https://github.com/EspressoSystems/jellyfish/pull/302) Followup APIs for non-native ECC circuit support.
- [#323](https://github.com/EspressoSystems/jellyfish/pull/323) Improve performance of range gate in ultra plonk.
- [#371](https://github.com/EspressoSystems/jellyfish/pull/371) VID disperse also returns payload commitment.
- [#385](https://github.com/EspressoSystems/jellyfish/pull/385) Use FFT to encode polynomials in eval form.

### Removed

223 changes: 175 additions & 48 deletions primitives/src/vid/advz.rs
@@ -23,7 +23,7 @@ use ark_ff::{
fields::field_hashers::{DefaultFieldHasher, HashToField},
FftField, Field, PrimeField,
};
use ark_poly::{DenseUVPolynomial, EvaluationDomain};
use ark_poly::{DenseUVPolynomial, EvaluationDomain, Radix2EvaluationDomain};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
use ark_std::{
borrow::Borrow,
@@ -65,14 +65,20 @@ pub type Advz<E, H> = GenericAdvz<
pub struct GenericAdvz<P, T, H, V>
where
P: PolynomialCommitmentScheme,
P::Evaluation: FftField,
{
payload_chunk_size: usize,
num_storage_nodes: usize,
ck: <P::SRS as StructuredReferenceString>::ProverParam,
vk: <P::SRS as StructuredReferenceString>::VerifierParam,
_phantom_t: PhantomData<T>, // needed for trait bounds
_phantom_h: PhantomData<H>, // needed for trait bounds
_phantom_v: PhantomData<V>, // needed for trait bounds
multi_open_domain: Radix2EvaluationDomain<P::Evaluation>,

// TODO might be able to eliminate this field and instead use
// `EvaluationDomain::reindex_by_subdomain()` on `multi_open_domain`
// but that method consumes `other` and its doc is unclear.
eval_domain: Radix2EvaluationDomain<P::Evaluation>,

_pd: (PhantomData<T>, PhantomData<H>, PhantomData<V>),
}

impl<P, T, H, V> GenericAdvz<P, T, H, V>
@@ -96,15 +102,36 @@ where
payload_chunk_size, num_storage_nodes
)));
}
let (ck, vk) = P::trim_fft_size(srs, payload_chunk_size).map_err(vid)?;
let (ck, vk) = P::trim_fft_size(srs, payload_chunk_size - 1).map_err(vid)?;
let multi_open_domain =
P::multi_open_rou_eval_domain(payload_chunk_size - 1, num_storage_nodes)
.map_err(vid)?;
let eval_domain = Radix2EvaluationDomain::new(payload_chunk_size).ok_or_else(|| {
VidError::Internal(anyhow::anyhow!(
"fail to construct doman of size {}",
payload_chunk_size
))
})?;

// TODO TEMPORARY: enforce power-of-2 chunk size
// Remove this restriction after we get KZG in eval form
// https://github.com/EspressoSystems/jellyfish/issues/339
if payload_chunk_size != eval_domain.size() {
Review comment (Contributor): If `payload_chunk_size` is smaller, could we simply pad with zeros here?

Reply (Contributor, author): Yes. The FFT pads with zeros up to the next power of 2. This wouldn't be a problem except that it breaks recovery: for a degree-d polynomial I need next_power_of_2(d) points to interpolate, so we're essentially forced to use a power-of-2 degree anyway. This PR is intended as a temporary stop-gap, so I decided it's best for now to simply enforce a power-of-2 degree at construction. This sanity check is merely a redundant backup check in the constructor. (A short sketch of this power-of-2 rounding behavior follows this hunk.)

return Err(VidError::Argument(format!(
"payload_chunk_size {} currently unsupported, round to {} instead",
payload_chunk_size,
eval_domain.size()
)));
}

Ok(Self {
payload_chunk_size,
num_storage_nodes,
ck,
vk,
_phantom_t: PhantomData,
_phantom_h: PhantomData,
_phantom_v: PhantomData,
multi_open_domain,
eval_domain,
_pd: Default::default(),
})
}
}
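The power-of-2 restriction above hinges on `Radix2EvaluationDomain::new` rounding the requested size up to the next power of two, as noted in the review thread. A minimal sketch of that rounding behavior (editor's illustration, not part of this PR; assumes `ark-poly` and `ark-bls12-381` as dependencies):

```rust
use ark_bls12_381::Fr;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};

fn main() {
    // Requesting a domain for 6 elements yields the next power of two: size 8.
    let domain = Radix2EvaluationDomain::<Fr>::new(6).unwrap();
    assert_eq!(domain.size(), 8);

    // So payload_chunk_size = 6 would trip the constructor's
    // `payload_chunk_size != eval_domain.size()` check, while 8 passes.
    let domain = Radix2EvaluationDomain::<Fr>::new(8).unwrap();
    assert_eq!(domain.size(), 8);
}
```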
@@ -175,8 +202,14 @@ where
{
let mut hasher = H::new();
let elems_iter = bytes_to_field::<_, P::Evaluation>(payload);
for coeffs in elems_iter.chunks(self.payload_chunk_size).into_iter() {
let poly = DenseUVPolynomial::from_coefficients_vec(coeffs.collect());
for coeffs_iter in elems_iter.chunks(self.payload_chunk_size).into_iter() {
// TODO TEMPORARY: use FFT to encode polynomials in eval form
// Remove these FFTs after we get KZG in eval form
// https://github.com/EspressoSystems/jellyfish/issues/339
let mut coeffs: Vec<_> = coeffs_iter.collect();
self.eval_domain.fft_in_place(&mut coeffs);

let poly = DenseUVPolynomial::from_coefficients_vec(coeffs);
let commitment = P::commit(&self.ck, &poly).map_err(vid)?;
commitment
.serialize_uncompressed(&mut hasher)
@@ -242,17 +275,11 @@ where
let aggregate_eval =
polynomial_eval(share.evals.iter().map(FieldMultiplier), pseudorandom_scalar);

// prepare eval point for aggregate proof
// TODO(Gus) perf: don't re-compute domain elements: https://github.com/EspressoSystems/jellyfish/issues/313
let domain = P::multi_open_rou_eval_domain(self.payload_chunk_size, self.num_storage_nodes)
.map_err(vid)?;
let point = domain.element(share.index);

// verify aggregate proof
Ok(P::verify(
&self.vk,
&aggregate_poly_commit,
&point,
&self.multi_open_domain.element(share.index),
&aggregate_eval,
&share.aggregate_proof,
)
@@ -279,24 +306,35 @@ where
V::MembershipProof: Sync + Debug, /* TODO https://github.com/EspressoSystems/jellyfish/issues/253 */
V::Index: From<u64>,
{
/// Same as [`VidScheme::disperse`] except `payload` is a slice of
/// Same as [`VidScheme::disperse`] except `payload` iterates over
/// field elements.
pub fn disperse_from_elems<I>(&self, payload: I) -> VidResult<VidDisperse<Self>>
where
I: IntoIterator,
I::Item: Borrow<P::Evaluation>,
{
let domain = P::multi_open_rou_eval_domain(self.payload_chunk_size, self.num_storage_nodes)
.map_err(vid)?;

// partition payload into polynomial coefficients
// and count `elems_len` for later
let elems_iter = payload.into_iter().map(|elem| *elem.borrow());
let mut elems_len = 0;
let mut polys = Vec::new();
for coeffs_iter in elems_iter.chunks(self.payload_chunk_size).into_iter() {
let coeffs: Vec<_> = coeffs_iter.collect();
elems_len += coeffs.len();
// TODO TEMPORARY: use FFT to encode polynomials in eval form
// Remove these FFTs after we get KZG in eval form
// https://github.com/EspressoSystems/jellyfish/issues/339
let mut coeffs: Vec<_> = coeffs_iter.collect();
let pre_fft_len = coeffs.len();
self.eval_domain.fft_in_place(&mut coeffs);

// sanity check: the fft did not resize coeffs.
// If pre_fft_len != self.payload_chunk_size then we must be in the final chunk.
// In that case coeffs.len() could be anything, so there's nothing to sanity
// check.
if pre_fft_len == self.payload_chunk_size {
assert_eq!(coeffs.len(), pre_fft_len);
}

elems_len += pre_fft_len;
polys.push(DenseUVPolynomial::from_coefficients_vec(coeffs));
}
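Only the final chunk can be shorter than `payload_chunk_size`, and `elems_len` counts `pre_fft_len` rather than the post-FFT length because the FFT zero-pads a short vector up to the domain size. A small sketch of that padding behavior (editor's illustration, not part of this PR; same `ark-poly`/`ark-bls12-381` assumptions):

```rust
use ark_bls12_381::Fr;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};

fn main() {
    let domain = Radix2EvaluationDomain::<Fr>::new(4).unwrap();

    // A short final chunk: only 2 of the 4 slots hold real payload elements.
    let mut coeffs = vec![Fr::from(1u64), Fr::from(2u64)];
    let pre_fft_len = coeffs.len();

    domain.fft_in_place(&mut coeffs);

    // The FFT zero-pads up to the domain size, so the post-FFT length no longer
    // reflects how many payload elements the chunk actually carried.
    assert_eq!(pre_fft_len, 2);
    assert_eq!(coeffs.len(), domain.size());
}
```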

@@ -307,7 +345,8 @@ where

for poly in polys.iter() {
let poly_evals =
P::multi_open_rou_evals(poly, self.num_storage_nodes, &domain).map_err(vid)?;
P::multi_open_rou_evals(poly, self.num_storage_nodes, &self.multi_open_domain)
.map_err(vid)?;

for (storage_node_evals, poly_eval) in
all_storage_node_evals.iter_mut().zip(poly_evals)
@@ -372,9 +411,13 @@ where
let aggregate_poly =
polynomial_eval(polys.iter().map(PolynomialMultiplier), pseudorandom_scalar);

let aggregate_proofs =
P::multi_open_rou_proofs(&self.ck, &aggregate_poly, self.num_storage_nodes, &domain)
.map_err(vid)?;
let aggregate_proofs = P::multi_open_rou_proofs(
&self.ck,
&aggregate_poly,
self.num_storage_nodes,
&self.multi_open_domain,
)
.map_err(vid)?;

let shares = all_storage_node_evals
.into_iter()
@@ -438,15 +481,19 @@ where

let result_len = num_polys * self.payload_chunk_size;
let mut result = Vec::with_capacity(result_len);
let domain = P::multi_open_rou_eval_domain(self.payload_chunk_size, self.num_storage_nodes)
.map_err(vid)?;
for i in 0..num_polys {
let mut coeffs = reed_solomon_erasure_decode_rou(
shares.iter().map(|s| (s.index, s.evals[i])),
self.payload_chunk_size,
&domain,
&self.multi_open_domain,
)
.map_err(vid)?;

// TODO TEMPORARY: use FFT to encode polynomials in eval form
// Remove these FFTs after we get KZG in eval form
// https://github.com/EspressoSystems/jellyfish/issues/339
self.eval_domain.ifft_in_place(&mut coeffs);

result.append(&mut coeffs);
}
assert_eq!(result.len(), result_len);
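`disperse_from_elems` applies `fft_in_place` to each payload chunk and `recover_payload` undoes it with `ifft_in_place`, so correctness of recovery rests on the two transforms round-tripping. A minimal check of that property (editor's sketch, not part of this PR; assumes `ark-poly`, `ark-ff`, `ark-std`, and `ark-bls12-381`):

```rust
use ark_bls12_381::Fr;
use ark_ff::UniformRand;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};

fn main() {
    let mut rng = ark_std::test_rng();
    let domain = Radix2EvaluationDomain::<Fr>::new(4).unwrap();

    // A full-size payload chunk of random field elements.
    let chunk: Vec<Fr> = (0..domain.size()).map(|_| Fr::rand(&mut rng)).collect();

    // Encode as in disperse_from_elems, then decode as in recover_payload.
    let mut roundtrip = chunk.clone();
    domain.fft_in_place(&mut roundtrip);
    domain.ifft_in_place(&mut roundtrip);

    assert_eq!(roundtrip, chunk);
}
```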
@@ -577,9 +624,16 @@ mod tests {
mod tests {
use super::{VidError::Argument, *};

use crate::{merkle_tree::hasher::HasherNode, pcs::checked_fft_size};
use crate::{
merkle_tree::hasher::HasherNode,
pcs::{checked_fft_size, prelude::UnivariateUniversalParams},
};
use ark_bls12_381::Bls12_381;
use ark_std::{rand::RngCore, vec};
use ark_std::{
rand::{CryptoRng, RngCore},
vec,
};
use digest::{generic_array::ArrayLength, OutputSizeUser};
use sha2::Sha256;

#[test]
@@ -755,42 +809,115 @@ mod tests {
// corrupted index, out of bounds
{
let mut shares_bad_indices = shares;
let domain = UnivariateKzgPCS::<Bls12_381>::multi_open_rou_eval_domain(
advz.payload_chunk_size,
advz.num_storage_nodes,
)
.unwrap();
for i in 0..shares_bad_indices.len() {
shares_bad_indices[i].index += domain.size();
shares_bad_indices[i].index += advz.multi_open_domain.size();
advz.recover_payload(&shares_bad_indices, &common)
.expect_err("recover_payload should fail when indices are out of bounds");
}
}
}

fn prove_namespace_generic<E, H>()
where
E: Pairing,
H: Digest + DynDigest + Default + Clone + Write,
<<H as OutputSizeUser>::OutputSize as ArrayLength<u8>>::ArrayType: Copy,
{
// play with these items
let (payload_chunk_size, num_storage_nodes) = (4, 6);
let num_polys = 4;

// more items as a function of the above
let payload_elems_len = num_polys * payload_chunk_size;
let payload_bytes_len = payload_elems_len * modulus_byte_len::<E>();
let mut rng = jf_utils::test_rng();
let payload_bytes = init_random_bytes(payload_bytes_len, &mut rng);
let srs = init_srs(payload_elems_len, &mut rng);

let advz = Advz::<E, H>::new(payload_chunk_size, num_storage_nodes, srs).unwrap();
let d = advz.disperse(&payload_bytes).unwrap();

// TEST: verify "namespaces" (each namespace is a polynomial)
// This test is currently trivial: we simply repeat the commit computation.
// In the future there will be a proper API that can be tested meaningfully.

// encode payload as field elements, partition into polynomials, compute
// commitments, compare against VID common data
let elems_iter = bytes_to_field::<_, E::ScalarField>(payload_bytes);
for (coeffs_iter, poly_commit) in elems_iter
.chunks(payload_chunk_size)
.into_iter()
.zip(d.common.poly_commits.iter())
{
let mut coeffs: Vec<_> = coeffs_iter.collect();
advz.eval_domain.fft_in_place(&mut coeffs);

let poly = <UnivariateKzgPCS::<E> as PolynomialCommitmentScheme>::Polynomial::from_coefficients_vec(coeffs);
let my_poly_commit = UnivariateKzgPCS::<E>::commit(&advz.ck, &poly).unwrap();
assert_eq!(my_poly_commit, *poly_commit);
}

// compute payload commitment and verify
let commit = {
let mut hasher = H::new();
for poly_commit in d.common.poly_commits.iter() {
// TODO compiler bug? `as` should not be needed here!
(poly_commit as &<UnivariateKzgPCS<E> as PolynomialCommitmentScheme>::Commitment)
.serialize_uncompressed(&mut hasher)
.unwrap();
}
hasher.finalize()
};
assert_eq!(commit, d.commit);
}

#[test]
fn prove_namespace() {
prove_namespace_generic::<Bls12_381, Sha256>();
}

/// Routine initialization tasks.
///
/// Returns the following tuple:
/// 1. An initialized [`Advz`] instance.
/// 2. A `Vec<u8>` filled with random bytes.
fn avdz_init() -> (Advz<Bls12_381, Sha256>, Vec<u8>) {
let (payload_chunk_size, num_storage_nodes) = (3, 5);
let (payload_chunk_size, num_storage_nodes) = (4, 6);
let mut rng = jf_utils::test_rng();
let srs = UnivariateKzgPCS::<Bls12_381>::gen_srs_for_testing(
&mut rng,
checked_fft_size(payload_chunk_size).unwrap(),
)
.unwrap();
let srs = init_srs(payload_chunk_size, &mut rng);
let advz = Advz::new(payload_chunk_size, num_storage_nodes, srs).unwrap();

let mut bytes_random = vec![0u8; 4000];
rng.fill_bytes(&mut bytes_random);

let bytes_random = init_random_bytes(4000, &mut rng);
(advz, bytes_random)
}

/// Convenience wrapper to assert [`VidError::Argument`] return value.
fn assert_arg_err<T>(res: VidResult<T>, msg: &str) {
assert!(matches!(res, Err(Argument(_))), "{}", msg);
}

fn init_random_bytes<R>(len: usize, rng: &mut R) -> Vec<u8>
where
R: RngCore + CryptoRng,
{
let mut bytes_random = vec![0u8; len];
rng.fill_bytes(&mut bytes_random);
bytes_random
}

fn init_srs<E, R>(num_coeffs: usize, rng: &mut R) -> UnivariateUniversalParams<E>
where
E: Pairing,
R: RngCore + CryptoRng,
{
UnivariateKzgPCS::gen_srs_for_testing(rng, checked_fft_size(num_coeffs - 1).unwrap())
.unwrap()
}

fn modulus_byte_len<E>() -> usize
where
E: Pairing,
{
usize::try_from((<<UnivariateKzgPCS<Bls12_381> as PolynomialCommitmentScheme>::Evaluation as Field>::BasePrimeField
::MODULUS_BIT_SIZE - 7)/8 + 1).unwrap()
}
}
4 changes: 2 additions & 2 deletions primitives/tests/advz.rs
@@ -12,11 +12,11 @@ mod vid;
#[test]
fn round_trip() {
// play with these items
let vid_sizes = [(2, 3), (5, 9)];
let vid_sizes = [(2, 3), (8, 11)];
let byte_lens = [0, 1, 2, 16, 32, 47, 48, 49, 64, 100, 400];

// more items as a function of the above
let supported_degree = vid_sizes.iter().max_by_key(|v| v.0).unwrap().0;
let supported_degree = vid_sizes.iter().max_by_key(|v| v.0).unwrap().0 - 1;
let mut rng = jf_utils::test_rng();
let srs = UnivariateKzgPCS::<Bls12_381>::gen_srs_for_testing(
&mut rng,