Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change: Refactor & optimize the NAF #63

Merged
merged 7 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ rand = "0.8"
rand_chacha = "0.3"
serde_json = "1.0"
frost-rerandomized = { version = "0.2", features=["test-impl"] }
num-bigint = "0.4.3"
num-traits = "0.2.15"

# `alloc` is only used in test code
[dev-dependencies.pasta_curves]
Expand Down
68 changes: 10 additions & 58 deletions src/orchard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,63 +87,14 @@ impl private::Sealed<Binding> for Binding {

#[cfg(feature = "alloc")]
impl NonAdjacentForm for pallas::Scalar {
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
///
/// Thanks to curve25519-dalek
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
// required by the NAF definition
debug_assert!(w >= 2);
// required so that the NAF digits fit in i8
debug_assert!(w <= 8);

use byteorder::{ByteOrder, LittleEndian};

let mut naf = [0i8; 256];

let mut x_u64 = [0u64; 5];
LittleEndian::read_u64_into(self.to_repr().as_ref(), &mut x_u64[0..4]);

let width = 1 << w;
let window_mask = width - 1;

let mut pos = 0;
let mut carry = 0;
while pos < 256 {
// Construct a buffer of bits of the scalar, starting at bit `pos`
let u64_idx = pos / 64;
let bit_idx = pos % 64;
let bit_buf = if bit_idx < 64 - w {
// This window's bits are contained in a single u64
x_u64[u64_idx] >> bit_idx
} else {
// Combine the current u64's bits with the bits from the next u64
(x_u64[u64_idx] >> bit_idx) | (x_u64[1 + u64_idx] << (64 - bit_idx))
};

// Add the carry into the current window
let window = carry + (bit_buf & window_mask);

if window & 1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0, then the next carry should be 0
// If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
pos += 1;
continue;
}

if window < width / 2 {
carry = 0;
naf[pos] = window as i8;
} else {
carry = 1;
naf[pos] = (window as i8).wrapping_sub(width as i8);
}

pos += w;
}
fn inner_to_bytes(&self) -> [u8; 32] {
self.to_repr()
}

naf
/// The NAF length for Pallas is 255 since Pallas' order is about 2<sup>254</sup> +
/// 2<sup>125.1</sup>.
dconnolly marked this conversation as resolved.
Show resolved Hide resolved
fn naf_length() -> usize {
255
}
}

Expand Down Expand Up @@ -184,14 +135,15 @@ impl VartimeMultiscalarMul for pallas::Point {
.collect::<Option<Vec<_>>>()?;

let mut r = pallas::Point::identity();
let naf_size = Self::Scalar::naf_length();

for i in (0..256).rev() {
for i in (0..naf_size).rev() {
let mut t = r.double();

for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
#[allow(clippy::comparison_chain)]
if naf[i] > 0 {
t += lookup_table.select(naf[i] as usize)
t += lookup_table.select(naf[i] as usize);
} else if naf[i] < 0 {
t -= lookup_table.select(-naf[i] as usize);
}
Expand Down
16 changes: 13 additions & 3 deletions src/orchard/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::scalar_mul::VartimeMultiscalarMul;
use crate::scalar_mul::{self, VartimeMultiscalarMul};
use alloc::vec::Vec;
use group::ff::Field;
use group::{ff::PrimeField, GroupEncoding};
use rand::thread_rng;

use pasta_curves::arithmetic::CurveExt;
use pasta_curves::pallas;
Expand All @@ -27,8 +29,7 @@ fn orchard_binding_basepoint() {
// #[test]
#[allow(dead_code)]
fn gen_pallas_test_vectors() {
use group::{ff::Field, Group};
use rand::thread_rng;
use group::Group;
use std::println;

let rng = thread_rng();
Expand Down Expand Up @@ -105,3 +106,12 @@ fn test_pallas_vartime_multiscalar_mul() {
let product = pallas::Point::vartime_multiscalar_mul(scalars, points);
assert_eq!(expected_product, product);
}

/// Tests the non-adjacent form for a Pallas scalar.
#[test]
fn test_non_adjacent_form() {
let rng = thread_rng();

let scalar = pallas::Scalar::random(rng);
scalar_mul::tests::test_non_adjacent_form_for_scalar(5, scalar);
}
54 changes: 42 additions & 12 deletions src/scalar_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@ use core::{borrow::Borrow, fmt::Debug};

use jubjub::{ExtendedNielsPoint, ExtendedPoint};

pub trait NonAdjacentForm {
fn non_adjacent_form(&self, w: usize) -> [i8; 256];
}

#[cfg(test)]
mod tests;
pub(crate) mod tests;

/// A trait for variable-time multiscalar multiplication without precomputation.
pub trait VartimeMultiscalarMul {
Expand Down Expand Up @@ -67,31 +63,53 @@ pub trait VartimeMultiscalarMul {
}
}

impl NonAdjacentForm for jubjub::Scalar {
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
/// Produces the non-adjacent form (NAF) of a 32-byte scalar.
pub trait NonAdjacentForm {
/// Returns the scalar represented as a little-endian byte array.
fn inner_to_bytes(&self) -> [u8; 32];
conradoplg marked this conversation as resolved.
Show resolved Hide resolved

/// Returns the number of coefficients in the NAF.
///
/// Claim: The length of the NAF requires at most one more coefficient than the length of the
/// binary representation of the scalar. [^1]
///
/// This trait works with scalars of at most 256 binary bits, so the default implementation
/// returns 257. However, some (sub)groups' orders don't reach 256 bits and their scalars don't
/// need the full 256 bits. Setting the corresponding NAF length for a particular curve will
/// speed up the multiscalar multiplication since the number of loop iterations required for the
/// multiplication is equal to the length of the NAF.
///
/// [^1]: The proof is left as an exercise to the reader.
fn naf_length() -> usize {
257
}

/// Computes the width-`w` non-adjacent form (width-`w` NAF) of the scalar.
///
/// Thanks to [`curve25519-dalek`].
///
/// [`curve25519-dalek`]: https://github.com/dalek-cryptography/curve25519-dalek/blob/3e189820da03cc034f5fa143fc7b2ccb21fffa5e/src/scalar.rs#L907
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
fn non_adjacent_form(&self, w: usize) -> Vec<i8> {
// required by the NAF definition
debug_assert!(w >= 2);
// required so that the NAF digits fit in i8
debug_assert!(w <= 8);

use byteorder::{ByteOrder, LittleEndian};

let mut naf = [0i8; 256];
let naf_length = Self::naf_length();
let mut naf = vec![0; naf_length];

let mut x_u64 = [0u64; 5];
LittleEndian::read_u64_into(&self.to_bytes(), &mut x_u64[0..4]);
LittleEndian::read_u64_into(&self.inner_to_bytes(), &mut x_u64[0..4]);

let width = 1 << w;
let window_mask = width - 1;

let mut pos = 0;
let mut carry = 0;
while pos < 256 {

while pos < naf_length {
// Construct a buffer of bits of the scalar, starting at bit `pos`
let u64_idx = pos / 64;
let bit_idx = pos % 64;
Expand Down Expand Up @@ -130,6 +148,17 @@ impl NonAdjacentForm for jubjub::Scalar {
}
}

impl NonAdjacentForm for jubjub::Scalar {
fn inner_to_bytes(&self) -> [u8; 32] {
self.to_bytes()
}

/// The NAF length for Jubjub is 253 since Jubjub's order is about 2<sup>251.85</sup>.
dconnolly marked this conversation as resolved.
Show resolved Hide resolved
fn naf_length() -> usize {
253
}
}

/// Holds odd multiples 1A, 3A, ..., 15A of a point A.
#[derive(Copy, Clone)]
pub(crate) struct LookupTable5<T>(pub(crate) [T; 8]);
Expand Down Expand Up @@ -195,8 +224,9 @@ impl VartimeMultiscalarMul for ExtendedPoint {
.collect::<Option<Vec<_>>>()?;

let mut r = ExtendedPoint::identity();
let naf_size = Self::Scalar::naf_length();

for i in (0..256).rev() {
for i in (0..naf_size).rev() {
let mut t = r.double();

for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
Expand Down
107 changes: 93 additions & 14 deletions src/scalar_mul/tests.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,41 @@
use alloc::vec::Vec;
use group::GroupEncoding;
use jubjub::{ExtendedPoint, Scalar};
use group::{ff::Field, GroupEncoding};
use num_bigint::BigInt;
use num_traits::Zero;
use rand::thread_rng;

use crate::scalar_mul::VartimeMultiscalarMul;

use super::NonAdjacentForm;

/// Generates test vectors for [`test_jubjub_vartime_multiscalar_mul`].
// #[test]
#[allow(dead_code)]
fn gen_jubjub_test_vectors() {
use group::{ff::Field, Group};
use rand::thread_rng;
use group::Group;
use std::println;

let rng = thread_rng();

let scalars = [Scalar::random(rng.clone()), Scalar::random(rng.clone())];
let scalars = [
jubjub::Scalar::random(rng.clone()),
jubjub::Scalar::random(rng.clone()),
];
println!("Scalars:");
for scalar in scalars {
println!("{:?}", scalar.to_bytes());
}

let points = [
ExtendedPoint::random(rng.clone()),
ExtendedPoint::random(rng),
jubjub::ExtendedPoint::random(rng.clone()),
jubjub::ExtendedPoint::random(rng),
];
println!("Points:");
for point in points {
println!("{:?}", point.to_bytes());
}

let res = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
let res = jubjub::ExtendedPoint::vartime_multiscalar_mul(scalars, points);
println!("Result:");
println!("{:?}", res.to_bytes());
}
Expand Down Expand Up @@ -65,21 +71,94 @@ fn test_jubjub_vartime_multiscalar_mul() {
131, 180, 48, 148, 72, 212, 148, 212, 240, 77, 244, 91, 213,
];

let scalars: Vec<Scalar> = scalars
let scalars: Vec<jubjub::Scalar> = scalars
.into_iter()
.map(|s| Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
.map(|s| jubjub::Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
.collect();

let points: Vec<ExtendedPoint> = points
let points: Vec<jubjub::ExtendedPoint> = points
.into_iter()
.map(|p| {
ExtendedPoint::from_bytes(&p).expect("Could not deserialize a `jubjub::ExtendedPoint`.")
jubjub::ExtendedPoint::from_bytes(&p)
.expect("Could not deserialize a `jubjub::ExtendedPoint`.")
})
.collect();

let expected_product = ExtendedPoint::from_bytes(&expected_product)
let expected_product = jubjub::ExtendedPoint::from_bytes(&expected_product)
.expect("Could not deserialize a `jubjub::ExtendedPoint`.");

let product = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
let product = jubjub::ExtendedPoint::vartime_multiscalar_mul(scalars, points);
assert_eq!(expected_product, product);
}

/// Tests the non-adjacent form for a Jubjub scalar.
#[test]
fn test_non_adjacent_form() {
let rng = thread_rng();

let scalar = jubjub::Scalar::random(rng);
test_non_adjacent_form_for_scalar(5, scalar);
}

/// Tests the non-adjacent form for a particular scalar.
pub(crate) fn test_non_adjacent_form_for_scalar<Scalar: NonAdjacentForm>(w: usize, scalar: Scalar) {
let naf = scalar.non_adjacent_form(w);
let naf_length = Scalar::naf_length();

// Check that the computed w-NAF has the intended length.
assert_eq!(naf.len(), naf_length);

let w = u32::try_from(w).expect("The window `w` did not fit into `u32`.");

// `bound` <- 2^(w-1)
let bound = 2_i32.pow(w - 1);

// `valid_coeffs` <- a range of odd integers from -2^(w-1) to 2^(w-1)
let valid_coeffs: Vec<i32> = (-bound..bound).filter(|x| x.rem_euclid(2) == 1).collect();

let mut reconstructed_scalar: BigInt = Zero::zero();

// Reconstruct the original scalar, and check two general invariants for any w-NAF along the
// way.
let mut i = 0;
while i < naf_length {
if naf[i] != 0 {
// In a w-NAF, every nonzero coefficient `naf[i]` is an odd signed integer with
// -2^(w-1) < `naf[i]` < 2^(w-1).
assert!(valid_coeffs.contains(&i32::from(naf[i])));

// Incrementally keep reconstructing the original scalar.
reconstructed_scalar += naf[i] * BigInt::from(2).pow(i.try_into().unwrap());

// In a w-NAF, at most one of any `w` consecutive coefficients is nonzero.
for _ in 1..w {
i += 1;
if i >= naf_length {
break;
}
assert_eq!(naf[i], 0)
}
}

i += 1;
}

// Check that the reconstructed scalar is not negative, and convert it to little-endian bytes.
let reconstructed_scalar = reconstructed_scalar
.to_biguint()
.expect("The reconstructed scalar is negative.")
.to_bytes_le();

// Check that the reconstructed scalar is not too big.
assert!(reconstructed_scalar.len() <= 32);

// Convert the reconstructed scalar to a fixed byte array so we can compare it with the orginal
// scalar.
let mut reconstructed_scalar_bytes: [u8; 32] = [0; 32];
for (i, byte) in reconstructed_scalar.iter().enumerate() {
reconstructed_scalar_bytes[i] = *byte;
}
upbqdn marked this conversation as resolved.
Show resolved Hide resolved

// Check that the reconstructed scalar matches the original one.
assert_eq!(reconstructed_scalar_bytes, scalar.inner_to_bytes());
}