Skip to content

Commit

Permalink
Add support for RISC Zero acceleration for k256
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit 5fea17d
Author: Victor Graf <[email protected]>
Date:   Fri Sep 15 11:13:34 2023 -0700

    fix potential overflow error in FieldElement8x32R0::add (#2)

commit 44b1fc2
Author: Victor Graf <[email protected]>
Date:   Tue Jun 13 10:39:13 2023 -0700

    Use RISC Zero BigInt multiplier to accelerate k256 within the zkVM guest (#1)

    Building on risc0/risc0#466, this PR enables the use of the RISC Zero 256-bit modular multiplication accelerator within guest code for k256 arithmetic, including ECDSA.

    A key application, ECDSA verification is accelerated significantly from a little over 5M cycles without acceleration support to about 890k cycles.

    Based on [email protected]
  • Loading branch information
nategraf committed Nov 26, 2024
1 parent 5ac8f5d commit f913b0a
Show file tree
Hide file tree
Showing 13 changed files with 614 additions and 74 deletions.
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ members = [

[profile.dev]
opt-level = 2

[patch.crates-io.crypto-bigint]
git = "https://github.com/risc0/RustCrypto-crypto-bigint"
tag = "v0.5.5-risczero.0"

[patch.crates-io.sha2]
git = "https://github.com/risc0/RustCrypto-hashes"
tag = "sha2-v0.10.8-risczero.0"
10 changes: 8 additions & 2 deletions k256/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,21 @@ signature = { version = "2", optional = true }

[dev-dependencies]
blobby = "0.3"
criterion = "0.5"
ecdsa-core = { version = "0.16", package = "ecdsa", default-features = false, features = ["dev"] }
hex-literal = "0.4"
num-bigint = "0.4"
num-traits = "0.2"
proptest = "1.4"
rand_core = { version = "0.6", features = ["getrandom"] }
sha3 = { version = "0.10", default-features = false }

[target.'cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))'.dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
proptest = "1.4"

[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dev-dependencies]
proptest = { version = "1.4", default-features = false, features = ["alloc"] }
hex = "0.4"

[features]
default = ["arithmetic", "ecdsa", "pkcs8", "precomputed-tables", "schnorr", "std"]
alloc = ["ecdsa-core?/alloc", "elliptic-curve/alloc"]
Expand Down
35 changes: 32 additions & 3 deletions k256/src/arithmetic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
use cfg_if::cfg_if;

cfg_if! {
if #[cfg(target_pointer_width = "32")] {
if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] {
mod field_8x32_risc0;
} else if #[cfg(target_pointer_width = "32")] {
mod field_10x26;
} else if #[cfg(target_pointer_width = "64")] {
mod field_5x52;
Expand All @@ -20,7 +22,9 @@ cfg_if! {
use field_impl::FieldElementImpl;
} else {
cfg_if! {
if #[cfg(target_pointer_width = "32")] {
if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] {
use field_8x32_risc0::FieldElement8x32R0 as FieldElementImpl;
} else if #[cfg(target_pointer_width = "32")] {
use field_10x26::FieldElement10x26 as FieldElementImpl;
} else if #[cfg(target_pointer_width = "64")] {
use field_5x52::FieldElement5x52 as FieldElementImpl;
Expand Down Expand Up @@ -104,6 +108,12 @@ impl FieldElement {
Self(FieldElementImpl::from_u64(w))
}

/// Convert a `i64` to a field element.
/// Returned value may be only weakly normalized.
pub const fn from_i64(w: i64) -> Self {
Self(FieldElementImpl::from_i64(w))
}

/// Returns the SEC1 encoding of this field element.
pub fn to_bytes(self) -> FieldBytes {
self.0.normalize().to_bytes()
Expand Down Expand Up @@ -141,7 +151,11 @@ impl FieldElement {
/// Returns 2*self.
/// Doubles the magnitude.
pub fn double(&self) -> Self {
Self(self.0.add(&(self.0)))
if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) {
self.mul_single(2)
} else {
Self(self.0.add(&(self.0)))
}
}

/// Returns self * rhs mod p
Expand Down Expand Up @@ -360,6 +374,12 @@ impl From<u64> for FieldElement {
}
}

impl From<i64> for FieldElement {
fn from(k: i64) -> Self {
Self(FieldElementImpl::from_i64(k))
}
}

impl PartialEq for FieldElement {
fn eq(&self, other: &Self) -> bool {
self.0.ct_eq(&(other.0)).into()
Expand Down Expand Up @@ -761,7 +781,16 @@ mod tests {
}
}

fn config() -> ProptestConfig {
if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) {
ProptestConfig::with_cases(1)
} else {
ProptestConfig::default()
}
}

proptest! {
#![proptest_config(config())]

#[test]
fn fuzzy_add(
Expand Down
14 changes: 11 additions & 3 deletions k256/src/arithmetic/field/field_10x26.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ impl FieldElement10x26 {
Self([w0, w1, w2, 0, 0, 0, 0, 0, 0, 0])
}

pub const fn from_i64(val: i64) -> Self {
// Compute val_abs = |val|
let val_mask = val >> 63;
let val_abs = ((val + val_mask) ^ val_mask) as u64;

Self::from_u64(val_abs).negate(1).normalize_weak()
}

/// Returns the SEC1 encoding of this field element.
pub fn to_bytes(self) -> FieldBytes {
let mut r = FieldBytes::default();
Expand Down Expand Up @@ -126,7 +134,7 @@ impl FieldElement10x26 {
}

/// Adds `x * (2^256 - modulus)`.
fn add_modulus_correction(&self, x: u32) -> Self {
const fn add_modulus_correction(&self, x: u32) -> Self {
// add (2^256 - modulus) * x to the first limb
let t0 = self.0[0] + x * 0x3D1u32;

Expand Down Expand Up @@ -164,7 +172,7 @@ impl FieldElement10x26 {

/// Subtracts the overflow in the last limb and return it with the new field element.
/// Equivalent to subtracting a multiple of 2^256.
fn subtract_modulus_approximation(&self) -> (Self, u32) {
const fn subtract_modulus_approximation(&self) -> (Self, u32) {
let x = self.0[9] >> 22;
let t9 = self.0[9] & 0x03FFFFFu32; // equivalent to self -= 2^256 * x
(
Expand All @@ -187,7 +195,7 @@ impl FieldElement10x26 {
}

/// Brings the field element's magnitude to 1, but does not necessarily normalize it.
pub fn normalize_weak(&self) -> Self {
pub const fn normalize_weak(&self) -> Self {
// Reduce t9 at the start so there will be at most a single carry from the first pass
let (t, x) = self.subtract_modulus_approximation();

Expand Down
14 changes: 11 additions & 3 deletions k256/src/arithmetic/field/field_5x52.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ impl FieldElement5x52 {
Self([w0, w1, 0, 0, 0])
}

pub const fn from_i64(val: i64) -> Self {
// Compute val_abs = |val|
let val_mask = val >> 63;
let val_abs = ((val + val_mask) ^ val_mask) as u64;

Self::from_u64(val_abs).negate(1).normalize_weak()
}

/// Returns the SEC1 encoding of this field element.
pub fn to_bytes(self) -> FieldBytes {
let mut ret = FieldBytes::default();
Expand Down Expand Up @@ -123,7 +131,7 @@ impl FieldElement5x52 {
}

/// Adds `x * (2^256 - modulus)`.
fn add_modulus_correction(&self, x: u64) -> Self {
const fn add_modulus_correction(&self, x: u64) -> Self {
// add (2^256 - modulus) * x to the first limb
let t0 = self.0[0] + x * 0x1000003D1u64;

Expand All @@ -145,7 +153,7 @@ impl FieldElement5x52 {

/// Subtracts the overflow in the last limb and return it with the new field element.
/// Equivalent to subtracting a multiple of 2^256.
fn subtract_modulus_approximation(&self) -> (Self, u64) {
const fn subtract_modulus_approximation(&self) -> (Self, u64) {
let x = self.0[4] >> 48;
let t4 = self.0[4] & 0x0FFFFFFFFFFFFu64; // equivalent to self -= 2^256 * x
(Self([self.0[0], self.0[1], self.0[2], self.0[3], t4]), x)
Expand All @@ -162,7 +170,7 @@ impl FieldElement5x52 {
}

/// Brings the field element's magnitude to 1, but does not necessarily normalize it.
pub fn normalize_weak(&self) -> Self {
pub const fn normalize_weak(&self) -> Self {
// Reduce t4 at the start so there will be at most a single carry from the first pass
let (t, x) = self.subtract_modulus_approximation();

Expand Down
Loading

0 comments on commit f913b0a

Please sign in to comment.