diff --git a/Cargo.toml b/Cargo.toml
index b7878ae843..526f9d841c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,3 +4,12 @@ members = [
     "halo2_gadgets",
     "halo2_proofs",
 ]
+
+[profile.bench]
+opt-level = 3
+debug = false
+debug-assertions = false
+overflow-checks = false
+lto = true
+incremental = false
+codegen-units = 1
diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml
index 1bb805e37e..216849ca1f 100644
--- a/halo2_proofs/Cargo.toml
+++ b/halo2_proofs/Cargo.toml
@@ -36,7 +36,7 @@ group = "0.11"
 rand = "0.8"
 rand_core = { version = "0.6", default-features = false }
 blake2b_simd = "1"
-pairing = { git = 'https://github.com/appliedzkp/pairing', package = "pairing_bn256" }
+pairing = { git = 'https://github.com/Brechtpd/pairing', branch = "msm", package = "pairing_bn256" }
 subtle = "2.3"
 cfg-if = "0.1"
@@ -57,6 +57,7 @@ criterion = "0.3"
 gumdrop = "0.8"
 proptest = "1"
 rand_core = { version = "0.6", default-features = false, features = ["getrandom"] }
+ark-std = { version = "0.3" }

 [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies]
 getrandom = { version = "0.2", features = ["js"] }
@@ -68,6 +69,8 @@ gadget-traces = ["backtrace"]
 sanity-checks = []
 shplonk = []
 gwc = []
+asm = ["pairing/asm"]
+prefetch = ["pairing/prefetch"]

 [lib]
 bench = false
diff --git a/halo2_proofs/benches/plonk.rs b/halo2_proofs/benches/plonk.rs
index ef2c0e76a1..69e2fbedb2 100644
--- a/halo2_proofs/benches/plonk.rs
+++ b/halo2_proofs/benches/plonk.rs
@@ -255,16 +255,16 @@ fn criterion_benchmark(c: &mut Criterion) {
         ParamsVerifier<Bn256>,
         ProvingKey<G1Affine>,
     ) {
-        let params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(k);
+        let mut params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(k);
         let params_verifier: ParamsVerifier<Bn256> = params.verifier(0).unwrap();
         let empty_circuit: MyCircuit<Fp> = MyCircuit { a: None, k };
-        let vk = keygen_vk(&params, &empty_circuit).expect("keygen_vk should not fail");
+        let vk = keygen_vk(&mut params, &empty_circuit).expect("keygen_vk should not fail");
         let pk = keygen_pk(&params, vk, &empty_circuit).expect("keygen_pk should not fail");
         (params, params_verifier, pk)
     }

-    fn prover(k: u32, params: &Params<G1Affine>, pk: &ProvingKey<G1Affine>) -> Vec<u8> {
+    fn prover(k: u32, params: &mut Params<G1Affine>, pk: &ProvingKey<G1Affine>) -> Vec<u8> {
         let rng = OsRng;

         let circuit: MyCircuit<Fp> = MyCircuit {
@@ -304,7 +304,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             BenchmarkId::from_parameter(k),
             &(k, &params, &pk),
             |b, &(k, params, pk)| {
-                b.iter(|| prover(k, params, pk));
+                b.iter(|| prover(k, &mut params.clone(), pk));
             },
         );
     }
@@ -312,8 +312,8 @@ fn criterion_benchmark(c: &mut Criterion) {
     let mut verifier_group = c.benchmark_group("plonk-verifier");

     for k in k_range {
-        let (params, params_verifier, pk) = keygen(k);
-        let proof = prover(k, &params, &pk);
+        let (mut params, params_verifier, pk) = keygen(k);
+        let proof = prover(k, &mut params, &pk);

         verifier_group.bench_with_input(
             BenchmarkId::from_parameter(k),
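A note on the benchmark change above: `b.iter(|| prover(k, &mut params.clone(), pk))` clones `params` inside the timed closure, and the clone now carries the precomputed msm bases, so it is not free. Criterion's `iter_batched` can move that clone into an untimed setup phase. A sketch, reusing the `keygen` and `prover` helpers as defined in this bench:

```rust
use criterion::{BatchSize, Criterion};

fn bench_prover(c: &mut Criterion, k: u32) {
    // `keygen` and `prover` as defined in benches/plonk.rs above.
    let (params, _params_verifier, pk) = keygen(k);
    c.bench_function("plonk-prover", |b| {
        // The setup closure runs outside the measured region, so cloning
        // `params` no longer counts towards the reported proving time.
        b.iter_batched(
            || params.clone(),
            |mut params| prover(k, &mut params, &pk),
            BatchSize::LargeInput,
        )
    });
}
```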
diff --git a/halo2_proofs/examples/simple-example-2.rs b/halo2_proofs/examples/simple-example-2.rs
index 59e88c4953..d55f4e4a51 100644
--- a/halo2_proofs/examples/simple-example-2.rs
+++ b/halo2_proofs/examples/simple-example-2.rs
@@ -245,11 +245,11 @@ fn main() {
     let empty_circuit: MyCircuit<Fp> = MyCircuit { a: None, k };

     // Initialize the polynomial commitment parameters
-    let params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(k);
+    let mut params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(k);
     let params_verifier: ParamsVerifier<Bn256> = params.verifier(public_inputs_size).unwrap();

     // Initialize the proving key
-    let vk = keygen_vk(&params, &empty_circuit).expect("keygen_vk should not fail");
+    let vk = keygen_vk(&mut params, &empty_circuit).expect("keygen_vk should not fail");
     let pk = keygen_pk(&params, vk, &empty_circuit).expect("keygen_pk should not fail");

     let circuit: MyCircuit<Fp> = MyCircuit {
@@ -263,7 +263,7 @@ fn main() {
     use std::time::Instant;
     let _dur = Instant::now();

-    create_proof(&params, &pk, &[circuit], &[&[]], OsRng, &mut transcript)
+    create_proof(&mut params, &pk, &[circuit], &[&[]], OsRng, &mut transcript)
         .expect("proof generation should not fail");

     println!("proving period: {:?}", _dur.elapsed());
diff --git a/halo2_proofs/src/arithmetic_msm.rs b/halo2_proofs/src/arithmetic_msm.rs
new file mode 100644
index 0000000000..b03765f6a0
--- /dev/null
+++ b/halo2_proofs/src/arithmetic_msm.rs
@@ -0,0 +1,1032 @@
+//! This module implements a fast method for multi-scalar multiplications.
+//!
+//! Generally it works like Pippenger's algorithm, with a couple of tricks to make it faster.
+//!
+//! - First the coefficients are split into two parts (using the endomorphism). This
+//!   reduces the number of rounds by half, but doubles the number of points per round.
+//!   This is faster because halving the number of rounds also halves the number of
+//!   times all bucket results have to be added together.
+//!
+//! - The coefficients are then sorted into buckets. Instead of using
+//!   the binary representation to do this, a signed digit representation (WNAF) is
+//!   used instead. Unfortunately this doesn't directly reduce the number of additions
+//!   in a bucket, but it does reduce the number of buckets by half, which halves the
+//!   work required to accumulate the results of the buckets.
+//!
+//! - We then need to add all the points in each bucket together. To do this
+//!   the affine addition formulas are used. If the points are linearly independent, the
+//!   incomplete version of the formula can be used, which is quite a bit faster than
+//!   the full one because some checks can be skipped.
+//!   The affine formula is only fast if a lot of independent points can be added
+//!   together, because getting the actual result of an addition requires an inversion,
+//!   which is very expensive on its own but cheap when batch inversion can be used.
+//!   So the idea is to add a lot of pairs of points together using a single batched
+//!   inversion. We then have the results of all those additions, and can do a new batch
+//!   of additions on those results. This process is repeated as many times as needed
+//!   until all additions for each bucket are done. To do this efficiently we first
+//!   build up an addition tree that sets everything up correctly per round. We then
+//!   process each addition tree per round.
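For readers unfamiliar with the batch-inversion trick the last bullet relies on, here is a minimal self-contained sketch over a toy prime field (u128 arithmetic mod 2^61 - 1; the real code uses `ff::BatchInvert` on curve field elements). It shows why n independent inversions collapse into one inversion plus O(n) multiplications:

```rust
// Montgomery's batch-inversion trick over a toy Mersenne-prime field.
const P: u128 = (1u128 << 61) - 1;

fn mul(a: u128, b: u128) -> u128 {
    (a * b) % P
}

fn inv(a: u128) -> u128 {
    // Fermat's little theorem: a^(p-2) mod p. This is the "expensive" step.
    let (mut base, mut e, mut acc) = (a, P - 2, 1u128);
    while e > 0 {
        if e & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        e >>= 1;
    }
    acc
}

/// Inverts all (nonzero) values in place using a single field inversion.
fn batch_invert(values: &mut [u128]) {
    // Forward pass: prefix[i] = product of values[0..i].
    let mut prefix = Vec::with_capacity(values.len());
    let mut acc = 1u128;
    for v in values.iter() {
        prefix.push(acc);
        acc = mul(acc, *v);
    }
    // One single inversion for the whole batch.
    let mut inv_acc = inv(acc);
    // Backward pass: peel off one inverse at a time.
    for (v, p) in values.iter_mut().zip(prefix).rev() {
        let inv_v = mul(inv_acc, p);
        inv_acc = mul(inv_acc, *v);
        *v = inv_v;
    }
}

fn main() {
    let mut xs = vec![2u128, 3, 5, 7];
    let orig = xs.clone();
    batch_invert(&mut xs);
    for (x, ix) in orig.iter().zip(xs.iter()) {
        assert_eq!(mul(*x, *ix), 1);
    }
}
```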
+
+use core::slice;
+use std::{
+    convert::TryInto,
+    env::var,
+    mem::{self, size_of},
+    time::Instant,
+};
+
+use super::multicore;
+pub use ff::Field;
+use group::{
+    ff::{BatchInvert, PrimeField},
+    prime::PrimeCurveAffine,
+    Group as _,
+};
+
+pub use pairing::arithmetic::*;
+
+fn num_bits(value: usize) -> usize {
+    (0usize.leading_zeros() - value.leading_zeros()) as usize
+}
+
+fn div_up(a: usize, b: usize) -> usize {
+    (a + (b - 1)) / b
+}
+
+fn get_wnaf_size_bits(num_bits: usize, w: usize) -> usize {
+    div_up(num_bits, w)
+}
+
+fn get_wnaf_size<C: CurveAffine>(w: usize) -> usize {
+    get_wnaf_size_bits(div_up(C::Scalar::NUM_BITS as usize, 2), w)
+}
+
+fn get_num_rounds<C: CurveAffine>(c: usize) -> usize {
+    get_wnaf_size::<C>(c + 1)
+}
+
+fn get_num_buckets(c: usize) -> usize {
+    (1 << c) + 1
+}
+
+fn get_max_tree_size(num_points: usize, c: usize) -> usize {
+    num_points * 2 + get_num_buckets(c)
+}
+
+fn get_num_tree_levels(num_points: usize) -> usize {
+    1 + num_bits(num_points - 1)
+}
+
+/// Returns the signed digit representation of value with the specified window size.
+/// The result is written to the wnaf slice with the specified stride.
+fn get_wnaf(value: u128, w: usize, num_rounds: usize, wnaf: &mut [u32], stride: usize) {
+    fn get_bits_at(v: u128, pos: usize, num: usize) -> usize {
+        ((v >> pos) & ((1 << num) - 1)) as usize
+    }
+
+    let mut borrow = 0;
+    let max = 1 << (w - 1);
+    for idx in 0..num_rounds {
+        let b = get_bits_at(value, idx * w, w) + borrow;
+        if b >= max {
+            // Set the highest bit to 1 to represent a negative value.
+            // This way the lower bits directly represent the bucket index.
+            wnaf[idx * stride] = (0x80000000 | ((1 << w) - b)) as u32;
+            borrow = 1;
+        } else {
+            wnaf[idx * stride] = b as u32;
+            borrow = 0;
+        }
+    }
+    assert_eq!(borrow, 0);
+}
+
+/// Returns the best bucket width for the given number of points.
+fn get_best_c(num_points: usize) -> usize {
+    if num_points >= 262144 {
+        15
+    } else if num_points >= 65536 {
+        12
+    } else if num_points >= 16384 {
+        11
+    } else if num_points >= 8192 {
+        10
+    } else if num_points >= 1024 {
+        9
+    } else {
+        7
+    }
+}
+
+/// MultiExp
+#[derive(Clone, Debug, Default)]
+pub struct MultiExp<C: CurveAffine> {
+    /// The bases
+    bases: Vec<C>,
+}
+
+/// MultiExp context object
+#[derive(Clone, Debug, Default)]
+pub struct MultiExpContext<C: CurveAffine> {
+    /// Memory to store the points in the addition tree
+    points: Vec<C>,
+    /// Memory to store wnafs
+    wnafs: Vec<u32>,
+    /// Memory split up between rounds
+    rounds: SharedRoundData,
+}
+
+/// SharedRoundData
+#[derive(Clone, Debug, Default)]
+struct SharedRoundData {
+    /// Memory to store bucket sizes
+    bucket_sizes: Vec<usize>,
+    /// Memory to store bucket offsets
+    bucket_offsets: Vec<usize>,
+    /// Memory to store the point data
+    point_data: Vec<u32>,
+    /// Memory to store the output indices
+    output_indices: Vec<u32>,
+    /// Memory to store the base positions (on the first level)
+    base_positions: Vec<u32>,
+    /// Memory to store the scatter maps
+    scatter_map: Vec<ScatterData>,
+}
+
+/// RoundData
+#[derive(Debug, Default)]
+struct RoundData<'a> {
+    /// Number of levels in the addition tree
+    pub num_levels: usize,
+    /// The length of each level in the addition tree
+    pub level_sizes: Vec<usize>,
+    /// The offset to each level in the addition tree
+    pub level_offset: Vec<usize>,
+    /// The size of each bucket
+    pub bucket_sizes: &'a mut [usize],
+    /// The offset of each bucket
+    pub bucket_offsets: &'a mut [usize],
+    /// The point to use for each coefficient
+    pub point_data: &'a mut [u32],
+    /// The output index in the point array for each pair addition
+    pub output_indices: &'a mut [u32],
+    /// The point to use on the first level in the addition tree
+    pub base_positions: &'a mut [u32],
+    /// List of points that are scattered to the addition tree
+    pub scatter_map: &'a mut [ScatterData],
+    /// The length of scatter_map
+    pub scatter_map_len: usize,
+}
+
+/// ScatterData
+#[derive(Default, Debug, Clone)]
+struct ScatterData {
+    /// The position in the addition tree to store the point
+    pub position: u32,
+    /// The point to write
+    pub point_data: u32,
+}
+
+impl<C: CurveAffine> MultiExp<C> {
+    /// Create a new MultiExp instance with the specified bases
+    pub fn new(bases: &[C]) -> Self {
+        let mut endo_bases = vec![C::identity(); bases.len() * 2];
+
+        // Generate the endomorphism bases
+        let num_threads = multicore::current_num_threads();
+        multicore::scope(|scope| {
+            let num_points_per_thread = div_up(bases.len(), num_threads);
+            for (endo_bases, bases) in endo_bases
+                .chunks_mut(num_points_per_thread * 2)
+                .zip(bases.chunks(num_points_per_thread))
+            {
+                scope.spawn(move |_| {
+                    for (idx, base) in bases.iter().enumerate() {
+                        endo_bases[idx * 2] = *base;
+                        endo_bases[idx * 2 + 1] = C::get_endomorphism_base(base);
+                    }
+                });
+            }
+        });
+
+        Self { bases: endo_bases }
+    }
+
+    /// Performs a multi-exponentiation operation.
+    /// Set complete to true if the bases are not guaranteed linearly independent.
+    pub fn evaluate(
+        &self,
+        ctx: &mut MultiExpContext<C>,
+        coeffs: &[C::Scalar],
+        complete: bool,
+    ) -> C::Curve {
+        self.evaluate_with(ctx, coeffs, complete, get_best_c(coeffs.len()))
+    }
+
+    /// Performs a multi-exponentiation operation with the given bucket width.
+    /// Set complete to true if the bases are not guaranteed linearly independent.
+    pub fn evaluate_with(
+        &self,
+        ctx: &mut MultiExpContext<C>,
+        coeffs: &[C::Scalar],
+        complete: bool,
+        c: usize,
+    ) -> C::Curve {
+        assert!(coeffs.len() * 2 <= self.bases.len());
+        assert!(c >= 4);
+
+        // Allocate more memory if required
+        ctx.allocate(coeffs.len(), c);
+
+        // Get the data for each round
+        let mut rounds = ctx.rounds.get_rounds::<C>(coeffs.len(), c);
+
+        // Get the bases for the coefficients
+        let bases = &self.bases[..coeffs.len() * 2];
+
+        let num_threads = multicore::current_num_threads();
+        let start = start_measure(
+            format!("msm {} ({}) ({} threads)", coeffs.len(), c, num_threads),
+            false,
+        );
+        if coeffs.len() >= 16 {
+            let num_points = coeffs.len() * 2;
+            let w = c + 1;
+            let num_rounds = get_num_rounds::<C>(c);
+
+            // Prepare WNAFs of all coefficients for all rounds
+            calculate_wnafs::<C>(coeffs, &mut ctx.wnafs, c);
+            // Sort WNAFs into buckets for all rounds
+            sort::<C>(&mut ctx.wnafs[0..num_rounds * num_points], &mut rounds, c);
+            // Calculate addition trees for all rounds
+            create_addition_trees(&mut rounds);
+
+            // Now process each round individually
+            let mut partials = vec![C::Curve::identity(); num_rounds];
+            for (round, acc) in rounds.iter().zip(partials.iter_mut()) {
+                // Scatter the odd points in the odd length buckets to the addition tree
+                do_point_scatter::<C>(round, bases, &mut ctx.points);
+                // Do all bucket additions
+                do_batch_additions::<C>(round, bases, &mut ctx.points, complete);
+                // Get the final result of the round
+                *acc = accumulate_buckets::<C>(round, &mut ctx.points, c);
+            }
+
+            // Accumulate round results
+            let res = partials
+                .iter()
+                .rev()
+                .skip(1)
+                .fold(partials[num_rounds - 1], |acc, partial| {
+                    let mut res = acc;
+                    for _ in 0..w {
+                        res = res.double();
+                    }
+                    res + partial
+                });
+            stop_measure(start);
+
+            res
+        } else {
+            // Just do a naive msm
+            let mut acc = C::Curve::identity();
+            for (idx, coeff) in coeffs.iter().enumerate() {
+                // Skip over endomorphism bases
+                acc += bases[idx * 2] * coeff;
+            }
+            stop_measure(start);
+            acc
+        }
+    }
+}
+
+impl<C: CurveAffine> MultiExpContext<C> {
+    /// Allocate memory for the evaluation
+    pub fn allocate(&mut self, num_points: usize, c: usize) {
+        let num_points = num_points * 2;
+        let num_buckets = get_num_buckets(c);
+        let num_rounds = get_num_rounds::<C>(c);
+        let tree_size = get_max_tree_size(num_points, c);
+        let num_points_total = num_rounds * num_points;
+        let num_buckets_total = num_rounds * num_buckets;
+        let tree_size_total = num_rounds * tree_size;
+
+        // Allocate memory when necessary
+        if self.points.len() < tree_size {
+            self.points.resize(tree_size, C::identity());
+        }
+        if self.wnafs.len() < num_points_total {
+            self.wnafs.resize(num_points_total, 0u32);
+        }
+        if self.rounds.bucket_sizes.len() < num_buckets_total {
+            self.rounds.bucket_sizes.resize(num_buckets_total, 0usize);
+        }
+        if self.rounds.bucket_offsets.len() < num_buckets_total {
+            self.rounds.bucket_offsets.resize(num_buckets_total, 0usize);
+        }
+        if self.rounds.point_data.len() < num_points_total {
+            self.rounds.point_data.resize(num_points_total, 0u32);
+        }
+        if self.rounds.output_indices.len() < tree_size_total / 2 {
+            self.rounds.output_indices.resize(tree_size_total / 2, 0u32);
+        }
+        if self.rounds.base_positions.len() < num_points_total {
+            self.rounds.base_positions.resize(num_points_total, 0u32);
+        }
+        if self.rounds.scatter_map.len() < num_buckets_total {
+            self.rounds
+                .scatter_map
+                .resize(num_buckets_total, ScatterData::default());
+        }
+    }
+}
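The fold at the end of `evaluate_with` recombines the per-round partials as partials[0] + partials[1]·2^w + partials[2]·2^(2w) + ..., applying w doublings between rounds, most significant round first. A plain-integer sketch of that fold (u64 arithmetic standing in for the group operations):

```rust
// Scalar stand-in for the round recombination: doubling a u64 plays the
// role of doubling a curve point.
fn recombine(partials: &[u64], w: u32) -> u64 {
    partials
        .iter()
        .rev()
        .skip(1)
        .fold(partials[partials.len() - 1], |acc, partial| {
            let mut res = acc;
            for _ in 0..w {
                res *= 2; // `res = res.double()` on a curve point
            }
            res + partial
        })
}

fn main() {
    let (partials, w) = ([3u64, 1, 2], 4u32);
    // Expected: 3*2^0 + 1*2^4 + 2*2^8 = 3 + 16 + 512 = 531
    assert_eq!(recombine(&partials, w), 531);
}
```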
+impl SharedRoundData {
+    fn get_rounds<C: CurveAffine>(&mut self, num_points: usize, c: usize) -> Vec<RoundData<'_>> {
+        let num_points = num_points * 2;
+        let num_buckets = get_num_buckets(c);
+        let num_rounds = get_num_rounds::<C>(c);
+        let tree_size = num_points * 2 + num_buckets;
+
+        let mut bucket_sizes_rest = self.bucket_sizes.as_mut_slice();
+        let mut bucket_offsets_rest = self.bucket_offsets.as_mut_slice();
+        let mut point_data_rest = self.point_data.as_mut_slice();
+        let mut output_indices_rest = self.output_indices.as_mut_slice();
+        let mut base_positions_rest = self.base_positions.as_mut_slice();
+        let mut scatter_map_rest = self.scatter_map.as_mut_slice();
+
+        // Use the memory allocated above to initialize the memory used for each round.
+        // This way we don't need to reallocate memory for each msm with
+        // a different configuration (a different number of points or a different bucket width).
+        let mut rounds: Vec<RoundData> = Vec::with_capacity(num_rounds);
+        for _ in 0..num_rounds {
+            let (bucket_sizes, rest) = bucket_sizes_rest.split_at_mut(num_buckets);
+            bucket_sizes_rest = rest;
+            let (bucket_offsets, rest) = bucket_offsets_rest.split_at_mut(num_buckets);
+            bucket_offsets_rest = rest;
+            let (point_data, rest) = point_data_rest.split_at_mut(num_points);
+            point_data_rest = rest;
+            let (output_indices, rest) = output_indices_rest.split_at_mut(tree_size / 2);
+            output_indices_rest = rest;
+            let (base_positions, rest) = base_positions_rest.split_at_mut(num_points);
+            base_positions_rest = rest;
+            let (scatter_map, rest) = scatter_map_rest.split_at_mut(num_buckets);
+            scatter_map_rest = rest;
+
+            rounds.push(RoundData {
+                num_levels: 0,
+                level_sizes: vec![],
+                level_offset: vec![],
+                bucket_sizes,
+                bucket_offsets,
+                point_data,
+                output_indices,
+                base_positions,
+                scatter_map,
+                scatter_map_len: 0,
+            });
+        }
+        rounds
+    }
+}
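A stripped-down illustration of the `split_at_mut` pattern used by `get_rounds` above: one long-lived allocation is carved into disjoint mutable per-round slices, so repeated msm calls reuse the same memory and the slices can later be handed to different threads:

```rust
// Carve one backing buffer into `rounds` non-overlapping mutable views.
fn carve<'a>(mut rest: &'a mut [u32], chunk: usize, rounds: usize) -> Vec<&'a mut [u32]> {
    let mut out = Vec::with_capacity(rounds);
    for _ in 0..rounds {
        // Each split hands out a disjoint region, so the borrow checker is
        // happy to let all views be mutable at the same time.
        let (head, tail) = rest.split_at_mut(chunk);
        rest = tail;
        out.push(head);
    }
    out
}

fn main() {
    let mut backing = vec![0u32; 12];
    let views = carve(&mut backing, 4, 3);
    assert_eq!(views.len(), 3);
    assert_eq!(views[0].len(), 4);
}
```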
+#[derive(Clone, Copy)]
+struct ThreadBox<T>(*mut T, usize);
+#[allow(unsafe_code)]
+unsafe impl<T> Send for ThreadBox<T> {}
+#[allow(unsafe_code)]
+unsafe impl<T> Sync for ThreadBox<T> {}
+
+/// Wraps a mutable slice so it can be passed into a thread without
+/// hard-to-fix borrow-check errors caused by difficult data access patterns.
+impl<T> ThreadBox<T> {
+    fn wrap(data: &mut [T]) -> Self {
+        Self(data.as_mut_ptr(), data.len())
+    }
+
+    fn unwrap(&mut self) -> &mut [T] {
+        #[allow(unsafe_code)]
+        unsafe {
+            slice::from_raw_parts_mut(self.0, self.1)
+        }
+    }
+}
+
+fn calculate_wnafs<C: CurveAffine>(coeffs: &[C::Scalar], wnafs: &mut [u32], c: usize) {
+    let num_threads = multicore::current_num_threads();
+    let num_points = coeffs.len() * 2;
+    let num_rounds = get_num_rounds::<C>(c);
+    let w = c + 1;
+
+    let start = start_measure("calculate wnafs".to_string(), false);
+    let mut wnafs_box = ThreadBox::wrap(wnafs);
+    let chunk_size = div_up(coeffs.len(), num_threads);
+    multicore::scope(|scope| {
+        for (thread_idx, coeffs) in coeffs.chunks(chunk_size).enumerate() {
+            scope.spawn(move |_| {
+                let wnafs = &mut wnafs_box.unwrap()[thread_idx * chunk_size * 2..];
+                for (idx, coeff) in coeffs.iter().enumerate() {
+                    let (p0, p1) = C::get_endomorphism_scalars(coeff);
+                    get_wnaf(p0, w, num_rounds, &mut wnafs[idx * 2..], num_points);
+                    get_wnaf(p1, w, num_rounds, &mut wnafs[idx * 2 + 1..], num_points);
+                }
+            });
+        }
+    });
+    stop_measure(start);
+}
+
+fn radix_sort(wnafs: &mut [u32], round: &mut RoundData) {
+    let bucket_sizes = &mut round.bucket_sizes;
+    let bucket_offsets = &mut round.bucket_offsets;
+
+    // Calculate bucket sizes, first resetting all sizes to 0
+    bucket_sizes.fill_with(|| 0);
+    for wnaf in wnafs.iter() {
+        bucket_sizes[(wnaf & 0x7FFFFFFF) as usize] += 1;
+    }
+
+    // Calculate bucket offsets
+    let mut offset = 0;
+    let mut max_bucket_size = 0;
+    bucket_offsets[0] = offset;
+    offset += bucket_sizes[0];
+    for (bucket_offset, bucket_size) in bucket_offsets
+        .iter_mut()
+        .skip(1)
+        .zip(bucket_sizes.iter().skip(1))
+    {
+        *bucket_offset = offset;
+        offset += bucket_size;
+        max_bucket_size = max_bucket_size.max(*bucket_size);
+    }
+    // Number of levels we need in our addition tree
+    round.num_levels = get_num_tree_levels(max_bucket_size);
+
+    // Fill in point data grouped in buckets
+    let point_data = &mut round.point_data;
+    for (idx, wnaf) in wnafs.iter().enumerate() {
+        let bucket_idx = (wnaf & 0x7FFFFFFF) as usize;
+        point_data[bucket_offsets[bucket_idx]] = (wnaf & 0x80000000) | (idx as u32);
+        bucket_offsets[bucket_idx] += 1;
+    }
+}
+
+/// Sorts the points so they are grouped per bucket
+fn sort<C: CurveAffine>(wnafs: &mut [u32], rounds: &mut [RoundData], c: usize) {
+    let num_rounds = get_num_rounds::<C>(c);
+    let num_points = wnafs.len() / num_rounds;
+
+    // Sort per bucket for each round separately
+    let start = start_measure("radix sort".to_string(), false);
+    multicore::scope(|scope| {
+        for (round, wnafs) in rounds.chunks_mut(1).zip(wnafs.chunks_mut(num_points)) {
+            scope.spawn(move |_| {
+                radix_sort(wnafs, &mut round[0]);
+            });
+        }
+    });
+    stop_measure(start);
+}
+
+/// Creates the addition tree.
+/// When PREPROCESS is true we just calculate the size of each level.
+/// All points in a bucket need to be added to each other. Because the affine formulas
+/// are used we need to add points together in pairs, so we have to make sure that
+/// each level has an even number of points. Odd points are
+/// added to lower levels where the length of the addition results is odd (which then
+/// makes the length even).
+fn process_addition_tree<const PREPROCESS: bool>(round: &mut RoundData) {
+    let num_levels = round.num_levels;
+    let bucket_sizes = &round.bucket_sizes;
+    let point_data = &round.point_data;
+
+    let mut level_sizes = vec![0usize; num_levels];
+    let mut level_offset = vec![0usize; num_levels];
+    let output_indices = &mut round.output_indices;
+    let scatter_map = &mut round.scatter_map;
+    let base_positions = &mut round.base_positions;
+    let mut point_idx = bucket_sizes[0];
+
+    if !PREPROCESS {
+        // Set the offsets to the different levels in the tree
+        level_offset[0] = 0;
+        for idx in 1..level_offset.len() {
+            level_offset[idx] = level_offset[idx - 1] + round.level_sizes[idx - 1];
+        }
+    }
+
+    // The level where all bucket results will be stored
+    let bucket_level = num_levels - 1;
+
+    // Run over all buckets
+    for bucket_size in bucket_sizes.iter().skip(1) {
+        let mut size = *bucket_size;
+        if size == 0 {
+            level_sizes[bucket_level] += 1;
+        } else if size == 1 {
+            if !PREPROCESS {
+                scatter_map[round.scatter_map_len] = ScatterData {
+                    position: (level_offset[bucket_level] + level_sizes[bucket_level]) as u32,
+                    point_data: point_data[point_idx],
+                };
+                round.scatter_map_len += 1;
+                point_idx += 1;
+            }
+            level_sizes[bucket_level] += 1;
+        } else {
+            #[derive(Clone, Copy, PartialEq)]
+            enum State {
+                Even,
+                OddPoint(usize),
+                OddResult(usize),
+            }
+            let mut state = State::Even;
+            let num_levels_bucket = get_num_tree_levels(size);
+
+            let mut start_level_size = level_sizes[0];
+            for level in 0..num_levels_bucket - 1 {
+                let is_level_odd = size & 1;
+                let first_level = level == 0;
+                let last_level = level == num_levels_bucket - 2;
+
+                // If this level has odd size we have to handle it
+                if is_level_odd == 1 {
+                    // If we already have a point saved from a previous odd level, use it
+                    // to make the current level even
+                    if state != State::Even {
+                        if !PREPROCESS {
+                            let pos = (level_offset[level] + level_sizes[level]) as u32;
+                            match state {
+                                State::OddPoint(point_idx) => {
+                                    scatter_map[round.scatter_map_len] = ScatterData {
+                                        position: pos,
+                                        point_data: point_data[point_idx],
+                                    };
+                                    round.scatter_map_len += 1;
+                                }
+                                State::OddResult(output_idx) => {
+                                    output_indices[output_idx] = pos;
+                                }
+                                _ => unreachable!(),
+                            };
+                        }
+                        level_sizes[level] += 1;
+                        size += 1;
+                        state = State::Even;
+                    } else {
+                        // Not odd yet, so the state is now odd
+                        // Store the point we have to add later
+                        if !PREPROCESS {
+                            if first_level {
+                                state = State::OddPoint(point_idx + size - 1);
+                            } else {
+                                state = State::OddResult(
+                                    (level_offset[level] + level_sizes[level] + size) >> 1,
+                                );
+                            }
+                        } else {
+                            // Just mark it as odd, we won't use the actual value anywhere
+                            state = State::OddPoint(0);
+                        }
+                        size -= 1;
+                    }
+                }
+
+                // Write initial points on the first level
+                if first_level {
+                    if !PREPROCESS {
+                        // Just write all points (except the odd size one)
+                        let pos = level_offset[level] + level_sizes[level];
+                        base_positions[pos..pos + size]
+                            .copy_from_slice(&point_data[point_idx..point_idx + size]);
+                        point_idx += size + is_level_odd;
+                    }
+                    level_sizes[level] += size;
+                }
+
+                // Write output indices
+                // If the next level would be odd, we have to make it even
+                // by writing the last result of this level to the next level that is odd
+                // (unless we are writing the final result to the bucket level)
+                let next_level_size = size >> 1;
+                let next_level_odd = next_level_size & 1 == 1;
+                let redirect =
+                    if next_level_odd && state == State::Even && level < num_levels_bucket - 2 {
+                        1usize
+                    } else {
+                        0usize
+                    };
+                // An addition works on two points and has one result, so this takes only half the size
+                let sub_level_offset = (level_offset[level] + start_level_size) >> 1;
+                // Cache the start position of the next level
+                start_level_size = level_sizes[level + 1];
+                if !PREPROCESS {
+                    // Write the destination positions of the addition results in the tree
+                    let dst_pos = level_offset[level + 1] + level_sizes[level + 1];
+                    for (idx, output_index) in output_indices
+                        [sub_level_offset..sub_level_offset + next_level_size]
+                        .iter_mut()
+                        .enumerate()
+                    {
+                        *output_index = (dst_pos + idx) as u32;
+                    }
+                }
+                if last_level {
+                    // The result of the last addition for this bucket is written
+                    // to the last level (so all bucket results are nicely after each other).
+                    // Overwrite the output locations of the last result here.
+                    if !PREPROCESS {
+                        output_indices[sub_level_offset] =
+                            (level_offset[bucket_level] + level_sizes[bucket_level]) as u32;
+                    }
+                    level_sizes[bucket_level] += 1;
+                } else {
+                    // Update the sizes
+                    level_sizes[level + 1] += next_level_size - redirect;
+                    size -= redirect;
+                    // We have to redirect the last result to a lower level
+                    if redirect == 1 {
+                        state = State::OddResult(sub_level_offset + next_level_size - 1);
+                    }
+                }
+
+                // We added pairs of points together so the next level has half the size
+                size >>= 1;
+            }
+        }
+    }
+
+    // Store the tree level data
+    round.level_sizes = level_sizes;
+    round.level_offset = level_offset;
+}
+
+/// The affine formula is only efficient for independent point additions
+/// (using the result of an addition requires an inversion, which needs to be avoided as much as possible).
+/// So we try to add as many points together as possible on each level of the tree, writing the result of
+/// each addition to a lower level. Each level thus contains independent point additions, requiring only a
+/// single inversion per level in the tree.
+fn create_addition_trees(rounds: &mut [RoundData]) {
+    let start = start_measure("create addition trees".to_string(), false);
+    multicore::scope(|scope| {
+        for round in rounds.chunks_mut(1) {
+            scope.spawn(move |_| {
+                // Collect tree level sizes
+                process_addition_tree::<true>(&mut round[0]);
+                // Construct the tree
+                process_addition_tree::<false>(&mut round[0]);
+            });
+        }
+    });
+    stop_measure(start);
+}
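A simplified integer model of the layered reduction such an addition tree encodes (u64 addition standing in for affine point addition; the odd element is simply carried to the next level here, whereas the real tree redirects it to a level with an odd result count). The key property is that all pair additions within one level are independent, which is what allows a single batched inversion per level:

```rust
// Reduce a list to one value by repeated, independent pair additions.
fn reduce_tree(mut level: Vec<u64>) -> u64 {
    while level.len() > 1 {
        // Make the level even by carrying one element forward.
        let carry = if level.len() % 2 == 1 { level.pop() } else { None };
        // One "batch": these pair additions don't depend on each other,
        // so the real code can share a single batched inversion for them.
        let mut next: Vec<u64> = level.chunks(2).map(|p| p[0] + p[1]).collect();
        if let Some(c) = carry {
            next.push(c);
        }
        level = next;
    }
    level[0]
}

fn main() {
    assert_eq!(reduce_tree(vec![1, 2, 3, 4, 5]), 15);
}
```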
+/// Here we write the odd points in odd length buckets (the other points are loaded on the fly).
+/// This will do random reads AND random writes, which is normally terrible for performance.
+/// Luckily this doesn't really matter because we only have to write at most num_buckets points.
+fn do_point_scatter<C: CurveAffine>(round: &RoundData, bases: &[C], points: &mut [C]) {
+    let num_threads = multicore::current_num_threads();
+    let scatter_map = &round.scatter_map[..round.scatter_map_len];
+    let mut points_box = ThreadBox::wrap(points);
+    let start = start_measure("point scatter".to_string(), false);
+    if !scatter_map.is_empty() {
+        multicore::scope(|scope| {
+            let num_copies_per_thread = div_up(scatter_map.len(), num_threads);
+            for scatter_map in scatter_map.chunks(num_copies_per_thread) {
+                scope.spawn(move |_| {
+                    let points = points_box.unwrap();
+                    for scatter_data in scatter_map.iter() {
+                        let target_idx = scatter_data.position as usize;
+                        let negate = scatter_data.point_data & 0x80000000 != 0;
+                        let base_idx = (scatter_data.point_data & 0x7FFFFFFF) as usize;
+                        if negate {
+                            points[target_idx] = bases[base_idx].neg();
+                        } else {
+                            points[target_idx] = bases[base_idx];
+                        }
+                    }
+                });
+            }
+        });
+    }
+    stop_measure(start);
+}
+
+/// Finally do all additions using the addition tree we've set up.
+fn do_batch_additions<C: CurveAffine>(
+    round: &RoundData,
+    bases: &[C],
+    points: &mut [C],
+    complete: bool,
+) {
+    let num_threads = multicore::current_num_threads();
+
+    let num_levels = round.num_levels;
+    let level_counter = &round.level_sizes;
+    let level_offset = &round.level_offset;
+    let output_indices = &round.output_indices;
+    let base_positions = &round.base_positions;
+    let mut points_box = ThreadBox::wrap(points);
+
+    let start = start_measure("batch additions".to_string(), false);
+    for i in 0..num_levels - 1 {
+        let start = level_offset[i];
+        let num_points = level_counter[i];
+        multicore::scope(|scope| {
+            // We have to make sure we have an even amount here so we don't split within a pair
+            let num_points_per_thread = div_up(num_points / 2, num_threads) * 2;
+            for thread_idx in 0..num_threads {
+                scope.spawn(move |_| {
+                    let points = points_box.unwrap();
+
+                    let thread_start = thread_idx * num_points_per_thread;
+                    let mut thread_num_points = num_points_per_thread;
+
+                    if thread_start < num_points {
+                        if thread_start + thread_num_points > num_points {
+                            thread_num_points = num_points - thread_start;
+                        }
+
+                        let points = &mut points[(start + thread_start)..];
+                        let output_indices = &output_indices[(start + thread_start) / 2..];
+                        let offset = start + thread_start;
+                        if i == 0 {
+                            let base_positions = &base_positions[(start + thread_start)..];
+                            if complete {
+                                C::batch_add::<true, true>(
+                                    points,
+                                    output_indices,
+                                    thread_num_points,
+                                    offset,
+                                    bases,
+                                    base_positions,
+                                );
+                            } else {
+                                C::batch_add::<false, true>(
+                                    points,
+                                    output_indices,
+                                    thread_num_points,
+                                    offset,
+                                    bases,
+                                    base_positions,
+                                );
+                            }
+                        } else {
+                            #[allow(clippy::collapsible_else_if)]
+                            if complete {
+                                C::batch_add::<true, false>(
+                                    points,
+                                    output_indices,
+                                    thread_num_points,
+                                    offset,
+                                    &[],
+                                    &[],
+                                );
+                            } else {
+                                C::batch_add::<false, false>(
+                                    points,
+                                    output_indices,
+                                    thread_num_points,
+                                    offset,
+                                    &[],
+                                    &[],
+                                );
+                            }
+                        }
+                    }
+                });
+            }
+        });
+    }
+    stop_measure(start);
+}
+
+/// Accumulate all bucket results to get the result of the round
+fn accumulate_buckets<C: CurveAffine>(round: &RoundData, points: &mut [C], c: usize) -> C::Curve {
+    let num_threads = multicore::current_num_threads();
+    let num_buckets = get_num_buckets(c);
+
+    let num_levels = round.num_levels;
+    let bucket_sizes = &round.bucket_sizes;
+    let level_offset = &round.level_offset;
+
+    let start_time = start_measure("accumulate buckets".to_string(), false);
+    let start = level_offset[num_levels - 1];
+    let buckets = &mut points[start..(start + num_buckets)];
+    let mut results: Vec<C::Curve> = vec![C::Curve::identity(); num_threads];
+    multicore::scope(|scope| {
+        let chunk_size = num_buckets / num_threads;
+        for (thread_idx, ((bucket_sizes, buckets), result)) in bucket_sizes[1..]
+            .chunks(chunk_size)
+            .zip(buckets[..].chunks_mut(chunk_size))
+            .zip(results.chunks_mut(1))
+            .enumerate()
+        {
+            scope.spawn(move |_| {
+                // Accumulate all bucket results
+                let num_buckets_thread = bucket_sizes.len();
+                let mut acc = C::Curve::identity();
+                let mut running_sum = C::Curve::identity();
+                for b in (0..num_buckets_thread).rev() {
+                    if bucket_sizes[b] > 0 {
+                        running_sum = running_sum + buckets[b];
+                    }
+                    acc = acc + &running_sum;
+                }
+
+                // Each thread started at a different bucket location
+                // so correct for that here
+                let bucket_start = thread_idx * chunk_size;
+                let num_bits = num_bits(bucket_start);
+                let mut accumulator = C::Curve::identity();
+                for idx in (0..num_bits).rev() {
+                    accumulator = accumulator.double();
+                    if (bucket_start >> idx) & 1 != 0 {
+                        accumulator += running_sum;
+                    }
+                }
+                acc += accumulator;
+
+                // Store the result
+                result[0] = acc;
+            });
+        }
+    });
+    stop_measure(start_time);
+
+    // Add the results of all threads together
+    results
+        .iter()
+        .fold(C::Curve::identity(), |acc, result| acc + result)
+}
+
+use crate::{
+    arithmetic::{best_multiexp, parallelize},
+    env_value, start_measure, stop_measure,
+};
+#[cfg(test)]
+use pairing::bn256::Fr as Fp;
+use pairing::bn256::{self as bn256, Fq, G1Affine, G1};
+
+#[cfg(test)]
+fn get_random_data<const INDEPENDENT: bool>(n: usize) -> (Vec<G1Affine>, Vec<Fp>) {
+    let mut bases = vec![pairing::bn256::G1Affine::identity(); n];
+    parallelize(&mut bases, |bases, _| {
+        let mut rng = rand::thread_rng();
+        let base_rnd = pairing::bn256::G1Affine::random(&mut rng);
+        for base in bases {
+            if INDEPENDENT {
+                *base = pairing::bn256::G1Affine::random(&mut rng);
+            } else {
+                *base = base_rnd;
+            }
+        }
+    });
+
+    let mut coeffs = vec![Fp::zero(); n];
+    parallelize(&mut coeffs, |coeffs, _| {
+        for coeff in coeffs {
+            *coeff = Fp::rand();
+        }
+    });
+
+    (bases, coeffs)
+}
+
+#[test]
+fn test_multiexp_simple() {
+    let n = 1 << env_value("K", 15);
+
+    let (bases, coeffs) = get_random_data::<true>(n);
+
+    let res_base = best_multiexp(&coeffs, &bases);
+    let res_base_affine: pairing::bn256::G1Affine = res_base.into();
+
+    let msm = MultiExp::new(&bases);
+    let mut ctx = MultiExpContext::default();
+    let res = msm.evaluate(&mut ctx, &coeffs, false);
+    let res_affine: pairing::bn256::G1Affine = res.into();
+
+    assert_eq!(res_base_affine, res_affine);
+}
+
+#[test]
+fn test_multiexp_complete_simple() {
+    let n = 1 << env_value("K", 15);
+
+    let (bases, coeffs) = get_random_data::<false>(n);
+
+    let res_base = best_multiexp(&coeffs, &bases);
+    let res_base_affine: pairing::bn256::G1Affine = res_base.into();
+
+    let msm = MultiExp::new(&bases);
+    let mut ctx = MultiExpContext::default();
+    let res = msm.evaluate(&mut ctx, &coeffs, true);
+    let res_affine: pairing::bn256::G1Affine = res.into();
+
+    assert_eq!(res_base_affine, res_affine);
+}
+
+#[test]
+fn test_multiexp_small() {
+    let n = 5;
+
+    let (bases, coeffs) = get_random_data::<true>(n);
+
+    let res_base = best_multiexp(&coeffs, &bases);
+    let res_base_affine: pairing::bn256::G1Affine = res_base.into();
+
+    let msm = MultiExp::new(&bases);
+    let mut ctx = MultiExpContext::default();
+    let res = msm.evaluate(&mut ctx, &coeffs, false);
+    let res_affine: pairing::bn256::G1Affine = res.into();
+
+    assert_eq!(res_base_affine, res_affine);
+}
+
+#[test]
+#[ignore]
+fn test_multiexp_bench() {
+    let min_k = 10;
+    let max_k = 20;
+    let n = 1 << max_k;
+    let (bases, coeffs) = get_random_data::<true>(n);
+    let msm = MultiExp::new(&bases);
+    let mut ctx = MultiExpContext::default();
+    for k in min_k..=max_k {
+        let n = 1 << k;
+        let coeffs = &coeffs[..n];
+
+        let start = start_measure("msm".to_string(), false);
+        msm.evaluate(&mut ctx, coeffs, false);
+        let duration = stop_measure(start);
+
+        println!("{}: {}s", n, (duration as f32) / 1000000.0);
+    }
+}
+
+#[test]
+#[ignore]
+fn test_multiexp_best_c() {
+    let max_k = 21;
+    let n = 1 << max_k;
+
+    let (bases, coeffs) = get_random_data::<true>(n);
+
+    let msm = MultiExp::new(&bases);
+    let mut ctx = MultiExpContext::default();
+    for k in 4..=max_k {
+        let n = 1 << k;
+        let coeffs = &coeffs[..n];
+        let bases = &bases[..n];
+
+        let res_base = best_multiexp(coeffs, bases);
+        let res_base_affine: pairing::bn256::G1Affine = res_base.into();
+
+        let mut best_c = 0;
+        let mut best_duration = usize::MAX;
+        for c in 4..=21 {
+            // Allocate memory so it doesn't impact performance
+            ctx.allocate(n, c);
+
+            let start = start_measure("measure performance".to_string(), false);
+            let res = msm.evaluate_with(&mut ctx, coeffs, false, c);
+            let duration = stop_measure(start);
+
+            if duration < best_duration {
+                best_duration = duration;
+                best_c = c;
+            }
+
+            let res_affine: pairing::bn256::G1Affine = res.into();
+            assert_eq!(res_base_affine, res_affine);
+        }
+        println!("{}: {}", n, best_c);
+    }
+}
+
+#[test]
+fn test_endomorphism() {
+    let rng = &mut rand::thread_rng();
+
+    let scalar = Fp::rand();
+    let point = bn256::G1Affine::random(rng);
+
+    let expected = point * scalar;
+    let (part1, part2) = bn256::G1Affine::get_endomorphism_scalars(&scalar);
+
+    let k1 = Fp::from_u128(part1);
+    let k2 = Fp::from_u128(part2);
+
+    let t1 = point * k1;
+    let base = bn256::G1Affine::get_endomorphism_base(&point);
+
+    let t2 = base * k2;
+    let result = t1 + t2;
+
+    let res_affine: pairing::bn256::G1Affine = result.into();
+    let exp_affine: pairing::bn256::G1Affine = expected.into();
+
+    assert_eq!(res_affine, exp_affine);
+}
diff --git a/halo2_proofs/src/lib.rs b/halo2_proofs/src/lib.rs
index 2623a93833..2cedd2dd20 100644
--- a/halo2_proofs/src/lib.rs
+++ b/halo2_proofs/src/lib.rs
@@ -25,7 +25,14 @@
 #![allow(unused_imports)]

 pub mod arithmetic;
+pub mod arithmetic_msm;
 pub mod circuit;
+use std::{
+    env::var,
+    sync::atomic::{AtomicUsize, Ordering},
+    time::Instant,
+};
+
 pub use pairing;
 mod multicore;
 pub mod plonk;
@@ -34,3 +41,69 @@ pub mod transcript;

 pub mod dev;
 mod helpers;
+
+/// Measurement information for a single timed region
+#[allow(missing_debug_implementations)]
+pub struct MeasurementInfo {
+    /// Show measurement
+    pub show: bool,
+    /// The start time
+    pub time: Instant,
+    /// What is being measured
+    pub message: String,
+    /// The indent
+    pub indent: usize,
+}
+
+/// Global indent counter
+pub static NUM_INDENT: AtomicUsize = AtomicUsize::new(0);
+
+/// Gets the time elapsed since `start`, in microseconds
+pub fn get_duration(start: Instant) -> usize {
+    let final_time = Instant::now() - start;
+    let secs = final_time.as_secs() as usize;
+    let millis = final_time.subsec_millis() as usize;
+    let micros = (final_time.subsec_micros() % 1000) as usize;
+    secs * 1000000 + millis * 1000 + micros
+}
{}s", + "*".repeat(indent), + msg, + (duration as f32) / 1000000.0 + ); +} + +/// Starts a measurement +pub fn start_measure(msg: String, always: bool) -> MeasurementInfo { + let measure = env_value("MEASURE", 0); + let indent = NUM_INDENT.fetch_add(1, Ordering::Relaxed); + MeasurementInfo { + show: always || measure == 1, + time: Instant::now(), + message: msg, + indent, + } +} + +/// Stops a measurement, returns the duration +pub fn stop_measure(info: MeasurementInfo) -> usize { + NUM_INDENT.fetch_sub(1, Ordering::Relaxed); + let duration = get_duration(info.time); + if info.show { + log_measurement(Some(info.indent), info.message, duration); + } + duration +} + +/// Gets the ENV variable if defined, otherwise returns the default value +pub fn env_value(key: &str, default: usize) -> usize { + match var(key) { + Ok(val) => val.parse().unwrap(), + Err(_) => default, + } +} diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index 1709dc5a31..0706e8b9d0 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -182,7 +182,7 @@ impl Assignment for Assembly { /// Generate a `VerifyingKey` from an instance of `Circuit`. pub fn keygen_vk( - params: &Params, + params: &mut Params, circuit: &ConcreteCircuit, ) -> Result, Error> where diff --git a/halo2_proofs/src/plonk/lookup/prover.rs b/halo2_proofs/src/plonk/lookup/prover.rs index d995c1ba33..947dbf7c4b 100644 --- a/halo2_proofs/src/plonk/lookup/prover.rs +++ b/halo2_proofs/src/plonk/lookup/prover.rs @@ -76,7 +76,7 @@ impl Argument { >( &self, pk: &ProvingKey, - params: &Params, + params: &mut Params, domain: &EvaluationDomain, value_evaluator: &poly::Evaluator, coset_evaluator: &mut poly::Evaluator, @@ -178,7 +178,7 @@ impl Argument { )?; // Closure to construct commitment to vector of values - let commit_values = |values: &Polynomial| { + let mut commit_values = |values: &Polynomial| { let poly = pk.vk.domain.lagrange_to_coeff(values.clone()); let commitment = params.commit_lagrange(values).to_affine(); (poly, commitment) @@ -231,7 +231,7 @@ impl Permuted { >( self, pk: &ProvingKey, - params: &Params, + params: &mut Params, theta: ChallengeTheta, beta: ChallengeBeta, gamma: ChallengeGamma, diff --git a/halo2_proofs/src/plonk/permutation/keygen.rs b/halo2_proofs/src/plonk/permutation/keygen.rs index 7c57b74fd8..e790bfdf50 100644 --- a/halo2_proofs/src/plonk/permutation/keygen.rs +++ b/halo2_proofs/src/plonk/permutation/keygen.rs @@ -99,7 +99,7 @@ impl Assembly { pub(crate) fn build_vk( self, - params: &Params, + params: &mut Params, domain: &EvaluationDomain, p: &Argument, ) -> VerifyingKey { diff --git a/halo2_proofs/src/plonk/permutation/prover.rs b/halo2_proofs/src/plonk/permutation/prover.rs index c2eb883559..09808a4817 100644 --- a/halo2_proofs/src/plonk/permutation/prover.rs +++ b/halo2_proofs/src/plonk/permutation/prover.rs @@ -49,7 +49,7 @@ impl Argument { T: TranscriptWrite, >( &self, - params: &Params, + params: &mut Params, pk: &plonk::ProvingKey, pkey: &ProvingKey, advice: &[Polynomial], diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index cc5c476dc9..b4a9218323 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -38,7 +38,7 @@ pub fn create_proof< T: TranscriptWrite, ConcreteCircuit: Circuit, >( - params: &Params, + params: &mut Params, pk: &ProvingKey, circuits: &[ConcreteCircuit], instances: &[&[&[C::Scalar]]], diff --git a/halo2_proofs/src/plonk/vanishing/prover.rs 
diff --git a/halo2_proofs/src/plonk/vanishing/prover.rs b/halo2_proofs/src/plonk/vanishing/prover.rs
index a1d45cb432..cdf3ad438b 100644
--- a/halo2_proofs/src/plonk/vanishing/prover.rs
+++ b/halo2_proofs/src/plonk/vanishing/prover.rs
@@ -34,7 +34,7 @@ pub(in crate::plonk) struct Evaluated<C: CurveAffine> {
 impl<C: CurveAffine> Argument<C> {
     pub(in crate::plonk) fn commit<E: EncodedChallenge<C>, R: RngCore, T: TranscriptWrite<C, E>>(
-        params: &Params<C>,
+        params: &mut Params<C>,
         domain: &EvaluationDomain<C::Scalar>,
         mut rng: R,
         transcript: &mut T,
@@ -60,7 +60,7 @@ impl<C: CurveAffine> Committed<C> {
         T: TranscriptWrite<C, E>,
     >(
         self,
-        params: &Params<C>,
+        params: &mut Params<C>,
         domain: &EvaluationDomain<C::Scalar>,
         evaluator: poly::Evaluator<Ev>,
         expressions: impl Iterator<Item = poly::Ast<Ev, C::Scalar, ExtendedLagrangeCoeff>>,
diff --git a/halo2_proofs/src/poly/commitment.rs b/halo2_proofs/src/poly/commitment.rs
index 50361c35e2..f82e2a39ad 100644
--- a/halo2_proofs/src/poly/commitment.rs
+++ b/halo2_proofs/src/poly/commitment.rs
@@ -7,6 +7,7 @@ use super::{Coeff, LagrangeCoeff, Polynomial, MSM};
 use crate::arithmetic::{
     best_fft, best_multiexp, parallelize, CurveAffine, CurveExt, Engine, FieldExt, Group,
 };
+use crate::arithmetic_msm::{MultiExp, MultiExpContext};
 use crate::helpers::CurveRead;

 use ff::{Field, PrimeField};
@@ -18,13 +19,16 @@ use std::ops::{Add, AddAssign, Mul, MulAssign};
 use std::io;

 /// These are the prover parameters for the polynomial commitment scheme.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Params<C: CurveAffine> {
     pub(crate) k: u32,
     pub(crate) n: u64,
     pub(crate) g: Vec<C>,
     pub(crate) g_lagrange: Vec<C>,
     pub(crate) additional_data: Vec<u8>,
+    pub(crate) msm_ctx: MultiExpContext<C>,
+    pub(crate) msm_g: MultiExp<C>,
+    pub(crate) msm_g_lagrange: MultiExp<C>,
 }

 /// These are the verifier parameters for the polynomial commitment scheme.
@@ -107,37 +111,36 @@ impl<C: CurveAffine> Params<C> {
         let g2 = <E::G2Affine as PrimeCurveAffine>::generator();
         let s_g2 = g2 * s;
         let additional_data = Vec::from(s_g2.to_bytes().as_ref());
+
+        let msm_ctx = MultiExpContext::default();
+        let msm_g = MultiExp::new(&g);
+        let msm_g_lagrange = MultiExp::new(&g_lagrange);
+
         Params {
             k,
             n,
             g,
             g_lagrange,
             additional_data,
+            msm_ctx,
+            msm_g,
+            msm_g_lagrange,
         }
     }

     /// This computes a commitment to a polynomial described by the provided
     /// slice of coefficients. The commitment will be blinded by the blinding
     /// factor `r`.
-    pub fn commit(&self, poly: &Polynomial<C::Scalar, Coeff>) -> C::Curve {
-        let mut scalars = Vec::with_capacity(poly.len());
-        scalars.extend(poly.iter());
-        let bases = &self.g;
-        let size = scalars.len();
-        assert!(bases.len() >= size);
-        best_multiexp(&scalars, &bases[0..size])
+    pub fn commit(&mut self, poly: &Polynomial<C::Scalar, Coeff>) -> C::Curve {
+        self.msm_g.evaluate(&mut self.msm_ctx, &poly.values, false)
     }

     /// This commits to a polynomial using its evaluations over the $2^k$ size
     /// evaluation domain. The commitment will be blinded by the blinding factor
     /// `r`.
-    pub fn commit_lagrange(&self, poly: &Polynomial<C::Scalar, LagrangeCoeff>) -> C::Curve {
-        let mut scalars = Vec::with_capacity(poly.len());
-        scalars.extend(poly.iter());
-        let bases = &self.g_lagrange;
-        let size = scalars.len();
-        assert!(bases.len() >= size);
-        best_multiexp(&scalars, &bases[0..size])
+    pub fn commit_lagrange(&mut self, poly: &Polynomial<C::Scalar, LagrangeCoeff>) -> C::Curve {
+        self.msm_g_lagrange
+            .evaluate(&mut self.msm_ctx, &poly.values, false)
     }

     /// Generates an empty multiscalar multiplication struct using the
@@ -187,12 +190,19 @@ impl<C: CurveAffine> Params<C> {
         reader.read_exact(&mut additional_data[..])?;

+        let msm_ctx = MultiExpContext::default();
+        let msm_g = MultiExp::new(&g);
+        let msm_g_lagrange = MultiExp::new(&g_lagrange);
+
         Ok(Params {
             k,
             n,
             g,
             g_lagrange,
             additional_data,
+            msm_ctx,
+            msm_g,
+            msm_g_lagrange,
         })
     }

@@ -384,7 +394,7 @@ fn test_parameter_serialization() {
 fn test_commit_lagrange() {
     const K: u32 = 6;

-    let params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
+    let mut params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
     let domain = super::EvaluationDomain::new(1, K);

     let mut a = domain.empty_lagrange();
diff --git a/halo2_proofs/src/poly/multiopen.rs b/halo2_proofs/src/poly/multiopen.rs
index 982937b5ee..7da6e03751 100644
--- a/halo2_proofs/src/poly/multiopen.rs
+++ b/halo2_proofs/src/poly/multiopen.rs
@@ -181,7 +181,7 @@ mod tests {

     const K: u32 = 4;

-    let params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
+    let mut params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
     let params_verifier: ParamsVerifier<Bn256> = params.verifier(0).unwrap();

     let domain = EvaluationDomain::new(1, K);
@@ -216,7 +216,7 @@ mod tests {
         let mut transcript = crate::transcript::Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);

         create_proof(
-            &params,
+            &mut params,
             &mut transcript,
             std::iter::empty()
                 .chain(Some(ProverQuery {
@@ -282,7 +282,7 @@ mod tests {
     fn test_multiopen() {
         const K: u32 = 3;

-        let params = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
+        let mut params = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
         let params_verifier: ParamsVerifier<Bn256> = params.verifier(0).unwrap();

         let rotation_sets_init = vec![
@@ -358,7 +358,7 @@ mod tests {
         // prover
         let proof = {
             let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]);
-            create_proof(&params, &mut transcript, prover_queries).unwrap();
+            create_proof(&mut params, &mut transcript, prover_queries).unwrap();
             transcript.finalize()
         };
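The first change in the shplonk prover below (`let n = params.n;`) reads like a borrow-checker fix rather than an optimization: with `params` now taken as `&mut`, a closure capturing `&params` just for its length would hold a shared borrow across the later mutable uses. A minimal illustration with toy types (not the halo2 API):

```rust
// Toy model of the borrow conflict the shplonk change avoids.
struct Params {
    n: u64,
}
impl Params {
    fn commit(&mut self) {}
}

fn create_proof(params: &mut Params) {
    let n = params.n; // copy the value out; the closure captures only `n`
    let zero = || vec![0u8; n as usize];
    params.commit(); // fine: no outstanding shared borrow of `params`
    let _ = zero();
}

fn main() {
    create_proof(&mut Params { n: 8 });
}
```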
diff --git a/halo2_proofs/src/poly/multiopen/shplonk/prover.rs b/halo2_proofs/src/poly/multiopen/shplonk/prover.rs
index 657790f38b..d4f6cc66d7 100644
--- a/halo2_proofs/src/poly/multiopen/shplonk/prover.rs
+++ b/halo2_proofs/src/poly/multiopen/shplonk/prover.rs
@@ -74,15 +74,16 @@ impl<'a, C: CurveAffine> RotationSet<Evaluated<C>> {

 /// Create a multi-opening proof
 pub fn create_proof<'a, I, C: CurveAffine, E: EncodedChallenge<C>, T: TranscriptWrite<C, E>>(
-    params: &Params<C>,
+    params: &mut Params<C>,
     transcript: &mut T,
     queries: I,
 ) -> io::Result<()>
 where
     I: IntoIterator<Item = ProverQuery<'a, C>> + Clone,
 {
+    let n = params.n;
     let zero = || Polynomial::<C::Scalar, Coeff> {
-        values: vec![C::Scalar::zero(); params.n as usize],
+        values: vec![C::Scalar::zero(); n as usize],
         _marker: PhantomData,
     };
@@ -113,7 +114,7 @@ where
         // Q_i(X) = N_i(X) / Z_i(X) where
         // Z_i(X) = (x - r_i_0) * (x - r_i_1) * ...
         let mut poly = div_by_vanishing(n_x, points);
-        poly.resize(params.n as usize, C::Scalar::zero());
+        poly.resize(n as usize, C::Scalar::zero());

         Polynomial {
             values: poly,
@@ -133,9 +134,7 @@ where
             let commitments: Vec<CommitmentExtension<'a, C>> = rotation_set
                 .commitments
                 .iter()
-                .map(|commitment_data| {
-                    commitment_data.extend(params.n, rotation_set.points.clone())
-                })
+                .map(|commitment_data| commitment_data.extend(n, rotation_set.points.clone()))
                 .collect();
             rotation_set.extend(commitments)
         })
diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs
index 534005c9fa..d702ec5b15 100644
--- a/halo2_proofs/tests/plonk_api.rs
+++ b/halo2_proofs/tests/plonk_api.rs
@@ -31,7 +31,7 @@ fn plonk_api() {
     pub struct Variable(Column<Advice>, usize);

     // Initialize the polynomial commitment parameters
-    let params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
+    let mut params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(K);
     let params_verifier: ParamsVerifier<Bn256> = params.verifier(public_inputs_size).unwrap();

     #[derive(Clone)]
@@ -406,9 +406,9 @@ fn plonk_api() {
     // Check that we get an error if we try to initialize the proving key with a value of
     // k that is too small for the minimum required number of rows.
-    let much_too_small_params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(1);
+    let mut much_too_small_params: Params<G1Affine> = Params::<G1Affine>::unsafe_setup::<Bn256>(1);
     assert_matches!(
-        keygen_vk(&much_too_small_params, &empty_circuit),
+        keygen_vk(&mut much_too_small_params, &empty_circuit),
         Err(Error::NotEnoughRowsAvailable {
             current_k,
         }) if current_k == 1
@@ -416,17 +416,17 @@ fn plonk_api() {
     // Check that we get an error if we try to initialize the proving key with a value of
     // k that is too small for the number of rows the circuit uses.
-    let slightly_too_small_params: Params<G1Affine> =
-        Params::<G1Affine>::unsafe_setup::<Bn256>(K - 1);
+    let mut slightly_too_small_params: Params<G1Affine> =
+        Params::<G1Affine>::unsafe_setup::<Bn256>(K - 1);
     assert_matches!(
-        keygen_vk(&slightly_too_small_params, &empty_circuit),
+        keygen_vk(&mut slightly_too_small_params, &empty_circuit),
         Err(Error::NotEnoughRowsAvailable {
             current_k,
         }) if current_k == K - 1
     );

     // Initialize the proving key
-    let vk = keygen_vk(&params, &empty_circuit).expect("keygen_vk should not fail");
+    let vk = keygen_vk(&mut params, &empty_circuit).expect("keygen_vk should not fail");
     let pk = keygen_pk(&params, vk, &empty_circuit).expect("keygen_pk should not fail");

     let pubinputs = vec![instance];
@@ -442,7 +442,7 @@ fn plonk_api() {
         let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
         // Create a proof
         create_proof(
-            &params,
+            &mut params,
             &pk,
             &[circuit.clone(), circuit.clone()],
             &[&[&[instance]], &[&[instance]]],
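Finally, an end-to-end sketch of driving the new msm engine directly, mirroring the tests in arithmetic_msm.rs (bn256 types assumed):

```rust
use halo2_proofs::arithmetic_msm::{MultiExp, MultiExpContext};
use halo2_proofs::pairing::bn256::{Fr, G1, G1Affine};

fn example(bases: &[G1Affine], coeffs: &[Fr]) -> G1 {
    // `new` precomputes the endomorphism bases once (two points per base);
    // the context owns the scratch memory and is reused across calls.
    let msm = MultiExp::new(bases);
    let mut ctx = MultiExpContext::default();
    // `complete = false` is only sound when the bases are linearly independent.
    msm.evaluate(&mut ctx, coeffs, false)
}
```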