diff --git a/halo2_proofs/src/poly.rs b/halo2_proofs/src/poly.rs index d95fda66f6..35c37fd095 100644 --- a/halo2_proofs/src/poly.rs +++ b/halo2_proofs/src/poly.rs @@ -209,6 +209,82 @@ impl Polynomial { _marker: PhantomData, } } + + /// Gets the specified chunk of the rotated version of this polynomial. + /// + /// Equivalent to: + /// ```ignore + /// self.rotate(rotation) + /// .chunks(chunk_size) + /// .nth(chunk_index) + /// .unwrap() + /// .to_vec() + /// ``` + pub(crate) fn get_chunk_of_rotated( + &self, + rotation: Rotation, + chunk_size: usize, + chunk_index: usize, + ) -> Vec { + self.get_chunk_of_rotated_helper( + rotation.0 < 0, + rotation.0.unsigned_abs() as usize, + chunk_size, + chunk_index, + ) + } +} + +impl Polynomial { + pub(crate) fn get_chunk_of_rotated_helper( + &self, + rotation_is_negative: bool, + rotation_abs: usize, + chunk_size: usize, + chunk_index: usize, + ) -> Vec { + // Compute the lengths such that when applying the rotation, the first `mid` + // coefficients move to the end, and the last `k` coefficients move to the front. + // The coefficient previously at `mid` will be the first coefficient in the + // rotated polynomial, and the position from which chunk indexing begins. + #[allow(clippy::branches_sharing_code)] + let (mid, k) = if rotation_is_negative { + let k = rotation_abs; + assert!(k <= self.len()); + let mid = self.len() - k; + (mid, k) + } else { + let mid = rotation_abs; + assert!(mid <= self.len()); + let k = self.len() - mid; + (mid, k) + }; + + // Compute [chunk_start..chunk_end], the range of the chunk within the rotated + // polynomial. + let chunk_start = chunk_size * chunk_index; + let chunk_end = self.len().min(chunk_size * (chunk_index + 1)); + + if chunk_end < k { + // The chunk is entirely in the last `k` coefficients of the unrotated + // polynomial. + self.values[mid + chunk_start..mid + chunk_end].to_vec() + } else if chunk_start >= k { + // The chunk is entirely in the first `mid` coefficients of the unrotated + // polynomial. + self.values[chunk_start - k..chunk_end - k].to_vec() + } else { + // The chunk falls across the boundary between the last `k` and first `mid` + // coefficients of the unrotated polynomial. Splice the halves together. + let chunk = self.values[mid + chunk_start..] + .iter() + .chain(&self.values[..chunk_end - k]) + .copied() + .collect::>(); + assert!(chunk.len() <= chunk_size); + chunk + } + } } impl Mul for Polynomial { @@ -247,3 +323,43 @@ impl Rotation { Rotation(1) } } + +#[cfg(test)] +mod tests { + use ff::Field; + use pasta_curves::pallas; + use rand_core::OsRng; + + use super::{EvaluationDomain, Polynomial, Rotation}; + + #[test] + fn test_get_chunk_of_rotated() { + let k = 11; + let domain = EvaluationDomain::::new(1, k); + + // Create a random polynomial. + let mut poly = domain.empty_lagrange(); + for coefficient in poly.iter_mut() { + *coefficient = pallas::Base::random(OsRng); + } + + // Pick a chunk size that is guaranteed to not be a multiple of the polynomial + // length. + let chunk_size = 7; + + for rotation in [ + Rotation(-6), + Rotation::prev(), + Rotation::cur(), + Rotation::next(), + Rotation(12), + ] { + for (chunk_index, chunk) in poly.rotate(rotation).chunks(chunk_size).enumerate() { + assert_eq!( + poly.get_chunk_of_rotated(rotation, chunk_size, chunk_index), + chunk + ); + } + } + } +} diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index 708d72ee1e..68ccd5ae1a 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -272,6 +272,27 @@ impl EvaluationDomain { poly } + /// Gets the specified chunk of the rotated version of this polynomial. + /// + /// Equivalent to: + /// ```ignore + /// self.rotate_extended(poly, rotation) + /// .chunks(chunk_size) + /// .nth(chunk_index) + /// .unwrap() + /// .to_vec() + /// ``` + pub(crate) fn get_chunk_of_rotated_extended( + &self, + poly: &Polynomial, + rotation: Rotation, + chunk_size: usize, + chunk_index: usize, + ) -> Vec { + let new_rotation = ((1 << (self.extended_k - self.k)) * rotation.0.abs()) as usize; + poly.get_chunk_of_rotated_helper(rotation.0 < 0, new_rotation, chunk_size, chunk_index) + } + /// This takes us from the extended evaluation domain and gets us the /// quotient polynomial coefficients. /// @@ -545,3 +566,41 @@ fn test_l_i() { assert_eq!(eval_polynomial(&l[(8 - i) % 8][..], x), evaluations[7 - i]); } } + +#[test] +fn test_get_chunk_of_rotated_extended() { + use pasta_curves::pallas; + use rand_core::OsRng; + + let k = 11; + let domain = EvaluationDomain::::new(3, k); + + // Create a random polynomial. + let mut poly = domain.empty_extended(); + for coefficient in poly.iter_mut() { + *coefficient = pallas::Base::random(OsRng); + } + + // Pick a chunk size that is guaranteed to not be a multiple of the polynomial + // length. + let chunk_size = 7; + + for rotation in [ + Rotation(-6), + Rotation::prev(), + Rotation::cur(), + Rotation::next(), + Rotation(12), + ] { + for (chunk_index, chunk) in domain + .rotate_extended(&poly, rotation) + .chunks(chunk_size) + .enumerate() + { + assert_eq!( + domain.get_chunk_of_rotated_extended(&poly, rotation, chunk_size, chunk_index), + chunk + ); + } + } +} diff --git a/halo2_proofs/src/poly/evaluator.rs b/halo2_proofs/src/poly/evaluator.rs index d034d3c50c..99c613b6fb 100644 --- a/halo2_proofs/src/poly/evaluator.rs +++ b/halo2_proofs/src/poly/evaluator.rs @@ -138,82 +138,30 @@ impl Evaluator { F: FieldExt, B: BasisOps, { - // Traverse `ast` to collect the used leaves. - fn collect_rotations( - ast: &Ast, - ) -> HashSet> { - match ast { - Ast::Poly(leaf) => vec![*leaf].into_iter().collect(), - Ast::Add(a, b) | Ast::Mul(AstMul(a, b)) => { - let lhs = collect_rotations(a); - let rhs = collect_rotations(b); - lhs.union(&rhs).cloned().collect() - } - Ast::Scale(a, _) => collect_rotations(a), - Ast::DistributePowers(terms, _) => terms - .iter() - .flat_map(|term| collect_rotations(term).into_iter()) - .collect(), - Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(), - } - } - let leaves = collect_rotations(ast); - - // Produce a map from each leaf to the rotated polynomial it corresponds to, if - // any (or None if the leaf uses an unrotated polynomial). - let rotated: HashMap<_, _> = leaves - .iter() - .cloned() - .map(|leaf| { - ( - leaf, - if leaf.rotation == Rotation::cur() { - // We can use the polynomial as-is for this leaf. - None - } else { - Some(B::rotate(domain, &self.polys[leaf.index], leaf.rotation)) - }, - ) - }) - .collect(); - // We're working in a single basis, so all polynomials are the same length. let poly_len = self.polys.first().unwrap().len(); - let (chunk_size, num_chunks) = get_chunk_params(poly_len); - - // Split each rotated and unrotated polynomial into chunks. - let chunks: Vec> = (0..num_chunks) - .map(|i| { - rotated - .iter() - .map(|(leaf, poly)| { - ( - *leaf, - poly.as_ref() - .unwrap_or(&self.polys[leaf.index]) - .chunks(chunk_size) - .nth(i) - .expect("num_chunks was calculated correctly"), - ) - }) - .collect() - }) - .collect(); + let (chunk_size, _num_chunks) = get_chunk_params(poly_len); - struct AstContext<'a, E, F: FieldExt, B: Basis> { + struct AstContext<'a, F: FieldExt, B: Basis> { domain: &'a EvaluationDomain, poly_len: usize, chunk_size: usize, chunk_index: usize, - leaves: &'a HashMap, &'a [F]>, + polys: &'a [Polynomial], } fn recurse( ast: &Ast, - ctx: &AstContext<'_, E, F, B>, + ctx: &AstContext<'_, F, B>, ) -> Vec { match ast { - Ast::Poly(leaf) => ctx.leaves.get(leaf).expect("We prepared this").to_vec(), + Ast::Poly(leaf) => B::get_chunk_of_rotated( + ctx.domain, + ctx.chunk_size, + ctx.chunk_index, + &ctx.polys[leaf.index], + leaf.rotation, + ), Ast::Add(a, b) => { let mut lhs = recurse(a, ctx); let rhs = recurse(b, ctx); @@ -265,16 +213,14 @@ impl Evaluator { // polynomial. let mut result = B::empty_poly(domain); multicore::scope(|scope| { - for (chunk_index, (out, leaves)) in - result.chunks_mut(chunk_size).zip(chunks.iter()).enumerate() - { + for (chunk_index, out) in result.chunks_mut(chunk_size).enumerate() { scope.spawn(move |_| { let ctx = AstContext { domain, poly_len, chunk_size, chunk_index, - leaves, + polys: &self.polys, }; out.copy_from_slice(&recurse(ast, &ctx)); }); @@ -511,11 +457,13 @@ pub(crate) trait BasisOps: Basis { chunk_index: usize, scalar: F, ) -> Vec; - fn rotate( + fn get_chunk_of_rotated( domain: &EvaluationDomain, + chunk_size: usize, + chunk_index: usize, poly: &Polynomial, rotation: Rotation, - ) -> Polynomial; + ) -> Vec; } impl BasisOps for Coeff { @@ -558,11 +506,13 @@ impl BasisOps for Coeff { chunk } - fn rotate( + fn get_chunk_of_rotated( _: &EvaluationDomain, + _: usize, + _: usize, _: &Polynomial, _: Rotation, - ) -> Polynomial { + ) -> Vec { panic!("Can't rotate polynomials in the standard basis") } } @@ -600,12 +550,14 @@ impl BasisOps for LagrangeCoeff { .collect() } - fn rotate( + fn get_chunk_of_rotated( _: &EvaluationDomain, + chunk_size: usize, + chunk_index: usize, poly: &Polynomial, rotation: Rotation, - ) -> Polynomial { - poly.rotate(rotation) + ) -> Vec { + poly.get_chunk_of_rotated(rotation, chunk_size, chunk_index) } } @@ -645,12 +597,14 @@ impl BasisOps for ExtendedLagrangeCoeff { .collect() } - fn rotate( + fn get_chunk_of_rotated( domain: &EvaluationDomain, + chunk_size: usize, + chunk_index: usize, poly: &Polynomial, rotation: Rotation, - ) -> Polynomial { - domain.rotate_extended(poly, rotation) + ) -> Vec { + domain.get_chunk_of_rotated_extended(poly, rotation, chunk_size, chunk_index) } }