Skip to content

Commit

Permalink
ec: Use LeakyLimb for public values.
Browse files Browse the repository at this point in the history
Take a step toward making `Word` and `Limb` opaque types.

This adds some unnecessary copies but the overhead is
negligible as those copies are outside of loops.
  • Loading branch information
briansmith committed Dec 7, 2024
1 parent 5f3dbbf commit 6ce531b
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 58 deletions.
17 changes: 9 additions & 8 deletions mk/generate_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
p: limbs_from_hex("%(q)x"),
rr: limbs_from_hex(%(q_rr)s),
},
n: Elem::from_hex("%(n)x"),
n: PublicElem::from_hex("%(n)x"),
a: Elem::from_hex(%(a)s),
b: Elem::from_hex(%(b)s),
a: PublicElem::from_hex(%(a)s),
b: PublicElem::from_hex(%(b)s),
elem_mul_mont: p%(bits)s_elem_mul_mont,
elem_sqr_mont: p%(bits)s_elem_sqr_mont,
Expand All @@ -56,8 +56,8 @@
};
pub(super) static GENERATOR: (Elem<R>, Elem<R>) = (
Elem::from_hex(%(Gx)s),
Elem::from_hex(%(Gy)s),
PublicElem::from_hex(%(Gx)s),
PublicElem::from_hex(%(Gy)s),
);
pub static PRIVATE_KEY_OPS: PrivateKeyOps = PrivateKeyOps {
Expand Down Expand Up @@ -93,7 +93,8 @@
fn p%(bits)s_point_mul_base_impl(a: &Scalar) -> Point {
// XXX: Not efficient. TODO: Precompute multiples of the generator.
PRIVATE_KEY_OPS.point_mul(a, &GENERATOR)
let generator = (Elem::from(&GENERATOR.0), Elem::from(&GENERATOR.1));
PRIVATE_KEY_OPS.point_mul(a, &generator)
}
pub static PUBLIC_KEY_OPS: PublicKeyOps = PublicKeyOps {
Expand All @@ -112,7 +113,7 @@
twin_mul_inefficient(&PRIVATE_KEY_OPS, g_scalar, p_scalar, p_xy, cpu)
},
q_minus_n: Elem::from_hex("%(q_minus_n)x"),
q_minus_n: PublicElem::from_hex("%(q_minus_n)x"),
// TODO: Use an optimized variable-time implementation.
scalar_inv_to_mont_vartime: |s| PRIVATE_SCALAR_OPS.scalar_inv_to_mont(s),
Expand All @@ -121,7 +122,7 @@
pub static PRIVATE_SCALAR_OPS: PrivateScalarOps = PrivateScalarOps {
scalar_ops: &SCALAR_OPS,
oneRR_mod_n: Scalar::from_hex(%(oneRR_mod_n)s),
oneRR_mod_n: PublicScalar::from_hex(%(oneRR_mod_n)s),
scalar_inv_to_mont: p%(bits)s_scalar_inv_to_mont,
};
Expand Down
8 changes: 4 additions & 4 deletions src/arithmetic/constant.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::limb::Limb;
use crate::limb::LeakyLimb;
use core::mem::size_of;

const fn parse_digit(d: u8) -> u8 {
Expand All @@ -10,16 +10,16 @@ const fn parse_digit(d: u8) -> u8 {
}

// TODO: this would be nicer as a trait, but currently traits don't support const functions
pub const fn limbs_from_hex<const LIMBS: usize>(hex: &str) -> [Limb; LIMBS] {
pub const fn limbs_from_hex<const LIMBS: usize>(hex: &str) -> [LeakyLimb; LIMBS] {
let hex = hex.as_bytes();
let mut limbs = [0; LIMBS];
let limb_nibbles = size_of::<Limb>() * 2;
let limb_nibbles = size_of::<LeakyLimb>() * 2;
let mut i = 0;

while i < hex.len() {
let char = hex[hex.len() - 1 - i];
let val = parse_digit(char);
limbs[i / limb_nibbles] |= (val as Limb) << ((i % limb_nibbles) * 4);
limbs[i / limb_nibbles] |= (val as LeakyLimb) << ((i % limb_nibbles) * 4);
i += 1;
}

Expand Down
3 changes: 2 additions & 1 deletion src/ec/curve25519/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ impl Scalar {
pub fn from_bytes_checked(bytes: [u8; SCALAR_LEN]) -> Result<Self, error::Unspecified> {
const ORDER: [limb::Limb; SCALAR_LEN / limb::LIMB_BYTES] =
limbs_from_hex("1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed");
let order = ORDER.map(limb::Limb::from);

// `bytes` is in little-endian order.
let mut reversed = bytes;
Expand All @@ -34,7 +35,7 @@ impl Scalar {
limb::parse_big_endian_in_range_and_pad_consttime(
untrusted::Input::from(&reversed),
limb::AllowZero::Yes,
&ORDER,
&order,
&mut limbs,
)?;

Expand Down
11 changes: 8 additions & 3 deletions src/ec/suite_b.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ fn verify_affine_point_is_on_the_curve(
ops: &CommonOps,
(x, y): (&Elem<R>, &Elem<R>),
) -> Result<(), error::Unspecified> {
verify_affine_point_is_on_the_curve_scaled(ops, (x, y), &ops.a, &ops.b)
verify_affine_point_is_on_the_curve_scaled(
ops,
(x, y),
&Elem::from(&ops.a),
&Elem::from(&ops.b),
)
}

// Use `verify_affine_point_is_on_the_curve` instead of this function whenever
Expand Down Expand Up @@ -101,9 +106,9 @@ fn verify_jacobian_point_is_on_the_curve(
//
let z2 = ops.elem_squared(&z);
let z4 = ops.elem_squared(&z2);
let z4_a = ops.elem_product(&z4, &ops.a);
let z4_a = ops.elem_product(&z4, &Elem::from(&ops.a));
let z6 = ops.elem_product(&z4, &z2);
let z6_b = ops.elem_product(&z6, &ops.b);
let z6_b = ops.elem_product(&z6, &Elem::from(&ops.b));
verify_affine_point_is_on_the_curve_scaled(ops, (&x, &y), &z4_a, &z6_b)?;
Ok(z2)
}
Expand Down
3 changes: 2 additions & 1 deletion src/ec/suite_b/ecdsa/verification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ impl EcdsaVerificationAlgorithm {
return Ok(());
}
if self.ops.elem_less_than(&r, &self.ops.q_minus_n) {
self.ops.scalar_ops.common.elem_add(&mut r, self.ops.n());
let n = Elem::from(self.ops.n());
self.ops.scalar_ops.common.elem_add(&mut r, &n);
if sig_r_equals_x(self.ops, &r, &x, &z2) {
return Ok(());
}
Expand Down
56 changes: 35 additions & 21 deletions src/ec/suite_b/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub use self::elem::*;
/// A field element, i.e. an element of ℤ/qℤ for the curve's field modulus
/// *q*.
pub type Elem<E> = elem::Elem<Q, E>;
/// A field element (modulus tag `Q`) held in `elem::PublicElem` form, i.e. with
/// `LeakyLimb` limbs. Used for public values such as the curve constants.
type PublicElem<E> = elem::PublicElem<Q, E>;

/// Represents the (prime) order *q* of the curve's prime field.
#[derive(Clone, Copy)]
Expand All @@ -31,6 +32,7 @@ pub enum Q {}
/// A scalar. Its value is in [0, n). Zero-valued scalars are forbidden in most
/// contexts.
pub type Scalar<E = Unencoded> = elem::Elem<N, E>;
/// A scalar (modulus tag `N`) held in `elem::PublicElem` form, i.e. with
/// `LeakyLimb` limbs. Used for public constants such as `oneRR_mod_n`.
type PublicScalar<E> = elem::PublicElem<N, E>;

/// Represents the prime order *n* of the curve's group.
#[derive(Clone, Copy)]
Expand All @@ -57,10 +59,10 @@ impl Point {
pub struct CommonOps {
num_limbs: usize,
q: Modulus,
n: Elem<Unencoded>,
n: PublicElem<Unencoded>,

pub a: Elem<R>, // Must be -3 mod q
pub b: Elem<R>,
pub a: PublicElem<R>, // Must be -3 mod q
pub b: PublicElem<R>,

// In all cases, `r`, `a`, and `b` may all alias each other.
elem_mul_mont: unsafe extern "C" fn(r: *mut Limb, a: *const Limb, b: *const Limb),
Expand Down Expand Up @@ -98,8 +100,7 @@ impl CommonOps {

#[inline]
pub fn elem_unencoded(&self, a: &Elem<R>) -> Elem<Unencoded> {
const ONE: Elem<Unencoded> = Elem::from_hex("1");
self.elem_product(a, &ONE)
self.elem_product(a, &Elem::one())
}

#[inline]
Expand Down Expand Up @@ -171,8 +172,8 @@ impl CommonOps {
}

struct Modulus {
p: [Limb; MAX_LIMBS],
rr: [Limb; MAX_LIMBS],
p: [LeakyLimb; MAX_LIMBS],
rr: [LeakyLimb; MAX_LIMBS],
}

/// Operations on private keys, for ECDH and ECDSA signing.
Expand Down Expand Up @@ -301,11 +302,11 @@ pub struct PublicScalarOps {
cpu: cpu::Features,
) -> Point,
scalar_inv_to_mont_vartime: fn(s: &Scalar<Unencoded>, cpu: cpu::Features) -> Scalar<R>,
pub(super) q_minus_n: Elem<Unencoded>,
pub(super) q_minus_n: PublicElem<Unencoded>,
}

impl PublicScalarOps {
pub fn n(&self) -> &Elem<Unencoded> {
pub fn n(&self) -> &PublicElem<Unencoded> {
&self.scalar_ops.common.n
}

Expand All @@ -323,7 +324,7 @@ impl PublicScalarOps {
== b.limbs[..self.public_key_ops.common.num_limbs]
}

pub fn elem_less_than(&self, a: &Elem<Unencoded>, b: &Elem<Unencoded>) -> bool {
pub fn elem_less_than(&self, a: &Elem<Unencoded>, b: &PublicElem<Unencoded>) -> bool {
let num_limbs = self.public_key_ops.common.num_limbs;
limbs_less_than_limbs_vartime(&a.limbs[..num_limbs], &b.limbs[..num_limbs])
}
Expand All @@ -341,13 +342,14 @@ impl PublicScalarOps {
pub struct PrivateScalarOps {
pub scalar_ops: &'static ScalarOps,

oneRR_mod_n: Scalar<RR>, // 1 * R**2 (mod n). TOOD: Use One<RR>.
oneRR_mod_n: PublicScalar<RR>, // 1 * R**2 (mod n). TOOD: Use One<RR>.
scalar_inv_to_mont: fn(a: Scalar<R>, cpu: cpu::Features) -> Scalar<R>,
}

impl PrivateScalarOps {
pub(super) fn to_mont(&self, s: &Scalar<Unencoded>, cpu: cpu::Features) -> Scalar<R> {
self.scalar_ops.scalar_product(s, &self.oneRR_mod_n, cpu)
self.scalar_ops
.scalar_product(s, &Scalar::from(&self.oneRR_mod_n), cpu)
}

/// Returns the modular inverse of `a` (mod `n`). Panics if `a` is zero.
Expand Down Expand Up @@ -417,15 +419,15 @@ pub fn elem_parse_big_endian_fixed_consttime(
ops: &CommonOps,
bytes: untrusted::Input,
) -> Result<Elem<Unencoded>, error::Unspecified> {
parse_big_endian_fixed_consttime(ops, bytes, AllowZero::Yes, &ops.q.p[..ops.num_limbs])
parse_big_endian_fixed_consttime(ops, bytes, AllowZero::Yes, &ops.q.p)
}

#[inline]
pub fn scalar_parse_big_endian_fixed_consttime(
ops: &CommonOps,
bytes: untrusted::Input,
) -> Result<Scalar, error::Unspecified> {
parse_big_endian_fixed_consttime(ops, bytes, AllowZero::No, &ops.n.limbs[..ops.num_limbs])
parse_big_endian_fixed_consttime(ops, bytes, AllowZero::No, &ops.n.limbs)
}

#[inline]
Expand All @@ -434,11 +436,12 @@ pub fn scalar_parse_big_endian_variable(
allow_zero: AllowZero,
bytes: untrusted::Input,
) -> Result<Scalar, error::Unspecified> {
let n = ops.n.limbs.map(Limb::from);
let mut r = Scalar::zero();
parse_big_endian_in_range_and_pad_consttime(
bytes,
allow_zero,
&ops.n.limbs[..ops.num_limbs],
&n[..ops.num_limbs],
&mut r.limbs[..ops.num_limbs],
)?;
Ok(r)
Expand All @@ -463,16 +466,18 @@ fn parse_big_endian_fixed_consttime<M>(
ops: &CommonOps,
bytes: untrusted::Input,
allow_zero: AllowZero,
max_exclusive: &[Limb],
max_exclusive: &[LeakyLimb; MAX_LIMBS],
) -> Result<elem::Elem<M, Unencoded>, error::Unspecified> {
let max_exclusive = max_exclusive.map(Limb::from);

if bytes.len() != ops.len() {
return Err(error::Unspecified);
}
let mut r = elem::Elem::zero();
parse_big_endian_in_range_and_pad_consttime(
bytes,
allow_zero,
max_exclusive,
&max_exclusive[..ops.num_limbs],
&mut r.limbs[..ops.num_limbs],
)?;
Ok(r)
Expand Down Expand Up @@ -509,8 +514,8 @@ mod tests {

fn q_minus_n_plus_n_equals_0_test(ops: &PublicScalarOps) {
let cops = ops.scalar_ops.common;
let mut x = ops.q_minus_n;
cops.elem_add(&mut x, &cops.n);
let mut x = Elem::from(&ops.q_minus_n);
cops.elem_add(&mut x, &Elem::from(&cops.n));
assert!(cops.is_zero(&x));
}

Expand Down Expand Up @@ -958,19 +963,28 @@ mod tests {
/// TODO: We should be testing `point_mul` with points other than the generator.
#[test]
fn p256_point_mul_test() {
let generator = (
Elem::from(&p256::GENERATOR.0),
Elem::from(&p256::GENERATOR.1),
);
point_mul_base_tests(
&p256::PRIVATE_KEY_OPS,
|s, cpu| p256::PRIVATE_KEY_OPS.point_mul(s, &p256::GENERATOR, cpu),
|s, cpu| p256::PRIVATE_KEY_OPS.point_mul(s, &generator, cpu),
test_file!("ops/p256_point_mul_base_tests.txt"),
);
}

/// TODO: We should be testing `point_mul` with points other than the generator.
#[test]
fn p384_point_mul_test() {
let generator = (
Elem::from(&p384::GENERATOR.0),
Elem::from(&p384::GENERATOR.1),
);

point_mul_base_tests(
&p384::PRIVATE_KEY_OPS,
|s, cpu| p384::PRIVATE_KEY_OPS.point_mul(s, &p384::GENERATOR, cpu),
|s, cpu| p384::PRIVATE_KEY_OPS.point_mul(s, &generator, cpu),
test_file!("ops/p384_point_mul_base_tests.txt"),
);
}
Expand Down
32 changes: 29 additions & 3 deletions src/ec/suite_b/ops/elem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
use crate::{
arithmetic::{
limbs_from_hex,
montgomery::{Encoding, ProductEncoding},
montgomery::{Encoding, ProductEncoding, Unencoded},
},
limb::{Limb, LIMB_BITS},
limb::{LeakyLimb, Limb, LIMB_BITS},
};
use core::marker::PhantomData;

Expand All @@ -36,6 +36,22 @@ pub struct Elem<M, E: Encoding> {
pub(super) encoding: PhantomData<E>,
}

/// Counterpart of `Elem` whose limbs are `LeakyLimb`s instead of `Limb`s.
///
/// Intended for public values (for example constants built with `from_hex`).
/// Convert to an `Elem` with `Elem::from(&public_elem)` to use the value with
/// the operations that take `Elem`s.
pub struct PublicElem<M, E: Encoding> {
// The value's limbs; always `MAX_LIMBS` entries long.
pub(super) limbs: [LeakyLimb; MAX_LIMBS],
// Zero-sized marker for the modulus tag `M`; carries no data.
pub(super) m: PhantomData<M>,
// Zero-sized marker for the encoding `E`; carries no data.
pub(super) encoding: PhantomData<E>,
}

impl<M, E: Encoding> From<&PublicElem<M, E>> for Elem<M, E> {
fn from(value: &PublicElem<M, E>) -> Self {
Self {
limbs: core::array::from_fn(|i| Limb::from(value.limbs[i])),
m: value.m,
encoding: value.encoding,
}
}
}

impl<M, E: Encoding> Elem<M, E> {
// There's no need to convert `value` to the Montgomery domain since
// 0 * R**2 (mod m) == 0, so neither the modulus nor the encoding are needed
Expand All @@ -47,9 +63,19 @@ impl<M, E: Encoding> Elem<M, E> {
encoding: PhantomData,
}
}
}

impl<M> Elem<M, Unencoded> {
    /// Returns the unencoded element whose value is 1.
    ///
    /// Built from `Self::zero()` with the limb at index 0 set to 1; every other
    /// field is taken over from the zero element unchanged.
    pub fn one() -> Self {
        let zero = Self::zero();
        let mut limbs = zero.limbs;
        limbs[0] = 1;
        Self { limbs, ..zero }
    }
}

impl<M, E: Encoding> PublicElem<M, E> {
pub const fn from_hex(hex: &str) -> Self {
Elem {
Self {
limbs: limbs_from_hex(hex),
m: PhantomData,
encoding: PhantomData,
Expand Down
Loading

0 comments on commit 6ce531b

Please sign in to comment.