From 723c9762fbb09bf3e82ca9fcb4ed0adc12d11b5a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 21 Aug 2023 11:14:00 -0600 Subject: [PATCH 01/13] Add field conversion to/from `[u64;4]` (#80) * feat: add field conversion to/from `[u64;4]` * Added conversion tests * Added `montgomery_reduce_short` for no-asm * For bn256, uses assembly conversion when asm feature is on * fix: remove conflict for asm * chore: bump rust-toolchain to 1.67.0 --- rust-toolchain | 2 +- src/bn256/assembly.rs | 8 +++++++ src/bn256/fq.rs | 25 ++++++++++------------ src/bn256/fr.rs | 25 ++++++++++------------ src/derive/field.rs | 49 +++++++++++++++++++++++++++++++++++++++++++ src/secp256k1/fp.rs | 20 ++++++++++-------- src/secp256k1/fq.rs | 20 ++++++++++-------- src/secp256r1/fp.rs | 28 +++++++++++++------------ src/secp256r1/fq.rs | 20 ++++++++++-------- src/tests/field.rs | 16 ++++++++++++++ 10 files changed, 144 insertions(+), 69 deletions(-) diff --git a/rust-toolchain b/rust-toolchain index 7cc6ef41..77c582d8 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.63.0 \ No newline at end of file +1.67.0 \ No newline at end of file diff --git a/src/bn256/assembly.rs b/src/bn256/assembly.rs index a466005a..a8c1984d 100644 --- a/src/bn256/assembly.rs +++ b/src/bn256/assembly.rs @@ -608,6 +608,14 @@ macro_rules! field_arithmetic_asm { $field([r0, r1, r2, r3]) } } + + impl From<$field> for [u64; 4] { + fn from(elt: $field) -> [u64; 4] { + // Turn into canonical form by computing + // (a.R) / R = a + elt.montgomery_reduce_256().0 + } + } }; } diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index b8e1383a..0e2fc10f 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -3,7 +3,7 @@ use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] use crate::{field_arithmetic, field_specific}; -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::bn256::LegendreSymbol; use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ @@ -271,20 +271,12 @@ impl ff::PrimeField for Fq { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - - #[cfg(not(feature = "asm"))] - let tmp = - Self::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - #[cfg(feature = "asm")] - let tmp = self.montgomery_reduce_256(); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -384,6 +376,11 @@ mod test { crate::tests::field::random_field_tests::("fq".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("fq".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 890c12e8..cd422d4b 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -18,7 +18,7 @@ pub use table::FR_TABLE; #[cfg(not(feature = "bn256-table"))] use crate::impl_from_u64; -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, @@ -300,20 +300,12 @@ impl ff::PrimeField for Fr { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - - #[cfg(not(feature = "asm"))] - let tmp = - Self::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - #[cfg(feature = "asm")] - let tmp = self.montgomery_reduce_256(); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -406,6 +398,11 @@ mod test { ); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("fr".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/derive/field.rs b/src/derive/field.rs index 0a88556a..945ee981 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -267,6 +267,12 @@ macro_rules! field_common { } } + impl From<[u64; 4]> for $field { + fn from(digits: [u64; 4]) -> Self { + Self::from_raw(digits) + } + } + impl From<$field> for [u8; 32] { fn from(value: $field) -> [u8; 32] { value.to_repr() @@ -442,6 +448,49 @@ macro_rules! field_arithmetic { $field([d0 & mask, d1 & mask, d2 & mask, d3 & mask]) } + + /// Montgomery reduce where last 4 registers are 0 + #[inline(always)] + pub(crate) const fn montgomery_reduce_short(r: &[u64; 4]) -> $field { + // The Montgomery reduction here is based on Algorithm 14.32 in + // Handbook of Applied Cryptography + // . + + let k = r[0].wrapping_mul($inv); + let (_, r0) = macx(r[0], k, $modulus.0[0]); + let (r1, r0) = mac(r[1], k, $modulus.0[1], r0); + let (r2, r0) = mac(r[2], k, $modulus.0[2], r0); + let (r3, r0) = mac(r[3], k, $modulus.0[3], r0); + + let k = r1.wrapping_mul($inv); + let (_, r1) = macx(r1, k, $modulus.0[0]); + let (r2, r1) = mac(r2, k, $modulus.0[1], r1); + let (r3, r1) = mac(r3, k, $modulus.0[2], r1); + let (r0, r1) = mac(r0, k, $modulus.0[3], r1); + + let k = r2.wrapping_mul($inv); + let (_, r2) = macx(r2, k, $modulus.0[0]); + let (r3, r2) = mac(r3, k, $modulus.0[1], r2); + let (r0, r2) = mac(r0, k, $modulus.0[2], r2); + let (r1, r2) = mac(r1, k, $modulus.0[3], r2); + + let k = r3.wrapping_mul($inv); + let (_, r3) = macx(r3, k, $modulus.0[0]); + let (r0, r3) = mac(r0, k, $modulus.0[1], r3); + let (r1, r3) = mac(r1, k, $modulus.0[2], r3); + let (r2, r3) = mac(r2, k, $modulus.0[3], r3); + + // Result may be within MODULUS of the correct value + (&$field([r0, r1, r2, r3])).sub(&$modulus) + } + } + + impl From<$field> for [u64; 4] { + fn from(elt: $field) -> [u64; 4] { + // Turn into canonical form by computing + // (a.R) / R = a + $field::montgomery_reduce_short(&elt.0).0 + } } }; } diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index f332f0b6..01fecf84 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output, @@ -255,15 +255,12 @@ impl ff::PrimeField for Fp { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fp::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -353,6 +350,11 @@ mod test { crate::tests::field::random_field_tests::("secp256k1 base".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256k1 base".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 9c5e7665..d38dc517 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output, @@ -266,15 +266,12 @@ impl ff::PrimeField for Fq { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fq::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -360,6 +357,11 @@ mod test { crate::tests::field::random_field_tests::("secp256k1 scalar".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256k1 scalar".to_string()); + } + #[test] #[cfg(feature = "bits")] fn test_bits() { diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index d351b64d..e26f19fc 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output, @@ -273,15 +273,12 @@ impl ff::PrimeField for Fp { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fp::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -368,19 +365,24 @@ mod test { #[test] fn test_field() { - crate::tests::field::random_field_tests::("secp256k1 base".to_string()); + crate::tests::field::random_field_tests::("secp256r1 base".to_string()); + } + + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256r1 base".to_string()); } #[test] #[cfg(feature = "bits")] fn test_bits() { - crate::tests::field::random_bits_tests::("secp256k1 base".to_string()); + crate::tests::field::random_bits_tests::("secp256r1 base".to_string()); } #[test] fn test_serialization() { - crate::tests::field::random_serialization_test::("secp256k1 base".to_string()); + crate::tests::field::random_serialization_test::("secp256r1 base".to_string()); #[cfg(feature = "derive_serde")] - crate::tests::field::random_serde_test::("secp256k1 base".to_string()); + crate::tests::field::random_serde_test::("secp256r1 base".to_string()); } } diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index e28c3fe6..05fcf1fa 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -1,4 +1,4 @@ -use crate::arithmetic::{adc, mac, sbb}; +use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use core::convert::TryInto; use core::fmt; @@ -262,15 +262,12 @@ impl ff::PrimeField for Fq { } fn to_repr(&self) -> Self::Repr { - // Turn into canonical form by computing - // (a.R) / R = a - let tmp = Fq::montgomery_reduce(&[self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0]); - + let tmp: [u64; 4] = (*self).into(); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); - res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); - res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); - res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + res[0..8].copy_from_slice(&tmp[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp[3].to_le_bytes()); res } @@ -362,6 +359,11 @@ mod test { crate::tests::field::random_field_tests::("secp256r1 scalar".to_string()); } + #[test] + fn test_conversion() { + crate::tests::field::random_conversion_tests::("secp256r1 scalar".to_string()); + } + #[test] fn test_serialization() { crate::tests::field::random_serialization_test::("secp256r1 scalar".to_string()); diff --git a/src/tests/field.rs b/src/tests/field.rs index a85b3f0e..a064441e 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -212,6 +212,22 @@ fn random_expansion_tests(mut rng: R, type_name: String) { end_timer!(start); } +pub fn random_conversion_tests>(type_name: String) { + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + let _message = format!("conversion {type_name}"); + let start = start_timer!(|| _message); + for _ in 0..1000000 { + let a = F::random(&mut rng); + let bytes = a.to_repr(); + let b = F::from_repr(bytes).unwrap(); + assert_eq!(a, b); + } + end_timer!(start); +} + #[cfg(feature = "bits")] pub fn random_bits_tests(type_name: String) { let mut rng = XorShiftRng::from_seed([ From 1d71d343b3764f6a819513a0cd9a855cd3f6b698 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Tue, 22 Aug 2023 16:18:30 +0200 Subject: [PATCH 02/13] Compute Legendre symbol for `hash_to_curve` (#77) * Add `Legendre` trait and macro - Add Legendre macro with norm and legendre symbol computation - Add macro for automatic implementation in prime fields * Add legendre macro call for prime fields * Remove unused imports * Remove leftover * Add `is_quadratic_non_residue` for hash_to_curve * Add `legendre` function * Compute modulus separately * Substitute division for shift * Update modulus computation * Add quadratic residue check func * Add quadratic residue tests * Add hash_to_curve bench * Implement Legendre trait for all curves * Move misplaced comment * Add all curves to hash bench * fix: add suggestion for legendre_exp * fix: imports after rebase --- Cargo.toml | 4 +++ benches/hash_to_curve.rs | 59 ++++++++++++++++++++++++++++++++++++++++ src/bn256/fq.rs | 33 ++++++---------------- src/bn256/fq2.rs | 56 +++++++++++++++++++------------------- src/bn256/fr.rs | 10 +++++-- src/bn256/mod.rs | 7 ----- src/hash_to_curve.rs | 9 ++++-- src/legendre.rs | 50 ++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ src/pasta/mod.rs | 9 ++++++ src/secp256k1/fp.rs | 7 +++++ src/secp256k1/fq.rs | 6 ++++ src/secp256r1/fp.rs | 7 +++++ src/secp256r1/fq.rs | 8 +++++- src/tests/curve.rs | 5 +++- src/tests/field.rs | 16 ++++++++++- 16 files changed, 220 insertions(+), 68 deletions(-) create mode 100644 benches/hash_to_curve.rs create mode 100644 src/legendre.rs diff --git a/Cargo.toml b/Cargo.toml index e0b9d8ba..f29c917e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,3 +63,7 @@ required-features = ["reexport"] [[bench]] name = "group" harness = false + +[[bench]] +name = "hash_to_curve" +harness = false diff --git a/benches/hash_to_curve.rs b/benches/hash_to_curve.rs new file mode 100644 index 00000000..76f6733a --- /dev/null +++ b/benches/hash_to_curve.rs @@ -0,0 +1,59 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use pasta_curves::arithmetic::CurveExt; +use rand_core::{OsRng, RngCore}; +use std::iter; + +fn hash_to_secp256k1(c: &mut Criterion) { + hash_to_curve::(c, "Secp256k1"); +} + +fn hash_to_secq256k1(c: &mut Criterion) { + hash_to_curve::(c, "Secq256k1"); +} + +fn hash_to_secp256r1(c: &mut Criterion) { + hash_to_curve::(c, "Secp256r1"); +} + +fn hash_to_pallas(c: &mut Criterion) { + hash_to_curve::(c, "Pallas"); +} + +fn hash_to_vesta(c: &mut Criterion) { + hash_to_curve::(c, "Vesta"); +} + +fn hash_to_bn256(c: &mut Criterion) { + hash_to_curve::(c, "Bn256"); +} + +fn hash_to_grumpkin(c: &mut Criterion) { + hash_to_curve::(c, "Grumpkin"); +} + +fn hash_to_curve(c: &mut Criterion, name: &'static str) { + { + let hasher = G::hash_to_curve("test"); + let mut rng = OsRng; + let message = iter::repeat_with(|| rng.next_u32().to_be_bytes()) + .take(1024) + .flatten() + .collect::>(); + + c.bench_function(&format!("Hash to {}", name), move |b| { + b.iter(|| hasher(black_box(&message))) + }); + } +} + +criterion_group!( + benches, + hash_to_secp256k1, + hash_to_secq256k1, + hash_to_secp256r1, + hash_to_pallas, + hash_to_vesta, + hash_to_bn256, + hash_to_grumpkin, +); +criterion_main!(benches); diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index 0e2fc10f..0024723a 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -1,11 +1,10 @@ #[cfg(feature = "asm")] use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] -use crate::{field_arithmetic, field_specific}; +use crate::{arithmetic::macx, field_arithmetic, field_specific}; -use crate::arithmetic::{adc, mac, macx, sbb}; -use crate::bn256::LegendreSymbol; -use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use crate::arithmetic::{adc, mac, sbb}; +use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, impl_binops_multiplicative, @@ -160,27 +159,10 @@ impl Fq { pub const fn size() -> usize { 32 } - - pub fn legendre(&self) -> LegendreSymbol { - // s = self^((modulus - 1) // 2) - // 0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3 - let s = &[ - 0x9e10460b6c3e7ea3u64, - 0xcbc0b548b438e546u64, - 0xdc2822db40c0ac2eu64, - 0x183227397098d014u64, - ]; - let s = self.pow(s); - if s == Self::zero() { - LegendreSymbol::Zero - } else if s == Self::one() { - LegendreSymbol::QuadraticResidue - } else { - LegendreSymbol::QuadraticNonResidue - } - } } +prime_field_legendre!(Fq); + impl ff::Field for Fq { const ZERO: Self = Self::zero(); const ONE: Self = Self::one(); @@ -310,6 +292,7 @@ impl WithSmallOrderMulGroup<3> for Fq { #[cfg(test)] mod test { use super::*; + use crate::legendre::Legendre; use ff::Field; use rand_core::OsRng; @@ -322,7 +305,7 @@ mod test { let a = Fq::random(OsRng); let mut b = a; b = b.square(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); let b = b.sqrt().unwrap(); let mut negb = b; @@ -335,7 +318,7 @@ mod test { for _ in 0..10000 { let mut b = c; b = b.square(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); b = b.sqrt().unwrap(); diff --git a/src/bn256/fq2.rs b/src/bn256/fq2.rs index e5a249ee..468b018c 100644 --- a/src/bn256/fq2.rs +++ b/src/bn256/fq2.rs @@ -1,6 +1,6 @@ use super::fq::{Fq, NEGATIVE_ONE}; -use super::LegendreSymbol; use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use crate::legendre::Legendre; use core::convert::TryInto; use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; @@ -125,6 +125,30 @@ impl_binops_additive!(Fq2, Fq2); impl_binops_multiplicative!(Fq2, Fq2); impl_sum_prod!(Fq2); +impl Legendre for Fq2 { + type BasePrimeField = Fq; + + #[inline] + fn legendre_exp() -> &'static [u64] { + lazy_static::lazy_static! { + // (p-1) / 2 + static ref LEGENDRE_EXP: Vec = + (num_bigint::BigUint::from_bytes_le((-::ONE).to_repr().as_ref())/2usize).to_u64_digits(); + } + &*LEGENDRE_EXP + } + + /// Norm of Fq2 as extension field in i over Fq + #[inline] + fn norm(&self) -> Self::BasePrimeField { + let mut t0 = self.c0; + let mut t1 = self.c1; + t0 = t0.square(); + t1 = t1.square(); + t1 + t0 + } +} + impl Fq2 { #[inline] pub const fn zero() -> Fq2 { @@ -174,10 +198,6 @@ impl Fq2 { res } - pub fn legendre(&self) -> LegendreSymbol { - self.norm().legendre() - } - pub fn mul_assign(&mut self, other: &Self) { let mut t1 = self.c0 * other.c0; let mut t0 = self.c0 + self.c1; @@ -298,15 +318,6 @@ impl Fq2 { self.c1 += &t0; } - /// Norm of Fq2 as extension field in i over Fq - pub fn norm(&self) -> Fq { - let mut t0 = self.c0; - let mut t1 = self.c1; - t0 = t0.square(); - t1 = t1.square(); - t1 + t0 - } - pub fn invert(&self) -> CtOption { let mut t1 = self.c1; t1 = t1.square(); @@ -696,17 +707,6 @@ fn test_fq2_mul_nonresidue() { } } -#[test] -fn test_fq2_legendre() { - assert_eq!(LegendreSymbol::Zero, Fq2::ZERO.legendre()); - // i^2 = -1 - let mut m1 = Fq2::ONE; - m1 = m1.neg(); - assert_eq!(LegendreSymbol::QuadraticResidue, m1.legendre()); - m1.mul_by_nonresidue(); - assert_eq!(LegendreSymbol::QuadraticNonResidue, m1.legendre()); -} - #[test] pub fn test_sqrt() { let mut rng = XorShiftRng::from_seed([ @@ -716,7 +716,7 @@ pub fn test_sqrt() { for _ in 0..10000 { let a = Fq2::random(&mut rng); - if a.legendre() == LegendreSymbol::QuadraticNonResidue { + if a.legendre() == -Fq::ONE { assert!(bool::from(a.sqrt().is_none())); } } @@ -725,7 +725,7 @@ pub fn test_sqrt() { let a = Fq2::random(&mut rng); let mut b = a; b.square_assign(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); let b = b.sqrt().unwrap(); let mut negb = b; @@ -738,7 +738,7 @@ pub fn test_sqrt() { for _ in 0..10000 { let mut b = c; b.square_assign(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); b = b.sqrt().unwrap(); diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index cd422d4b..8a57ff9f 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -1,7 +1,7 @@ #[cfg(feature = "asm")] use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] -use crate::{field_arithmetic, field_specific}; +use crate::{arithmetic::macx, field_arithmetic, field_specific}; #[cfg(feature = "bn256-table")] #[rustfmt::skip] @@ -18,7 +18,7 @@ pub use table::FR_TABLE; #[cfg(not(feature = "bn256-table"))] use crate::impl_from_u64; -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, mac, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, @@ -166,6 +166,7 @@ field_common!( R3 ); impl_sum_prod!(Fr); +prime_field_legendre!(Fr); #[cfg(not(feature = "bn256-table"))] impl_from_u64!(Fr, R2); @@ -470,4 +471,9 @@ mod test { end_timer!(timer); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/bn256/mod.rs b/src/bn256/mod.rs index 9cd08946..3530b765 100644 --- a/src/bn256/mod.rs +++ b/src/bn256/mod.rs @@ -16,10 +16,3 @@ pub use fq12::*; pub use fq2::*; pub use fq6::*; pub use fr::*; - -#[derive(Debug, PartialEq, Eq)] -pub enum LegendreSymbol { - Zero = 0, - QuadraticResidue = 1, - QuadraticNonResidue = -1, -} diff --git a/src/hash_to_curve.rs b/src/hash_to_curve.rs index 4cef7095..5d51ec75 100644 --- a/src/hash_to_curve.rs +++ b/src/hash_to_curve.rs @@ -5,6 +5,8 @@ use pasta_curves::arithmetic::CurveExt; use static_assertions::const_assert; use subtle::{ConditionallySelectable, ConstantTimeEq}; +use crate::legendre::Legendre; + /// Hashes over a message and writes the output to all of `buf`. /// Modified from https://github.com/zcash/pasta_curves/blob/7e3fc6a4919f6462a32b79dd226cb2587b7961eb/src/hashtocurve.rs#L11. fn hash_to_field>( @@ -94,6 +96,7 @@ pub(crate) fn svdw_map_to_curve( ) -> C where C: CurveExt, + C::Base: Legendre, { let one = C::Base::ONE; let a = C::a(); @@ -128,7 +131,7 @@ where // 14. gx1 = gx1 + B let gx1 = gx1 + b; // 15. e1 = is_square(gx1) - let e1 = gx1.sqrt().is_some(); + let e1 = !gx1.ct_quadratic_non_residue(); // 16. x2 = c2 + tv4 let x2 = c2 + tv4; // 17. gx2 = x2^2 @@ -140,7 +143,7 @@ where // 20. gx2 = gx2 + B let gx2 = gx2 + b; // 21. e2 = is_square(gx2) AND NOT e1 # Avoid short-circuit logic ops - let e2 = gx2.sqrt().is_some() & (!e1); + let e2 = !gx2.ct_quadratic_non_residue() & (!e1); // 22. x3 = tv2^2 let x3 = tv2.square(); // 23. x3 = x3 * tv3 @@ -182,7 +185,7 @@ pub(crate) fn svdw_hash_to_curve<'a, C>( ) -> Box C + 'a> where C: CurveExt, - C::Base: FromUniformBytes<64>, + C::Base: FromUniformBytes<64> + Legendre, { let [c1, c2, c3, c4] = svdw_precomputed_constants::(z); diff --git a/src/legendre.rs b/src/legendre.rs new file mode 100644 index 00000000..6f6fda17 --- /dev/null +++ b/src/legendre.rs @@ -0,0 +1,50 @@ +use ff::{Field, PrimeField}; +use subtle::{Choice, ConstantTimeEq}; + +pub trait Legendre: Field { + type BasePrimeField: PrimeField; + + // This is (p-1)/2 where p is the modulus of the base prime field + fn legendre_exp() -> &'static [u64]; + + fn norm(&self) -> Self::BasePrimeField; + + #[inline] + fn legendre(&self) -> Self::BasePrimeField { + self.norm().pow(Self::legendre_exp()) + } + + #[inline] + fn ct_quadratic_residue(&self) -> Choice { + self.legendre().ct_eq(&Self::BasePrimeField::ONE) + } + + #[inline] + fn ct_quadratic_non_residue(&self) -> Choice { + self.legendre().ct_eq(&-Self::BasePrimeField::ONE) + } +} + +#[macro_export] +macro_rules! prime_field_legendre { + ($field:ident ) => { + impl crate::legendre::Legendre for $field { + type BasePrimeField = Self; + + #[inline] + fn legendre_exp() -> &'static [u64] { + lazy_static::lazy_static! { + // (p-1) / 2 + static ref LEGENDRE_EXP: Vec = + (num_bigint::BigUint::from_bytes_le((-<$field as ff::Field>::ONE).to_repr().as_ref())/2usize).to_u64_digits(); + } + &*LEGENDRE_EXP + } + + #[inline] + fn norm(&self) -> Self::BasePrimeField { + self.clone() + } + } + }; +} diff --git a/src/lib.rs b/src/lib.rs index b75d7143..3fa8e98f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ mod arithmetic; pub mod hash_to_curve; +#[macro_use] +pub mod legendre; pub mod serde; pub mod bn256; diff --git a/src/pasta/mod.rs b/src/pasta/mod.rs index 164697b5..078b663e 100644 --- a/src/pasta/mod.rs +++ b/src/pasta/mod.rs @@ -38,6 +38,9 @@ const ENDO_PARAMS_EP: EndoParameters = EndoParameters { endo!(Eq, Fp, ENDO_PARAMS_EQ); endo!(Ep, Fq, ENDO_PARAMS_EP); +prime_field_legendre!(Fp); +prime_field_legendre!(Fq); + #[test] fn test_endo() { use ff::Field; @@ -71,3 +74,9 @@ fn test_endo() { } } } + +#[test] +fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + crate::tests::field::random_quadratic_residue_test::(); +} diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index 01fecf84..f6a2a54b 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -295,6 +295,8 @@ impl WithSmallOrderMulGroup<3> for Fp { const ZETA: Self = ZETA; } +prime_field_legendre!(Fp); + #[cfg(test)] mod test { use super::*; @@ -367,4 +369,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256k1 base".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index d38dc517..304f5f10 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -302,6 +302,8 @@ impl WithSmallOrderMulGroup<3> for Fq { const ZETA: Self = ZETA; } +prime_field_legendre!(Fq); + #[cfg(test)] mod test { use super::*; @@ -374,4 +376,8 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256k1 scalar".to_string()); } + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index e26f19fc..228e4a67 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -313,6 +313,8 @@ impl WithSmallOrderMulGroup<3> for Fp { const ZETA: Self = ZETA; } +prime_field_legendre!(Fp); + #[cfg(test)] mod test { use super::*; @@ -385,4 +387,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256r1 base".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index 05fcf1fa..1b98761c 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -1,6 +1,5 @@ use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; -use core::convert::TryInto; use core::fmt; use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; @@ -298,6 +297,8 @@ impl WithSmallOrderMulGroup<3> for Fq { const ZETA: Self = ZETA; } +prime_field_legendre!(Fq); + #[cfg(test)] mod test { use super::*; @@ -370,4 +371,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256r1 scalar".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 9bb0fe0b..54d23791 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -2,6 +2,7 @@ use crate::ff::Field; use crate::group::prime::PrimeCurveAffine; +use crate::legendre::Legendre; use crate::tests::fe_from_str; use crate::{group::GroupEncoding, serde::SerdeObject}; use crate::{hash_to_curve, CurveAffine, CurveExt}; @@ -343,7 +344,9 @@ pub fn svdw_map_to_curve_test( z: G::Base, precomputed_constants: [&'static str; 4], test_vector: impl IntoIterator, -) { +) where + ::Base: Legendre, +{ let [c1, c2, c3, c4] = hash_to_curve::svdw_precomputed_constants::(z); assert_eq!([c1, c2, c3, c4], precomputed_constants.map(fe_from_str)); for (u, (x, y)) in test_vector.into_iter() { diff --git a/src/tests/field.rs b/src/tests/field.rs index a064441e..b04f801e 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -1,6 +1,7 @@ -use crate::ff::Field; use crate::serde::SerdeObject; +use crate::{ff::Field, legendre::Legendre}; use ark_std::{end_timer, start_timer}; +use ff::PrimeField; use rand::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; @@ -287,3 +288,16 @@ where } end_timer!(start); } + +pub fn random_quadratic_residue_test() { + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + for _ in 0..100000 { + let elem = F::random(&mut rng); + let is_quad_res_or_zero: bool = elem.sqrt().is_some().into(); + let is_quad_non_res: bool = elem.ct_quadratic_non_residue().into(); + assert_eq!(!is_quad_non_res, is_quad_res_or_zero) + } +} From 2bb4633111e77b7a962d2bbbf2de0c1641f16880 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Thu, 24 Aug 2023 15:36:42 +0200 Subject: [PATCH 03/13] Add simplified SWU method (#81) * Fix broken link * Add simple SWU algorithm * Add simplified SWU hash_to_curve for secp256r1 * add: sswu z reference * update MAP_ID identifier Co-authored-by: Han --------- Co-authored-by: Han --- src/hash_to_curve.rs | 129 ++++++++++++++++++++++++++++++++++++++++- src/secp256r1/curve.rs | 124 +++++++++++++++++++++++---------------- 2 files changed, 200 insertions(+), 53 deletions(-) diff --git a/src/hash_to_curve.rs b/src/hash_to_curve.rs index 5d51ec75..22251102 100644 --- a/src/hash_to_curve.rs +++ b/src/hash_to_curve.rs @@ -3,7 +3,7 @@ use ff::{Field, FromUniformBytes, PrimeField}; use pasta_curves::arithmetic::CurveExt; use static_assertions::const_assert; -use subtle::{ConditionallySelectable, ConstantTimeEq}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use crate::legendre::Legendre; @@ -85,6 +85,94 @@ fn hash_to_field>( } } +// Implementation of +#[allow(clippy::too_many_arguments)] +pub(crate) fn simple_svdw_map_to_curve(u: C::Base, z: C::Base) -> C +where + C: CurveExt, +{ + let zero = C::Base::ZERO; + let one = C::Base::ONE; + let a = C::a(); + let b = C::b(); + + //1. tv1 = u^2 + let tv1 = u.square(); + //2. tv1 = Z * tv1 + let tv1 = z * tv1; + //3. tv2 = tv1^2 + let tv2 = tv1.square(); + //4. tv2 = tv2 + tv1 + let tv2 = tv2 + tv1; + //5. tv3 = tv2 + 1 + let tv3 = tv2 + one; + //6. tv3 = B * tv3 + let tv3 = b * tv3; + //7. tv4 = CMOV(Z, -tv2, tv2 != 0) # tv4 = z if tv2 is 0 else tv4 = -tv2 + let tv2_is_not_zero = !tv2.ct_eq(&zero); + let tv4 = C::Base::conditional_select(&z, &-tv2, tv2_is_not_zero); + //8. tv4 = A * tv4 + let tv4 = a * tv4; + //9. tv2 = tv3^2 + let tv2 = tv3.square(); + //10. tv6 = tv4^2 + let tv6 = tv4.square(); + //11. tv5 = A * tv6 + let tv5 = a * tv6; + //12. tv2 = tv2 + tv5 + let tv2 = tv2 + tv5; + //13. tv2 = tv2 * tv3 + let tv2 = tv2 * tv3; + //14. tv6 = tv6 * tv4 + let tv6 = tv6 * tv4; + //15. tv5 = B * tv6 + let tv5 = b * tv6; + //16. tv2 = tv2 + tv5 + let tv2 = tv2 + tv5; + //17. x = tv1 * tv3 + let x = tv1 * tv3; + //18. (is_gx1_square, y1) = sqrt_ratio(tv2, tv6) + let (is_gx1_square, y1) = sqrt_ratio(&tv2, &tv6, &z); + //19. y = tv1 * u + let y = tv1 * u; + //20. y = y * y1 + let y = y * y1; + //21. x = CMOV(x, tv3, is_gx1_square) + let x = C::Base::conditional_select(&x, &tv3, is_gx1_square); + //22. y = CMOV(y, y1, is_gx1_square) + let y = C::Base::conditional_select(&y, &y1, is_gx1_square); + //23. e1 = sgn0(u) == sgn0(y) + let e1 = u.is_odd().ct_eq(&y.is_odd()); + //24. y = CMOV(-y, y, e1) # Select correct sign of y + let y = C::Base::conditional_select(&-y, &y, e1); + //25. x = x / tv4 + let x = x * tv4.invert().unwrap(); + //26. return (x, y) + C::new_jacobian(x, y, one).unwrap() +} + +#[allow(clippy::type_complexity)] +pub(crate) fn simple_svdw_hash_to_curve<'a, C>( + curve_id: &'static str, + domain_prefix: &'a str, + z: C::Base, +) -> Box C + 'a> +where + C: CurveExt, + C::Base: FromUniformBytes<64>, +{ + Box::new(move |message| { + let mut us = [C::Base::ZERO; 2]; + hash_to_field("SSWU", curve_id, domain_prefix, message, &mut us); + + let [q0, q1]: [C; 2] = us.map(|u| simple_svdw_map_to_curve(u, z)); + + let r = q0 + &q1; + debug_assert!(bool::from(r.is_on_curve())); + r + }) +} + #[allow(clippy::too_many_arguments)] pub(crate) fn svdw_map_to_curve( u: C::Base, @@ -176,7 +264,44 @@ where C::new_jacobian(x, y, one).unwrap() } -/// Implementation of https://www.ietf.org/id/draft-irtf-cfrg-hash-to-curve-16.html#name-shallue-van-de-woestijne-met +// Implement https://datatracker.ietf.org/doc/html/rfc9380#name-sqrt_ratio-for-any-field +// Copied from ff sqrt_ratio_generic subsituting F::ROOT_OF_UNITY for input Z +fn sqrt_ratio(num: &F, div: &F, z: &F) -> (Choice, F) { + // General implementation: + // + // a = num * inv0(div) + // = { 0 if div is zero + // { num/div otherwise + // + // b = z * a + // = { 0 if div is zero + // { z*num/div otherwise + + // Since z is non-square, a and b are either both zero (and both square), or + // only one of them is square. We can therefore choose the square root to return + // based on whether a is square, but for the boolean output we need to handle the + // num != 0 && div == 0 case specifically. + + let a = div.invert().unwrap_or(F::ZERO) * num; + let b = a * z; + let sqrt_a = a.sqrt(); + let sqrt_b = b.sqrt(); + + let num_is_zero = num.is_zero(); + let div_is_zero = div.is_zero(); + let is_square = sqrt_a.is_some(); + let is_nonsquare = sqrt_b.is_some(); + assert!(bool::from( + num_is_zero | div_is_zero | (is_square ^ is_nonsquare) + )); + + ( + is_square & (num_is_zero | !div_is_zero), + CtOption::conditional_select(&sqrt_b, &sqrt_a, is_square).unwrap(), + ) +} + +/// Implementation of https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#section-6.6.1 #[allow(clippy::type_complexity)] pub(crate) fn svdw_hash_to_curve<'a, C>( curve_id: &'static str, diff --git a/src/secp256r1/curve.rs b/src/secp256r1/curve.rs index 5ce4522c..7e6e24ef 100644 --- a/src/secp256r1/curve.rs +++ b/src/secp256r1/curve.rs @@ -1,6 +1,7 @@ use crate::ff::WithSmallOrderMulGroup; use crate::ff::{Field, PrimeField}; use crate::group::{prime::PrimeCurveAffine, Curve, Group as _, GroupEncoding}; +use crate::hash_to_curve::simple_svdw_hash_to_curve; use crate::secp256r1::Fp; use crate::secp256r1::Fq; use crate::{Coordinates, CurveAffine, CurveExt}; @@ -75,77 +76,98 @@ new_curve_impl!( SECP_A, SECP_B, "secp256r1", - |_, _| unimplemented!(), + |curve_id, domain_prefix| simple_svdw_hash_to_curve(curve_id, domain_prefix, Secp256r1::SSVDW_Z), ); -#[test] -fn test_curve() { - crate::tests::curve::curve_tests::(); +impl Secp256r1 { + // Optimal Z with: + // 0xffffffff00000001000000000000000000000000fffffffffffffffffffffff5 + // Z = -10 (reference: ) + const SSVDW_Z: Fp = Fp::from_raw([ + 0xfffffffffffffff5, + 0x00000000ffffffff, + 0x0000000000000000, + 0xffffffff00000001, + ]); } -#[test] -fn test_serialization() { - crate::tests::curve::random_serialization_test::(); - #[cfg(feature = "derive_serde")] - crate::tests::curve::random_serde_test::(); -} - -#[test] -fn ecdsa_example() { +#[cfg(test)] +mod tests { + use super::*; use crate::group::Curve; - use crate::CurveAffine; + use crate::secp256r1::{Fp, Fq, Secp256r1}; use ff::FromUniformBytes; use rand_core::OsRng; - fn mod_n(x: Fp) -> Fq { - let mut x_repr = [0u8; 32]; - x_repr.copy_from_slice(x.to_repr().as_ref()); - let mut x_bytes = [0u8; 64]; - x_bytes[..32].copy_from_slice(&x_repr[..]); - Fq::from_uniform_bytes(&x_bytes) + #[test] + fn test_hash_to_curve() { + crate::tests::curve::hash_to_curve_test::(); + } + + #[test] + fn test_curve() { + crate::tests::curve::curve_tests::(); } - let g = Secp256r1::generator(); + #[test] + fn test_serialization() { + crate::tests::curve::random_serialization_test::(); + #[cfg(feature = "derive_serde")] + crate::tests::curve::random_serde_test::(); + } + + #[test] + fn ecdsa_example() { + fn mod_n(x: Fp) -> Fq { + let mut x_repr = [0u8; 32]; + x_repr.copy_from_slice(x.to_repr().as_ref()); + let mut x_bytes = [0u8; 64]; + x_bytes[..32].copy_from_slice(&x_repr[..]); + Fq::from_uniform_bytes(&x_bytes) + } + + let g = Secp256r1::generator(); - for _ in 0..1000 { - // Generate a key pair - let sk = Fq::random(OsRng); - let pk = (g * sk).to_affine(); + for _ in 0..1000 { + // Generate a key pair + let sk = Fq::random(OsRng); + let pk = (g * sk).to_affine(); - // Generate a valid signature - // Suppose `m_hash` is the message hash - let msg_hash = Fq::random(OsRng); + // Generate a valid signature + // Suppose `m_hash` is the message hash + let msg_hash = Fq::random(OsRng); - let (r, s) = { - // Draw arandomness - let k = Fq::random(OsRng); - let k_inv = k.invert().unwrap(); + let (r, s) = { + // Draw arandomness + let k = Fq::random(OsRng); + let k_inv = k.invert().unwrap(); - // Calculate `r` - let r_point = (g * k).to_affine().coordinates().unwrap(); - let x = r_point.x(); - let r = mod_n(*x); + // Calculate `r` + let r_point = (g * k).to_affine().coordinates().unwrap(); + let x = r_point.x(); + let r = mod_n(*x); - // Calculate `s` - let s = k_inv * (msg_hash + (r * sk)); + // Calculate `s` + let s = k_inv * (msg_hash + (r * sk)); - (r, s) - }; + (r, s) + }; - { - // Verify - let s_inv = s.invert().unwrap(); - let u_1 = msg_hash * s_inv; - let u_2 = r * s_inv; + { + // Verify + let s_inv = s.invert().unwrap(); + let u_1 = msg_hash * s_inv; + let u_2 = r * s_inv; - let v_1 = g * u_1; - let v_2 = pk * u_2; + let v_1 = g * u_1; + let v_2 = pk * u_2; - let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); - let x_candidate = r_point.x(); - let r_candidate = mod_n(*x_candidate); + let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); + let x_candidate = r_point.x(); + let r_candidate = mod_n(*x_candidate); - assert_eq!(r, r_candidate); + assert_eq!(r, r_candidate); + } } } } From 6e2ff3853c8fe91300650a733100640dacf313e6 Mon Sep 17 00:00:00 2001 From: Han Date: Mon, 4 Sep 2023 12:24:29 +0800 Subject: [PATCH 04/13] Bring back curve algorithms for `a = 0` (#82) * refactor: bring back curve algorithms for `a = 0` * fix: clippy warning --- benches/group.rs | 16 +- benches/hash_to_curve.rs | 2 +- src/bn256/fq2.rs | 2 +- src/derive/curve.rs | 441 +++++++++++++++++++++++++-------------- src/legendre.rs | 2 +- 5 files changed, 291 insertions(+), 172 deletions(-) diff --git a/benches/group.rs b/benches/group.rs index 68cfee53..b1936e68 100644 --- a/benches/group.rs +++ b/benches/group.rs @@ -18,28 +18,28 @@ fn criterion_benchmark(c: &mut Criterion) { let v = vec![G::generator(); N]; let mut q = vec![G::AffineExt::identity(); N]; - c.bench_function(&format!("{} check on curve", name), move |b| { + c.bench_function(&format!("{name} check on curve"), move |b| { b.iter(|| black_box(p1).is_on_curve()) }); - c.bench_function(&format!("{} check equality", name), move |b| { + c.bench_function(&format!("{name} check equality"), move |b| { b.iter(|| black_box(p1) == black_box(p1)) }); - c.bench_function(&format!("{} to affine", name), move |b| { + c.bench_function(&format!("{name} to affine"), move |b| { b.iter(|| G::AffineExt::from(black_box(p1))) }); - c.bench_function(&format!("{} doubling", name), move |b| { + c.bench_function(&format!("{name} doubling"), move |b| { b.iter(|| black_box(p1).double()) }); - c.bench_function(&format!("{} addition", name), move |b| { + c.bench_function(&format!("{name} addition"), move |b| { b.iter(|| black_box(p1).add(&p2)) }); - c.bench_function(&format!("{} mixed addition", name), move |b| { + c.bench_function(&format!("{name} mixed addition"), move |b| { b.iter(|| black_box(p2).add(&p1_affine)) }); - c.bench_function(&format!("{} scalar multiplication", name), move |b| { + c.bench_function(&format!("{name} scalar multiplication"), move |b| { b.iter(|| black_box(p1) * black_box(s)) }); - c.bench_function(&format!("{} batch to affine n={}", name, N), move |b| { + c.bench_function(&format!("{name} batch to affine n={N}"), move |b| { b.iter(|| { G::batch_normalize(black_box(&v), black_box(&mut q)); black_box(&q)[0] diff --git a/benches/hash_to_curve.rs b/benches/hash_to_curve.rs index 76f6733a..bda1c1d3 100644 --- a/benches/hash_to_curve.rs +++ b/benches/hash_to_curve.rs @@ -40,7 +40,7 @@ fn hash_to_curve(c: &mut Criterion, name: &'static str) { .flatten() .collect::>(); - c.bench_function(&format!("Hash to {}", name), move |b| { + c.bench_function(&format!("Hash to {name}"), move |b| { b.iter(|| hasher(black_box(&message))) }); } diff --git a/src/bn256/fq2.rs b/src/bn256/fq2.rs index 468b018c..66d2c6a7 100644 --- a/src/bn256/fq2.rs +++ b/src/bn256/fq2.rs @@ -135,7 +135,7 @@ impl Legendre for Fq2 { static ref LEGENDRE_EXP: Vec = (num_bigint::BigUint::from_bytes_le((-::ONE).to_repr().as_ref())/2usize).to_u64_digits(); } - &*LEGENDRE_EXP + &LEGENDRE_EXP } /// Norm of Fq2 as extension field in i over Fq diff --git a/src/derive/curve.rs b/src/derive/curve.rs index 1eeef572..098d1a2f 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -114,8 +114,7 @@ macro_rules! new_curve_impl { $base::from_bytes(&xbytes).and_then(|x| { CtOption::new(Self::identity(), x.is_zero() & (is_inf)).or_else(|| { - let x3 = x.square() * x; - (x3 + $name::curve_constant_a() * x + $name::curve_constant_b()).sqrt().and_then(|y| { + $name_affine::y2(x).sqrt().and_then(|y| { let sign = Choice::from(y.to_bytes()[0] & 1); let y = $base::conditional_select(&y, &-y, ysign ^ sign); @@ -321,18 +320,10 @@ macro_rules! new_curve_impl { } } - const fn curve_constant_a() -> $base { - $name_affine::curve_constant_a() - } - - const fn curve_constant_b() -> $base { - $name_affine::curve_constant_b() - } - #[inline] fn curve_constant_3b() -> $base { lazy_static::lazy_static! { - static ref CONST_3B: $base = $constant_b + $constant_b + $constant_b; + static ref CONST_3B: $base = $constant_b + $constant_b + $constant_b; } *CONST_3B } @@ -354,23 +345,24 @@ macro_rules! new_curve_impl { } } - const fn curve_constant_a() -> $base { - $constant_a - } - - const fn curve_constant_b() -> $base { - $constant_b + #[inline(always)] + fn y2(x: $base) -> $base { + if $constant_a == $base::ZERO { + let x3 = x.square() * x; + (x3 + $constant_b) + } else { + let x2 = x.square(); + ((x2 + $constant_a) * x + $constant_b) + } } - pub fn random(mut rng: impl RngCore) -> Self { loop { let x = $base::random(&mut rng); let ysign = (rng.next_u32() % 2) as u8; - let x3 = x.square() * x; - let y = (x3 + $name::curve_constant_a() * x + $name::curve_constant_b()).sqrt(); - if let Some(y) = Option::<$base>::from(y) { + let y2 = $name_affine::y2(x); + if let Some(y) = Option::<$base>::from(y2.sqrt()) { let sign = y.to_bytes()[0] & 1; let y = if ysign ^ sign == 0 { y } else { -y }; @@ -479,20 +471,30 @@ macro_rules! new_curve_impl { } fn is_on_curve(&self) -> Choice { - // Check (Y/Z)^2 = (X/Z)^3 + a(X/Z) + b - // <=> Z Y^2 - X^3 - a(X Z^2) = Z^3 b + if $constant_a == $base::ZERO { + // Check (Y/Z)^2 = (X/Z)^3 + b + // <=> Z Y^2 - X^3 = Z^3 b + + (self.z * self.y.square() - self.x.square() * self.x) + .ct_eq(&(self.z.square() * self.z * $constant_b)) + | self.z.is_zero() + } else { + // Check (Y/Z)^2 = (X/Z)^3 + a(X/Z) + b + // <=> Z Y^2 - X^3 - a(X Z^2) = Z^3 b - (self.z * self.y.square() - self.x.square() * self.x - $name::curve_constant_a() * self.x * self.z.square()) - .ct_eq(&(self.z.square() * self.z * $name::curve_constant_b())) - | self.z.is_zero() + let z2 = self.z.square(); + (self.z * self.y.square() - (self.x.square() + $constant_a * z2) * self.x) + .ct_eq(&(z2 * self.z * $constant_b)) + | self.z.is_zero() + } } fn b() -> Self::Base { - $name::curve_constant_b() + $constant_b } fn a() -> Self::Base { - $name::curve_constant_a() + $constant_a } fn new_jacobian(x: Self::Base, y: Self::Base, z: Self::Base) -> CtOption { @@ -566,46 +568,76 @@ macro_rules! new_curve_impl { } fn double(&self) -> Self { - // Algorithm 3, https://eprint.iacr.org/2015/1060.pdf - let t0 = self.x.square(); - let t1 = self.y.square(); - let t2 = self.z.square(); - let t3 = self.x * self.y; - let t3 = t3 + t3; - let z3 = self.x * self.z; - let z3 = z3 + z3; - let x3 = $name::curve_constant_a() * z3; - let y3 = $name::mul_by_3b(&t2); - let y3 = x3 + y3; - let x3 = t1 - y3; - let y3 = t1 + y3; - let y3 = x3 * y3; - let x3 = t3 * x3; - let z3 = $name::mul_by_3b(&z3); - let t2 = $name::curve_constant_a() * t2; - let t3 = t0 - t2; - let t3 = $name::curve_constant_a() * t3; - let t3 = t3 + z3; - let z3 = t0 + t0; - let t0 = z3 + t0; - let t0 = t0 + t2; - let t0 = t0 * t3; - let y3 = y3 + t0; - let t2 = self.y * self.z; - let t2 = t2 + t2; - let t0 = t2 * t3; - let x3 = x3 - t0; - let z3 = t2 * t1; - let z3 = z3 + z3; - let z3 = z3 + z3; - - let tmp = $name { - x: x3, - y: y3, - z: z3, - }; - - $name::conditional_select(&tmp, &$name::identity(), self.is_identity()) + if $constant_a == $base::ZERO { + // Algorithm 9, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.y.square(); + let z3 = t0 + t0; + let z3 = z3 + z3; + let z3 = z3 + z3; + let t1 = self.y * self.z; + let t2 = self.z.square(); + let t2 = $name::mul_by_3b(&t2); + let x3 = t2 * z3; + let y3 = t0 + t2; + let z3 = t1 * z3; + let t1 = t2 + t2; + let t2 = t1 + t2; + let t0 = t0 - t2; + let y3 = t0 * y3; + let y3 = x3 + y3; + let t1 = self.x * self.y; + let x3 = t0 * t1; + let x3 = x3 + x3; + + let tmp = $name { + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, &$name::identity(), self.is_identity()) + } else { + // Algorithm 3, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x.square(); + let t1 = self.y.square(); + let t2 = self.z.square(); + let t3 = self.x * self.y; + let t3 = t3 + t3; + let z3 = self.x * self.z; + let z3 = z3 + z3; + let x3 = $constant_a * z3; + let y3 = $name::mul_by_3b(&t2); + let y3 = x3 + y3; + let x3 = t1 - y3; + let y3 = t1 + y3; + let y3 = x3 * y3; + let x3 = t3 * x3; + let z3 = $name::mul_by_3b(&z3); + let t2 = $constant_a * t2; + let t3 = t0 - t2; + let t3 = $constant_a * t3; + let t3 = t3 + z3; + let z3 = t0 + t0; + let t0 = z3 + t0; + let t0 = t0 + t2; + let t0 = t0 * t3; + let y3 = y3 + t0; + let t2 = self.y * self.z; + let t2 = t2 + t2; + let t0 = t2 * t3; + let x3 = x3 - t0; + let z3 = t2 * t1; + let z3 = z3 + z3; + let z3 = z3 + z3; + + let tmp = $name { + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, &$name::identity(), self.is_identity()) + } } fn generator() -> Self { @@ -823,9 +855,15 @@ macro_rules! new_curve_impl { type CurveExt = $name; fn is_on_curve(&self) -> Choice { - // y^2 - x^3 - ax ?= b - (self.y.square() - self.x.square() * self.x - $name::curve_constant_a() * self.x).ct_eq(&$name::curve_constant_b()) - | self.is_identity() + if $constant_a == $base::ZERO { + // y^2 - x^3 ?= b + (self.y.square() - self.x.square() * self.x).ct_eq(&$constant_b) + | self.is_identity() + } else { + // y^2 - x^3 - ax ?= b + (self.y.square() - (self.x.square() + $constant_a) * self.x).ct_eq(&$constant_b) + | self.is_identity() + } } fn coordinates(&self) -> CtOption> { @@ -840,11 +878,11 @@ macro_rules! new_curve_impl { } fn a() -> Self::Base { - $name::curve_constant_a() + $constant_a } fn b() -> Self::Base { - $name::curve_constant_b() + $constant_b } } @@ -892,52 +930,95 @@ macro_rules! new_curve_impl { type Output = $name; fn add(self, rhs: &'a $name) -> $name { - // Algorithm 1, https://eprint.iacr.org/2015/1060.pdf - let t0 = self.x * rhs.x; - let t1 = self.y * rhs.y; - let t2 = self.z * rhs.z; - let t3 = self.x + self.y; - let t4 = rhs.x + rhs.y; - let t3 = t3 * t4; - let t4 = t0 + t1; - let t3 = t3 - t4; - let t4 = self.x + self.z; - let t5 = rhs.x + rhs.z; - let t4 = t4 * t5; - let t5 = t0 + t2; - let t4 = t4 - t5; - let t5 = self.y + self.z; - let x3 = rhs.y + rhs.z; - let t5 = t5 * x3; - let x3 = t1 + t2; - let t5 = t5 - x3; - let z3 = $name::curve_constant_a() * t4; - let x3 = $name::mul_by_3b(&t2); - let z3 = x3 + z3; - let x3 = t1 - z3; - let z3 = t1 + z3; - let y3 = x3 * z3; - let t1 = t0 + t0; - let t1 = t1 + t0; - let t2 = $name::curve_constant_a() * t2; - let t4 = $name::mul_by_3b(&t4); - let t1 = t1 + t2; - let t2 = t0 - t2; - let t2 = $name::curve_constant_a() * t2; - let t4 = t4 + t2; - let t0 = t1 * t4; - let y3 = y3 + t0; - let t0 = t5 * t4; - let x3 = t3 * x3; - let x3 = x3 - t0; - let t0 = t3 * t1; - let z3 = t5 * z3; - let z3 = z3 + t0; - - $name { - x: x3, - y: y3, - z: z3, + if $constant_a == $base::ZERO { + // Algorithm 7, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t2 = self.z * rhs.z; + let t3 = self.x + self.y; + let t4 = rhs.x + rhs.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = self.y + self.z; + let x3 = rhs.y + rhs.z; + let t4 = t4 * x3; + let x3 = t1 + t2; + let t4 = t4 - x3; + let x3 = self.x + self.z; + let y3 = rhs.x + rhs.z; + let x3 = x3 * y3; + let y3 = t0 + t2; + let y3 = x3 - y3; + let x3 = t0 + t0; + let t0 = x3 + t0; + let t2 = $name::mul_by_3b(&t2); + let z3 = t1 + t2; + let t1 = t1 - t2; + let y3 = $name::mul_by_3b(&y3); + let x3 = t4 * y3; + let t2 = t3 * t1; + let x3 = t2 - x3; + let y3 = y3 * t0; + let t1 = t1 * z3; + let y3 = t1 + y3; + let t0 = t0 * t3; + let z3 = z3 * t4; + let z3 = z3 + t0; + + $name { + x: x3, + y: y3, + z: z3, + } + } else { + // Algorithm 1, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t2 = self.z * rhs.z; + let t3 = self.x + self.y; + let t4 = rhs.x + rhs.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = self.x + self.z; + let t5 = rhs.x + rhs.z; + let t4 = t4 * t5; + let t5 = t0 + t2; + let t4 = t4 - t5; + let t5 = self.y + self.z; + let x3 = rhs.y + rhs.z; + let t5 = t5 * x3; + let x3 = t1 + t2; + let t5 = t5 - x3; + let z3 = $constant_a * t4; + let x3 = $name::mul_by_3b(&t2); + let z3 = x3 + z3; + let x3 = t1 - z3; + let z3 = t1 + z3; + let y3 = x3 * z3; + let t1 = t0 + t0; + let t1 = t1 + t0; + let t2 = $constant_a * t2; + let t4 = $name::mul_by_3b(&t4); + let t1 = t1 + t2; + let t2 = t0 - t2; + let t2 = $constant_a * t2; + let t4 = t4 + t2; + let t0 = t1 * t4; + let y3 = y3 + t0; + let t0 = t5 * t4; + let x3 = t3 * x3; + let x3 = x3 - t0; + let t0 = t3 * t1; + let z3 = t5 * z3; + let z3 = z3 + t0; + + $name { + x: x3, + y: y3, + z: z3, + } } } } @@ -947,48 +1028,86 @@ macro_rules! new_curve_impl { // Mixed addition fn add(self, rhs: &'a $name_affine) -> $name { - // Algorithm 2, https://eprint.iacr.org/2015/1060.pdf - let t0 = self.x * rhs.x; - let t1 = self.y * rhs.y; - let t3 = rhs.x + rhs.y; - let t4 = self.x + self.y; - let t3 = t3 * t4; - let t4 = t0 + t1; - let t3 = t3 - t4; - let t4 = rhs.x * self.z; - let t4 = t4 + self.x; - let t5 = rhs.y * self.z; - let t5 = t5 + self.y; - let z3 = $name::curve_constant_a() * t4; - let x3 = $name::mul_by_3b(&self.z); - let z3 = x3 + z3; - let x3 = t1 - z3; - let z3 = t1 + z3; - let y3 = x3 * z3; - let t1 = t0 + t0; - let t1 = t1 + t0; - let t2 = $name::curve_constant_a() * self.z; - let t4 = $name::mul_by_3b(&t4); - let t1 = t1 + t2; - let t2 = t0 - t2; - let t2 = $name::curve_constant_a() * t2; - let t4 = t4 + t2; - let t0 = t1 * t4; - let y3 = y3 + t0; - let t0 = t5 * t4; - let x3 = t3 * x3; - let x3 = x3 - t0; - let t0 = t3 * t1; - let z3 = t5 * z3; - let z3 = z3 + t0; - - let tmp = $name{ - x: x3, - y: y3, - z: z3, - }; - - $name::conditional_select(&tmp, self, rhs.is_identity()) + if $constant_a == $base::ZERO { + // Algorithm 8, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t3 = rhs.x + rhs.y; + let t4 = self.x + self.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = rhs.y * self.z; + let t4 = t4 + self.y; + let y3 = rhs.x * self.z; + let y3 = y3 + self.x; + let x3 = t0 + t0; + let t0 = x3 + t0; + let t2 = $name::mul_by_3b(&self.z); + let z3 = t1 + t2; + let t1 = t1 - t2; + let y3 = $name::mul_by_3b(&y3); + let x3 = t4 * y3; + let t2 = t3 * t1; + let x3 = t2 - x3; + let y3 = y3 * t0; + let t1 = t1 * z3; + let y3 = t1 + y3; + let t0 = t0 * t3; + let z3 = z3 * t4; + let z3 = z3 + t0; + + let tmp = $name{ + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, self, rhs.is_identity()) + } else { + // Algorithm 2, https://eprint.iacr.org/2015/1060.pdf + let t0 = self.x * rhs.x; + let t1 = self.y * rhs.y; + let t3 = rhs.x + rhs.y; + let t4 = self.x + self.y; + let t3 = t3 * t4; + let t4 = t0 + t1; + let t3 = t3 - t4; + let t4 = rhs.x * self.z; + let t4 = t4 + self.x; + let t5 = rhs.y * self.z; + let t5 = t5 + self.y; + let z3 = $constant_a * t4; + let x3 = $name::mul_by_3b(&self.z); + let z3 = x3 + z3; + let x3 = t1 - z3; + let z3 = t1 + z3; + let y3 = x3 * z3; + let t1 = t0 + t0; + let t1 = t1 + t0; + let t2 = $constant_a * self.z; + let t4 = $name::mul_by_3b(&t4); + let t1 = t1 + t2; + let t2 = t0 - t2; + let t2 = $constant_a * t2; + let t4 = t4 + t2; + let t0 = t1 * t4; + let y3 = y3 + t0; + let t0 = t5 * t4; + let x3 = t3 * x3; + let x3 = x3 - t0; + let t0 = t3 * t1; + let z3 = t5 * z3; + let z3 = z3 + t0; + + let tmp = $name{ + x: x3, + y: y3, + z: z3, + }; + + $name::conditional_select(&tmp, self, rhs.is_identity()) + } } } diff --git a/src/legendre.rs b/src/legendre.rs index 6f6fda17..7e4b9971 100644 --- a/src/legendre.rs +++ b/src/legendre.rs @@ -28,7 +28,7 @@ pub trait Legendre: Field { #[macro_export] macro_rules! prime_field_legendre { ($field:ident ) => { - impl crate::legendre::Legendre for $field { + impl $crate::legendre::Legendre for $field { type BasePrimeField = Self; #[inline] From 8e3a33af78c941bb87ab8a5e81dc4cb3d09c0d69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 18 Sep 2023 11:43:26 +0200 Subject: [PATCH 05/13] fix: Improve serialization for prime fields (#85) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: Improve serialization for prime fields Summary: 256-bit field serialization is currently 4x u64, ie. the native format. This implements the standard of byte-serialization (corresponding to the PrimeField::{to,from}_repr), and an hex-encoded variant of that for (de)serializers that are human-readable (concretely, json). - Added a new macro `serialize_deserialize_32_byte_primefield!` for custom serialization and deserialization of 32-byte prime field in different struct (Fq, Fp, Fr) across the secp256r, bn256, and derive libraries. - Implemented the new macro for serialization and deserialization in various structs, replacing the previous `serde::{Deserialize, Serialize}` direct use. - Enhanced error checking in the custom serialization methods to ensure valid field elements. - Updated the test function in the tests/field.rs file to include JSON serialization and deserialization tests for object integrity checking. * fixup! fix: Improve serialization for prime fields --------- Co-authored-by: Carlos Pérez <37264926+CPerezz@users.noreply.github.com> --- Cargo.toml | 4 +++- src/bn256/fq.rs | 7 +++---- src/bn256/fr.rs | 7 +++---- src/derive/field.rs | 35 +++++++++++++++++++++++++++++++++++ src/secp256k1/fp.rs | 7 +++---- src/secp256k1/fq.rs | 7 +++---- src/secp256r1/fp.rs | 7 +++---- src/secp256r1/fq.rs | 7 +++---- src/tests/field.rs | 7 +++++++ 9 files changed, 63 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f29c917e..a843a97c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ criterion = { version = "0.3", features = ["html_reports"] } rand_xorshift = "0.3" ark-std = { version = "0.3" } bincode = "1.3.3" +serde_json = "1.0.105" [dependencies] subtle = "2.4" @@ -30,6 +31,7 @@ num-traits = "0.2" paste = "1.0.11" serde = { version = "1.0", default-features = false, optional = true } serde_arrays = { version = "0.1.0", optional = true } +hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } blake2b_simd = "1" [features] @@ -37,7 +39,7 @@ default = ["reexport", "bits"] asm = [] bits = ["ff/bits"] bn256-table = [] -derive_serde = ["serde/derive", "serde_arrays"] +derive_serde = ["serde/derive", "serde_arrays", "hex"] prefetch = [] print-trace = ["ark-std/print-trace"] reexport = [] diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index 0024723a..fec8d863 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -16,9 +16,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `p = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47` @@ -28,9 +25,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 const MODULUS: Fq = Fq([ diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 8a57ff9f..7e3b5ae8 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -31,9 +31,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_r$ where /// /// `r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001` @@ -43,9 +40,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fr` values are always in // Montgomery form; i.e., Fr(a) = aR mod r, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fr(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fr); + /// Constant representing the modulus /// r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 const MODULUS: Fr = Fr([ diff --git a/src/derive/field.rs b/src/derive/field.rs index 945ee981..da962457 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -686,3 +686,38 @@ macro_rules! field_bits { } }; } + +/// A macro to help define serialization and deserialization for prime field implementations +/// that use 32-byte representations. This assumes the concerned type implements PrimeField +/// (for from_repr, to_repr). +#[macro_export] +macro_rules! serialize_deserialize_32_byte_primefield { + ($type:ty) => { + impl ::serde::Serialize for $type { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_repr(); + if serializer.is_human_readable() { + hex::serde::serialize(bytes, serializer) + } else { + bytes.serialize(serializer) + } + } + } + + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $type { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + <[u8; 32]>::deserialize(deserializer)? + }; + Option::from(Self::from_repr(bytes)).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + }; +} diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index f6a2a54b..c346dc6c 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_p$ where /// /// `p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f` @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fp` values are always in // Montgomery form; i.e., Fp(a) = aR mod p, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fp(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fp); + /// Constant representing the modulus /// p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f const MODULUS: Fp = Fp([ diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 304f5f10..189daaba 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `q = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141` @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 const MODULUS: Fq = Fq([ diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index 228e4a67..bf86e157 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_p$ where /// /// `p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fp` values are always in // Montgomery form; i.e., Fp(a) = aR mod p, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fp(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fp); + /// Constant representing the modulus /// p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff const MODULUS: Fp = Fp([ diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index 1b98761c..d1a7b809 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -5,9 +5,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `q = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551` @@ -17,9 +14,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551 const MODULUS: Fq = Fq([ diff --git a/src/tests/field.rs b/src/tests/field.rs index b04f801e..02f5509f 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -280,11 +280,18 @@ where let _message = format!("serialization with serde {type_name}"); let start = start_timer!(|| _message); for _ in 0..1000000 { + // byte serialization let a = F::random(&mut rng); let bytes = bincode::serialize(&a).unwrap(); let reader = std::io::Cursor::new(bytes); let b: F = bincode::deserialize_from(reader).unwrap(); assert_eq!(a, b); + + // json serialization + let json = serde_json::to_string(&a).unwrap(); + let reader = std::io::Cursor::new(json); + let b: F = serde_json::from_reader(reader).unwrap(); + assert_eq!(a, b); } end_timer!(start); } From 2f3e388eef9b788adf126bb4a8abb10877a0a04d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:02:14 +0200 Subject: [PATCH 06/13] refactor: (De)Serialization of points using `GroupEncoding` (#88) * refactor: implement (De)Serialization of points using the `GroupEncoding` trait - Updated curve point (de)serialization logic from the internal representation to the representation offered by the implementation of the `GroupEncoding` trait. * fix: add explicit json serde tests --- src/derive/curve.rs | 70 +++++++++++++++++++++++++++++++++++++++++++-- src/tests/curve.rs | 12 ++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/derive/curve.rs b/src/derive/curve.rs index 098d1a2f..660e85d8 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -288,8 +288,72 @@ macro_rules! new_curve_impl { } + /// A macro to help define point serialization using the [`group::GroupEncoding`] trait + /// This assumes both point types ($name, $nameaffine) implement [`group::GroupEncoding`]. + #[cfg(feature = "derive_serde")] + macro_rules! serialize_deserialize_to_from_bytes { + () => { + impl ::serde::Serialize for $name { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_bytes(); + if serializer.is_human_readable() { + ::hex::serde::serialize(&bytes.0, serializer) + } else { + ::serde_arrays::serialize(&bytes.0, serializer) + } + } + } + + paste::paste! { + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $name { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + ::serde_arrays::deserialize::<_, u8, [< $name _COMPRESSED_SIZE >]>(deserializer)? + }; + Option::from(Self::from_bytes(&[< $name Compressed >](bytes))).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + } + + impl ::serde::Serialize for $name_affine { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_bytes(); + if serializer.is_human_readable() { + ::hex::serde::serialize(&bytes.0, serializer) + } else { + ::serde_arrays::serialize(&bytes.0, serializer) + } + } + } + + paste::paste! { + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $name_affine { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + ::serde_arrays::deserialize::<_, u8, [< $name _COMPRESSED_SIZE >]>(deserializer)? + }; + Option::from(Self::from_bytes(&[< $name Compressed >](bytes))).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + } + }; + } + #[derive(Copy, Clone, Debug)] - #[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] $($privacy)* struct $name { pub x: $base, pub y: $base, @@ -297,13 +361,13 @@ macro_rules! new_curve_impl { } #[derive(Copy, Clone, PartialEq)] - #[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] $($privacy)* struct $name_affine { pub x: $base, pub y: $base, } - + #[cfg(feature = "derive_serde")] + serialize_deserialize_to_from_bytes!(); impl_compressed!(); impl_uncompressed!(); diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 54d23791..2f93bbb4 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -74,12 +74,24 @@ where assert_eq!(projective_point.to_affine(), affine_point_rec); assert_eq!(affine_point, affine_point_rec); } + { + let affine_json = serde_json::to_string(&affine_point).unwrap(); + let reader = std::io::Cursor::new(affine_json); + let affine_point_rec: G::AffineExt = serde_json::from_reader(reader).unwrap(); + assert_eq!(affine_point, affine_point_rec); + } { let projective_bytes = bincode::serialize(&projective_point).unwrap(); let reader = std::io::Cursor::new(projective_bytes); let projective_point_rec: G = bincode::deserialize_from(reader).unwrap(); assert_eq!(projective_point, projective_point_rec); } + { + let projective_json = serde_json::to_string(&projective_point).unwrap(); + let reader = std::io::Cursor::new(projective_json); + let projective_point_rec: G = serde_json::from_reader(reader).unwrap(); + assert_eq!(projective_point, projective_point_rec); + } } } From ee7cb86ce7d733586e7ac48e4dc25930d7851d85 Mon Sep 17 00:00:00 2001 From: einar-taiko <126954546+einar-taiko@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:09:47 +0800 Subject: [PATCH 07/13] Insert MSM and FFT code and their benchmarks. (#86) * Insert MSM and FFT code and their benchmarks. Resolves taikoxyz/zkevm-circuits#150. * feedback * Add instructions * feeback * Implement feedback: Actually supply the correct arguments to `best_multiexp`. Split into `singlecore` and `multicore` benchmarks so Criterion's result caching and comparison over multiple runs makes sense. Rewrite point and scalar generation. * Use slicing and parallelism to to decrease running time. Laptop measurements: k=22: 109 sec k=16: 1 sec * Refactor msm * Refactor fft * Update module comments * Fix formatting * Implement suggestion for fixing CI --- Cargo.toml | 13 +++- benches/fft.rs | 57 ++++++++++++++++++ benches/msm.rs | 116 +++++++++++++++++++++++++++++++++++ src/fft.rs | 134 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 3 + src/msm.rs | 153 +++++++++++++++++++++++++++++++++++++++++++++++ src/multicore.rs | 16 +++++ 7 files changed, 491 insertions(+), 1 deletion(-) create mode 100644 benches/fft.rs create mode 100644 benches/msm.rs create mode 100644 src/fft.rs create mode 100644 src/msm.rs create mode 100644 src/multicore.rs diff --git a/Cargo.toml b/Cargo.toml index a843a97c..43fa7d03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,9 +33,11 @@ serde = { version = "1.0", default-features = false, optional = true } serde_arrays = { version = "0.1.0", optional = true } hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } blake2b_simd = "1" +maybe-rayon = { version = "0.1.0", default-features = false } [features] -default = ["reexport", "bits"] +default = ["reexport", "bits", "multicore"] +multicore = ["maybe-rayon/threads"] asm = [] bits = ["ff/bits"] bn256-table = [] @@ -69,3 +71,12 @@ harness = false [[bench]] name = "hash_to_curve" harness = false + +[[bench]] +name = "fft" +harness = false + +[[bench]] +name = "msm" +harness = false +required-features = ["multicore"] diff --git a/benches/fft.rs b/benches/fft.rs new file mode 100644 index 00000000..a250308d --- /dev/null +++ b/benches/fft.rs @@ -0,0 +1,57 @@ +//! This benchmarks Fast-Fourier Transform (FFT). +//! Since it is over a finite field, it is actually the Number Theoretical +//! Transform (NNT). It uses the `Fr` scalar field from the BN256 curve. +//! +//! To run this benchmark: +//! +//! cargo bench -- fft +//! +//! Caveat: The multicore benchmark assumes: +//! 1. a multi-core system +//! 2. that the `multicore` feature is enabled. It is by default. + +#[macro_use] +extern crate criterion; + +use criterion::{BenchmarkId, Criterion}; +use group::ff::Field; +use halo2curves::bn256::Fr as Scalar; +use halo2curves::fft::best_fft; +use rand_core::OsRng; +use std::ops::Range; +use std::time::SystemTime; + +const RANGE: Range = 3..19; + +fn generate_data(k: u32) -> Vec { + let n = 1 << k; + let timer = SystemTime::now(); + println!("\n\nGenerating 2^{k} = {n} values..",); + let data: Vec = (0..n).map(|_| Scalar::random(OsRng)).collect(); + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} values took: {} sec.\n\n", + end.as_secs() + ); + data +} + +fn fft(c: &mut Criterion) { + let max_k = RANGE.max().unwrap_or(16); + let mut data = generate_data(max_k); + let omega = Scalar::random(OsRng); + let mut group = c.benchmark_group("fft"); + for k in RANGE { + group.bench_function(BenchmarkId::new("k", k), |b| { + let n = 1 << k; + assert!(n <= data.len()); + b.iter(|| { + best_fft(&mut data[..n], omega, k); + }); + }); + } + group.finish(); +} + +criterion_group!(benches, fft); +criterion_main!(benches); diff --git a/benches/msm.rs b/benches/msm.rs new file mode 100644 index 00000000..c78952b7 --- /dev/null +++ b/benches/msm.rs @@ -0,0 +1,116 @@ +//! This benchmarks Multi Scalar Multiplication (MSM). +//! It measures `G1` from the BN256 curve. +//! +//! To run this benchmark: +//! +//! cargo bench -- msm +//! +//! Caveat: The multicore benchmark assumes: +//! 1. a multi-core system +//! 2. that the `multicore` feature is enabled. It is by default. + +#[macro_use] +extern crate criterion; + +use criterion::{BenchmarkId, Criterion}; +use ff::Field; +use group::prime::PrimeCurveAffine; +use halo2curves::bn256::{Fr as Scalar, G1Affine as Point}; +use halo2curves::msm::{best_multiexp, multiexp_serial}; +use maybe_rayon::current_thread_index; +use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rand_core::SeedableRng; +use rand_xorshift::XorShiftRng; +use std::time::SystemTime; + +const SAMPLE_SIZE: usize = 10; +const SINGLECORE_RANGE: [u8; 6] = [3, 8, 10, 12, 14, 16]; +const MULTICORE_RANGE: [u8; 9] = [3, 8, 10, 12, 14, 16, 18, 20, 22]; +const SEED: [u8; 16] = [ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5, +]; + +fn generate_coefficients_and_curvepoints(k: u8) -> (Vec, Vec) { + let n: u64 = { + assert!(k < 64); + 1 << k + }; + + println!("\n\nGenerating 2^{k} = {n} coefficients and curve points..",); + let timer = SystemTime::now(); + let coeffs = (0..n) + .into_par_iter() + .map_init( + || { + let mut thread_seed = SEED; + let uniq = current_thread_index().unwrap().to_ne_bytes(); + assert!(std::mem::size_of::() == 8); + for i in 0..uniq.len() { + thread_seed[i] += uniq[i]; + thread_seed[i + 8] += uniq[i]; + } + XorShiftRng::from_seed(thread_seed) + }, + |rng, _| Scalar::random(rng), + ) + .collect(); + let bases = (0..n) + .into_par_iter() + .map_init( + || { + let mut thread_seed = SEED; + let uniq = current_thread_index().unwrap().to_ne_bytes(); + assert!(std::mem::size_of::() == 8); + for i in 0..uniq.len() { + thread_seed[i] += uniq[i]; + thread_seed[i + 8] += uniq[i]; + } + XorShiftRng::from_seed(thread_seed) + }, + |rng, _| Point::random(rng), + ) + .collect(); + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} coefficients and curve points took: {} sec.\n\n", + end.as_secs() + ); + + (coeffs, bases) +} + +fn msm(c: &mut Criterion) { + let mut group = c.benchmark_group("msm"); + let max_k = *SINGLECORE_RANGE + .iter() + .chain(MULTICORE_RANGE.iter()) + .max() + .unwrap_or(&16); + let (coeffs, bases) = generate_coefficients_and_curvepoints(max_k); + + for k in SINGLECORE_RANGE { + group + .bench_function(BenchmarkId::new("singlecore", k), |b| { + assert!(k < 64); + let n: usize = 1 << k; + let mut acc = Point::identity().into(); + b.iter(|| multiexp_serial(&coeffs[..n], &bases[..n], &mut acc)); + }) + .sample_size(10); + } + for k in MULTICORE_RANGE { + group + .bench_function(BenchmarkId::new("multicore", k), |b| { + assert!(k < 64); + let n: usize = 1 << k; + b.iter(|| { + best_multiexp(&coeffs[..n], &bases[..n]); + }) + }) + .sample_size(SAMPLE_SIZE); + } + group.finish(); +} + +criterion_group!(benches, msm); +criterion_main!(benches); diff --git a/src/fft.rs b/src/fft.rs new file mode 100644 index 00000000..6eb3487e --- /dev/null +++ b/src/fft.rs @@ -0,0 +1,134 @@ +use crate::multicore; +pub use crate::{CurveAffine, CurveExt}; +use ff::Field; +use group::{GroupOpsOwned, ScalarMulOwned}; + +/// This represents an element of a group with basic operations that can be +/// performed. This allows an FFT implementation (for example) to operate +/// generically over either a field or elliptic curve group. +pub trait FftGroup: + Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned +{ +} + +impl FftGroup for T +where + Scalar: Field, + T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned, +{ +} + +/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size +/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative +/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when +/// interpreted as the coefficients of a polynomial of degree $n - 1$, is +/// transformed into the evaluations of this polynomial at each of the $n$ +/// distinct powers of $\omega$. This transformation is invertible by providing +/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element +/// by $n$. +/// +/// This will use multithreading if beneficial. +pub fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + fn bitreverse(mut n: usize, l: usize) -> usize { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r + } + + let threads = multicore::current_num_threads(); + let log_threads = threads.ilog2(); + let n = a.len(); + assert_eq!(n, 1 << log_n); + + for k in 0..n { + let rk = bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk, k); + } + } + + // precompute twiddle factors + let twiddles: Vec<_> = (0..(n / 2)) + .scan(Scalar::ONE, |w, _| { + let tw = *w; + *w *= ω + Some(tw) + }) + .collect(); + + if log_n <= log_threads { + let mut chunk = 2_usize; + let mut twiddle_chunk = n / 2; + for _ in 0..log_n { + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + }); + chunk *= 2; + twiddle_chunk /= 2; + } + } else { + recursive_butterfly_arithmetic(a, n, 1, &twiddles) + } +} + +/// This perform recursive butterfly arithmetic +pub fn recursive_butterfly_arithmetic>( + a: &mut [G], + n: usize, + twiddle_chunk: usize, + twiddles: &[Scalar], +) { + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0] += &t; + a[1] -= &t; + } else { + let (left, right) = a.split_at_mut(n / 2); + multicore::join( + || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), + || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3fa8e98f..670a6448 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ mod arithmetic; +pub mod fft; pub mod hash_to_curve; +pub mod msm; +pub mod multicore; #[macro_use] pub mod legendre; pub mod serde; diff --git a/src/msm.rs b/src/msm.rs new file mode 100644 index 00000000..de30be55 --- /dev/null +++ b/src/msm.rs @@ -0,0 +1,153 @@ +use ff::PrimeField; +use group::Group; +use pasta_curves::arithmetic::CurveAffine; + +use crate::multicore; + +pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + +/// Performs a small multi-exponentiation operation. +/// Uses the double-and-add algorithm with doublings shared across points. +pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + let mut acc = C::Curve::identity(); + + // for byte idx + for byte_idx in (0..32).rev() { + // for bit idx + for bit_idx in (0..8).rev() { + acc = acc.double(); + // for each coeff + for coeff_idx in 0..coeffs.len() { + let byte = coeffs[coeff_idx].as_ref()[byte_idx]; + if ((byte >> bit_idx) & 1) != 0 { + acc += bases[coeff_idx]; + } + } + } + } + + acc +} + +/// Performs a multi-exponentiation operation. +/// +/// This function will panic if coeffs and bases have a different length. +/// +/// This will use multithreading if beneficial. +pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); + + let num_threads = multicore::current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![C::Curve::identity(); num_chunks]; + multicore::scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + multiexp_serial(coeffs, bases, acc); + }); + } + }); + results.iter().fold(C::Curve::identity(), |a, b| a + b) + } else { + let mut acc = C::Curve::identity(); + multiexp_serial(coeffs, bases, &mut acc); + acc + } +} diff --git a/src/multicore.rs b/src/multicore.rs new file mode 100644 index 00000000..d8323553 --- /dev/null +++ b/src/multicore.rs @@ -0,0 +1,16 @@ +pub use maybe_rayon::{ + iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, + join, scope, Scope, +}; + +#[cfg(feature = "multicore")] +pub use maybe_rayon::{ + current_num_threads, + iter::{IndexedParallelIterator, IntoParallelRefIterator}, + slice::ParallelSliceMut, +}; + +#[cfg(not(feature = "multicore"))] +pub fn current_num_threads() -> usize { + 1 +} From 294c86f6c6a1b700cbe2f060ed38049d5b55f897 Mon Sep 17 00:00:00 2001 From: Han Date: Thu, 5 Oct 2023 16:20:29 +0800 Subject: [PATCH 08/13] Re-export also mod `pairing` and remove flag `reexport` to alwasy re-export (#93) fix: re-export also mod `pairing` and remove flag `reexport` to alwasy re-export --- Cargo.toml | 3 +-- src/lib.rs | 11 +++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 43fa7d03..dac2d327 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ blake2b_simd = "1" maybe-rayon = { version = "0.1.0", default-features = false } [features] -default = ["reexport", "bits", "multicore"] +default = ["bits", "multicore"] multicore = ["maybe-rayon/threads"] asm = [] bits = ["ff/bits"] @@ -44,7 +44,6 @@ bn256-table = [] derive_serde = ["serde/derive", "serde_arrays", "hex"] prefetch = [] print-trace = ["ark-std/print-trace"] -reexport = [] [profile.bench] opt-level = 3 diff --git a/src/lib.rs b/src/lib.rs index 670a6448..5bd7d506 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,17 +16,12 @@ pub mod secq256k1; #[macro_use] mod derive; -pub use pasta_curves::arithmetic::{Coordinates, CurveAffine, CurveExt}; -// Re-export ff and group to simplify down stream dependencies -#[cfg(feature = "reexport")] +// Re-export to simplify down stream dependencies pub use ff; -#[cfg(not(feature = "reexport"))] -use ff; -#[cfg(feature = "reexport")] pub use group; -#[cfg(not(feature = "reexport"))] -use group; +pub use pairing; +pub use pasta_curves::arithmetic::{Coordinates, CurveAffine, CurveExt}; #[cfg(test)] pub mod tests; From 91d8dc1ca3a802933a0153330bc9d79ea8789bbb Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Mon, 16 Oct 2023 08:17:57 +0000 Subject: [PATCH 09/13] fix regression in #93 reexport field benches aren't run (#94) fix regression in https://github.com/privacy-scaling-explorations/halo2curves/pull/93, field benches aren't run --- Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index dac2d327..dc0d61dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,6 @@ harness = false [[bench]] name = "bn256_field" harness = false -required-features = ["reexport"] [[bench]] name = "group" From a83509660676f252d8534684fcc718603e72b1ee Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Thu, 19 Oct 2023 03:26:57 +0200 Subject: [PATCH 10/13] Fast modular inverse - 9.4x acceleration (#83) * Bernstein yang modular multiplicative inverter (#2) * rename similar to https://github.com/privacy-scaling-explorations/halo2curves/pull/95 --------- Co-authored-by: Aleksei Vambol <77882392+AlekseiVambol@users.noreply.github.com> --- benches/bn256_field.rs | 3 + src/bn256/fq.rs | 13 +- src/bn256/fr.rs | 13 +- src/derive/field.rs | 16 ++ src/ff_inverse.rs | 424 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/secp256k1/fp.rs | 13 +- src/secp256k1/fq.rs | 13 +- src/secp256r1/fp.rs | 13 +- src/secp256r1/fq.rs | 13 +- 10 files changed, 462 insertions(+), 60 deletions(-) create mode 100644 src/ff_inverse.rs diff --git a/benches/bn256_field.rs b/benches/bn256_field.rs index 8a17aef3..814d13b4 100644 --- a/benches/bn256_field.rs +++ b/benches/bn256_field.rs @@ -40,6 +40,9 @@ pub fn bench_bn256_field(c: &mut Criterion) { group.bench_function("bn256_fq_square", |bencher| { bencher.iter(|| black_box(&a).square()) }); + group.bench_function("bn256_fq_invert", |bencher| { + bencher.iter(|| black_box(&a).invert()) + }); } criterion_group!(benches, bench_bn256_field); diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index fec8d863..56be690f 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -198,17 +198,10 @@ impl ff::Field for Fq { ff::helpers::sqrt_ratio_generic(num, div) } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow([ - 0x3c208c16d87cfd45, - 0x97816a916871ca8d, - 0xb85045b68181585d, - 0x30644e72e131a029, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } } diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 7e3b5ae8..9e02fcdc 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -231,17 +231,10 @@ impl ff::Field for Fr { self.square() } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow([ - 0x43e1f593efffffff, - 0x2833e84879b97091, - 0xb85045b68181585d, - 0x30644e72e131a029, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn sqrt(&self) -> CtOption { diff --git a/src/derive/field.rs b/src/derive/field.rs index da962457..a9f57c4a 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -24,6 +24,11 @@ macro_rules! field_common { $r2:ident, $r3:ident ) => { + /// Bernstein-Yang modular multiplicative inverter created for the modulus equal to + /// the characteristic of the field to invert positive integers in the Montgomery form. + const BYINVERTOR: $crate::ff_inverse::BYInverter<6> = + $crate::ff_inverse::BYInverter::<6>::new(&$modulus.0, &$r2.0); + impl $field { /// Returns zero, the additive identity. #[inline] @@ -37,6 +42,16 @@ macro_rules! field_common { $r } + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. + pub fn invert(&self) -> CtOption { + if let Some(inverse) = BYINVERTOR.invert(&self.0) { + CtOption::new(Self(inverse), Choice::from(1)) + } else { + CtOption::new(Self::zero(), Choice::from(0)) + } + } + fn from_u512(limbs: [u64; 8]) -> $field { // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits // with the higher bits multiplied by 2^256. Thus, we perform two reductions @@ -345,6 +360,7 @@ macro_rules! field_common { macro_rules! field_arithmetic { ($field:ident, $modulus:ident, $inv:ident, $field_type:ident) => { field_specific!($field, $modulus, $inv, $field_type); + impl $field { /// Doubles this field element. #[inline] diff --git a/src/ff_inverse.rs b/src/ff_inverse.rs new file mode 100644 index 00000000..53285e6c --- /dev/null +++ b/src/ff_inverse.rs @@ -0,0 +1,424 @@ +use core::cmp::PartialEq; +use std::ops::{Add, Mul, Neg, Sub}; + +/// Big signed (B * L)-bit integer type, whose variables store +/// numbers in the two's complement code as arrays of B-bit chunks. +/// The ordering of the chunks in these arrays is little-endian. +/// The arithmetic operations for this type are wrapping ones. +#[derive(Clone)] +struct CInt(pub [u64; L]); + +impl CInt { + /// Mask, in which the B lowest bits are 1 and only they + pub const MASK: u64 = u64::MAX >> (64 - B); + + /// Representation of -1 + pub const MINUS_ONE: Self = Self([Self::MASK; L]); + + /// Representation of 0 + pub const ZERO: Self = Self([0; L]); + + /// Representation of 1 + pub const ONE: Self = { + let mut data = [0; L]; + data[0] = 1; + Self(data) + }; + + /// Returns the result of applying B-bit right + /// arithmetical shift to the current number + pub fn shift(&self) -> Self { + let mut data = [0; L]; + if self.is_negative() { + data[L - 1] = Self::MASK; + } + data[..L - 1].copy_from_slice(&self.0[1..]); + Self(data) + } + + /// Returns the lowest B bits of the current number + pub fn lowest(&self) -> u64 { + self.0[0] + } + + /// Returns "true" iff the current number is negative + pub fn is_negative(&self) -> bool { + self.0[L - 1] > (Self::MASK >> 1) + } +} + +impl PartialEq for CInt { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Add for &CInt { + type Output = CInt; + fn add(self, other: Self) -> Self::Output { + let (mut data, mut carry) = ([0; L], 0); + for i in 0..L { + let sum = self.0[i] + other.0[i] + carry; + data[i] = sum & CInt::::MASK; + carry = sum >> B; + } + Self::Output { 0: data } + } +} + +impl Add<&CInt> for CInt { + type Output = CInt; + fn add(self, other: &Self) -> Self::Output { + &self + other + } +} + +impl Add for CInt { + type Output = CInt; + fn add(self, other: Self) -> Self::Output { + &self + &other + } +} + +impl Sub for &CInt { + type Output = CInt; + fn sub(self, other: Self) -> Self::Output { + // For the two's complement code the additive negation is the result of + // adding 1 to the bitwise inverted argument's representation. Thus, for + // any encoded integers x and y we have x - y = x + !y + 1, where "!" is + // the bitwise inversion and addition is done according to the rules of + // the code. The algorithm below uses this formula and is the modified + // addition algorithm, where the carry flag is initialized with 1 and + // the chunks of the second argument are bitwise inverted + let (mut data, mut carry) = ([0; L], 1); + for i in 0..L { + let sum = self.0[i] + (other.0[i] ^ CInt::::MASK) + carry; + data[i] = sum & CInt::::MASK; + carry = sum >> B; + } + Self::Output { 0: data } + } +} + +impl Sub<&CInt> for CInt { + type Output = CInt; + fn sub(self, other: &Self) -> Self::Output { + &self - other + } +} + +impl Sub for CInt { + type Output = CInt; + fn sub(self, other: Self) -> Self::Output { + &self - &other + } +} + +impl Neg for &CInt { + type Output = CInt; + fn neg(self) -> Self::Output { + // For the two's complement code the additive negation is the result + // of adding 1 to the bitwise inverted argument's representation + let (mut data, mut carry) = ([0; L], 1); + for i in 0..L { + let sum = (self.0[i] ^ CInt::::MASK) + carry; + data[i] = sum & CInt::::MASK; + carry = sum >> B; + } + Self::Output { 0: data } + } +} + +impl Neg for CInt { + type Output = CInt; + fn neg(self) -> Self::Output { + -&self + } +} + +impl Mul for &CInt { + type Output = CInt; + fn mul(self, other: Self) -> Self::Output { + let mut data = [0; L]; + for i in 0..L { + let mut carry = 0; + for k in 0..(L - i) { + let sum = (data[i + k] as u128) + + (carry as u128) + + (self.0[i] as u128) * (other.0[k] as u128); + data[i + k] = sum as u64 & CInt::::MASK; + carry = (sum >> B) as u64; + } + } + Self::Output { 0: data } + } +} + +impl Mul<&CInt> for CInt { + type Output = CInt; + fn mul(self, other: &Self) -> Self::Output { + &self * other + } +} + +impl Mul for CInt { + type Output = CInt; + fn mul(self, other: Self) -> Self::Output { + &self * &other + } +} + +impl Mul for &CInt { + type Output = CInt; + fn mul(self, other: i64) -> Self::Output { + let mut data = [0; L]; + // If the short multiplicand is non-negative, the standard multiplication + // algorithm is performed. Otherwise, the product of the additively negated + // multiplicands is found as follows. Since for the two's complement code + // the additive negation is the result of adding 1 to the bitwise inverted + // argument's representation, for any encoded integers x and y we have + // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y), where "!" is + // the bitwise inversion and arithmetic operations are performed according + // to the rules of the code. If the short multiplicand is negative, the + // algorithm below uses this formula by substituting the short multiplicand + // for y and turns into the modified standard multiplication algorithm, + // where the carry flag is initialized with the additively negated short + // multiplicand and the chunks of the long multiplicand are bitwise inverted + let (other, mut carry, mask) = if other < 0 { + (-other, -other as u64, CInt::::MASK) + } else { + (other, 0, 0) + }; + for i in 0..L { + let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128); + data[i] = sum as u64 & CInt::::MASK; + carry = (sum >> B) as u64; + } + Self::Output { 0: data } + } +} + +impl Mul for CInt { + type Output = CInt; + fn mul(self, other: i64) -> Self::Output { + &self * other + } +} + +impl Mul<&CInt> for i64 { + type Output = CInt; + fn mul(self, other: &CInt) -> Self::Output { + other * self + } +} + +impl Mul> for i64 { + type Output = CInt; + fn mul(self, other: CInt) -> Self::Output { + other * self + } +} + +/// Type of the modular multiplicative inverter based on the Bernstein-Yang method. +/// The inverter can be created for a specified modulus M and adjusting parameter A +/// to compute the adjusted multiplicative inverses of positive integers, i.e. for +/// computing (1 / x) * A (mod M) for a positive integer x. +/// +/// The adjusting parameter allows computing the multiplicative inverses in the case of +/// using the Montgomery representation for the input or the expected output. If R is +/// the Montgomery factor, the multiplicative inverses in the appropriate representation +/// can be computed provided that the value of A is chosen as follows: +/// - A = 1, if both the input and the expected output are in the standard form +/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form +/// - A = R mod M, if either the input or the expected output is in the Montgomery form, +/// but not both of them +/// +/// The public methods of this type receive and return unsigned big integers as arrays of +/// 64-bit chunks, the ordering of which is little-endian. Both the modulus and the integer +/// to be inverted should not exceed 2 ^ (62 * L - 64) +/// +/// For better understanding the implementation, the following resources are recommended: +/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular inversion", +/// https://gcd.cr.yp.to/safegcd-20190413.pdf +/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained", +/// https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md +pub struct BYInverter { + /// Modulus + modulus: CInt<62, L>, + + /// Adjusting parameter + adjuster: CInt<62, L>, + + /// Multiplicative inverse of the modulus modulo 2^62 + inverse: i64, +} + +/// Type of the Bernstein-Yang transition matrix multiplied by 2^62 +type Matrix = [[i64; 2]; 2]; + +impl BYInverter { + /// Returns the Bernstein-Yang transition matrix multiplied by 2^62 and the new value + /// of the delta variable for the 62 basic steps of the Bernstein-Yang method, which + /// are to be performed sequentially for specified initial values of f, g and delta + fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) { + let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128); + let mut t: Matrix = [[1, 0], [0, 1]]; + + loop { + let zeros = steps.min(g.trailing_zeros() as i64); + (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros); + t[0] = [t[0][0] << zeros, t[0][1] << zeros]; + + if steps == 0 { + break; + } + if delta > 0 { + (delta, f, g) = (-delta, g as i64, -f as i128); + (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]); + } + + // The formula (3 * x) xor 28 = -1 / x (mod 32) for an odd integer x + // in the two's complement code has been derived from the formula + // (3 * x) xor 2 = 1 / x (mod 32) attributed to Peter Montgomery + let mask = (1 << steps.min(1 - delta).min(5)) - 1; + let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask; + + t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]]; + g += w as i128 * f as i128; + } + + (delta, t) + } + + /// Returns the updated values of the variables f and g for specified initial ones and Bernstein-Yang transition + /// matrix multiplied by 2^62. The returned vector is "matrix * (f, g)' / 2^62", where "'" is the transpose operator + fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) { + ( + (t[0][0] * &f + t[0][1] * &g).shift(), + (t[1][0] * &f + t[1][1] * &g).shift(), + ) + } + + /// Returns the updated values of the variables d and e for specified initial ones and Bernstein-Yang transition + /// matrix multiplied by 2^62. The returned vector is congruent modulo M to "matrix * (d, e)' / 2^62 (mod M)", + /// where M is the modulus the inverter was created for and "'" stands for the transpose operator. Both the input + /// and output values lie in the interval (-2 * M, M) + fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) { + let mask = CInt::<62, L>::MASK as i64; + let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64; + let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64; + + let cd = t[0][0] + .wrapping_mul(d.lowest() as i64) + .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64)) + & mask; + let ce = t[1][0] + .wrapping_mul(d.lowest() as i64) + .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64)) + & mask; + + md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask; + me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask; + + let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus; + let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus; + + (cd.shift(), ce.shift()) + } + + /// Returns either "value (mod M)" or "-value (mod M)", where M is the modulus the + /// inverter was created for, depending on "negate", which determines the presence + /// of "-" in the used formula. The input integer lies in the interval (-2 * M, M) + fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> { + if value.is_negative() { + value = value + &self.modulus; + } + + if negate { + value = -value; + } + + if value.is_negative() { + value = value + &self.modulus; + } + + value + } + + /// Returns a big unsigned integer as an array of O-bit chunks, which is equal modulo + /// 2 ^ (O * S) to the input big unsigned integer stored as an array of I-bit chunks. + /// The ordering of the chunks in these arrays is little-endian + const fn convert(input: &[u64]) -> [u64; S] { + // This function is defined because the method "min" of the usize type is not constant + const fn min(a: usize, b: usize) -> usize { + if a > b { + b + } else { + a + } + } + + let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0); + + while bits < total { + let (i, o) = (bits % I, bits % O); + output[bits / O] |= (input[bits / I] >> i) << o; + bits += min(I - i, O - o); + } + + let mask = u64::MAX >> (64 - O); + let mut filled = total / O + if total % O > 0 { 1 } else { 0 }; + + while filled > 0 { + filled -= 1; + output[filled] &= mask; + } + + output + } + + /// Returns the multiplicative inverse of the argument modulo 2^62. The implementation is based + /// on the Hurchalla's method for computing the multiplicative inverse modulo a power of two. + /// For better understanding the implementation, the following paper is recommended: + /// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)", + /// https://arxiv.org/pdf/2204.04342.pdf + const fn inv(value: u64) -> i64 { + let x = value.wrapping_mul(3) ^ 2; + let y = 1u64.wrapping_sub(x.wrapping_mul(value)); + let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); + let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); + let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); + (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64 + } + + /// Creates the inverter for specified modulus and adjusting parameter + pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self { + Self { + modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)), + adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)), + inverse: Self::inv(modulus[0]), + } + } + + /// Returns either the adjusted modular multiplicative inverse for the argument or None + /// depending on invertibility of the argument, i.e. its coprimality with the modulus + pub fn invert(&self, value: &[u64]) -> Option<[u64; S]> { + let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone()); + let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value)); + let (mut delta, mut f) = (1, self.modulus.clone()); + let mut matrix; + while g != CInt::ZERO { + (delta, matrix) = Self::jump(&f, &g, delta); + (f, g) = Self::fg(f, g, matrix); + (d, e) = self.de(d, e, matrix); + } + // At this point the absolute value of "f" equals the greatest common divisor + // of the integer to be inverted and the modulus the inverter was created for. + // Thus, if "f" is neither 1 nor -1, then the sought inverse does not exist + let antiunit = f == CInt::MINUS_ONE; + if (f != CInt::ONE) && !antiunit { + return None; + } + Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 5bd7d506..afc95fdb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod arithmetic; +mod ff_inverse; pub mod fft; pub mod hash_to_curve; pub mod msm; diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index c346dc6c..db496559 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -178,17 +178,10 @@ impl ff::Field for Fp { CtOption::new(tmp, tmp.square().ct_eq(self)) } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xfffffffefffffc2d, - 0xffffffffffffffff, - 0xffffffffffffffff, - 0xffffffffffffffff, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 189daaba..c6cb8e06 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -178,17 +178,10 @@ impl ff::Field for Fq { self.square() } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xbfd25e8cd036413f, - 0xbaaedce6af48a03b, - 0xfffffffffffffffe, - 0xffffffffffffffff, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index bf86e157..331545c3 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -196,17 +196,10 @@ impl ff::Field for Fp { CtOption::new(tmp, tmp.square().ct_eq(self)) } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xfffffffffffffffd, - 0x00000000ffffffff, - 0x0000000000000000, - 0xffffffff00000001, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index d1a7b809..077ec331 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -172,17 +172,10 @@ impl ff::Field for Fq { self.square() } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xf3b9cac2fc63254f, - 0xbce6faada7179e84, - 0xffffffffffffffff, - 0xffffffff00000000, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { From 81a078254518a7a4b7c69fab120621deaace9389 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Tue, 31 Oct 2023 18:05:45 +0100 Subject: [PATCH 11/13] Fast isSquare / Legendre symbol / Jacobi symbol - 16.8x acceleration (#95) * Derivatives of the Pornin's method (https://github.com/taikoxyz/halo2curves/pull/3) * renaming file * make cargo fmt happy * clarifications from https://github.com/privacy-scaling-explorations/halo2curves/pull/95#issuecomment-1770484641 [skip ci] * Formatting and slightly changing a comment --------- Co-authored-by: Aleksei Vambol <77882392+AlekseiVambol@users.noreply.github.com> --- benches/bn256_field.rs | 9 +- src/derive/field.rs | 28 +++ src/ff_jacobi.rs | 415 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 4 files changed, 451 insertions(+), 2 deletions(-) create mode 100644 src/ff_jacobi.rs diff --git a/benches/bn256_field.rs b/benches/bn256_field.rs index 814d13b4..fa8db6a3 100644 --- a/benches/bn256_field.rs +++ b/benches/bn256_field.rs @@ -1,6 +1,5 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; -use halo2curves::bn256::*; -use halo2curves::ff::Field; +use halo2curves::{bn256::*, ff::Field, legendre::Legendre}; use rand::SeedableRng; use rand_xorshift::XorShiftRng; @@ -43,6 +42,12 @@ pub fn bench_bn256_field(c: &mut Criterion) { group.bench_function("bn256_fq_invert", |bencher| { bencher.iter(|| black_box(&a).invert()) }); + group.bench_function("bn256_fq_legendre", |bencher| { + bencher.iter(|| black_box(&a).legendre()) + }); + group.bench_function("bn256_fq_jacobi", |bencher| { + bencher.iter(|| black_box(&a).jacobi()) + }); } criterion_group!(benches, bench_bn256_field); diff --git a/src/derive/field.rs b/src/derive/field.rs index a9f57c4a..912a888e 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -52,6 +52,12 @@ macro_rules! field_common { } } + // Returns the Legendre symbol, where the numerator and denominator + // are the element and the characteristic of the field, respectively. + pub fn jacobi(&self) -> i64 { + $crate::ff_jacobi::jacobi::<5>(&self.0, &$modulus.0) + } + fn from_u512(limbs: [u64; 8]) -> $field { // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits // with the higher bits multiplied by 2^256. Thus, we perform two reductions @@ -353,6 +359,28 @@ macro_rules! field_common { Ok(()) } } + + #[test] + fn test_jacobi() { + use rand::SeedableRng; + use $crate::ff::Field; + use $crate::legendre::Legendre; + let mut rng = rand_xorshift::XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, + 0xbc, 0xe5, + ]); + for _ in 0..100000 { + let e = $field::random(&mut rng); + assert_eq!( + e.legendre(), + match e.jacobi() { + 1 => $field::ONE, + -1 => -$field::ONE, + _ => $field::ZERO, + } + ); + } + } }; } diff --git a/src/ff_jacobi.rs b/src/ff_jacobi.rs new file mode 100644 index 00000000..8c559352 --- /dev/null +++ b/src/ff_jacobi.rs @@ -0,0 +1,415 @@ +use core::cmp::PartialEq; +use std::ops::{Add, Mul, Neg, Shr, Sub}; + +/// Big signed (64 * L)-bit integer type, whose variables store +/// numbers in the two's complement code as arrays of 64-bit chunks. +/// The ordering of the chunks in these arrays is little-endian. +/// The arithmetic operations for this type are wrapping ones +#[derive(Clone)] +pub struct LInt([u64; L]); + +impl LInt { + /// Representation of -1 + pub const MINUS_ONE: Self = Self([u64::MAX; L]); + + /// Representation of 0 + pub const ZERO: Self = Self([0; L]); + + /// Representation of 1 + pub const ONE: Self = { + let mut data = [0; L]; + data[0] = 1; + Self(data) + }; + + /// Returns the number, which is stored as the specified + /// sequence padded with zeros to length L. If the input + /// sequence is longer than L, the method panics + pub fn new(data: &[u64]) -> Self { + let mut number = Self::ZERO; + number.0[..data.len()].copy_from_slice(data); + number + } + + /// Returns "true" iff the current number is negative + #[inline] + pub fn is_negative(&self) -> bool { + self.0[L - 1] > (u64::MAX >> 1) + } + + /// Returns a tuple representing the sum of the first two arguments and the bit + /// described by the third argument. The first element of the tuple is this sum + /// modulo 2^64, the second one indicates whether the sum is no less than 2^64 + #[inline] + fn sum(first: u64, second: u64, carry: bool) -> (u64, bool) { + // The implementation is inspired with the "carrying_add" function from this source: + // https://github.com/rust-lang/rust/blob/master/library/core/src/num/uint_macros.rs + let (second, carry) = second.overflowing_add(carry as u64); + let (first, high) = first.overflowing_add(second); + (first, carry || high) + } + + /// Returns "(low, high)", where "high * 2^64 + low = first * second + carry + summand" + #[inline] + fn prodsum(first: u64, second: u64, summand: u64, carry: u64) -> (u64, u64) { + let all = (first as u128) * (second as u128) + (carry as u128) + (summand as u128); + (all as u64, (all >> u64::BITS) as u64) + } +} + +impl PartialEq for LInt { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Shr for &LInt { + type Output = LInt; + /// Returns the result of applying the arithmetic right shift to the current number. + /// The specified bit quantity the number is shifted by must lie in {1, 2, ..., 63}. + /// For the quantities outside of the range, the behavior of the method is undefined + fn shr(self, bits: u32) -> Self::Output { + debug_assert!( + (bits > 0) && (bits < 64), + "Cannot shift by 0 or more than 63 bits!" + ); + let (mut data, right) = ([0; L], u64::BITS - bits); + for i in 0..(L - 1) { + data[i] = (self.0[i] >> bits) | (self.0[i + 1] << right); + } + data[L - 1] = self.0[L - 1] >> bits; + if self.is_negative() { + data[L - 1] |= u64::MAX << right; + } + Self::Output { 0: data } + } +} + +impl Shr for LInt { + type Output = LInt; + fn shr(self, bits: u32) -> Self::Output { + &self >> bits + } +} + +impl Add for &LInt { + type Output = LInt; + fn add(self, other: Self) -> Self::Output { + let (mut data, mut carry) = ([0; L], false); + for i in 0..L { + (data[i], carry) = Self::Output::sum(self.0[i], other.0[i], carry); + } + Self::Output { 0: data } + } +} + +impl Add<&LInt> for LInt { + type Output = LInt; + fn add(self, other: &Self) -> Self::Output { + &self + other + } +} + +impl Add for LInt { + type Output = LInt; + fn add(self, other: Self) -> Self::Output { + &self + &other + } +} + +impl Sub for &LInt { + type Output = LInt; + fn sub(self, other: Self) -> Self::Output { + // For the two's complement code the additive negation is the result of + // adding 1 to the bitwise inverted argument's representation. Thus, for + // any encoded integers x and y we have x - y = x + !y + 1, where "!" is + // the bitwise inversion and addition is done according to the rules of + // the code. The algorithm below uses this formula and is the modified + // addition algorithm, where the carry flag is initialized with "true" + // and the chunks of the second argument are bitwise inverted + let (mut data, mut carry) = ([0; L], true); + for i in 0..L { + (data[i], carry) = Self::Output::sum(self.0[i], !other.0[i], carry); + } + Self::Output { 0: data } + } +} + +impl Sub<&LInt> for LInt { + type Output = LInt; + fn sub(self, other: &Self) -> Self::Output { + &self - other + } +} + +impl Sub for LInt { + type Output = LInt; + fn sub(self, other: Self) -> Self::Output { + &self - &other + } +} + +impl Neg for &LInt { + type Output = LInt; + fn neg(self) -> Self::Output { + // For the two's complement code the additive negation is the result + // of adding 1 to the bitwise inverted argument's representation + let (mut data, mut carry) = ([0; L], true); + for i in 0..L { + (data[i], carry) = (!self.0[i]).overflowing_add(carry as u64); + } + Self::Output { 0: data } + } +} + +impl Neg for LInt { + type Output = LInt; + fn neg(self) -> Self::Output { + -&self + } +} + +impl Mul for &LInt { + type Output = LInt; + fn mul(self, other: Self) -> Self::Output { + let mut data = [0; L]; + for i in 0..L { + let mut carry = 0; + for k in 0..(L - i) { + (data[i + k], carry) = + Self::Output::prodsum(self.0[i], other.0[k], data[i + k], carry); + } + } + Self::Output { 0: data } + } +} + +impl Mul<&LInt> for LInt { + type Output = LInt; + fn mul(self, other: &Self) -> Self::Output { + &self * other + } +} + +impl Mul for LInt { + type Output = LInt; + fn mul(self, other: Self) -> Self::Output { + &self * &other + } +} + +impl Mul for &LInt { + type Output = LInt; + fn mul(self, other: i64) -> Self::Output { + let mut data = [0; L]; + // If the short multiplicand is non-negative, the standard multiplication + // algorithm is performed. Otherwise, the product of the additively negated + // multiplicands is found as follows. Since for the two's complement code + // the additive negation is the result of adding 1 to the bitwise inverted + // argument's representation, for any encoded integers x and y we have + // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y), where "!" is + // the bitwise inversion and arithmetic operations are performed according + // to the rules of the code. If the short multiplicand is negative, the + // algorithm below uses this formula by substituting the short multiplicand + // for y and becomes the modified standard multiplication algorithm, where + // the carry variable is being initialized with the additively negated short + // multiplicand and the chunks of the long multiplicand are bitwise inverted + let (other, mut carry, mask) = if other < 0 { + (-other as u64, -other as u64, u64::MAX) + } else { + (other as u64, 0, 0) + }; + for i in 0..L { + (data[i], carry) = Self::Output::prodsum(self.0[i] ^ mask, other, 0, carry); + } + Self::Output { 0: data } + } +} + +impl Mul for LInt { + type Output = LInt; + fn mul(self, other: i64) -> Self::Output { + &self * other + } +} + +impl Mul<&LInt> for i64 { + type Output = LInt; + fn mul(self, other: &LInt) -> Self::Output { + other * self + } +} + +impl Mul> for i64 { + type Output = LInt; + fn mul(self, other: LInt) -> Self::Output { + other * self + } +} + +/// Returns the "approximations" of the arguments and the flag indicating whether +/// both arguments are equal to their "approximations". Both the arguments must be +/// non-negative, and at least one of them must be non-zero. For an incorrect input, +/// the behavior of the function is undefined. These "approximations" are defined +/// in the following way. Let n be the bit length of the largest argument without +/// leading zeros. For n > 64 the "approximation" of the argument, which equals v, +/// is (v div 2 ^ (n - 32)) * 2 ^ 32 + (v mod 2 ^ 32), i.e. it retains the high and +/// low bits of the n-bit representation of v. If n does not exceed 64, an argument +/// and its "approximation" are equal. These "approximations" are defined slightly +/// differently from the ones in the Pornin's method for modular inversion: instead +/// of taking the 33 high and 31 low bits of the n-bit representation of an argument, +/// the 32 high and 32 low bits are taken +fn approximate(x: &LInt, y: &LInt) -> (u64, u64, bool) { + debug_assert!( + !(x.is_negative() || y.is_negative()), + "Both the arguments must be non-negative!" + ); + debug_assert!( + (*x != LInt::ZERO) || (*y != LInt::ZERO), + "At least one argument must be non-zero!" + ); + let mut i = L - 1; + while (x.0[i] == 0) && (y.0[i] == 0) { + i -= 1; + } + if i == 0 { + return (x.0[0], y.0[0], true); + } + let mut h = (x.0[i], y.0[i]); + let z = h.0.leading_zeros().min(h.1.leading_zeros()); + h = (h.0 << z, h.1 << z); + if z > 32 { + h.0 |= x.0[i - 1] >> z; + h.1 |= y.0[i - 1] >> z; + } + let h = (h.0 & u64::MAX << 32, h.1 & u64::MAX << 32); + let l = (x.0[0] & u64::MAX >> 32, y.0[0] & u64::MAX >> 32); + (h.0 | l.0, h.1 | l.1, false) +} + +/// Returns the Jacobi symbol ("n" / "d") multiplied by either 1 or -1. +/// The later multiplicand is -1 iff the second-lowest bit of "t" is 1. +/// The value of "d" must be odd in accordance with the Jacobi symbol +/// definition. For even values of "d", the behavior is not defined. +/// The implementation is based on the binary Euclidean algorithm +fn jacobinary(mut n: u64, mut d: u64, mut t: u64) -> i64 { + debug_assert!(d & 1 > 0, "The second argument must be odd!"); + while n != 0 { + if n & 1 > 0 { + if n < d { + (n, d) = (d, n); + t ^= n & d; + } + n = (n - d) >> 1; + t ^= d ^ d >> 1; + } else { + let z = n.trailing_zeros(); + t ^= (d ^ d >> 1) & (z << 1) as u64; + n >>= z; + } + } + (d == 1) as i64 * (1 - (t & 2) as i64) +} + +/// Returns the Jacobi symbol ("n" / "d") computed by means of the modification +/// of the the Pornin's method for modular inversion. The arguments are unsigned +/// big integers in the form of arrays of 64-bit chunks, the ordering of which +/// is little-endian. The value of "d" must be odd in accordance with the Jacobi +/// symbol definition. Both the arguments must be less than 2 ^ (64 * L - 31). +/// For an incorrect input, the behavior of the function is undefined. The method +/// differs from the Pornin's method for modular inversion in absence of the parts, +/// which are not necessary to compute the greatest common divisor of arguments, +/// presence of the parts used to compute the Jacobi symbol, which are based on +/// the properties of the modified Jacobi symbol (x / |y|) described by M. Hamburg, +/// and some original optimizations. Only these differences have been commented; +/// the aforesaid Pornin's method and the used ideas of M. Hamburg were given here: +/// - T. Pornin, "Optimized Binary GCD for Modular Inversion", +/// https://eprint.iacr.org/2020/972.pdf +/// - M. Hamburg, "Computing the Jacobi symbol using Bernstein-Yang", +/// https://eprint.iacr.org/2021/1271.pdf +pub fn jacobi(n: &[u64], d: &[u64]) -> i64 { + // Instead of the variable "j" taking the values from {-1, 1} and satysfying + // at the end of the outer loop iteration the equation J = "j" * ("n" / |"d"|) + // for the modified Jacobi symbol ("n" / |"d"|) and the sought Jacobi symbol J, + // we store the sign bit of "j" in the second-lowest bit of "t" for optimization + // purposes. This approach was influenced by the paper by M. Hamburg + let (mut n, mut d, mut t) = (LInt::::new(n), LInt::::new(d), 0u64); + debug_assert!(d.0[0] & 1 > 0, "The second argument must be odd!"); + debug_assert!( + n.0[L - 1].leading_zeros().min(d.0[L - 1].leading_zeros()) >= 31, + "Both the arguments must be less than 2 ^ (64 * L - 31)!" + ); + loop { + // The inner loop performs 30 iterations instead of 31 ones in the aforementioned + // Pornin's method, and the "approximations" of "n" and "d" retain 32 of the lowest + // bits instead of 31 in that method. These modifications allow the values of the + // "approximation" variables to be equal modulo 8 to the corresponding "precise" + // variables' values, which would have been computed, if the "precise" variables + // had been updated in the inner loop along with the "approximations". This equality + // modulo 8 is used to update the second-lowest bit of "t" in accordance with the + // properties of the modified Jacobi symbol (x / |y|). The admissibility of these + // modifications has been proven using the appropriately modified Pornin's theorems + let (mut u, mut v, mut i) = ((1i64, 0i64), (0i64, 1i64), 30); + let (mut a, mut b, precise) = approximate(&n, &d); + // When each "approximation" variable has the same value as the corresponding "precise" + // one, the computation is accomplished using the short-arithmetic method of the Jacobi + // symbol calculation by means of the binary Euclidean algorithm. This approach aims at + // avoiding the parts of the final computations, which are related to long arithmetics + if precise { + return jacobinary(a, b, t); + } + while i > 0 { + if a & 1 > 0 { + if a < b { + (a, b, u, v) = (b, a, v, u); + // In both the aforesaid Pornin's method and its modification "n" and "d" + // could not become negative simultaneously even if they were updated after + // each iteration of the inner loop. Also at this point they both have odd + // values. Therefore, the quadratic reciprocity law for the modified Jacobi + // symbol (x / |y|) can be used. According to it, if both x and y are odd + // numbers, among which there is a positive one, then for x = y = 3 (mod 4) + // we have (x / |y|) = -(y / |x|) and for either x or y equal 1 modulo 4 + // the symbols (x / |y|) and (y / |x|) are equal + t ^= a & b; + } + a = (a - b) >> 1; + u = (u.0 - v.0, u.1 - v.1); + v = (v.0 << 1, v.1 << 1); + // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5} + t ^= b ^ b >> 1; + i -= 1; + } else { + // Performing the batch of sequential iterations, which divide "a" by 2 + let z = i.min(a.trailing_zeros()); + // The modified Jacobi symbol (2 / |y|) is -1, iff y mod 8 is {3, 5}. However, + // we do not need its value for a batch with an even number of divisions by 2 + t ^= (b ^ b >> 1) & (z << 1) as u64; + v = (v.0 << z, v.1 << z); + a >>= z; + i -= z; + } + } + (n, d) = ((&n * u.0 + &d * u.1) >> 30, (&n * v.0 + &d * v.1) >> 30); + + // This fragment is present to guarantee the correct behavior of the function + // in the case of arguments, whose greatest common divisor is no less than 2^64 + if n == LInt::ZERO { + // In both the aforesaid Pornin's method and its modification the pair of the values + // of "n" and "d" after the divergence point contains a positive number and a negative + // one. Since the value of "n" is 0, the divergence point has not been reached by the + // inner loop this time, so there is no need to check whether "d" is equal to -1 + return (d == LInt::ONE) as i64 * (1 - (t & 2) as i64); + } + + if n.is_negative() { + // Since in both the aforesaid Pornin's method and its modification "d" is always odd + // and cannot become negative simultaneously with "n", the value of "d" is positive. + // The modified Jacobi symbol (-1 / |y|) for a positive y is -1, iff y mod 4 = 3 + t ^= d.0[0]; + n = -n; + } else if d.is_negative() { + // The modified Jacobi symbols (x / |y|) and (x / |-y|) are equal, so "t" is not updated + d = -d; + } + } +} diff --git a/src/lib.rs b/src/lib.rs index afc95fdb..a09fb05f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod arithmetic; mod ff_inverse; +mod ff_jacobi; pub mod fft; pub mod hash_to_curve; pub mod msm; From 69ff87c9ac79ec49accd50feb0117ec48927a504 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 12 Nov 2023 09:33:18 -0800 Subject: [PATCH 12/13] chore: delete bernsteinyang module (replaced by ff_inverse) --- src/bernsteinyang.rs | 424 ------------------------------------------- 1 file changed, 424 deletions(-) delete mode 100644 src/bernsteinyang.rs diff --git a/src/bernsteinyang.rs b/src/bernsteinyang.rs deleted file mode 100644 index 53285e6c..00000000 --- a/src/bernsteinyang.rs +++ /dev/null @@ -1,424 +0,0 @@ -use core::cmp::PartialEq; -use std::ops::{Add, Mul, Neg, Sub}; - -/// Big signed (B * L)-bit integer type, whose variables store -/// numbers in the two's complement code as arrays of B-bit chunks. -/// The ordering of the chunks in these arrays is little-endian. -/// The arithmetic operations for this type are wrapping ones. -#[derive(Clone)] -struct CInt(pub [u64; L]); - -impl CInt { - /// Mask, in which the B lowest bits are 1 and only they - pub const MASK: u64 = u64::MAX >> (64 - B); - - /// Representation of -1 - pub const MINUS_ONE: Self = Self([Self::MASK; L]); - - /// Representation of 0 - pub const ZERO: Self = Self([0; L]); - - /// Representation of 1 - pub const ONE: Self = { - let mut data = [0; L]; - data[0] = 1; - Self(data) - }; - - /// Returns the result of applying B-bit right - /// arithmetical shift to the current number - pub fn shift(&self) -> Self { - let mut data = [0; L]; - if self.is_negative() { - data[L - 1] = Self::MASK; - } - data[..L - 1].copy_from_slice(&self.0[1..]); - Self(data) - } - - /// Returns the lowest B bits of the current number - pub fn lowest(&self) -> u64 { - self.0[0] - } - - /// Returns "true" iff the current number is negative - pub fn is_negative(&self) -> bool { - self.0[L - 1] > (Self::MASK >> 1) - } -} - -impl PartialEq for CInt { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Add for &CInt { - type Output = CInt; - fn add(self, other: Self) -> Self::Output { - let (mut data, mut carry) = ([0; L], 0); - for i in 0..L { - let sum = self.0[i] + other.0[i] + carry; - data[i] = sum & CInt::::MASK; - carry = sum >> B; - } - Self::Output { 0: data } - } -} - -impl Add<&CInt> for CInt { - type Output = CInt; - fn add(self, other: &Self) -> Self::Output { - &self + other - } -} - -impl Add for CInt { - type Output = CInt; - fn add(self, other: Self) -> Self::Output { - &self + &other - } -} - -impl Sub for &CInt { - type Output = CInt; - fn sub(self, other: Self) -> Self::Output { - // For the two's complement code the additive negation is the result of - // adding 1 to the bitwise inverted argument's representation. Thus, for - // any encoded integers x and y we have x - y = x + !y + 1, where "!" is - // the bitwise inversion and addition is done according to the rules of - // the code. The algorithm below uses this formula and is the modified - // addition algorithm, where the carry flag is initialized with 1 and - // the chunks of the second argument are bitwise inverted - let (mut data, mut carry) = ([0; L], 1); - for i in 0..L { - let sum = self.0[i] + (other.0[i] ^ CInt::::MASK) + carry; - data[i] = sum & CInt::::MASK; - carry = sum >> B; - } - Self::Output { 0: data } - } -} - -impl Sub<&CInt> for CInt { - type Output = CInt; - fn sub(self, other: &Self) -> Self::Output { - &self - other - } -} - -impl Sub for CInt { - type Output = CInt; - fn sub(self, other: Self) -> Self::Output { - &self - &other - } -} - -impl Neg for &CInt { - type Output = CInt; - fn neg(self) -> Self::Output { - // For the two's complement code the additive negation is the result - // of adding 1 to the bitwise inverted argument's representation - let (mut data, mut carry) = ([0; L], 1); - for i in 0..L { - let sum = (self.0[i] ^ CInt::::MASK) + carry; - data[i] = sum & CInt::::MASK; - carry = sum >> B; - } - Self::Output { 0: data } - } -} - -impl Neg for CInt { - type Output = CInt; - fn neg(self) -> Self::Output { - -&self - } -} - -impl Mul for &CInt { - type Output = CInt; - fn mul(self, other: Self) -> Self::Output { - let mut data = [0; L]; - for i in 0..L { - let mut carry = 0; - for k in 0..(L - i) { - let sum = (data[i + k] as u128) - + (carry as u128) - + (self.0[i] as u128) * (other.0[k] as u128); - data[i + k] = sum as u64 & CInt::::MASK; - carry = (sum >> B) as u64; - } - } - Self::Output { 0: data } - } -} - -impl Mul<&CInt> for CInt { - type Output = CInt; - fn mul(self, other: &Self) -> Self::Output { - &self * other - } -} - -impl Mul for CInt { - type Output = CInt; - fn mul(self, other: Self) -> Self::Output { - &self * &other - } -} - -impl Mul for &CInt { - type Output = CInt; - fn mul(self, other: i64) -> Self::Output { - let mut data = [0; L]; - // If the short multiplicand is non-negative, the standard multiplication - // algorithm is performed. Otherwise, the product of the additively negated - // multiplicands is found as follows. Since for the two's complement code - // the additive negation is the result of adding 1 to the bitwise inverted - // argument's representation, for any encoded integers x and y we have - // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y), where "!" is - // the bitwise inversion and arithmetic operations are performed according - // to the rules of the code. If the short multiplicand is negative, the - // algorithm below uses this formula by substituting the short multiplicand - // for y and turns into the modified standard multiplication algorithm, - // where the carry flag is initialized with the additively negated short - // multiplicand and the chunks of the long multiplicand are bitwise inverted - let (other, mut carry, mask) = if other < 0 { - (-other, -other as u64, CInt::::MASK) - } else { - (other, 0, 0) - }; - for i in 0..L { - let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128); - data[i] = sum as u64 & CInt::::MASK; - carry = (sum >> B) as u64; - } - Self::Output { 0: data } - } -} - -impl Mul for CInt { - type Output = CInt; - fn mul(self, other: i64) -> Self::Output { - &self * other - } -} - -impl Mul<&CInt> for i64 { - type Output = CInt; - fn mul(self, other: &CInt) -> Self::Output { - other * self - } -} - -impl Mul> for i64 { - type Output = CInt; - fn mul(self, other: CInt) -> Self::Output { - other * self - } -} - -/// Type of the modular multiplicative inverter based on the Bernstein-Yang method. -/// The inverter can be created for a specified modulus M and adjusting parameter A -/// to compute the adjusted multiplicative inverses of positive integers, i.e. for -/// computing (1 / x) * A (mod M) for a positive integer x. -/// -/// The adjusting parameter allows computing the multiplicative inverses in the case of -/// using the Montgomery representation for the input or the expected output. If R is -/// the Montgomery factor, the multiplicative inverses in the appropriate representation -/// can be computed provided that the value of A is chosen as follows: -/// - A = 1, if both the input and the expected output are in the standard form -/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form -/// - A = R mod M, if either the input or the expected output is in the Montgomery form, -/// but not both of them -/// -/// The public methods of this type receive and return unsigned big integers as arrays of -/// 64-bit chunks, the ordering of which is little-endian. Both the modulus and the integer -/// to be inverted should not exceed 2 ^ (62 * L - 64) -/// -/// For better understanding the implementation, the following resources are recommended: -/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular inversion", -/// https://gcd.cr.yp.to/safegcd-20190413.pdf -/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained", -/// https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md -pub struct BYInverter { - /// Modulus - modulus: CInt<62, L>, - - /// Adjusting parameter - adjuster: CInt<62, L>, - - /// Multiplicative inverse of the modulus modulo 2^62 - inverse: i64, -} - -/// Type of the Bernstein-Yang transition matrix multiplied by 2^62 -type Matrix = [[i64; 2]; 2]; - -impl BYInverter { - /// Returns the Bernstein-Yang transition matrix multiplied by 2^62 and the new value - /// of the delta variable for the 62 basic steps of the Bernstein-Yang method, which - /// are to be performed sequentially for specified initial values of f, g and delta - fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) { - let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128); - let mut t: Matrix = [[1, 0], [0, 1]]; - - loop { - let zeros = steps.min(g.trailing_zeros() as i64); - (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros); - t[0] = [t[0][0] << zeros, t[0][1] << zeros]; - - if steps == 0 { - break; - } - if delta > 0 { - (delta, f, g) = (-delta, g as i64, -f as i128); - (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]); - } - - // The formula (3 * x) xor 28 = -1 / x (mod 32) for an odd integer x - // in the two's complement code has been derived from the formula - // (3 * x) xor 2 = 1 / x (mod 32) attributed to Peter Montgomery - let mask = (1 << steps.min(1 - delta).min(5)) - 1; - let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask; - - t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]]; - g += w as i128 * f as i128; - } - - (delta, t) - } - - /// Returns the updated values of the variables f and g for specified initial ones and Bernstein-Yang transition - /// matrix multiplied by 2^62. The returned vector is "matrix * (f, g)' / 2^62", where "'" is the transpose operator - fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) { - ( - (t[0][0] * &f + t[0][1] * &g).shift(), - (t[1][0] * &f + t[1][1] * &g).shift(), - ) - } - - /// Returns the updated values of the variables d and e for specified initial ones and Bernstein-Yang transition - /// matrix multiplied by 2^62. The returned vector is congruent modulo M to "matrix * (d, e)' / 2^62 (mod M)", - /// where M is the modulus the inverter was created for and "'" stands for the transpose operator. Both the input - /// and output values lie in the interval (-2 * M, M) - fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) { - let mask = CInt::<62, L>::MASK as i64; - let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64; - let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64; - - let cd = t[0][0] - .wrapping_mul(d.lowest() as i64) - .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64)) - & mask; - let ce = t[1][0] - .wrapping_mul(d.lowest() as i64) - .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64)) - & mask; - - md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask; - me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask; - - let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus; - let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus; - - (cd.shift(), ce.shift()) - } - - /// Returns either "value (mod M)" or "-value (mod M)", where M is the modulus the - /// inverter was created for, depending on "negate", which determines the presence - /// of "-" in the used formula. The input integer lies in the interval (-2 * M, M) - fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> { - if value.is_negative() { - value = value + &self.modulus; - } - - if negate { - value = -value; - } - - if value.is_negative() { - value = value + &self.modulus; - } - - value - } - - /// Returns a big unsigned integer as an array of O-bit chunks, which is equal modulo - /// 2 ^ (O * S) to the input big unsigned integer stored as an array of I-bit chunks. - /// The ordering of the chunks in these arrays is little-endian - const fn convert(input: &[u64]) -> [u64; S] { - // This function is defined because the method "min" of the usize type is not constant - const fn min(a: usize, b: usize) -> usize { - if a > b { - b - } else { - a - } - } - - let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0); - - while bits < total { - let (i, o) = (bits % I, bits % O); - output[bits / O] |= (input[bits / I] >> i) << o; - bits += min(I - i, O - o); - } - - let mask = u64::MAX >> (64 - O); - let mut filled = total / O + if total % O > 0 { 1 } else { 0 }; - - while filled > 0 { - filled -= 1; - output[filled] &= mask; - } - - output - } - - /// Returns the multiplicative inverse of the argument modulo 2^62. The implementation is based - /// on the Hurchalla's method for computing the multiplicative inverse modulo a power of two. - /// For better understanding the implementation, the following paper is recommended: - /// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)", - /// https://arxiv.org/pdf/2204.04342.pdf - const fn inv(value: u64) -> i64 { - let x = value.wrapping_mul(3) ^ 2; - let y = 1u64.wrapping_sub(x.wrapping_mul(value)); - let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); - let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); - let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); - (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64 - } - - /// Creates the inverter for specified modulus and adjusting parameter - pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self { - Self { - modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)), - adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)), - inverse: Self::inv(modulus[0]), - } - } - - /// Returns either the adjusted modular multiplicative inverse for the argument or None - /// depending on invertibility of the argument, i.e. its coprimality with the modulus - pub fn invert(&self, value: &[u64]) -> Option<[u64; S]> { - let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone()); - let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value)); - let (mut delta, mut f) = (1, self.modulus.clone()); - let mut matrix; - while g != CInt::ZERO { - (delta, matrix) = Self::jump(&f, &g, delta); - (f, g) = Self::fg(f, g, matrix); - (d, e) = self.de(d, e, matrix); - } - // At this point the absolute value of "f" equals the greatest common divisor - // of the integer to be inverted and the modulus the inverter was created for. - // Thus, if "f" is neither 1 nor -1, then the sought inverse does not exist - let antiunit = f == CInt::MINUS_ONE; - if (f != CInt::ONE) && !antiunit { - return None; - } - Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0)) - } -} From e4e620527d6b238e64750ad0d14105b4579dab70 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 13 Nov 2023 00:24:28 -0800 Subject: [PATCH 13/13] Bump version to 0.4.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 34a95132..c8539854 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2curves" -version = "0.4.0" +version = "0.4.1" authors = ["Privacy Scaling Explorations team"] license = "MIT/Apache-2.0" edition = "2021"