diff --git a/CHANGELOG.md b/CHANGELOG.md index 970fa8033..e5bad3316 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and follow [semantic versioning](https://semver.org/) for our releases. - [#341](https://github.com/EspressoSystems/jellyfish/pull/341) Port VDF from another repo - [#343](https://github.com/EspressoSystems/jellyfish/pull/343) Rescue parameter for `ark_bn254::Fq` - [#362](https://github.com/EspressoSystems/jellyfish/pull/362) Derive Eq, Hash at a bunch of places +- [#381](https://github.com/EspressoSystems/jellyfish/pull/381) VID take iterator instead of slice ### Changed diff --git a/Cargo.toml b/Cargo.toml index b21d6bb22..2d92ff1e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,5 @@ [workspace] -members = [ - "plonk", - "primitives", - "relation", - "utilities", -] +members = ["plonk", "primitives", "relation", "utilities"] [workspace.package] version = "0.4.0-pre.0" @@ -15,3 +10,6 @@ rust-version = "1.64.0" homepage = "https://github.com/EspressoSystems/jellyfish" documentation = "https://jellyfish.docs.espressosys.com" repository = "https://github.com/EspressoSystems/jellyfish" + +[workspace.dependencies] +itertools = { version = "0.10.1", default-features = false } diff --git a/flake.lock b/flake.lock index 262ffc9d5..93aea596e 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "flake-compat": { "flake": false, "locked": { - "lastModified": 1673956053, - "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", "owner": "edolstra", "repo": "flake-compat", - "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", "type": "github" }, "original": { @@ -37,11 +37,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1689068808, - "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", + "lastModified": 1694529238, + "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", "owner": "numtide", "repo": "flake-utils", - "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", + "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", "type": "github" }, "original": { @@ -109,11 +109,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1690031011, - "narHash": "sha256-kzK0P4Smt7CL53YCdZCBbt9uBFFhE0iNvCki20etAf4=", + "lastModified": 1696375444, + "narHash": "sha256-Sv0ICt/pXfpnFhTGYTsX6lUr1SljnuXWejYTI2ZqHa4=", "owner": "nixos", "repo": "nixpkgs", - "rev": "12303c652b881435065a98729eb7278313041e49", + "rev": "81e8f48ebdecf07aab321182011b067aafc78896", "type": "github" }, "original": { @@ -166,11 +166,11 @@ "nixpkgs-stable": "nixpkgs-stable" }, "locked": { - "lastModified": 1689668210, - "narHash": "sha256-XAATwDkaUxH958yXLs1lcEOmU6pSEIkatY3qjqk8X0E=", + "lastModified": 1696516544, + "narHash": "sha256-8rKE8Je6twTNFRTGF63P9mE3lZIq917RAicdc4XJO80=", "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "eb433bff05b285258be76513add6f6c57b441775", + "rev": "66c352d33e0907239e4a69416334f64af2c685cc", "type": "github" }, "original": { @@ -194,11 +194,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1689992383, - "narHash": "sha256-x/MSjx2aA9aJqQA3fUHxiH0l8uG+1vxnkRNkqAZHQ2U=", + "lastModified": 1696558324, + "narHash": "sha256-TnnP4LGwDB8ZGE7h2n4nA9Faee8xPkMdNcyrzJ57cbw=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "97fcdd4793778cf8e4f9007079cb9d2b836d7ea9", + "rev": "fdb37574a04df04aaa8cf7708f94a9309caebe2b", "type": "github" }, "original": { diff --git a/flake.nix 
b/flake.nix index 8a8960d40..c659ef0d4 100644 --- a/flake.nix +++ b/flake.nix @@ -70,7 +70,7 @@ buildInputs = [ argbash openssl - pkgconfig + pkg-config git stableToolchain diff --git a/plonk/Cargo.toml b/plonk/Cargo.toml index 4672cfed4..a6f6aaba0 100644 --- a/plonk/Cargo.toml +++ b/plonk/Cargo.toml @@ -11,7 +11,7 @@ rust-version = { workspace = true } [dependencies] ark-ec = "0.4.0" -ark-ff = { version = "0.4.0", features = [ "asm" ] } +ark-ff = { version = "0.4.0", features = ["asm"] } ark-poly = "0.4.0" ark-serialize = "0.4.0" ark-std = { version = "0.4.0", default-features = false } @@ -21,7 +21,7 @@ downcast-rs = { version = "1.2.0", default-features = false } dyn-clone = "^1.0" espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", tag = "0.4.0" } hashbrown = "0.13.2" -itertools = { version = "0.10.1", default-features = false } +itertools = { workspace = true } jf-primitives = { path = "../primitives", default-features = false } jf-relation = { path = "../relation", default-features = false } jf-utils = { path = "../utilities" } @@ -52,14 +52,30 @@ harness = false [features] default = ["parallel"] std = [ - "ark-std/std", "ark-serialize/std", "ark-ff/std", "ark-ec/std", "ark-poly/std", - "downcast-rs/std", "itertools/use_std", "jf-primitives/std", "jf-relation/std", - "jf-utils/std", "num-bigint/std", "rand_chacha/std", "sha3/std" + "ark-std/std", + "ark-serialize/std", + "ark-ff/std", + "ark-ec/std", + "ark-poly/std", + "downcast-rs/std", + "itertools/use_std", + "jf-primitives/std", + "jf-relation/std", + "jf-utils/std", + "num-bigint/std", + "rand_chacha/std", + "sha3/std", ] test_apis = [] # exposing apis for testing purpose -parallel = ["ark-ff/parallel", "ark-ec/parallel", "ark-poly/parallel", - "jf-utils/parallel", "jf-relation/parallel", "jf-primitives/parallel", - "dep:rayon" ] +parallel = [ + "ark-ff/parallel", + "ark-ec/parallel", + "ark-poly/parallel", + "jf-utils/parallel", + "jf-relation/parallel", + "jf-primitives/parallel", + "dep:rayon", +] test-srs = [] [[example]] diff --git a/plonk/src/circuit/plonk_verifier/gadgets.rs b/plonk/src/circuit/plonk_verifier/gadgets.rs index 987874c08..8d00151c4 100644 --- a/plonk/src/circuit/plonk_verifier/gadgets.rs +++ b/plonk/src/circuit/plonk_verifier/gadgets.rs @@ -177,7 +177,7 @@ where } // ensure all the buffer has been consumed if v_and_uv_basis.next().is_some() { - return Err(PlonkError::IteratorOutOfRange)?; + Err(PlonkError::IteratorOutOfRange)?; } Ok(result) } diff --git a/plonk/src/circuit/plonk_verifier/poly.rs b/plonk/src/circuit/plonk_verifier/poly.rs index 7eec2fabb..5835b4db5 100644 --- a/plonk/src/circuit/plonk_verifier/poly.rs +++ b/plonk/src/circuit/plonk_verifier/poly.rs @@ -340,7 +340,7 @@ where let pi = public_inputs[0]; for &pi_i in public_inputs.iter().skip(1) { if pi != pi_i { - return Err(PlonkError::PublicInputsDoNotMatch)?; + Err(PlonkError::PublicInputsDoNotMatch)?; } } @@ -462,7 +462,7 @@ where } // ensure all the buffer has been consumed if alpha_bases_elem_var.next().is_some() { - return Err(PlonkError::IteratorOutOfRange)?; + Err(PlonkError::IteratorOutOfRange)?; } // ===================================================== // second statement @@ -690,7 +690,7 @@ where // ensure all the buffer has been consumed if alpha_bases_elem_var.next().is_some() { - return Err(PlonkError::IteratorOutOfRange)?; + Err(PlonkError::IteratorOutOfRange)?; } // ============================================ // Add splitted quotient commitments diff --git a/primitives/Cargo.toml 
b/primitives/Cargo.toml
index 72fc66cd3..6fd804592 100644
--- a/primitives/Cargo.toml
+++ b/primitives/Cargo.toml
@@ -36,9 +36,7 @@ digest = { version = "0.10.1", default-features = false, features = ["alloc"] }
 displaydoc = { version = "0.2.3", default-features = false }
 espresso-systems-common = { git = "https://github.com/espressosystems/espresso-systems-common", tag = "0.4.0" }
 hashbrown = "0.13.1"
-itertools = { version = "0.10.1", default-features = false, features = [
-    "use_alloc",
-] }
+itertools = { workspace = true, features = ["use_alloc"] }
 jf-relation = { path = "../relation", default-features = false }
 jf-utils = { path = "../utilities" }
 merlin = { version = "3.0.0", default-features = false }
diff --git a/primitives/src/merkle_tree/hasher.rs b/primitives/src/merkle_tree/hasher.rs
index 93b9882d2..7a8df6ee1 100644
--- a/primitives/src/merkle_tree/hasher.rs
+++ b/primitives/src/merkle_tree/hasher.rs
@@ -35,6 +35,10 @@
 //! Use [`GenericHasherMerkleTree`] if you prefer to specify your own `Arity`
 //! and node [`Index`] types.
 
+// clippy is freaking out about `HasherNode` and this is the only thing I
+// could do to stop it
+#![allow(clippy::incorrect_partial_ord_impl_on_ord_type)]
+
 use crate::errors::PrimitivesError;
 
 use super::{append_only::MerkleTree, DigestAlgorithm, Element, Index};
diff --git a/primitives/src/vid/advz.rs b/primitives/src/vid/advz.rs
index b5b811b05..9000dbc33 100644
--- a/primitives/src/vid/advz.rs
+++ b/primitives/src/vid/advz.rs
@@ -21,7 +21,7 @@ use anyhow::anyhow;
 use ark_ec::{pairing::Pairing, AffineRepr};
 use ark_ff::{
     fields::field_hashers::{DefaultFieldHasher, HashToField},
-    FftField, Field,
+    FftField, Field, PrimeField,
 };
 use ark_poly::{DenseUVPolynomial, EvaluationDomain};
 use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
@@ -37,7 +37,8 @@ use ark_std::{
 };
 use derivative::Derivative;
 use digest::{crypto_common::Output, Digest, DynDigest};
-use jf_utils::{bytes_from_field_elements, bytes_to_field_elements, canonical};
+use itertools::Itertools;
+use jf_utils::{bytes_to_field, canonical, field_to_bytes};
 use serde::{Deserialize, Serialize};
 
 /// The [ADVZ VID scheme](https://eprint.iacr.org/2021/1500), a concrete impl for [`VidScheme`].
@@ -140,6 +141,7 @@ where
     #[serde(with = "canonical")]
     poly_commits: Vec<P::Commitment>,
     all_evals_digest: V::NodeValue,
+    elems_len: usize,
 }
 
 // We take great pains to maintain abstraction by relying only on traits and not
@@ -147,10 +149,13 @@ where
 // 1,2: `Polynomial` is univariate: domain (`Point`) same field as range
 // (`Evaluation`). 3,4: `Commitment` is (convertible to/from) an elliptic curve
 // group in affine form. 5: `H` is a hasher
+//
+// `PrimeField` needed only because `bytes_to_field` needs it.
+// Otherwise we could relax to `FftField`.
 impl<P, T, H, V> VidScheme for GenericAdvz<P, T, H, V>
 where
     P: UnivariatePCS<Point = <P as PolynomialCommitmentScheme>::Evaluation>,
-    P::Evaluation: FftField,
+    P::Evaluation: PrimeField,
     P::Polynomial: DenseUVPolynomial<P::Evaluation>, // 2
     P::Commitment: From<T> + AsRef<T>,               // 3
     T: AffineRepr<ScalarField = P::Evaluation>,      // 4
@@ -163,26 +168,29 @@ where
     type Share = Share<P, V>;
     type Common = Common<P, V>;
 
-    fn commit_only(&self, payload: &[u8]) -> VidResult<Self::Commit> {
+    fn commit_only<I>(&self, payload: I) -> VidResult<Self::Commit>
+    where
+        I: IntoIterator,
+        I::Item: Borrow<u8>,
+    {
         let mut hasher = H::new();
-
-        // TODO perf: DenseUVPolynomial::from_coefficients_slice copies the slice.
-        // We could avoid unnecessary mem copies if bytes_to_field_elements returned
-        // Vec<Vec<F>>
-        let elems = bytes_to_field_elements(payload);
-        for coeffs in elems.chunks(self.payload_chunk_size) {
-            let poly = DenseUVPolynomial::from_coefficients_slice(coeffs);
+        let elems_iter = bytes_to_field::<_, P::Evaluation>(payload);
+        for coeffs in elems_iter.chunks(self.payload_chunk_size).into_iter() {
+            let poly = DenseUVPolynomial::from_coefficients_vec(coeffs.collect());
             let commitment = P::commit(&self.ck, &poly).map_err(vid)?;
             commitment
                 .serialize_uncompressed(&mut hasher)
                 .map_err(vid)?;
         }
-
         Ok(hasher.finalize())
     }
 
-    fn disperse(&self, payload: &[u8]) -> VidResult<VidDisperse<Self>> {
-        self.disperse_from_elems(&bytes_to_field_elements(payload))
+    fn disperse<I>(&self, payload: I) -> VidResult<VidDisperse<Self>>
+    where
+        I: IntoIterator,
+        I::Item: Borrow<u8>,
+    {
+        self.disperse_from_elems(bytes_to_field::<_, P::Evaluation>(payload))
     }
 
     fn verify_share(
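The rewrite above leans on `itertools::Itertools::chunks`, which buffers only one chunk at a time, so a payload arriving as an iterator is never materialized in full; each chunk becomes one polynomial's coefficient vector. A minimal sketch of the pattern in isolation (plain integers stand in for field elements; `itertools` 0.10 assumed):

    use itertools::Itertools;

    fn main() {
        let elems = 0u64..7; // stand-in for the field-element iterator
        let payload_chunk_size = 3;
        // same shape as the loop in `commit_only`/`disperse_from_elems` above
        for coeffs_iter in elems.chunks(payload_chunk_size).into_iter() {
            let coeffs: Vec<u64> = coeffs_iter.collect();
            println!("{coeffs:?}"); // [0, 1, 2] then [3, 4, 5] then [6]
        }
    }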
@@ -254,16 +262,15 @@ where
     }
 
     fn recover_payload(&self, shares: &[Self::Share], common: &Self::Common) -> VidResult<Vec<u8>> {
-        Ok(bytes_from_field_elements(
-            self.recover_elems(shares, common)?,
-        ))
+        // TODO can we avoid collect() here?
+        Ok(field_to_bytes(self.recover_elems(shares, common)?).collect())
     }
 }
 
 impl<P, T, H, V> GenericAdvz<P, T, H, V>
 where
     P: UnivariatePCS<Point = <P as PolynomialCommitmentScheme>::Evaluation>,
-    P::Evaluation: FftField,
+    P::Evaluation: PrimeField,
     P::Polynomial: DenseUVPolynomial<P::Evaluation>,
     P::Commitment: From<T> + AsRef<T>,
     T: AffineRepr<ScalarField = P::Evaluation>,
@@ -274,25 +281,29 @@ where
 {
-    /// Same as [`VidScheme::disperse`] except `payload` is a slice of
-    /// field elements.
-    pub fn disperse_from_elems(&self, payload: &[P::Evaluation]) -> VidResult<VidDisperse<Self>> {
-        let num_polys = if payload.is_empty() {
-            0
-        } else {
-            (payload.len() - 1) / self.payload_chunk_size + 1
-        };
+    /// Same as [`VidScheme::disperse`] except `payload` is an iterator over
+    /// field elements.
+    pub fn disperse_from_elems<I>(&self, payload: I) -> VidResult<VidDisperse<Self>>
+    where
+        I: IntoIterator,
+        I::Item: Borrow<P::Evaluation>,
+    {
         let domain = P::multi_open_rou_eval_domain(self.payload_chunk_size, self.num_storage_nodes)
             .map_err(vid)?;
 
         // partition payload into polynomial coefficients
-        let polys: Vec<P::Polynomial> = payload
-            .chunks(self.payload_chunk_size)
-            .map(DenseUVPolynomial::from_coefficients_slice)
-            .collect();
+        // and count `elems_len` for later
+        let elems_iter = payload.into_iter().map(|elem| *elem.borrow());
+        let mut elems_len = 0;
+        let mut polys = Vec::new();
+        for coeffs_iter in elems_iter.chunks(self.payload_chunk_size).into_iter() {
+            let coeffs: Vec<_> = coeffs_iter.collect();
+            elems_len += coeffs.len();
+            polys.push(DenseUVPolynomial::from_coefficients_vec(coeffs));
+        }
 
         // evaluate polynomials
         let all_storage_node_evals = {
             let mut all_storage_node_evals =
-                vec![Vec::with_capacity(num_polys); self.num_storage_nodes];
+                vec![Vec::with_capacity(polys.len()); self.num_storage_nodes];
 
             for poly in polys.iter() {
                 let poly_evals =
@@ -308,7 +319,7 @@ where
             // sanity checks
             assert_eq!(all_storage_node_evals.len(), self.num_storage_nodes);
             for storage_node_evals in all_storage_node_evals.iter() {
-                assert_eq!(storage_node_evals.len(), num_polys);
+                assert_eq!(storage_node_evals.len(), polys.len());
             }
 
             all_storage_node_evals
@@ -338,6 +349,7 @@ where
                 .collect::<Result<_, _>>()
                 .map_err(vid)?,
             all_evals_digest: all_evals_commit.commitment().digest(),
+            elems_len,
         };
 
         let commit = {
@@ -394,7 +406,7 @@ where
     pub fn recover_elems(
         &self,
         shares: &[<Self as VidScheme>::Share],
-        _common: &<Self as VidScheme>::Common,
+        common: &<Self as VidScheme>::Common,
     ) -> VidResult<Vec<P::Evaluation>> {
         if shares.len() < self.payload_chunk_size {
             return Err(VidError::Argument(format!(
@@ -438,6 +450,7 @@ where
             result.append(&mut coeffs);
         }
         assert_eq!(result.len(), result_len);
+        result.truncate(common.elems_len);
 
         Ok(result)
     }
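`elems_len` is why `Common` grew a field: with a slice the payload length came for free, but an iterator's length is only known after consumption, and recovery interpolates a whole number of polynomials, so the zero-padding in the last chunk must be cut off explicitly. A toy sketch of the arithmetic (hypothetical numbers, not the real advz types):

    fn main() {
        let payload_chunk_size = 3;
        let elems_len = 4; // counted on the fly in `disperse_from_elems`
        let num_polys = (elems_len - 1) / payload_chunk_size + 1; // 2 polynomials
        // recovery yields payload_chunk_size * num_polys coefficients, zero-padded
        let mut result = vec![7u64, 8, 9, 10, 0, 0];
        assert_eq!(result.len(), payload_chunk_size * num_polys);
        result.truncate(elems_len); // what `recover_elems` now does with `common.elems_len`
        assert_eq!(result, vec![7, 8, 9, 10]);
    }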
diff --git a/primitives/src/vid/mod.rs b/primitives/src/vid/mod.rs
index d2fa3cb51..00d1a7a40 100644
--- a/primitives/src/vid/mod.rs
+++ b/primitives/src/vid/mod.rs
@@ -7,8 +7,9 @@
 //! Trait and implementation for Verifiable Information Dispersal (VID).
 /// See section 1.3--1.4 for intro to VID semantics.
 use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
-use ark_std::{error::Error, fmt::Debug, string::String, vec::Vec};
+use ark_std::{borrow::Borrow, error::Error, fmt::Debug, hash::Hash, string::String, vec::Vec};
 use displaydoc::Display;
+use serde::{Deserialize, Serialize};
 
 pub mod advz;
 
@@ -52,11 +53,17 @@ pub trait VidScheme {
     /// Common data sent to all storage nodes.
     type Common: CanonicalSerialize + CanonicalDeserialize + Clone + Eq + PartialEq + Sync; // TODO https://github.com/EspressoSystems/jellyfish/issues/253
 
-    /// Compute a payload commitment.
-    fn commit_only(&self, payload: &[u8]) -> VidResult<Self::Commit>;
+    /// Compute a payload commitment
+    fn commit_only<I>(&self, payload: I) -> VidResult<Self::Commit>
+    where
+        I: IntoIterator,
+        I::Item: Borrow<u8>;
 
     /// Compute shares to send to the storage nodes
-    fn disperse(&self, payload: &[u8]) -> VidResult<VidDisperse<Self>>;
+    fn disperse<I>(&self, payload: I) -> VidResult<VidDisperse<Self>>
+    where
+        I: IntoIterator,
+        I::Item: Borrow<u8>;
 
     /// Verify a share. Used by both storage node and retrieval client.
     /// Why is return type a nested `Result`? See
@@ -78,6 +85,18 @@ pub trait VidScheme {
 ///
 /// # Why the `?Sized` bound?
 /// Rust hates you:
+#[derive(Derivative, Deserialize, Serialize)]
+#[serde(bound = "V::Share: Serialize + for<'a> Deserialize<'a>,
+    V::Common: Serialize + for<'a> Deserialize<'a>,
+    V::Commit: Serialize + for<'a> Deserialize<'a>,")]
+// Somehow these bizarre bounds suffice for downstream derivations
+#[derivative(
+    Clone(bound = ""),
+    Debug(bound = "V::Share: Debug, V::Common: Debug, V::Commit: Debug"),
+    Eq(bound = ""),
+    Hash(bound = "V::Share: Hash, V::Common: Hash, V::Commit: Hash"),
+    PartialEq(bound = "")
+)]
 pub struct VidDisperse<V: VidScheme + ?Sized> {
     /// VID disperse shares to send to the storage nodes.
     pub shares: Vec<V::Share>,
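The `IntoIterator`/`Borrow<u8>` bounds keep old slice-based call sites working while also admitting owned buffers and streams. A sketch of what now typechecks against the trait (hypothetical helper, not part of this PR; assumes `VidResult` is exported from `jf_primitives::vid`):

    use jf_primitives::vid::{VidResult, VidScheme};

    fn demo<V: VidScheme>(vid: &V) -> VidResult<()> {
        let payload = vec![1u8, 2, 3];
        let _ = vid.disperse(&payload)?;     // &Vec<u8>: Item = &u8, and &u8: Borrow<u8>
        let _ = vid.disperse(payload)?;      // Vec<u8>: Item = u8
        let _ = vid.commit_only(0..=255u8)?; // any byte stream, no slice required
        Ok(())
    }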
diff --git a/utilities/Cargo.toml b/utilities/Cargo.toml
index 6aba0e3ee..9844e0f6a 100644
--- a/utilities/Cargo.toml
+++ b/utilities/Cargo.toml
@@ -9,7 +9,7 @@ rust-version = { workspace = true }
 
 [dependencies]
 ark-ec = { version = "0.4.0", default-features = false }
-ark-ff = { version = "0.4.0", default-features = false, features = [ "asm" ] }
+ark-ff = { version = "0.4.0", default-features = false, features = ["asm"] }
 ark-serialize = { version = "0.4.0", default-features = false }
 ark-std = { version = "0.4.0", default-features = false }
 digest = { version = "0.10.1", default-features = false }
@@ -28,5 +28,18 @@ ark-ed-on-bn254 = "0.4.0"
 
 [features]
 default = []
-std = ["ark-ff/std", "ark-std/std", "ark-ec/std", "ark-serialize/std", "digest/std", "serde/std", "sha2/std"]
-parallel = ["ark-ff/parallel", "ark-std/parallel", "ark-ec/parallel", "dep:rayon"]
+std = [
+    "ark-ff/std",
+    "ark-std/std",
+    "ark-ec/std",
+    "ark-serialize/std",
+    "digest/std",
+    "serde/std",
+    "sha2/std",
+]
+parallel = [
+    "ark-ff/parallel",
+    "ark-std/parallel",
+    "ark-ec/parallel",
+    "dep:rayon",
+]
diff --git a/utilities/src/conversion.rs b/utilities/src/conversion.rs
index a6d73bc0f..377901059 100644
--- a/utilities/src/conversion.rs
+++ b/utilities/src/conversion.rs
@@ -9,9 +9,10 @@ use ark_ff::{BigInteger, Field, PrimeField};
 use ark_std::{
     borrow::Borrow,
     cmp::min,
-    iter::{once, repeat},
-    mem,
-    vec::Vec,
+    iter::{once, repeat, Peekable, Take},
+    marker::PhantomData,
+    mem, vec,
+    vec::{IntoIter, Vec},
 };
 use sha2::{Digest, Sha512};
 
@@ -289,6 +290,262 @@ fn compile_time_checks<F: Field>() -> (usize, usize, usize) {
     (primefield_bytes_len, extension_degree, field_bytes_len)
 }
 
+/// Deterministic, infallible, invertible iterator adaptor to convert from
+/// arbitrary bytes to field elements.
+///
+/// # TODO doc test
+///
+/// # How it works
+///
+/// Returns an iterator over [`PrimeField`] items defined as follows:
+/// - For each call to `next()`:
+///   - Consume P-1 items from `bytes` where P is the field characteristic byte
+///     length. (Consume all remaining B items from `bytes` if B < P-1.)
+///   - Convert the consumed bytes into a [`PrimeField`] via
+///     [`from_le_bytes_mod_order`]. Reduction modulo the field characteristic
+///     is guaranteed not to occur because we consumed at most P-1 bytes.
+///   - Return the resulting [`PrimeField`] item.
+/// - The returned iterator has an additional item that encodes the number of
+///   input items consumed in order to produce the final output item.
+/// - If `bytes` is empty then result is empty.
+///
+/// # Panics
+///
+/// Panics only under conditions that should be checkable at compile time:
+///
+/// - The [`PrimeField`] modulus bit length is too small to hold a `u64`.
+/// - The [`PrimeField`] byte length is too large to fit inside a `usize`.
+///
+/// If any of the above conditions holds then this function *always* panics.
+pub fn bytes_to_field<I, F>(bytes: I) -> impl Iterator<Item = F>
+where
+    F: PrimeField,
+    I: IntoIterator,
+    I::Item: Borrow<u8>,
+{
+    BytesToField::new(bytes.into_iter())
+}
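A sketch of the doc test that the TODO above calls for, mirroring the `bytes_field_elems_iter` test at the bottom of this file (`ark_ed_on_bn254::Fr` chosen arbitrarily):

    use ark_ed_on_bn254::Fr;
    use jf_utils::{bytes_to_field, field_to_bytes};

    fn main() {
        // one input byte -> one data element, plus a final element encoding
        // that 1 byte was consumed to produce the last data element
        let elems: Vec<Fr> = bytes_to_field([42u8]).collect();
        assert_eq!(elems, vec![Fr::from(42u64), Fr::from(1u64)]);

        // round trip back to bytes
        let bytes: Vec<u8> = field_to_bytes::<_, Fr>(elems).collect();
        assert_eq!(bytes, vec![42u8]);
    }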
+/// Deterministic, infallible inverse of [`bytes_to_field`].
+///
+/// This function is not invertible because [`bytes_to_field`] is not onto.
+///
+/// ## Panics
+///
+/// Panics under the conditions listed at [`bytes_to_field`].
+pub fn field_to_bytes<I, F>(elems: I) -> impl Iterator<Item = u8>
+where
+    F: PrimeField,
+    I: IntoIterator,
+    I::Item: Borrow<F>,
+{
+    FieldToBytes::new(elems.into_iter())
+}
+
+struct BytesToField<I, F>
+where
+    I: Iterator,
+{
+    bytes_iter: Peekable<I>,
+    final_byte_len: Option<usize>,
+    done: bool,
+    new: bool,
+    _phantom: PhantomData<F>,
+    primefield_bytes_len: usize,
+}
+
+impl<I, F: PrimeField> BytesToField<I, F>
+where
+    I: Iterator,
+{
+    fn new(iter: I) -> Self {
+        let (primefield_bytes_len, ..) = compile_time_checks::<F>();
+        Self {
+            bytes_iter: iter.peekable(),
+            final_byte_len: None,
+            done: false,
+            new: true,
+            _phantom: PhantomData,
+            primefield_bytes_len,
+        }
+    }
+}
+
+impl<I, F> Iterator for BytesToField<I, F>
+where
+    I: Iterator,
+    I::Item: Borrow<u8>,
+    F: PrimeField,
+{
+    type Item = F;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.done {
+            // we don't support iterators that return `Some` after returning `None`
+            return None;
+        }
+
+        if let Some(len) = self.final_byte_len {
+            // iterator is done. final field elem encodes length.
+            self.done = true;
+            return Some(F::from(len as u64));
+        }
+
+        if self.new && self.bytes_iter.peek().is_none() {
+            // zero-length iterator
+            self.done = true;
+            return None;
+        }
+        self.new = false;
+
+        // TODO const generics: use [u8; primefield_bytes_len]
+        let mut field_elem_bytes = vec![0u8; self.primefield_bytes_len];
+        for (i, b) in field_elem_bytes.iter_mut().enumerate() {
+            if let Some(byte) = self.bytes_iter.next() {
+                *b = *byte.borrow();
+            } else {
+                self.final_byte_len = Some(i);
+                break;
+            }
+        }
+        Some(F::from_le_bytes_mod_order(&field_elem_bytes))
+    }
+}
+
+struct FieldToBytes<I, F> {
+    elems_iter: I,
+    state: FieldToBytesState<F>,
+    primefield_bytes_len: usize,
+}
+
+enum FieldToBytesState<F> {
+    New,
+    Typical {
+        bytes_iter: Take<IntoIter<u8>>,
+        next_elem: F,
+        next_next_elem: F,
+    },
+    Final {
+        bytes_iter: Take<IntoIter<u8>>,
+    },
+}
+
+impl<I, F: PrimeField> FieldToBytes<I, F> {
+    fn new(elems_iter: I) -> Self {
+        let (primefield_bytes_len, ..) = compile_time_checks::<F>();
+        Self {
+            elems_iter,
+            state: FieldToBytesState::New,
+            primefield_bytes_len,
+        }
+    }
+
+    fn elem_to_usize(elem: F) -> usize {
+        usize::try_from(u64::from_le_bytes(
+            elem.into_bigint().to_bytes_le()[..mem::size_of::<u64>()]
+                .try_into()
+                .expect("conversion from [u8] to u64 should succeed"),
+        ))
+        .expect("result len conversion from u64 to usize should succeed")
+    }
+
+    fn elem_to_bytes_iter(elem: F) -> IntoIter<u8> {
+        elem.into_bigint().to_bytes_le().into_iter()
+    }
+}
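`elem_to_usize` works because the final element of a `bytes_to_field` stream was built with `F::from(len as u64)`, so its little-endian bigint representation carries the length in the first eight bytes. A standalone sketch of that round trip (assuming `ark_ed_on_bn254::Fr`):

    use ark_ed_on_bn254::Fr;
    use ark_ff::{BigInteger, PrimeField};

    fn main() {
        let marker = Fr::from(5u64); // how `BytesToField` encodes a final length of 5
        let le_bytes = marker.into_bigint().to_bytes_le(); // starts [5, 0, 0, ...]
        let len = u64::from_le_bytes(le_bytes[..8].try_into().unwrap());
        assert_eq!(len, 5);
    }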
+impl<I, F> Iterator for FieldToBytes<I, F>
+where
+    I: Iterator,
+    I::Item: Borrow<F>,
+    F: PrimeField,
+{
+    type Item = u8;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        use FieldToBytesState::{Final, New, Typical};
+        match &mut self.state {
+            New => {
+                let cur_elem = if let Some(elem) = self.elems_iter.next() {
+                    *elem.borrow()
+                } else {
+                    // length-0 iterator
+                    // move to `Final` state with an empty iterator
+                    self.state = Final {
+                        bytes_iter: Vec::new().into_iter().take(0),
+                    };
+                    return None;
+                };
+
+                let bytes_iter = Self::elem_to_bytes_iter(cur_elem);
+
+                let next_elem = if let Some(elem) = self.elems_iter.next() {
+                    *elem.borrow()
+                } else {
+                    // length-1 iterator: `bytes_to_field` never produces this, but handle it anyway
+                    // move to `Final` state with primefield_bytes_len bytes from the sole elem
+                    let mut bytes_iter = bytes_iter.take(self.primefield_bytes_len);
+                    let ret = bytes_iter.next();
+                    self.state = Final { bytes_iter };
+                    return ret;
+                };
+
+                let next_next_elem = if let Some(elem) = self.elems_iter.next() {
+                    *elem.borrow()
+                } else {
+                    // length-2 iterator
+                    let final_byte_len = Self::elem_to_usize(next_elem);
+                    let mut bytes_iter = bytes_iter.take(final_byte_len);
+                    let ret = bytes_iter.next();
+                    self.state = Final { bytes_iter };
+                    return ret;
+                };
+
+                // length >2 iterator
+                let mut bytes_iter = bytes_iter.take(self.primefield_bytes_len);
+                let ret = bytes_iter.next();
+                self.state = Typical {
+                    bytes_iter,
+                    next_elem,
+                    next_next_elem,
+                };
+                ret
+            },
+            Typical {
+                bytes_iter,
+                next_elem,
+                next_next_elem,
+            } => {
+                let ret = bytes_iter.next();
+                if ret.is_some() {
+                    return ret;
+                }
+
+                let bytes_iter = Self::elem_to_bytes_iter(*next_elem);
+
+                if let Some(elem) = self.elems_iter.next() {
+                    // advance to the next field element
+                    let mut bytes_iter = bytes_iter.take(self.primefield_bytes_len);
+                    let ret = bytes_iter.next();
+                    self.state = Typical {
+                        bytes_iter,
+                        next_elem: *next_next_elem,
+                        next_next_elem: *elem.borrow(),
+                    };
+                    return ret;
+                }
+
+                // done
+                let final_byte_len = Self::elem_to_usize(*next_next_elem);
+                let mut bytes_iter = bytes_iter.take(final_byte_len);
+                let ret = bytes_iter.next();
+                self.state = Final { bytes_iter };
+                ret
+            },
+            Final { bytes_iter } => bytes_iter.next(),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::test_rng;
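The two-element lookahead in `Typical` exists because the last element of a `bytes_to_field` stream is a length marker rather than data: while draining the second-to-last element, the adaptor must already know how many bytes of the final data element are real. An end-to-end sketch that exercises every arm above (assuming `ark_ed_on_bn254::Fr`, whose elements carry 31 payload bytes each):

    use ark_ed_on_bn254::Fr;
    use jf_utils::{bytes_to_field, field_to_bytes};

    fn main() {
        // 40 bytes = one full 31-byte chunk plus a 9-byte tail, so the element
        // stream is [data, data, marker(9)] and `New`, `Typical`, `Final` all run
        let bytes: Vec<u8> = (0..40u8).collect();
        let round_trip: Vec<u8> =
            field_to_bytes::<_, Fr>(bytes_to_field::<_, Fr>(bytes.clone())).collect();
        assert_eq!(round_trip, bytes);
    }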
@@ -300,6 +557,7 @@ mod tests {
     use ark_ed_on_bls12_377::{EdwardsConfig as Param377, Fr as Fr377};
     use ark_ed_on_bls12_381::{EdwardsConfig as Param381, Fr as Fr381};
     use ark_ed_on_bn254::{EdwardsConfig as Param254, Fr as Fr254};
+    use ark_ff::{Field, PrimeField};
     use ark_std::{rand::RngCore, UniformRand};
 
     #[test]
@@ -363,6 +621,65 @@ mod tests {
         }
     }
 
+    fn bytes_field_elems_iter<F: PrimeField>() {
+        // copied from bytes_field_elems()
+
+        let lengths = [0, 1, 2, 16, 31, 32, 33, 48, 65, 100, 200, 5000];
+        let trailing_zeros_lengths = [0, 1, 2, 5, 50];
+
+        let max_len = *lengths.iter().max().unwrap();
+        let max_trailing_zeros_len = *trailing_zeros_lengths.iter().max().unwrap();
+        let mut bytes = Vec::with_capacity(max_len + max_trailing_zeros_len);
+        let mut elems: Vec<F> = Vec::with_capacity(max_len);
+        let mut rng = test_rng();
+
+        for len in lengths {
+            for trailing_zeros_len in trailing_zeros_lengths {
+                // fill bytes with random bytes and trailing zeros
+                bytes.resize(len + trailing_zeros_len, 0);
+                rng.fill_bytes(&mut bytes[..len]);
+                bytes[len..].fill(0);
+
+                // debug
+                // println!("byte_len: {}, trailing_zeros: {}", len, trailing_zeros_len);
+                // println!("bytes: {:?}", bytes);
+                // let encoded: Vec<F> = bytes_to_field(bytes.iter()).collect();
+                // println!("encoded: {:?}", encoded);
+                // let result: Vec<_> = field_to_bytes(encoded).collect();
+                // println!("result: {:?}", result);
+
+                // round trip: bytes as Iterator<Item = u8>, elems as Iterator<Item = F>
+                let result_clone: Vec<_> =
+                    field_to_bytes(bytes_to_field::<_, F>(bytes.clone())).collect();
+                assert_eq!(result_clone, bytes);
+
+                // round trip: bytes as Iterator<Item = &u8>, elems as Iterator<Item = &F>
+                let encoded: Vec<_> = bytes_to_field::<_, F>(bytes.iter()).collect();
+                let result_borrow: Vec<_> = field_to_bytes::<_, F>(encoded.iter()).collect();
+                assert_eq!(result_borrow, bytes);
+            }
+
+            // test infallibility of field_to_bytes
+            // with random field elements
+            elems.resize(len, F::zero());
+            elems.iter_mut().for_each(|e| *e = F::rand(&mut rng));
+            let _: Vec<u8> = field_to_bytes::<_, F>(elems.iter()).collect();
+        }
+
+        // empty input -> empty output
+        let bytes = Vec::new();
+        assert!(bytes.iter().next().is_none());
+        let mut elems_iter = bytes_to_field::<_, F>(bytes.iter());
+        assert!(elems_iter.next().is_none());
+
+        // smallest non-empty input -> 2-item output
+        let bytes = [42u8; 1];
+        let mut elems_iter = bytes_to_field::<_, F>(bytes.iter());
+        assert_eq!(elems_iter.next().unwrap(), F::from(42u64));
+        assert_eq!(elems_iter.next().unwrap(), F::from(1u64));
+        assert!(elems_iter.next().is_none());
+    }
+
     #[test]
     fn test_bytes_field_elems() {
         bytes_field_elems::<Fr254>();
@@ -372,4 +689,11 @@ mod tests {
         bytes_field_elems::<Fr377>();
         bytes_field_elems::<Fr381>();
     }
+
+    #[test]
+    fn test_bytes_field_elems_iter() {
+        bytes_field_elems_iter::<Fr254>();
+        bytes_field_elems_iter::<Fr377>();
+        bytes_field_elems_iter::<Fr381>();
+    }
 }