From df14ce91553759c9d7c9acfa014dd920304b1d6e Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Fri, 6 Dec 2024 14:30:48 -0800 Subject: [PATCH] arithmetic: Remove `PartialEq` & `Debug` for `LimbMask`. Add `LimbMask::leak()` and change all callers to use it. This proactively prevents accidental leakage of the `LimbMask` value and makes it easier to audit the code for places where we intentionally leak the value of a `LimbMask`. Within the tests, use a `#[cfg(test)]-only wrapper `leak_in_test` to make it easier to see that those leaks are uninteresting. --- src/arithmetic/bigint.rs | 8 ++--- src/arithmetic/bigint/boxed_limbs.rs | 4 +-- src/arithmetic/bigint/modulusvalue.rs | 8 ++--- src/arithmetic/bigint/private_exponent.rs | 4 +-- src/ec/suite_b.rs | 4 +-- src/ec/suite_b/ops.rs | 2 +- src/limb.rs | 39 +++++++++++++++-------- 7 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 0c7f2d4e3..cbb93b57a 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -46,7 +46,7 @@ use crate::{ arithmetic::montgomery::*, bits::BitLength, c, error, - limb::{self, Limb, LimbMask, LIMB_BITS}, + limb::{self, Limb, LIMB_BITS}, }; use alloc::vec; use core::{marker::PhantomData, num::NonZeroU64}; @@ -85,7 +85,7 @@ impl Clone for Elem { impl Elem { #[inline] pub fn is_zero(&self) -> bool { - limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True + limb::limbs_are_zero_constant_time(&self.limbs).leak() } } @@ -132,7 +132,7 @@ impl Elem { } fn is_one(&self) -> bool { - limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True + limb::limbs_equal_limb_constant_time(&self.limbs, 1).leak() } } @@ -696,7 +696,7 @@ pub fn elem_verify_equal_consttime( a: &Elem, b: &Elem, ) -> Result<(), error::Unspecified> { - if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True { + if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs).leak() { Ok(()) } else { Err(error::Unspecified) diff --git a/src/arithmetic/bigint/boxed_limbs.rs b/src/arithmetic/bigint/boxed_limbs.rs index 361d8689d..c49ace220 100644 --- a/src/arithmetic/bigint/boxed_limbs.rs +++ b/src/arithmetic/bigint/boxed_limbs.rs @@ -15,7 +15,7 @@ use super::Modulus; use crate::{ error, - limb::{self, Limb, LimbMask, LIMB_BYTES}, + limb::{self, Limb, LIMB_BYTES}, }; use alloc::{boxed::Box, vec}; use core::{ @@ -88,7 +88,7 @@ impl BoxedLimbs { ) -> Result { let mut r = Self::zero(m.limbs().len()); limb::parse_big_endian_and_pad_consttime(input, &mut r)?; - if limb::limbs_less_than_limbs_consttime(&r, m.limbs()) != LimbMask::True { + if !limb::limbs_less_than_limbs_consttime(&r, m.limbs()).leak() { return Err(error::Unspecified); } Ok(r) diff --git a/src/arithmetic/bigint/modulusvalue.rs b/src/arithmetic/bigint/modulusvalue.rs index 075873e86..3bf18868f 100644 --- a/src/arithmetic/bigint/modulusvalue.rs +++ b/src/arithmetic/bigint/modulusvalue.rs @@ -19,7 +19,7 @@ use super::{ use crate::{ bits::BitLength, error, - limb::{self, Limb, LimbMask}, + limb::{self, Limb}, }; /// `OwnedModulus`, without the overhead of Montgomery multiplication support. @@ -47,10 +47,10 @@ impl OwnedModulusValue { if n.len() < MODULUS_MIN_LIMBS { return Err(error::KeyRejected::unexpected_error()); } - if limb::limbs_are_even_constant_time(&n) != LimbMask::False { + if limb::limbs_are_even_constant_time(&n).leak() { return Err(error::KeyRejected::invalid_component()); } - if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False { + if limb::limbs_less_than_limb_constant_time(&n, 3).leak() { return Err(error::KeyRejected::unexpected_error()); } @@ -62,7 +62,7 @@ impl OwnedModulusValue { pub fn verify_less_than(&self, l: &Modulus) -> Result<(), error::Unspecified> { if self.len_bits() > l.len_bits() || (self.limbs.len() == l.limbs().len() - && limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()) != LimbMask::True) + && !limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()).leak()) { return Err(error::Unspecified); } diff --git a/src/arithmetic/bigint/private_exponent.rs b/src/arithmetic/bigint/private_exponent.rs index b416cdcd9..725049a76 100644 --- a/src/arithmetic/bigint/private_exponent.rs +++ b/src/arithmetic/bigint/private_exponent.rs @@ -12,7 +12,7 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use super::{limb, BoxedLimbs, Limb, LimbMask, Modulus}; +use super::{limb, BoxedLimbs, Limb, Modulus}; use crate::error; use alloc::boxed::Box; @@ -36,7 +36,7 @@ impl PrivateExponent { // `p - 1` and so we know `dP < p - 1`. // // Further we know `dP != 0` because `dP` is not even. - if limb::limbs_are_even_constant_time(&dP) != LimbMask::False { + if limb::limbs_are_even_constant_time(&dP).leak() { return Err(error::Unspecified); } diff --git a/src/ec/suite_b.rs b/src/ec/suite_b.rs index 9c322165f..7753052e9 100644 --- a/src/ec/suite_b.rs +++ b/src/ec/suite_b.rs @@ -15,7 +15,7 @@ //! Elliptic curve operations on P-256 & P-384. use self::ops::*; -use crate::{arithmetic::montgomery::*, cpu, ec, error, io::der, limb::LimbMask, pkcs8}; +use crate::{arithmetic::montgomery::*, cpu, ec, error, io::der, pkcs8}; // NIST SP 800-56A Step 3: "If q is an odd prime p, verify that // yQ**2 = xQ**3 + axQ + b in GF(p), where the arithmetic is performed modulo @@ -146,7 +146,7 @@ fn verify_affine_point_is_on_the_curve_scaled( ops.elem_mul(&mut rhs, x); ops.elem_add(&mut rhs, b_scaled); - if ops.elems_are_equal(&lhs, &rhs) != LimbMask::True { + if !ops.elems_are_equal(&lhs, &rhs).leak() { return Err(error::Unspecified); } diff --git a/src/ec/suite_b/ops.rs b/src/ec/suite_b/ops.rs index 85d50f7ff..fda19031b 100644 --- a/src/ec/suite_b/ops.rs +++ b/src/ec/suite_b/ops.rs @@ -128,7 +128,7 @@ impl CommonOps { #[inline] pub fn is_zero(&self, a: &elem::Elem) -> bool { - limbs_are_zero_constant_time(&a.limbs[..self.num_limbs]) == LimbMask::True + limbs_are_zero_constant_time(&a.limbs[..self.num_limbs]).leak() } pub fn elem_verify_is_not_zero(&self, a: &Elem) -> Result<(), error::Unspecified> { diff --git a/src/limb.rs b/src/limb.rs index 1a37147c5..0927fc12f 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -32,12 +32,19 @@ pub const LIMB_BITS: usize = usize_from_u32(Limb::BITS); #[cfg_attr(target_pointer_width = "64", repr(u64))] #[cfg_attr(target_pointer_width = "32", repr(u32))] -#[derive(Debug, PartialEq)] pub enum LimbMask { + #[cfg_attr(not(test), allow(dead_code))] // Only constructed by non-Rust & test code. True = Limb::MAX, + #[cfg_attr(not(test), allow(dead_code))] // Only constructed by non-Rust & test code. False = 0, } +impl LimbMask { + pub fn leak(self) -> bool { + !matches!(self, LimbMask::False) + } +} + pub const LIMB_BYTES: usize = (LIMB_BITS + 7) / 8; #[inline] @@ -58,7 +65,7 @@ pub fn limbs_less_than_limbs_consttime(a: &[Limb], b: &[Limb]) -> LimbMask { #[inline] pub fn limbs_less_than_limbs_vartime(a: &[Limb], b: &[Limb]) -> bool { - limbs_less_than_limbs_consttime(a, b) == LimbMask::True + limbs_less_than_limbs_consttime(a, b).leak() } #[inline] @@ -142,11 +149,11 @@ pub fn parse_big_endian_in_range_and_pad_consttime( result: &mut [Limb], ) -> Result<(), error::Unspecified> { parse_big_endian_and_pad_consttime(input, result)?; - if limbs_less_than_limbs_consttime(result, max_exclusive) != LimbMask::True { + if !limbs_less_than_limbs_consttime(result, max_exclusive).leak() { return Err(error::Unspecified); } if allow_zero != AllowZero::Yes { - if limbs_are_zero_constant_time(result) != LimbMask::False { + if limbs_are_zero_constant_time(result).leak() { return Err(error::Unspecified); } } @@ -362,6 +369,10 @@ mod tests { const MAX: Limb = Limb::MAX; + fn leak_in_test(a: LimbMask) -> bool { + a.leak() + } + #[test] fn test_limbs_are_even() { static EVENS: &[&[Limb]] = &[ @@ -376,7 +387,7 @@ mod tests { &[0, 0, 0, 0, MAX], ]; for even in EVENS { - assert_eq!(limbs_are_even_constant_time(even), LimbMask::True); + assert!(leak_in_test(limbs_are_even_constant_time(even))); } static ODDS: &[&[Limb]] = &[ &[1], @@ -389,7 +400,7 @@ mod tests { &[1, 0, 0, 0, MAX], ]; for odd in ODDS { - assert_eq!(limbs_are_even_constant_time(odd), LimbMask::False); + assert!(!leak_in_test(limbs_are_even_constant_time(odd))); } } @@ -418,20 +429,20 @@ mod tests { #[test] fn test_limbs_are_zero() { for zero in ZEROES { - assert_eq!(limbs_are_zero_constant_time(zero), LimbMask::True); + assert!(leak_in_test(limbs_are_zero_constant_time(zero))); } for nonzero in NONZEROES { - assert_eq!(limbs_are_zero_constant_time(nonzero), LimbMask::False); + assert!(!leak_in_test(limbs_are_zero_constant_time(nonzero))); } } #[test] fn test_limbs_equal_limb() { for zero in ZEROES { - assert_eq!(limbs_equal_limb_constant_time(zero, 0), LimbMask::True); + assert!(leak_in_test(limbs_equal_limb_constant_time(zero, 0))); } for nonzero in NONZEROES { - assert_eq!(limbs_equal_limb_constant_time(nonzero, 0), LimbMask::False); + assert!(!leak_in_test(limbs_equal_limb_constant_time(nonzero, 0))); } static EQUAL: &[(&[Limb], Limb)] = &[ (&[1], 1), @@ -442,7 +453,7 @@ mod tests { (&[0b100, 0], 0b100), ]; for &(a, b) in EQUAL { - assert_eq!(limbs_equal_limb_constant_time(a, b), LimbMask::True); + assert!(leak_in_test(limbs_equal_limb_constant_time(a, b))); } static UNEQUAL: &[(&[Limb], Limb)] = &[ (&[0], 1), @@ -456,7 +467,7 @@ mod tests { (&[MAX, 1], MAX), ]; for &(a, b) in UNEQUAL { - assert_eq!(limbs_equal_limb_constant_time(a, b), LimbMask::False); + assert!(!leak_in_test(limbs_equal_limb_constant_time(a, b))); } } @@ -473,7 +484,7 @@ mod tests { (&[MAX - 1, 0], MAX), ]; for &(a, b) in LESSER { - assert_eq!(limbs_less_than_limb_constant_time(a, b), LimbMask::True); + assert!(leak_in_test(limbs_less_than_limb_constant_time(a, b))); } static EQUAL: &[(&[Limb], Limb)] = &[ (&[0], 0), @@ -492,7 +503,7 @@ mod tests { (&[MAX], MAX - 1), ]; for &(a, b) in EQUAL.iter().chain(GREATER.iter()) { - assert_eq!(limbs_less_than_limb_constant_time(a, b), LimbMask::False); + assert!(!leak_in_test(limbs_less_than_limb_constant_time(a, b))); } }