Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use halo2curves cycloneMSM #36

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "halo2-axiom"
version = "0.4.4"
version = "0.5.0-rc.1"
authors = [
"Sean Bowe <[email protected]>",
"Ying Tong Lai <[email protected]>",
Expand Down Expand Up @@ -32,10 +32,6 @@ autoexamples = false
all-features = true
rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"]

[[bench]]
name = "arithmetic"
harness = false

[[bench]]
name = "commit_zk"
harness = false
Expand Down Expand Up @@ -63,7 +59,11 @@ crossbeam = "0.8"
ff = "0.13"
group = "0.13"
pairing = "0.23"
halo2curves = { package = "halo2curves-axiom", version = "0.5.0", default-features = false, features = ["bits", "bn256-table", "derive_serde"] }
halo2curves = { package = "halo2curves-axiom", version = "0.7.0", default-features = false, features = [
"bits",
"bn256-table",
"derive_serde",
] }
rand = "0.8"
rand_core = { version = "0.6", default-features = false }
tracing = "0.1"
Expand Down Expand Up @@ -96,7 +96,12 @@ getrandom = { version = "0.2", features = ["js"] }
default = ["batch", "multicore", "circuit-params"]
multicore = ["maybe-rayon/threads"]
dev-graph = ["plotters", "tabbycat"]
test-dev-graph = ["dev-graph", "plotters/bitmap_backend", "plotters/bitmap_encoder", "plotters/ttf"]
test-dev-graph = [
"dev-graph",
"plotters/bitmap_backend",
"plotters/bitmap_encoder",
"plotters/ttf",
]
gadget-traces = ["backtrace"]
# thread-safe-region = []
sanity-checks = []
Expand Down
39 changes: 0 additions & 39 deletions halo2_proofs/benches/arithmetic.rs

This file was deleted.

185 changes: 4 additions & 181 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
//! This module provides common utilities, traits and structures for group,
//! field and polynomial arithmetic.

use std::cmp;

use super::multicore;
pub use ff::Field;
use group::{
ff::{BatchInvert, PrimeField},
prime::PrimeCurveAffine,
Curve, Group, GroupOpsOwned, ScalarMulOwned,
Curve, GroupOpsOwned, ScalarMulOwned,
};

use halo2curves::msm::msm_best;
pub use halo2curves::{CurveAffine, CurveExt};

/// This represents an element of a group with basic operations that can be
Expand All @@ -28,190 +27,14 @@ where
{
}

// ASSUMES C::Scalar::Repr is little endian
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
1
} else if bases.len() < 32 {
3
} else {
(f64::from(bases.len() as u32)).ln().ceil() as usize
};

// Group `bytes` into bits and take the `segment`th chunk of `c` bits
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
let skip_bits = segment * c;
let skip_bytes = skip_bits / 8;

if skip_bytes >= 32 {
return 0;
}

let mut v = [0; 8];
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o;
}

let mut tmp = u64::from_le_bytes(v);
tmp >>= skip_bits - (skip_bytes * 8);
tmp %= 1 << c;

tmp as usize
}

let segments = (C::Scalar::NUM_BITS as usize + c - 1) / c;

// this can be optimized
let mut coeffs_in_segments = Vec::with_capacity(segments);
// track what is the last segment where we actually have nonzero bits, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_segment = None;
for current_segment in 0..segments {
let coeff_segments: Vec<_> = coeffs
.iter()
.map(|coeff| {
let c_bits = get_at::<C::Scalar>(current_segment, c, coeff);
if c_bits != 0 {
max_nonzero_segment = Some(current_segment);
}
c_bits
})
.collect();
coeffs_in_segments.push(coeff_segments);
}

if max_nonzero_segment.is_none() {
return;
}
for coeffs_seg in coeffs_in_segments
.into_iter()
.take(max_nonzero_segment.unwrap() + 1)
.rev()
{
for _ in 0..c {
*acc = acc.double();
}

#[derive(Clone, Copy)]
enum Bucket<C: CurveAffine> {
None,
Affine(C),
Projective(C::Curve),
}

impl<C: CurveAffine> Bucket<C> {
fn add_assign(&mut self, other: &C) {
*self = match *self {
Bucket::None => Bucket::Affine(*other),
Bucket::Affine(a) => Bucket::Projective(a + *other),
Bucket::Projective(mut a) => {
a += *other;
Bucket::Projective(a)
}
}
}

fn add(self, mut other: C::Curve) -> C::Curve {
match self {
Bucket::None => other,
Bucket::Affine(a) => {
other += a;
other
}
Bucket::Projective(a) => other + &a,
}
}
}

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];

let mut max_bits = 0;
for (coeff, base) in coeffs_seg.into_iter().zip(bases.iter()) {
if coeff != 0 {
max_bits = cmp::max(max_bits, coeff);
buckets[coeff - 1].add_assign(base);
}
}

// Summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for exp in buckets.into_iter().take(max_bits).rev() {
running_sum = exp.add(running_sum);
*acc += &running_sum;
}
}
}

/// Performs a small multi-exponentiation operation.
/// Uses the double-and-add algorithm with doublings shared across points.
pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
let mut acc = C::Curve::identity();

// for byte idx
for byte_idx in (0..32).rev() {
// for bit idx
for bit_idx in (0..8).rev() {
acc = acc.double();
// for each coeff
for coeff_idx in 0..coeffs.len() {
let byte = coeffs[coeff_idx].as_ref()[byte_idx];
if ((byte >> bit_idx) & 1) != 0 {
acc += bases[coeff_idx];
}
}
}
}

acc
}

// [JPW] Keep this adapter to halo2curves to minimize code changes.
/// Performs a multi-exponentiation operation.
///
/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

//println!("msm: {}", coeffs.len());

// let start = get_time();
let num_threads = multicore::current_num_threads();
let res = if coeffs.len() > num_threads {
let chunk = coeffs.len() / num_threads;
let num_chunks = coeffs.chunks(chunk).len();
let mut results = vec![C::Curve::identity(); num_chunks];
multicore::scope(|scope| {
let chunk = coeffs.len() / num_threads;

for ((coeffs, bases), acc) in coeffs
.chunks(chunk)
.zip(bases.chunks(chunk))
.zip(results.iter_mut())
{
scope.spawn(move |_| {
multiexp_serial(coeffs, bases, acc);
});
}
});
results.iter().fold(C::Curve::identity(), |a, b| a + b)
} else {
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
acc
};

// let duration = get_duration(start);
#[allow(unsafe_code)]
// unsafe {
// MULTIEXP_TOTAL_TIME += duration;
// }
res
msm_best(coeffs, bases)
}

/// Dispatcher
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nightly-2023-08-11
nightly-2024-07-25
Loading