diff --git a/Cargo.lock b/Cargo.lock index 8043df46..b8af87f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -319,10 +319,10 @@ dependencies = [ [[package]] name = "crypto-bigint" version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +source = "git+https://github.com/risc0/RustCrypto-crypto-bigint?tag=v0.5.5-risczero.0#3ab63a6f1048833f7047d5a50532e4a4cc789384" dependencies = [ "generic-array", + "getrandom", "rand_core", "subtle", "zeroize", @@ -474,9 +474,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -609,6 +609,7 @@ dependencies = [ "criterion", "ecdsa", "elliptic-curve", + "hex", "hex-literal", "num-bigint", "num-traits", @@ -878,9 +879,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" dependencies = [ "unicode-ident", ] @@ -913,9 +914,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.23" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" dependencies = [ "proc-macro2", ] @@ -1147,8 +1148,7 @@ dependencies = [ [[package]] name = "sha2" version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +source = "git+https://github.com/risc0/RustCrypto-hashes?tag=sha2-v0.10.8-risczero.0#244dc3b08788f7a4ccce14c66896ae3b4f24c166" dependencies = [ "cfg-if", "cpufeatures", diff --git a/Cargo.toml b/Cargo.toml index 15523c0f..7a47be9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,11 @@ members = [ [profile.dev] opt-level = 2 + +[patch.crates-io.crypto-bigint] +git = "https://github.com/risc0/RustCrypto-crypto-bigint" +tag = "v0.5.5-risczero.0" + +[patch.crates-io.sha2] +git = "https://github.com/risc0/RustCrypto-hashes" +tag = "sha2-v0.10.8-risczero.0" diff --git a/k256/Cargo.toml b/k256/Cargo.toml index a29ac2b7..21069f3a 100644 --- a/k256/Cargo.toml +++ b/k256/Cargo.toml @@ -31,15 +31,21 @@ signature = { version = "2", optional = true } [dev-dependencies] blobby = "0.3" -criterion = "0.5" ecdsa-core = { version = "0.16", package = "ecdsa", default-features = false, features = ["dev"] } hex-literal = "0.4" num-bigint = "0.4" num-traits = "0.2" -proptest = "1.4" rand_core = { version = "0.6", features = ["getrandom"] } sha3 = { version = "0.10", default-features = false } +[target.'cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))'.dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } +proptest = "1.4" + +[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dev-dependencies] +proptest = { version = "1.4", default-features = false, features = ["alloc"] } +hex = "0.4" + [features] default = ["arithmetic", "ecdsa", "pkcs8", "precomputed-tables", "schnorr", "std"] alloc = ["ecdsa-core?/alloc", "elliptic-curve/alloc"] diff --git a/k256/src/arithmetic/field.rs b/k256/src/arithmetic/field.rs index db2f556c..d8e6351a 100644 --- a/k256/src/arithmetic/field.rs +++ b/k256/src/arithmetic/field.rs @@ -5,7 +5,9 @@ use cfg_if::cfg_if; cfg_if! 
{ - if #[cfg(target_pointer_width = "32")] { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + mod field_8x32_risc0; + } else if #[cfg(target_pointer_width = "32")] { mod field_10x26; } else if #[cfg(target_pointer_width = "64")] { mod field_5x52; @@ -20,7 +22,9 @@ cfg_if! { use field_impl::FieldElementImpl; } else { cfg_if! { - if #[cfg(target_pointer_width = "32")] { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + use field_8x32_risc0::FieldElement8x32R0 as FieldElementImpl; + } else if #[cfg(target_pointer_width = "32")] { use field_10x26::FieldElement10x26 as FieldElementImpl; } else if #[cfg(target_pointer_width = "64")] { use field_5x52::FieldElement5x52 as FieldElementImpl; @@ -104,6 +108,12 @@ impl FieldElement { Self(FieldElementImpl::from_u64(w)) } + /// Convert a `i64` to a field element. + /// Returned value may be only weakly normalized. + pub const fn from_i64(w: i64) -> Self { + Self(FieldElementImpl::from_i64(w)) + } + /// Returns the SEC1 encoding of this field element. pub fn to_bytes(self) -> FieldBytes { self.0.normalize().to_bytes() @@ -141,7 +151,11 @@ impl FieldElement { /// Returns 2*self. /// Doubles the magnitude. pub fn double(&self) -> Self { - Self(self.0.add(&(self.0))) + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + self.mul_single(2) + } else { + Self(self.0.add(&(self.0))) + } } /// Returns self * rhs mod p @@ -360,6 +374,12 @@ impl From for FieldElement { } } +impl From for FieldElement { + fn from(k: i64) -> Self { + Self(FieldElementImpl::from_i64(k)) + } +} + impl PartialEq for FieldElement { fn eq(&self, other: &Self) -> bool { self.0.ct_eq(&(other.0)).into() @@ -761,7 +781,16 @@ mod tests { } } + fn config() -> ProptestConfig { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + ProptestConfig::with_cases(1) + } else { + ProptestConfig::default() + } + } + proptest! 
{ + #![proptest_config(config())] #[test] fn fuzzy_add( diff --git a/k256/src/arithmetic/field/field_10x26.rs b/k256/src/arithmetic/field/field_10x26.rs index 6ea525a0..cdcb0286 100644 --- a/k256/src/arithmetic/field/field_10x26.rs +++ b/k256/src/arithmetic/field/field_10x26.rs @@ -87,6 +87,14 @@ impl FieldElement10x26 { Self([w0, w1, w2, 0, 0, 0, 0, 0, 0, 0]) } + pub const fn from_i64(val: i64) -> Self { + // Compute val_abs = |val| + let val_mask = val >> 63; + let val_abs = ((val + val_mask) ^ val_mask) as u64; + + Self::from_u64(val_abs).negate(1).normalize_weak() + } + /// Returns the SEC1 encoding of this field element. pub fn to_bytes(self) -> FieldBytes { let mut r = FieldBytes::default(); @@ -126,7 +134,7 @@ impl FieldElement10x26 { } /// Adds `x * (2^256 - modulus)`. - fn add_modulus_correction(&self, x: u32) -> Self { + const fn add_modulus_correction(&self, x: u32) -> Self { // add (2^256 - modulus) * x to the first limb let t0 = self.0[0] + x * 0x3D1u32; @@ -164,7 +172,7 @@ impl FieldElement10x26 { /// Subtracts the overflow in the last limb and return it with the new field element. /// Equivalent to subtracting a multiple of 2^256. - fn subtract_modulus_approximation(&self) -> (Self, u32) { + const fn subtract_modulus_approximation(&self) -> (Self, u32) { let x = self.0[9] >> 22; let t9 = self.0[9] & 0x03FFFFFu32; // equivalent to self -= 2^256 * x ( @@ -187,7 +195,7 @@ impl FieldElement10x26 { } /// Brings the field element's magnitude to 1, but does not necessarily normalize it. 
- pub fn normalize_weak(&self) -> Self { + pub const fn normalize_weak(&self) -> Self { // Reduce t9 at the start so there will be at most a single carry from the first pass let (t, x) = self.subtract_modulus_approximation(); diff --git a/k256/src/arithmetic/field/field_5x52.rs b/k256/src/arithmetic/field/field_5x52.rs index c14de07a..881b5cdf 100644 --- a/k256/src/arithmetic/field/field_5x52.rs +++ b/k256/src/arithmetic/field/field_5x52.rs @@ -84,6 +84,14 @@ impl FieldElement5x52 { Self([w0, w1, 0, 0, 0]) } + pub const fn from_i64(val: i64) -> Self { + // Compute val_abs = |val| + let val_mask = val >> 63; + let val_abs = ((val + val_mask) ^ val_mask) as u64; + + Self::from_u64(val_abs).negate(1).normalize_weak() + } + /// Returns the SEC1 encoding of this field element. pub fn to_bytes(self) -> FieldBytes { let mut ret = FieldBytes::default(); @@ -123,7 +131,7 @@ impl FieldElement5x52 { } /// Adds `x * (2^256 - modulus)`. - fn add_modulus_correction(&self, x: u64) -> Self { + const fn add_modulus_correction(&self, x: u64) -> Self { // add (2^256 - modulus) * x to the first limb let t0 = self.0[0] + x * 0x1000003D1u64; @@ -145,7 +153,7 @@ impl FieldElement5x52 { /// Subtracts the overflow in the last limb and return it with the new field element. /// Equivalent to subtracting a multiple of 2^256. - fn subtract_modulus_approximation(&self) -> (Self, u64) { + const fn subtract_modulus_approximation(&self) -> (Self, u64) { let x = self.0[4] >> 48; let t4 = self.0[4] & 0x0FFFFFFFFFFFFu64; // equivalent to self -= 2^256 * x (Self([self.0[0], self.0[1], self.0[2], self.0[3], t4]), x) @@ -162,7 +170,7 @@ impl FieldElement5x52 { } /// Brings the field element's magnitude to 1, but does not necessarily normalize it. 
- pub fn normalize_weak(&self) -> Self { + pub const fn normalize_weak(&self) -> Self { // Reduce t4 at the start so there will be at most a single carry from the first pass let (t, x) = self.subtract_modulus_approximation(); diff --git a/k256/src/arithmetic/field/field_8x32_risc0.rs b/k256/src/arithmetic/field/field_8x32_risc0.rs new file mode 100644 index 00000000..b76d4035 --- /dev/null +++ b/k256/src/arithmetic/field/field_8x32_risc0.rs @@ -0,0 +1,304 @@ +//! Field element modulo the curve internal modulus using 32-bit limbs. +#![allow(unsafe_code)] + +use crate::FieldBytes; +use elliptic_curve::{ + bigint::{risc0, ArrayEncoding, Integer, Limb, Zero, U256}, + subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}, + zeroize::Zeroize, +}; + +/// Base field characteristic for secp256k1 as an 8x32 big integer, least to most significant. +const MODULUS: U256 = + U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F"); + +/// Low two words of 2^256 - MODULUS, used for correcting the value after addition mod 2^256. +const MODULUS_CORRECTION: U256 = U256::ZERO.wrapping_sub(&MODULUS); + +/// Scalars modulo SECP256k1 modulus (2^256 - 2^32 - 2^9 - 2^8 - 2^7 - 2^6 - 2^4 - 1). +/// Uses 8 32-bit limbs (little-endian) and acceleration support from the RISC Zero rv32im impl. +/// Unlike the 10x26 and 5x52 implementations, the values in this implementation are always +/// fully reduced and normalized as there is no extra room in the representation. +/// +/// NOTE: This implementation will only run inside the RISC Zero guest. As a result, the +/// requirements for constant-timeness are different than on a physical platform. +#[derive(Clone, Copy, Debug)] +pub struct FieldElement8x32R0(pub(crate) U256); + +impl FieldElement8x32R0 { + /// Zero element. + pub const ZERO: Self = Self(U256::ZERO); + + /// Multiplicative identity.
+ pub const ONE: Self = Self(U256::ONE); + + /// Attempts to parse the given byte array as an SEC1-encoded field element. + /// Does not check the result for being in the correct range. + pub(crate) const fn from_bytes_unchecked(bytes: &[u8; 32]) -> Self { + Self(U256::from_be_slice(&bytes.as_slice())) + } + + /// Attempts to parse the given byte array as an SEC1-encoded field element. + /// + /// Returns None if the byte array does not contain a big-endian integer in the range + /// [0, p). + pub fn from_bytes(bytes: &FieldBytes) -> CtOption { + let res = Self::from_bytes_unchecked(bytes.as_ref()); + let overflow = res.get_overflow(); + + CtOption::new(res, !overflow) + } + + pub const fn from_u64(val: u64) -> Self { + let w0 = val as u32; + let w1 = (val >> 32) as u32; + Self(U256::from_words([w0, w1, 0, 0, 0, 0, 0, 0])) + } + + pub const fn from_i64(val: i64) -> Self { + // Compute val_abs = |val| + let val_mask = val >> 63; + let val_abs = ((val + val_mask) ^ val_mask) as u64; + + Self::from_u64(val_abs).negate_const() + } + + /// Returns the SEC1 encoding of this field element. + pub fn to_bytes(self) -> FieldBytes { + self.0.to_be_byte_array() + } + + /// Checks if the field element is greater or equal to the modulus. + fn get_overflow(&self) -> Choice { + let (_, carry) = self.0.adc(&MODULUS_CORRECTION, Limb(0)); + Choice::from(carry.0 as u8) + } + + /// Brings the field element's magnitude to 1, but does not necessarily normalize it. + /// + /// NOTE: In RISC Zero, this is a no-op since weak normalization is not an operation that + /// needs to be performed between calls to arithmetic routines. + #[inline(always)] + pub const fn normalize_weak(&self) -> Self { + Self(self.0) + } + + /// Returns the fully normalized and canonical representation of the value. + #[inline(always)] + pub fn normalize(&self) -> Self { + // When the prover is cooperative, the value is always normalized. 
+ assert!(!bool::from(self.get_overflow())); + self.clone() + } + + /// Checks if the field element becomes zero if normalized. + pub fn normalizes_to_zero(&self) -> Choice { + self.0.ct_eq(&U256::ZERO) | self.0.ct_eq(&MODULUS) + } + + /// Determine if this `FieldElement8x32R0` is zero. + /// + /// # Returns + /// + /// If zero, return `Choice(1)`. Otherwise, return `Choice(0)`. + pub fn is_zero(&self) -> Choice { + self.0.is_zero() + } + + /// Determine if this `FieldElement8x32R0` is odd in the SEC1 sense: `self mod 2 == 1`. + /// + /// Value must be normalized before calling is_odd. + /// + /// # Returns + /// + /// If odd, return `Choice(1)`. Otherwise, return `Choice(0)`. + pub fn is_odd(&self) -> Choice { + self.0.is_odd() + } + + #[cfg(debug_assertions)] + pub const fn max_magnitude() -> u32 { + // Results are always reduced, so this implementation does not need to track magnitude. + u32::MAX + } + + /// Returns -self. + const fn negate_const(&self) -> Self { + let (s, borrow) = MODULUS.sbb(&self.0, Limb(0)); + assert!(borrow.0 == 0); + Self(s) + } + + /// Returns -self. + pub fn negate(&self, _magnitude: u32) -> Self { + self.mul(&Self::ONE.negate_const()) + } + + /// Returns self + rhs mod p. + /// Sums the magnitudes. + pub fn add(&self, rhs: &Self) -> Self { + let self_limbs = self.0.as_limbs(); + let rhs_limbs = rhs.0.as_limbs(); + + // Carrying addition of self and rhs, with the overflow correction added in.
+ let (a0, carry0) = self_limbs[0].adc(rhs_limbs[0], MODULUS_CORRECTION.as_limbs()[0]); + let (a1, carry1) = self_limbs[1].adc( + rhs_limbs[1], + carry0.wrapping_add(MODULUS_CORRECTION.as_limbs()[1]), + ); + let (a2, carry2) = self_limbs[2].adc(rhs_limbs[2], carry1); + let (a3, carry3) = self_limbs[3].adc(rhs_limbs[3], carry2); + let (a4, carry4) = self_limbs[4].adc(rhs_limbs[4], carry3); + let (a5, carry5) = self_limbs[5].adc(rhs_limbs[5], carry4); + let (a6, carry6) = self_limbs[6].adc(rhs_limbs[6], carry5); + let (a7, carry7) = self_limbs[7].adc(rhs_limbs[7], carry6); + let a = U256::from([a0, a1, a2, a3, a4, a5, a6, a7]); + + // If the inputs are not in the range [0, p), then carry7 may be greater than 1, + // indicating more than one overflow occurred. In this case, the code below will not + // correct the value. If the host is cooperative, this should never happen. + assert!(carry7.0 <= 1); + + // If a carry occurred, then the correction was already added and the result is correct. + // If a carry did not occur, the correction needs to be removed. Result will be in [0, p). + // Wrap and unwrap to prevent the compiler interpreting this as a boolean, potentially + // introducing non-constant time code. + let mask = 1 - Choice::from(carry7.0 as u8).unwrap_u8(); + let c0 = MODULUS_CORRECTION.as_words()[0] * (mask as u32); + let c1 = MODULUS_CORRECTION.as_words()[1] * (mask as u32); + let correction = U256::from_words([c0, c1, 0, 0, 0, 0, 0, 0]); + + // The correction value was either already added to a, or is 0, so this sub will not + // underflow. + Self(a.wrapping_sub(&correction)) + } + + /// Returns self * rhs mod p + pub fn mul(&self, rhs: &Self) -> Self { + Self(risc0::modmul_u256_denormalized(&self.0, &rhs.0, &MODULUS)) + } + + /// Multiplies by a single-limb integer.
+ pub fn mul_single(&self, rhs: u32) -> Self { + Self(risc0::modmul_u256_denormalized( + &self.0, + &U256::from_words([rhs, 0, 0, 0, 0, 0, 0, 0]), + &MODULUS, + )) + } + + /// Returns self * self + pub fn square(&self) -> Self { + Self(risc0::modmul_u256_denormalized(&self.0, &self.0, &MODULUS)) + } +} + +impl Default for FieldElement8x32R0 { + fn default() -> Self { + Self::ZERO + } +} + +impl ConditionallySelectable for FieldElement8x32R0 { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self(U256::conditional_select(&a.0, &b.0, choice)) + } +} + +impl ConstantTimeEq for FieldElement8x32R0 { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl Zeroize for FieldElement8x32R0 { + fn zeroize(&mut self) { + self.0.zeroize(); + } +} + +#[cfg(test)] +mod tests { + use super::FieldElement8x32R0 as F; + use hex_literal::hex; + + const VAL_A: F = F::from_bytes_unchecked(&hex!( + "EC08EAC2CBCEFE58E61038DCA45BA2B4A56BDF05A3595EBEE1BCFC488889C1CF" + )); + const VAL_B: F = F::from_bytes_unchecked(&hex!( + "9FC3E90D2FAD03C8669F437A26374FA694CA76A7913C5E016322EBAA5C7616C5" + )); + + extern crate alloc; + + fn as_hex(&elem: &F) -> alloc::string::String { + // Call normalize here simply to assert that the value is normalized. + ::hex::encode_upper(elem.normalize().to_bytes()) + } + + #[test] + fn add() { + let expected: F = F::from_bytes_unchecked(&hex!( + "8BCCD3CFFB7C02214CAF7C56CA92F25B3A3655AD3495BCC044DFE7F3E4FFDC65" + )); + assert_eq!(as_hex(&VAL_A.add(&VAL_B)), as_hex(&expected)); + } + + // Tests the other "code path" returning the reduced or non-reduced result. 
+ #[test] + fn add_negated() { + let expected: F = F::from_bytes_unchecked(&hex!( + "74332C300483FDDEB35083A9356D0DA4C5C9AA52CB6A433FBB20180B1B001FCA" + )); + assert_eq!( + as_hex(&VAL_A.negate(0).add(&VAL_B.negate(0))), + as_hex(&expected) + ); + } + + #[test] + fn negate() { + let expected: F = F::from_bytes_unchecked(&hex!( + "13F7153D343101A719EFC7235BA45D4B5A9420FA5CA6A1411E4303B677763A60" + )); + assert_eq!(as_hex(&VAL_A.negate(0)), as_hex(&expected)); + assert_eq!(as_hex(&VAL_A.add(&VAL_A.negate(0))), as_hex(&F::ZERO)); + } + + #[test] + fn mul() { + let expected: F = F::from_bytes_unchecked(&hex!( + "26B936E25A89EBAF821A46DC6BD8A0B1F0ED329412FA75FADF9A494D6F0EB4DB" + )); + assert_eq!(as_hex(&VAL_A.mul(&VAL_B)), as_hex(&expected)); + } + + #[test] + fn mul_zero() { + assert_eq!(as_hex(&VAL_A.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&VAL_B.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ZERO.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ONE.mul(&F::ZERO)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ONE.negate(0).mul(&F::ZERO)), as_hex(&F::ZERO)); + } + + #[test] + fn mul_one() { + assert_eq!(as_hex(&VAL_A.mul(&F::ONE)), as_hex(&VAL_A)); + assert_eq!(as_hex(&VAL_B.mul(&F::ONE)), as_hex(&VAL_B)); + assert_eq!(as_hex(&F::ZERO.mul(&F::ONE)), as_hex(&F::ZERO)); + assert_eq!(as_hex(&F::ONE.mul(&F::ONE)), as_hex(&F::ONE)); + assert_eq!( + as_hex(&F::ONE.negate(0).mul(&F::ONE)), + as_hex(&F::ONE.negate(0)) + ); + } + + #[test] + fn square() { + let expected: F = F::from_bytes_unchecked(&hex!( + "111671376746955B968F48A94AFBACD243EA840AAE13EF85BC39AAE9552D8EDA" + )); + assert_eq!(as_hex(&VAL_A.square()), as_hex(&expected)); + } +} diff --git a/k256/src/arithmetic/field/field_impl.rs b/k256/src/arithmetic/field/field_impl.rs index 6c7820b1..8d46ea4e 100644 --- a/k256/src/arithmetic/field/field_impl.rs +++ b/k256/src/arithmetic/field/field_impl.rs @@ -8,11 +8,17 @@ use elliptic_curve::{ zeroize::Zeroize, }; 
-#[cfg(target_pointer_width = "32")] -use super::field_10x26::FieldElement10x26 as FieldElementUnsafeImpl; - -#[cfg(target_pointer_width = "64")] -use super::field_5x52::FieldElement5x52 as FieldElementUnsafeImpl; +cfg_if::cfg_if! { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + use super::field_8x32_risc0::FieldElement8x32R0 as FieldElementUnsafeImpl; + } else if #[cfg(target_pointer_width = "32")] { + use super::field_10x26::FieldElement10x26 as FieldElementUnsafeImpl; + } else if #[cfg(target_pointer_width = "64")] { + use super::field_5x52::FieldElement5x52 as FieldElementUnsafeImpl; + } else { + compile_error!("unsupported target word size (i.e. target_pointer_width)"); + } +} #[derive(Clone, Copy, Debug)] pub struct FieldElementImpl { @@ -54,10 +60,19 @@ impl FieldElementImpl { fn new(value: &FieldElementUnsafeImpl, magnitude: u32) -> Self { debug_assert!(magnitude <= FieldElementUnsafeImpl::max_magnitude()); - Self { - value: *value, - magnitude, - normalized: false, + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // In the RISC Zero field impl, magnitude is always 1. + Self { + value: *value, + magnitude: 1, + normalized: false, + } + } else { + Self { + value: *value, + magnitude, + normalized: false, + } } } @@ -70,6 +85,12 @@ impl FieldElementImpl { Self::new_normalized(&FieldElementUnsafeImpl::from_u64(val)) } + /// Convert a `i64` to a field element. + /// Returned value may be only weakly normalized. 
+ pub(crate) const fn from_i64(w: i64) -> Self { + Self::new_weak_normalized(&FieldElementUnsafeImpl::from_i64(w)) + } + pub fn from_bytes(bytes: &FieldBytes) -> CtOption { let value = FieldElementUnsafeImpl::from_bytes(bytes); CtOption::map(value, |x| Self::new_normalized(&x)) diff --git a/k256/src/arithmetic/hash2curve.rs b/k256/src/arithmetic/hash2curve.rs index 5ce39321..264ec167 100644 --- a/k256/src/arithmetic/hash2curve.rs +++ b/k256/src/arithmetic/hash2curve.rs @@ -415,7 +415,15 @@ mod tests { Scalar(reduced_scalar) }; - proptest!(ProptestConfig::with_cases(1000), |(b0 in ANY, b1 in ANY, b2 in ANY, b3 in ANY, b4 in ANY, b5 in ANY)| { + fn config() -> ProptestConfig { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + ProptestConfig::with_cases(1) + } else { + ProptestConfig::with_cases(1000) + } + } + + proptest!(config(), |(b0 in ANY, b1 in ANY, b2 in ANY, b3 in ANY, b4 in ANY, b5 in ANY)| { let mut data = GenericArray::default(); data[..8].copy_from_slice(&b0.to_be_bytes()); data[8..16].copy_from_slice(&b1.to_be_bytes()); diff --git a/k256/src/arithmetic/mul.rs b/k256/src/arithmetic/mul.rs index 6354a7d2..306752de 100644 --- a/k256/src/arithmetic/mul.rs +++ b/k256/src/arithmetic/mul.rs @@ -58,8 +58,10 @@ use once_cell::sync::Lazy; /// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]` #[derive(Copy, Clone, Default)] +#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] struct LookupTable([ProjectivePoint; 8]); +#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] impl From<&ProjectivePoint> for LookupTable { fn from(p: &ProjectivePoint) -> Self { let mut points = [*p; 8]; @@ -69,6 +71,23 @@ impl From<&ProjectivePoint> for LookupTable { LookupTable(points) } } +/// Lookup table containing precomputed values `[0, p, 2p, 3p, ..., 8p]` +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +#[repr(align(1024))] +#[derive(Copy, Clone, Default)] +struct LookupTable([ProjectivePoint; 9]); + +#[cfg(all(target_os = 
"zkvm", target_arch = "riscv32"))] +impl From<&ProjectivePoint> for LookupTable { + fn from(p: &ProjectivePoint) -> Self { + let mut points = [*p; 9]; + points[0] = ProjectivePoint::IDENTITY; + for j in 1..8 { + points[j + 1] = p + &points[j]; + } + LookupTable(points) + } +} impl LookupTable { /// Given -8 <= x <= 8, returns x * p in constant time. @@ -80,6 +99,17 @@ impl LookupTable { let xmask = x >> 7; let xabs = (x + xmask) ^ xmask; + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // All paged-in memory is constant time to access in RISC Zero. + // LookupTable fits in 864 bytes, which is less than the page size of 1024. Adding the + // repr(align(1024)) attribute above ensures the struct is placed on a page boundary and + // so all accesses within the table will result in the same paging behavior. + let value = self.0[xabs as usize]; + + let neg_mask = Choice::from((xmask & 1) as u8); + return ProjectivePoint::conditional_select(&value, &-value, neg_mask); + } + + // Get an array element in constant time + let mut t = ProjectivePoint::IDENTITY; + for j in 1..9 { diff --git a/k256/src/arithmetic/projective.rs b/k256/src/arithmetic/projective.rs index a7f7bbbe..8fe0cf70 100644 --- a/k256/src/arithmetic/projective.rs +++ b/k256/src/arithmetic/projective.rs @@ -108,6 +108,30 @@ impl ProjectivePoint { let yz_pairs = ((self.y + &self.z) * &(other.y + &other.z)) + &n_yy_zz; let xz_pairs = ((self.x + &self.z) * &(other.x + &other.z)) + &n_xx_zz; + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // Same as below, but using mul_single instead of repeated addition to get small + // multiplications and normalize_weak is removed.
+ let bzz3 = zz.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let yy_m_bzz3 = yy + &bzz3.negate(1); + let yy_p_bzz3 = yy + &bzz3; + + let byz3 = &yz_pairs.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let xx3 = xx.mul_single(3); + let bxx9 = xx3.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let new_x = (xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1); // m1 + let new_y = (yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs); + let new_z = (yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs); + + return ProjectivePoint { + x: new_x, + y: new_y, + z: new_z, + }; + } + let bzz = zz.mul_single(CURVE_EQUATION_B_SINGLE); let bzz3 = (bzz.double() + &bzz).normalize_weak(); @@ -147,6 +171,29 @@ impl ProjectivePoint { let yz_pairs = (other.y * &self.z) + &self.y; let xz_pairs = (other.x * &self.z) + &self.x; + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // Same as below, but using mul_single instead of repeated addition to get small + // multiplications and normalize_weak is removed. + let bzz3 = self.z.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let yy_m_bzz3 = yy + &bzz3.negate(1); + let yy_p_bzz3 = yy + &bzz3; + + let n_byz3 = + &yz_pairs.mul(&FieldElement::from_i64(CURVE_EQUATION_B_SINGLE as i64 * -3)); + + let xx3 = xx.mul_single(3); + let bxx9 = xx3.mul_single(CURVE_EQUATION_B_SINGLE * 3); + + let mut ret = ProjectivePoint { + x: (xy_pairs * &yy_m_bzz3) + &(n_byz3 * &xz_pairs), + y: (yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs), + z: (yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs), + }; + ret.conditional_assign(self, other.is_identity()); + return ret; + } + let bzz = &self.z.mul_single(CURVE_EQUATION_B_SINGLE); let bzz3 = (bzz.double() + bzz).normalize_weak(); @@ -183,6 +230,25 @@ impl ProjectivePoint { let zz = self.z.square(); let xy2 = (self.x * &self.y).double(); + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // Same as below, but using mul_single instead of repeated addition to get small + // multiplications and normalize_weak is removed. 
+ let bzz3 = zz.mul_single(CURVE_EQUATION_B_SINGLE * 3); + let n_bzz9 = zz.mul(&FieldElement::from_i64(CURVE_EQUATION_B_SINGLE as i64 * -9)); + + let yy_m_bzz9 = yy + &n_bzz9; + let yy_p_bzz3 = yy + &bzz3; + + let yy_zz = yy * &zz; + let t = yy_zz.mul_single(CURVE_EQUATION_B_SINGLE * 24); + + return ProjectivePoint { + x: xy2 * &yy_m_bzz9, + y: ((yy_m_bzz9 * &yy_p_bzz3) + &t), + z: ((yy * &self.y) * &self.z).mul_single(8), + }; + } + let bzz = &zz.mul_single(CURVE_EQUATION_B_SINGLE); let bzz3 = (bzz.double() + bzz).normalize_weak(); let bzz9 = (bzz3.double() + &bzz3).normalize_weak(); diff --git a/k256/src/arithmetic/scalar.rs b/k256/src/arithmetic/scalar.rs index e0f98595..397bbd4d 100644 --- a/k256/src/arithmetic/scalar.rs +++ b/k256/src/arithmetic/scalar.rs @@ -6,6 +6,9 @@ mod wide; pub(crate) use self::wide::WideScalar; +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use elliptic_curve::bigint::risc0; + use crate::{FieldBytes, Secp256k1, WideBytes, ORDER, ORDER_HEX}; use core::{ iter::{Product, Sum}, @@ -109,7 +112,25 @@ impl Scalar { /// Modulo multiplies two scalars. pub fn mul(&self, rhs: &Scalar) -> Scalar { - WideScalar::mul_wide(self, rhs).reduce() + cfg_if::cfg_if! { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + let result = Self(risc0::modmul_u256_denormalized(&self.0, &rhs.0, &ORDER)); + assert!(bool::from(result.0.ct_lt(&ORDER))); + result + } else { + WideScalar::mul_wide(self, rhs).reduce() + } + } + } + + fn mul_denormalized(&self, rhs: &Scalar) -> Scalar { + cfg_if::cfg_if! { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + Self(risc0::modmul_u256_denormalized(&self.0, &rhs.0, &ORDER)) + } else { + WideScalar::mul_wide(self, rhs).reduce() + } + } } /// Modulo squares the scalar. 
@@ -117,6 +138,18 @@ impl Scalar { self.mul(self) } + fn square_denormalized(&self) -> Self { + self.mul_denormalized(self) + } + + #[inline(always)] + fn normalize(&self) -> Self { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + assert!(bool::from(self.0.ct_lt(&ORDER))); + } + self.clone() + } + /// Right shifts the scalar. /// /// Note: not constant-time with respect to the `shift` parameter. @@ -130,49 +163,49 @@ impl Scalar { // https://briansmith.org/ecc-inversion-addition-chains-01#secp256k1_scalar_inversion let x_1 = *self; let x_10 = self.pow2k(1); - let x_11 = x_10.mul(&x_1); - let x_101 = x_10.mul(&x_11); - let x_111 = x_10.mul(&x_101); - let x_1001 = x_10.mul(&x_111); - let x_1011 = x_10.mul(&x_1001); - let x_1101 = x_10.mul(&x_1011); - - let x6 = x_1101.pow2k(2).mul(&x_1011); - let x8 = x6.pow2k(2).mul(&x_11); - let x14 = x8.pow2k(6).mul(&x6); - let x28 = x14.pow2k(14).mul(&x14); - let x56 = x28.pow2k(28).mul(&x28); + let x_11 = x_10.mul_denormalized(&x_1); + let x_101 = x_10.mul_denormalized(&x_11); + let x_111 = x_10.mul_denormalized(&x_101); + let x_1001 = x_10.mul_denormalized(&x_111); + let x_1011 = x_10.mul_denormalized(&x_1001); + let x_1101 = x_10.mul_denormalized(&x_1011); + + let x6 = x_1101.pow2k(2).mul_denormalized(&x_1011); + let x8 = x6.pow2k(2).mul_denormalized(&x_11); + let x14 = x8.pow2k(6).mul_denormalized(&x6); + let x28 = x14.pow2k(14).mul_denormalized(&x14); + let x56 = x28.pow2k(28).mul_denormalized(&x28); #[rustfmt::skip] let res = x56 - .pow2k(56).mul(&x56) - .pow2k(14).mul(&x14) - .pow2k(3).mul(&x_101) - .pow2k(4).mul(&x_111) - .pow2k(4).mul(&x_101) - .pow2k(5).mul(&x_1011) - .pow2k(4).mul(&x_1011) - .pow2k(4).mul(&x_111) - .pow2k(5).mul(&x_111) - .pow2k(6).mul(&x_1101) - .pow2k(4).mul(&x_101) - .pow2k(3).mul(&x_111) - .pow2k(5).mul(&x_1001) - .pow2k(6).mul(&x_101) - .pow2k(10).mul(&x_111) - .pow2k(4).mul(&x_111) - .pow2k(9).mul(&x8) - .pow2k(5).mul(&x_1001) - .pow2k(6).mul(&x_1011) - .pow2k(4).mul(&x_1101) - 
.pow2k(5).mul(&x_11) - .pow2k(6).mul(&x_1101) - .pow2k(10).mul(&x_1101) - .pow2k(4).mul(&x_1001) - .pow2k(6).mul(&x_1) - .pow2k(8).mul(&x6); - - CtOption::new(res, !self.is_zero()) + .pow2k(56).mul_denormalized(&x56) + .pow2k(14).mul_denormalized(&x14) + .pow2k(3).mul_denormalized(&x_101) + .pow2k(4).mul_denormalized(&x_111) + .pow2k(4).mul_denormalized(&x_101) + .pow2k(5).mul_denormalized(&x_1011) + .pow2k(4).mul_denormalized(&x_1011) + .pow2k(4).mul_denormalized(&x_111) + .pow2k(5).mul_denormalized(&x_111) + .pow2k(6).mul_denormalized(&x_1101) + .pow2k(4).mul_denormalized(&x_101) + .pow2k(3).mul_denormalized(&x_111) + .pow2k(5).mul_denormalized(&x_1001) + .pow2k(6).mul_denormalized(&x_101) + .pow2k(10).mul_denormalized(&x_111) + .pow2k(4).mul_denormalized(&x_111) + .pow2k(9).mul_denormalized(&x8) + .pow2k(5).mul_denormalized(&x_1001) + .pow2k(6).mul_denormalized(&x_1011) + .pow2k(4).mul_denormalized(&x_1101) + .pow2k(5).mul_denormalized(&x_11) + .pow2k(6).mul_denormalized(&x_1101) + .pow2k(10).mul_denormalized(&x_1101) + .pow2k(4).mul_denormalized(&x_1001) + .pow2k(6).mul_denormalized(&x_1) + .pow2k(8).mul_denormalized(&x6); + + CtOption::new(res.normalize(), !self.is_zero()) } /// Returns the scalar modulus as a `BigUint` object. @@ -214,7 +247,7 @@ impl Scalar { fn pow2k(&self, k: usize) -> Self { let mut x = *self; for _j in 0..k { - x = x.square(); + x = x.square_denormalized(); } x } @@ -432,6 +465,11 @@ impl Invert for Scalar { /// sidechannels. #[allow(non_snake_case)] fn invert_vartime(&self) -> CtOption { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // Constant time algorithm is faster in the RISC Zero zkVM. 
+ return self.invert(); + } + let mut u = *self; let mut v = Self::from_uint_unchecked(Secp256k1::ORDER); let mut A = Self::ONE; @@ -1127,7 +1165,17 @@ mod tests { } } + fn config() -> ProptestConfig { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + ProptestConfig::with_cases(1) + } else { + ProptestConfig::default() + } + } + proptest! { + #![proptest_config(config())] + #[test] fn fuzzy_roundtrip_to_bytes(a in scalar()) { let a_back = Scalar::from_repr(a.to_bytes()).unwrap(); diff --git a/k256/src/lib.rs b/k256/src/lib.rs index f47f4882..23df00a7 100644 --- a/k256/src/lib.rs +++ b/k256/src/lib.rs @@ -6,7 +6,11 @@ html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg" )] #![allow(clippy::needless_range_loop)] -#![forbid(unsafe_code)] +#![cfg_attr(all(target_os = "zkvm", target_arch = "riscv32"), deny(unsafe_code))] +#![cfg_attr( + not(all(target_os = "zkvm", target_arch = "riscv32")), + forbid(unsafe_code) +)] #![warn( clippy::mod_module_files, clippy::unwrap_used,