Skip to content

Commit

Permalink
Merge pull request #644 from zcash/remove-rotated-poly-cache
Browse files Browse the repository at this point in the history
halo2_proofs: Avoid caching rotated polynomials in `poly::Evaluator`
  • Loading branch information
str4d authored Sep 10, 2022
2 parents 1806b88 + d24f0fd commit 553584a
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 77 deletions.
116 changes: 116 additions & 0 deletions halo2_proofs/src/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,82 @@ impl<F: Field> Polynomial<F, LagrangeCoeff> {
_marker: PhantomData,
}
}

/// Gets the specified chunk of the rotated version of this polynomial.
///
/// Equivalent to:
/// ```ignore
/// self.rotate(rotation)
/// .chunks(chunk_size)
/// .nth(chunk_index)
/// .unwrap()
/// .to_vec()
/// ```
pub(crate) fn get_chunk_of_rotated(
&self,
rotation: Rotation,
chunk_size: usize,
chunk_index: usize,
) -> Vec<F> {
self.get_chunk_of_rotated_helper(
rotation.0 < 0,
rotation.0.unsigned_abs() as usize,
chunk_size,
chunk_index,
)
}
}

impl<F: Clone + Copy, B> Polynomial<F, B> {
pub(crate) fn get_chunk_of_rotated_helper(
&self,
rotation_is_negative: bool,
rotation_abs: usize,
chunk_size: usize,
chunk_index: usize,
) -> Vec<F> {
// Compute the lengths such that when applying the rotation, the first `mid`
// coefficients move to the end, and the last `k` coefficients move to the front.
// The coefficient previously at `mid` will be the first coefficient in the
// rotated polynomial, and the position from which chunk indexing begins.
#[allow(clippy::branches_sharing_code)]
let (mid, k) = if rotation_is_negative {
let k = rotation_abs;
assert!(k <= self.len());
let mid = self.len() - k;
(mid, k)
} else {
let mid = rotation_abs;
assert!(mid <= self.len());
let k = self.len() - mid;
(mid, k)
};

// Compute [chunk_start..chunk_end], the range of the chunk within the rotated
// polynomial.
let chunk_start = chunk_size * chunk_index;
let chunk_end = self.len().min(chunk_size * (chunk_index + 1));

if chunk_end < k {
// The chunk is entirely in the last `k` coefficients of the unrotated
// polynomial.
self.values[mid + chunk_start..mid + chunk_end].to_vec()
} else if chunk_start >= k {
// The chunk is entirely in the first `mid` coefficients of the unrotated
// polynomial.
self.values[chunk_start - k..chunk_end - k].to_vec()
} else {
// The chunk falls across the boundary between the last `k` and first `mid`
// coefficients of the unrotated polynomial. Splice the halves together.
let chunk = self.values[mid + chunk_start..]
.iter()
.chain(&self.values[..chunk_end - k])
.copied()
.collect::<Vec<_>>();
assert!(chunk.len() <= chunk_size);
chunk
}
}
}

impl<F: Field, B: Basis> Mul<F> for Polynomial<F, B> {
Expand Down Expand Up @@ -247,3 +323,43 @@ impl Rotation {
Rotation(1)
}
}

#[cfg(test)]
mod tests {
use ff::Field;
use pasta_curves::pallas;
use rand_core::OsRng;

use super::{EvaluationDomain, Polynomial, Rotation};

#[test]
fn test_get_chunk_of_rotated() {
let k = 11;
let domain = EvaluationDomain::<pallas::Base>::new(1, k);

// Create a random polynomial.
let mut poly = domain.empty_lagrange();
for coefficient in poly.iter_mut() {
*coefficient = pallas::Base::random(OsRng);
}

// Pick a chunk size that is guaranteed to not be a multiple of the polynomial
// length.
let chunk_size = 7;

for rotation in [
Rotation(-6),
Rotation::prev(),
Rotation::cur(),
Rotation::next(),
Rotation(12),
] {
for (chunk_index, chunk) in poly.rotate(rotation).chunks(chunk_size).enumerate() {
assert_eq!(
poly.get_chunk_of_rotated(rotation, chunk_size, chunk_index),
chunk
);
}
}
}
}
59 changes: 59 additions & 0 deletions halo2_proofs/src/poly/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,27 @@ impl<G: Group> EvaluationDomain<G> {
poly
}

/// Gets the specified chunk of the rotated version of this polynomial.
///
/// Equivalent to:
/// ```ignore
/// self.rotate_extended(poly, rotation)
/// .chunks(chunk_size)
/// .nth(chunk_index)
/// .unwrap()
/// .to_vec()
/// ```
pub(crate) fn get_chunk_of_rotated_extended(
&self,
poly: &Polynomial<G, ExtendedLagrangeCoeff>,
rotation: Rotation,
chunk_size: usize,
chunk_index: usize,
) -> Vec<G> {
let new_rotation = ((1 << (self.extended_k - self.k)) * rotation.0.abs()) as usize;
poly.get_chunk_of_rotated_helper(rotation.0 < 0, new_rotation, chunk_size, chunk_index)
}

/// This takes us from the extended evaluation domain and gets us the
/// quotient polynomial coefficients.
///
Expand Down Expand Up @@ -545,3 +566,41 @@ fn test_l_i() {
assert_eq!(eval_polynomial(&l[(8 - i) % 8][..], x), evaluations[7 - i]);
}
}

#[test]
fn test_get_chunk_of_rotated_extended() {
use pasta_curves::pallas;
use rand_core::OsRng;

let k = 11;
let domain = EvaluationDomain::<pallas::Base>::new(3, k);

// Create a random polynomial.
let mut poly = domain.empty_extended();
for coefficient in poly.iter_mut() {
*coefficient = pallas::Base::random(OsRng);
}

// Pick a chunk size that is guaranteed to not be a multiple of the polynomial
// length.
let chunk_size = 7;

for rotation in [
Rotation(-6),
Rotation::prev(),
Rotation::cur(),
Rotation::next(),
Rotation(12),
] {
for (chunk_index, chunk) in domain
.rotate_extended(&poly, rotation)
.chunks(chunk_size)
.enumerate()
{
assert_eq!(
domain.get_chunk_of_rotated_extended(&poly, rotation, chunk_size, chunk_index),
chunk
);
}
}
}
108 changes: 31 additions & 77 deletions halo2_proofs/src/poly/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,82 +138,30 @@ impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
F: FieldExt,
B: BasisOps,
{
// Traverse `ast` to collect the used leaves.
fn collect_rotations<E: Copy, F: Field, B: Basis>(
ast: &Ast<E, F, B>,
) -> HashSet<AstLeaf<E, B>> {
match ast {
Ast::Poly(leaf) => vec![*leaf].into_iter().collect(),
Ast::Add(a, b) | Ast::Mul(AstMul(a, b)) => {
let lhs = collect_rotations(a);
let rhs = collect_rotations(b);
lhs.union(&rhs).cloned().collect()
}
Ast::Scale(a, _) => collect_rotations(a),
Ast::DistributePowers(terms, _) => terms
.iter()
.flat_map(|term| collect_rotations(term).into_iter())
.collect(),
Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(),
}
}
let leaves = collect_rotations(ast);

// Produce a map from each leaf to the rotated polynomial it corresponds to, if
// any (or None if the leaf uses an unrotated polynomial).
let rotated: HashMap<_, _> = leaves
.iter()
.cloned()
.map(|leaf| {
(
leaf,
if leaf.rotation == Rotation::cur() {
// We can use the polynomial as-is for this leaf.
None
} else {
Some(B::rotate(domain, &self.polys[leaf.index], leaf.rotation))
},
)
})
.collect();

// We're working in a single basis, so all polynomials are the same length.
let poly_len = self.polys.first().unwrap().len();
let (chunk_size, num_chunks) = get_chunk_params(poly_len);

// Split each rotated and unrotated polynomial into chunks.
let chunks: Vec<HashMap<_, _>> = (0..num_chunks)
.map(|i| {
rotated
.iter()
.map(|(leaf, poly)| {
(
*leaf,
poly.as_ref()
.unwrap_or(&self.polys[leaf.index])
.chunks(chunk_size)
.nth(i)
.expect("num_chunks was calculated correctly"),
)
})
.collect()
})
.collect();
let (chunk_size, _num_chunks) = get_chunk_params(poly_len);

struct AstContext<'a, E, F: FieldExt, B: Basis> {
struct AstContext<'a, F: FieldExt, B: Basis> {
domain: &'a EvaluationDomain<F>,
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
leaves: &'a HashMap<AstLeaf<E, B>, &'a [F]>,
polys: &'a [Polynomial<F, B>],
}

fn recurse<E, F: FieldExt, B: BasisOps>(
ast: &Ast<E, F, B>,
ctx: &AstContext<'_, E, F, B>,
ctx: &AstContext<'_, F, B>,
) -> Vec<F> {
match ast {
Ast::Poly(leaf) => ctx.leaves.get(leaf).expect("We prepared this").to_vec(),
Ast::Poly(leaf) => B::get_chunk_of_rotated(
ctx.domain,
ctx.chunk_size,
ctx.chunk_index,
&ctx.polys[leaf.index],
leaf.rotation,
),
Ast::Add(a, b) => {
let mut lhs = recurse(a, ctx);
let rhs = recurse(b, ctx);
Expand Down Expand Up @@ -265,16 +213,14 @@ impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
// polynomial.
let mut result = B::empty_poly(domain);
multicore::scope(|scope| {
for (chunk_index, (out, leaves)) in
result.chunks_mut(chunk_size).zip(chunks.iter()).enumerate()
{
for (chunk_index, out) in result.chunks_mut(chunk_size).enumerate() {
scope.spawn(move |_| {
let ctx = AstContext {
domain,
poly_len,
chunk_size,
chunk_index,
leaves,
polys: &self.polys,
};
out.copy_from_slice(&recurse(ast, &ctx));
});
Expand Down Expand Up @@ -511,11 +457,13 @@ pub(crate) trait BasisOps: Basis {
chunk_index: usize,
scalar: F,
) -> Vec<F>;
fn rotate<F: FieldExt>(
fn get_chunk_of_rotated<F: FieldExt>(
domain: &EvaluationDomain<F>,
chunk_size: usize,
chunk_index: usize,
poly: &Polynomial<F, Self>,
rotation: Rotation,
) -> Polynomial<F, Self>;
) -> Vec<F>;
}

impl BasisOps for Coeff {
Expand Down Expand Up @@ -558,11 +506,13 @@ impl BasisOps for Coeff {
chunk
}

fn rotate<F: FieldExt>(
fn get_chunk_of_rotated<F: FieldExt>(
_: &EvaluationDomain<F>,
_: usize,
_: usize,
_: &Polynomial<F, Self>,
_: Rotation,
) -> Polynomial<F, Self> {
) -> Vec<F> {
panic!("Can't rotate polynomials in the standard basis")
}
}
Expand Down Expand Up @@ -600,12 +550,14 @@ impl BasisOps for LagrangeCoeff {
.collect()
}

fn rotate<F: FieldExt>(
fn get_chunk_of_rotated<F: FieldExt>(
_: &EvaluationDomain<F>,
chunk_size: usize,
chunk_index: usize,
poly: &Polynomial<F, Self>,
rotation: Rotation,
) -> Polynomial<F, Self> {
poly.rotate(rotation)
) -> Vec<F> {
poly.get_chunk_of_rotated(rotation, chunk_size, chunk_index)
}
}

Expand Down Expand Up @@ -645,12 +597,14 @@ impl BasisOps for ExtendedLagrangeCoeff {
.collect()
}

fn rotate<F: FieldExt>(
fn get_chunk_of_rotated<F: FieldExt>(
domain: &EvaluationDomain<F>,
chunk_size: usize,
chunk_index: usize,
poly: &Polynomial<F, Self>,
rotation: Rotation,
) -> Polynomial<F, Self> {
domain.rotate_extended(poly, rotation)
) -> Vec<F> {
domain.get_chunk_of_rotated_extended(poly, rotation, chunk_size, chunk_index)
}
}

Expand Down

0 comments on commit 553584a

Please sign in to comment.