diff --git a/program-runtime/src/compute_budget.rs b/program-runtime/src/compute_budget.rs index 733632ad3e756d..42cef21966609d 100644 --- a/program-runtime/src/compute_budget.rs +++ b/program-runtime/src/compute_budget.rs @@ -56,7 +56,11 @@ pub struct ComputeBudget { /// Number of compute units consumed to do a syscall without any work pub syscall_base_cost: u64, /// Number of compute units consumed to call zktoken_crypto_op - pub zk_token_elgamal_op_cost: u64, + pub zk_token_elgamal_op_cost: u64, // to be replaced by curve25519 operations + /// Number of compute units consumed to add/sub two edwards points + pub curve25519_edwards_validate_point_cost: u64, + /// Number of compute units consumed to add/sub two ristretto points + pub curve25519_ristretto_validate_point_cost: u64, /// Optional program heap region size, if `None` then loader default pub heap_size: Option, /// Number of compute units per additional 32k heap above the default (~.5 @@ -92,6 +96,8 @@ impl ComputeBudget { secp256k1_recover_cost: 25_000, syscall_base_cost: 100, zk_token_elgamal_op_cost: 25_000, + curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine cost + curve25519_ristretto_validate_point_cost: 25_000, heap_size: None, heap_cost: 8, mem_op_base_cost: 10, diff --git a/programs/bpf_loader/src/syscalls.rs b/programs/bpf_loader/src/syscalls.rs index 1a184203da238e..9cfc53732ba2ce 100644 --- a/programs/bpf_loader/src/syscalls.rs +++ b/programs/bpf_loader/src/syscalls.rs @@ -22,12 +22,13 @@ use { entrypoint::{BPF_ALIGN_OF_U128, MAX_PERMITTED_DATA_INCREASE, SUCCESS}, feature_set::{ add_get_processed_sibling_instruction_syscall, blake3_syscall_enabled, - check_physical_overlapping, check_slice_translation_size, disable_fees_sysvar, - do_support_realloc, executables_incur_cpi_data_cost, fixed_memcpy_nonoverlapping_check, - libsecp256k1_0_5_upgrade_enabled, limit_secp256k1_recovery_id, - prevent_calling_precompiles_as_programs, return_data_syscall_enabled, - secp256k1_recover_syscall_enabled, sol_log_data_syscall_enabled, - syscall_saturated_math, update_syscall_base_costs, zk_token_sdk_enabled, + check_physical_overlapping, check_slice_translation_size, curve25519_syscall_enabled, + disable_fees_sysvar, do_support_realloc, executables_incur_cpi_data_cost, + fixed_memcpy_nonoverlapping_check, libsecp256k1_0_5_upgrade_enabled, + limit_secp256k1_recovery_id, prevent_calling_precompiles_as_programs, + return_data_syscall_enabled, secp256k1_recover_syscall_enabled, + sol_log_data_syscall_enabled, syscall_saturated_math, update_syscall_base_costs, + zk_token_sdk_enabled, }, hash::{Hasher, HASH_BYTES}, instruction::{ @@ -137,6 +138,9 @@ pub fn register_syscalls( let zk_token_sdk_enabled = invoke_context .feature_set .is_active(&zk_token_sdk_enabled::id()); + let curve25519_syscall_enabled = invoke_context + .feature_set + .is_active(&curve25519_syscall_enabled::id()); let disable_fees_sysvar = invoke_context .feature_set .is_active(&disable_fees_sysvar::id()); @@ -247,6 +251,17 @@ pub fn register_syscalls( SyscallZkTokenElgamalOpWithScalar::call, )?; + // Elliptic Curve Point Validation + // + // TODO: add group operations and multiscalar multiplications + register_feature_gated_syscall!( + syscall_registry, + curve25519_syscall_enabled, + b"sol_curve25519_point_validation", + SyscallCurvePointValidation::init, + SyscallCurvePointValidation::call, + )?; + // Sysvars syscall_registry.register_syscall_by_name( b"sol_get_clock_sysvar", @@ -1890,6 +1905,80 @@ declare_syscall!( } ); +declare_syscall!( + // Elliptic Curve Point Validation + // + // Currently, only curve25519 Edwards and Ristretto representations are supported + SyscallCurvePointValidation, + fn call( + &mut self, + curve_id: u64, + point_addr: u64, + _arg3: u64, + _arg4: u64, + _arg5: u64, + memory_mapping: &MemoryMapping, + result: &mut Result>, + ) { + use solana_zk_token_sdk::curve25519::{curve_syscall_traits::*, edwards, ristretto}; + + let invoke_context = question_mark!( + self.invoke_context + .try_borrow() + .map_err(|_| SyscallError::InvokeContextBorrowFailed), + result + ); + + match curve_id { + CURVE25519_EDWARDS => { + let cost = invoke_context + .get_compute_budget() + .curve25519_edwards_validate_point_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let point = question_mark!( + translate_type::( + memory_mapping, + point_addr, + invoke_context.get_check_aligned() + ), + result + ); + + if edwards::validate_edwards(point) { + *result = Ok(0); + } else { + *result = Ok(1); + } + } + CURVE25519_RISTRETTO => { + let cost = invoke_context + .get_compute_budget() + .curve25519_ristretto_validate_point_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let point = question_mark!( + translate_type::( + memory_mapping, + point_addr, + invoke_context.get_check_aligned() + ), + result + ); + + if ristretto::validate_ristretto(point) { + *result = Ok(0); + } else { + *result = Ok(1); + } + } + _ => { + *result = Ok(1); + } + }; + } +); + declare_syscall!( // Blake3 SyscallBlake3, diff --git a/sdk/src/feature_set.rs b/sdk/src/feature_set.rs index b1ab0041c2e67e..20af010c103fc0 100644 --- a/sdk/src/feature_set.rs +++ b/sdk/src/feature_set.rs @@ -149,6 +149,11 @@ pub mod zk_token_sdk_enabled { solana_sdk::declare_id!("zk1snxsc6Fh3wsGNbbHAJNHiJoYgF29mMnTSusGx5EJ"); } +// TODO: temporary address for now +pub mod curve25519_syscall_enabled { + solana_sdk::declare_id!("curve25519111111111111111111111111111111111"); +} + pub mod versioned_tx_message_enabled { solana_sdk::declare_id!("3KZZ6Ks1885aGBQ45fwRcPXVBCtzUvxhUTkwKMR41Tca"); } diff --git a/zk-token-sdk/src/curve25519/curve_syscall_traits.rs b/zk-token-sdk/src/curve25519/curve_syscall_traits.rs new file mode 100644 index 00000000000000..05d1eb229a6f03 --- /dev/null +++ b/zk-token-sdk/src/curve25519/curve_syscall_traits.rs @@ -0,0 +1,94 @@ +//! The traits representing the basic elliptic curve operations. +//! +//! These traits are instantiatable by all the commonly used elliptic curves and should help in +//! organizing syscall support for other curves in the future. more complicated or curve-specific +//! functions that are needed in cryptographic applications should be representable by combining +//! the associated functions of these traits. +//! +//! NOTE: This module temporarily lives in zk_token_sdk/curve25519, but it is independent of +//! zk-token-sdk or curve25519. It should be moved to a more general location in the future. +//! + +pub trait PointValidation { + type Point; + + /// Verifies if a byte representation of a curve point lies in the curve. + fn validate_point(&self) -> bool; +} + +pub trait GroupOperations { + type Point; + type Scalar; + + /// Adds two curve points: P_0 + P_1. + fn add(left_point: &Self::Point, right_point: &Self::Point) -> Option; + + /// Subtracts two curve points: P_0 - P_1. + /// + /// NOTE: Altneratively, one can consider replacing this with a `negate` function that maps a + /// curve point P -> -P. Then subtraction can be computed by combining `negate` and `add` + /// syscalls. However, `subtract` is a much more widely used function than `negate`. + fn subtract(left_point: &Self::Point, right_point: &Self::Point) -> Option; + + /// Multiplies a scalar S with a curve point P: S*P + fn multiply(scalar: &Self::Scalar, point: &Self::Point) -> Option; +} + +pub trait MultiScalarMultiplication { + type Scalar; + type Point; + + /// Given a vector of scalsrs S_1, ..., S_N, and curve points P_1, ..., P_N, computes the + /// "inner product": S_1*P_1 + ... + S_N*P_N. + /// + /// NOTE: This operation can be represented by combining `add` and `multiply` functions in + /// `GroupOperations`, but computing it in a single batch is significantly cheaper. Given how + /// commonly used the multiscalar multiplication (MSM) is, it seems to make sense to have a + /// designated trait for MSM support. + /// + /// NOTE: The inputs to the function is a non-fixed size vector and hence, there are some + /// complications in computing the cost for the syscall. The computational costs should only + /// depend on the length of the vectors (and the curve), so it would be ideal to support + /// variable length inputs and compute the syscall cost as is done in eip-197: + /// https://github.com/ethereum/EIPs/blob/master/EIPS/eip-197.md#gas-costs. If not, then we can + /// consider bounding the length of the input and assigning worst-case cost. + fn multiscalar_multiply( + scalars: &[Self::Scalar], + points: &[Self::Point], + ) -> Option; +} + +pub trait Pairing { + type G1Point; + type G2Point; + type GTPoint; + + /// Applies the bilinear pairing operation to two curve points P1, P2 -> e(P1, P2). This trait + /// is only relevant for "pairing-friendly" curves such as BN254 and BLS12-381. + fn pairing_map( + left_point: &Self::G1Point, + right_point: &Self::G2Point, + ) -> Option; +} + +pub const CURVE25519_EDWARDS: u64 = 0; +pub const CURVE25519_RISTRETTO: u64 = 1; + +pub const ADD: u64 = 0; +pub const SUB: u64 = 1; +pub const MUL: u64 = 2; + +// Functions are organized by the curve traits, which can be instantiated by multiple curve +// representations. The functions take in a `curve_id` (e.g. `CURVE25519_EDWARDS`) and should run +// the associated functions in the appropriate trait instantiation. The `curve_op` function +// additionally takes in an `op_id` (e.g. `ADD`) that controls which associated functions to run in +// `GroupOperations`. +extern "C" { + pub fn sol_curve_validate_point(curve_id: u64, point: *const u8, result: *mut u8) -> u64; + + pub fn sol_curve_op(curve_id: u64, op_id: u64, point: *const u8, result: *mut u8) -> u64; + + pub fn sol_curve_multiscalar_mul(curve_id: u64, point: *const u8, result: *mut u8) -> u64; + + pub fn sol_curve_pairing_map(curve_id: u64, point: *const u8, result: *mut u8) -> u64; +} diff --git a/zk-token-sdk/src/curve25519/edwards.rs b/zk-token-sdk/src/curve25519/edwards.rs new file mode 100644 index 00000000000000..a37761ae17970c --- /dev/null +++ b/zk-token-sdk/src/curve25519/edwards.rs @@ -0,0 +1,286 @@ +use bytemuck::{Pod, Zeroable}; +pub use target_arch::*; + +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodEdwardsPoint(pub [u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +mod target_arch { + use { + super::*, + crate::curve25519::{ + curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation}, + errors::Curve25519Error, + scalar::PodScalar, + }, + curve25519_dalek::{ + edwards::{CompressedEdwardsY, EdwardsPoint}, + scalar::Scalar, + traits::VartimeMultiscalarMul, + }, + }; + + pub fn validate_edwards(point: &PodEdwardsPoint) -> bool { + point.validate_point() + } + + pub fn add_edwards( + left_point: &PodEdwardsPoint, + right_point: &PodEdwardsPoint, + ) -> Option { + PodEdwardsPoint::add(left_point, right_point) + } + + pub fn subtract_edwards( + left_point: &PodEdwardsPoint, + right_point: &PodEdwardsPoint, + ) -> Option { + PodEdwardsPoint::subtract(left_point, right_point) + } + + pub fn multiply_edwards( + scalar: &PodScalar, + point: &PodEdwardsPoint, + ) -> Option { + PodEdwardsPoint::multiply(scalar, point) + } + + pub fn multiscalar_multiply_edwards( + scalars: &[PodScalar], + points: &[PodEdwardsPoint], + ) -> Option { + PodEdwardsPoint::multiscalar_multiply(scalars, points) + } + + impl From<&EdwardsPoint> for PodEdwardsPoint { + fn from(point: &EdwardsPoint) -> Self { + Self(point.compress().to_bytes()) + } + } + + impl TryFrom<&PodEdwardsPoint> for EdwardsPoint { + type Error = Curve25519Error; + + fn try_from(pod: &PodEdwardsPoint) -> Result { + CompressedEdwardsY::from_slice(&pod.0) + .decompress() + .ok_or(Curve25519Error::PodConversion) + } + } + + impl PointValidation for PodEdwardsPoint { + type Point = Self; + + fn validate_point(&self) -> bool { + CompressedEdwardsY::from_slice(&self.0) + .decompress() + .is_some() + } + } + + impl GroupOperations for PodEdwardsPoint { + type Scalar = PodScalar; + type Point = Self; + + fn add(left_point: &Self, right_point: &Self) -> Option { + let left_point: EdwardsPoint = left_point.try_into().ok()?; + let right_point: EdwardsPoint = right_point.try_into().ok()?; + + let result = &left_point + &right_point; + Some((&result).into()) + } + + fn subtract(left_point: &Self, right_point: &Self) -> Option { + let left_point: EdwardsPoint = left_point.try_into().ok()?; + let right_point: EdwardsPoint = right_point.try_into().ok()?; + + let result = &left_point - &right_point; + Some((&result).into()) + } + + #[cfg(not(target_arch = "bpf"))] + fn multiply(scalar: &PodScalar, point: &Self) -> Option { + let scalar: Scalar = scalar.into(); + let point: EdwardsPoint = point.try_into().ok()?; + + let result = &scalar * &point; + Some((&result).into()) + } + } + + impl MultiScalarMultiplication for PodEdwardsPoint { + type Scalar = PodScalar; + type Point = Self; + + fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option { + EdwardsPoint::optional_multiscalar_mul( + scalars.iter().map(Scalar::from), + points + .iter() + .map(|point| EdwardsPoint::try_from(point).ok()), + ) + .map(|result| PodEdwardsPoint::from(&result)) + } + } +} + +#[cfg(target_arch = "bpf")] +mod target_arch { + use { + super::*, + crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_EDWARDS}, + }; + + pub fn validate_edwards(point: &PodEdwardsPoint) -> bool { + let mut validate_result = 0u8; + let result = unsafe { + sol_curve_validate_point( + CURVE25519_EDWARDS, + &point.0 as *const u8, + &mut validate_result, + ) + }; + + result == 0 + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::curve25519::scalar::PodScalar, + curve25519_dalek::{ + constants::ED25519_BASEPOINT_POINT as G, edwards::EdwardsPoint, traits::Identity, + }, + }; + + #[test] + fn test_validate_edwards() { + let pod = PodEdwardsPoint(G.compress().to_bytes()); + assert!(validate_edwards(&pod)); + + let invalid_bytes = [ + 120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84, + 60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79, + ]; + + assert!(!validate_edwards(&PodEdwardsPoint(invalid_bytes))); + } + + #[test] + fn test_edwards_add_subtract() { + // identity + let identity = PodEdwardsPoint(EdwardsPoint::identity().compress().to_bytes()); + let point = PodEdwardsPoint([ + 201, 179, 241, 122, 180, 185, 239, 50, 183, 52, 221, 0, 153, 195, 43, 18, 22, 38, 187, + 206, 179, 192, 210, 58, 53, 45, 150, 98, 89, 17, 158, 11, + ]); + + assert_eq!(add_edwards(&point, &identity).unwrap(), point); + assert_eq!(subtract_edwards(&point, &identity).unwrap(), point); + + // associativity + let point_a = PodEdwardsPoint([ + 33, 124, 71, 170, 117, 69, 151, 247, 59, 12, 95, 125, 133, 166, 64, 5, 2, 27, 90, 27, + 200, 167, 59, 164, 52, 54, 52, 200, 29, 13, 34, 213, + ]); + let point_b = PodEdwardsPoint([ + 70, 222, 137, 221, 253, 204, 71, 51, 78, 8, 124, 1, 67, 200, 102, 225, 122, 228, 111, + 183, 129, 14, 131, 210, 212, 95, 109, 246, 55, 10, 159, 91, + ]); + let point_c = PodEdwardsPoint([ + 72, 60, 66, 143, 59, 197, 111, 36, 181, 137, 25, 97, 157, 201, 247, 215, 123, 83, 220, + 250, 154, 150, 180, 192, 196, 28, 215, 137, 34, 247, 39, 129, + ]); + + assert_eq!( + add_edwards(&add_edwards(&point_a, &point_b).unwrap(), &point_c), + add_edwards(&point_a, &add_edwards(&point_b, &point_c).unwrap()), + ); + + assert_eq!( + subtract_edwards(&subtract_edwards(&point_a, &point_b).unwrap(), &point_c), + subtract_edwards(&point_a, &add_edwards(&point_b, &point_c).unwrap()), + ); + + // commutativity + assert_eq!( + add_edwards(&point_a, &point_b).unwrap(), + add_edwards(&point_b, &point_a).unwrap(), + ); + + // subtraction + let point = PodEdwardsPoint(G.compress().to_bytes()); + let point_negated = PodEdwardsPoint((-G).compress().to_bytes()); + + assert_eq!(point_negated, subtract_edwards(&identity, &point).unwrap(),) + } + + #[test] + fn test_edwards_mul() { + let scalar_a = PodScalar([ + 72, 191, 131, 55, 85, 86, 54, 60, 116, 10, 39, 130, 180, 3, 90, 227, 47, 228, 252, 99, + 151, 71, 118, 29, 34, 102, 117, 114, 120, 50, 57, 8, + ]); + let point_x = PodEdwardsPoint([ + 176, 121, 6, 191, 108, 161, 206, 141, 73, 14, 235, 97, 49, 68, 48, 112, 98, 215, 145, + 208, 44, 188, 70, 10, 180, 124, 230, 15, 98, 165, 104, 85, + ]); + let point_y = PodEdwardsPoint([ + 174, 86, 89, 208, 236, 123, 223, 128, 75, 54, 228, 232, 220, 100, 205, 108, 237, 97, + 105, 79, 74, 192, 67, 224, 185, 23, 157, 116, 216, 151, 223, 81, + ]); + + let ax = multiply_edwards(&scalar_a, &point_x).unwrap(); + let bx = multiply_edwards(&scalar_a, &point_y).unwrap(); + + assert_eq!( + add_edwards(&ax, &bx), + multiply_edwards(&scalar_a, &add_edwards(&point_x, &point_y).unwrap()), + ); + } + + #[test] + fn test_multiscalar_multiplication_edwards() { + let scalar = PodScalar([ + 205, 73, 127, 173, 83, 80, 190, 66, 202, 3, 237, 77, 52, 223, 238, 70, 80, 242, 24, 87, + 111, 84, 49, 63, 194, 76, 202, 108, 62, 240, 83, 15, + ]); + let point = PodEdwardsPoint([ + 222, 174, 184, 139, 143, 122, 253, 96, 0, 207, 120, 157, 112, 38, 54, 189, 91, 144, 78, + 111, 111, 122, 140, 183, 65, 250, 191, 133, 6, 42, 212, 93, + ]); + + let basic_product = multiply_edwards(&scalar, &point).unwrap(); + let msm_product = multiscalar_multiply_edwards(&[scalar], &[point]).unwrap(); + + assert_eq!(basic_product, msm_product); + + let scalar_a = PodScalar([ + 246, 154, 34, 110, 31, 185, 50, 1, 252, 194, 163, 56, 211, 18, 101, 192, 57, 225, 207, + 69, 19, 84, 231, 118, 137, 175, 148, 218, 106, 212, 69, 9, + ]); + let scalar_b = PodScalar([ + 27, 58, 126, 136, 253, 178, 176, 245, 246, 55, 15, 202, 35, 183, 66, 199, 134, 187, + 169, 154, 66, 120, 169, 193, 75, 4, 33, 241, 126, 227, 59, 3, + ]); + let point_x = PodEdwardsPoint([ + 252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168, + 53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133, + ]); + let point_y = PodEdwardsPoint([ + 10, 111, 8, 236, 97, 189, 124, 69, 89, 176, 222, 39, 199, 253, 111, 11, 248, 186, 128, + 90, 120, 128, 248, 210, 232, 183, 93, 104, 111, 150, 7, 241, + ]); + + let ax = multiply_edwards(&scalar_a, &point_x).unwrap(); + let by = multiply_edwards(&scalar_b, &point_y).unwrap(); + let basic_product = add_edwards(&ax, &by).unwrap(); + let msm_product = + multiscalar_multiply_edwards(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap(); + + assert_eq!(basic_product, msm_product); + } +} diff --git a/zk-token-sdk/src/curve25519/errors.rs b/zk-token-sdk/src/curve25519/errors.rs new file mode 100644 index 00000000000000..2aabc732a39006 --- /dev/null +++ b/zk-token-sdk/src/curve25519/errors.rs @@ -0,0 +1,7 @@ +use thiserror::Error; + +#[derive(Error, Clone, Debug, Eq, PartialEq)] +pub enum Curve25519Error { + #[error("pod conversion failed")] + PodConversion, +} diff --git a/zk-token-sdk/src/curve25519/mod.rs b/zk-token-sdk/src/curve25519/mod.rs new file mode 100644 index 00000000000000..0f1ab2c949e827 --- /dev/null +++ b/zk-token-sdk/src/curve25519/mod.rs @@ -0,0 +1,11 @@ +//! Syscall operations for curve25519 +//! +//! This module lives inside the zk-token-sdk for now, but should move to a general location since +//! it is independent of zk-tokens. + +pub mod curve_syscall_traits; +pub mod edwards; +#[cfg(not(target_arch = "bpf"))] +pub mod errors; +pub mod ristretto; +pub mod scalar; diff --git a/zk-token-sdk/src/curve25519/ristretto.rs b/zk-token-sdk/src/curve25519/ristretto.rs new file mode 100644 index 00000000000000..5542a756b74696 --- /dev/null +++ b/zk-token-sdk/src/curve25519/ristretto.rs @@ -0,0 +1,290 @@ +use bytemuck::{Pod, Zeroable}; +pub use target_arch::*; + +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodRistrettoPoint(pub [u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +mod target_arch { + use { + super::*, + crate::curve25519::{ + curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation}, + errors::Curve25519Error, + scalar::PodScalar, + }, + curve25519_dalek::{ + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + traits::VartimeMultiscalarMul, + }, + }; + + pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool { + point.validate_point() + } + + pub fn add_ristretto( + left_point: &PodRistrettoPoint, + right_point: &PodRistrettoPoint, + ) -> Option { + PodRistrettoPoint::add(left_point, right_point) + } + + pub fn subtract_ristretto( + left_point: &PodRistrettoPoint, + right_point: &PodRistrettoPoint, + ) -> Option { + PodRistrettoPoint::subtract(left_point, right_point) + } + + pub fn multiply_ristretto( + scalar: &PodScalar, + point: &PodRistrettoPoint, + ) -> Option { + PodRistrettoPoint::multiply(scalar, point) + } + + pub fn multiscalar_multiply_ristretto( + scalars: &[PodScalar], + points: &[PodRistrettoPoint], + ) -> Option { + PodRistrettoPoint::multiscalar_multiply(scalars, points) + } + + impl From<&RistrettoPoint> for PodRistrettoPoint { + fn from(point: &RistrettoPoint) -> Self { + Self(point.compress().to_bytes()) + } + } + + impl TryFrom<&PodRistrettoPoint> for RistrettoPoint { + type Error = Curve25519Error; + + fn try_from(pod: &PodRistrettoPoint) -> Result { + CompressedRistretto::from_slice(&pod.0) + .decompress() + .ok_or(Curve25519Error::PodConversion) + } + } + + impl PointValidation for PodRistrettoPoint { + type Point = Self; + + fn validate_point(&self) -> bool { + CompressedRistretto::from_slice(&self.0) + .decompress() + .is_some() + } + } + + impl GroupOperations for PodRistrettoPoint { + type Scalar = PodScalar; + type Point = Self; + + fn add(left_point: &Self, right_point: &Self) -> Option { + let left_point: RistrettoPoint = left_point.try_into().ok()?; + let right_point: RistrettoPoint = right_point.try_into().ok()?; + + let result = &left_point + &right_point; + Some((&result).into()) + } + + fn subtract(left_point: &Self, right_point: &Self) -> Option { + let left_point: RistrettoPoint = left_point.try_into().ok()?; + let right_point: RistrettoPoint = right_point.try_into().ok()?; + + let result = &left_point - &right_point; + Some((&result).into()) + } + + #[cfg(not(target_arch = "bpf"))] + fn multiply(scalar: &PodScalar, point: &Self) -> Option { + let scalar: Scalar = scalar.into(); + let point: RistrettoPoint = point.try_into().ok()?; + + let result = &scalar * &point; + Some((&result).into()) + } + } + + impl MultiScalarMultiplication for PodRistrettoPoint { + type Scalar = PodScalar; + type Point = Self; + + fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option { + RistrettoPoint::optional_multiscalar_mul( + scalars.iter().map(Scalar::from), + points + .iter() + .map(|point| RistrettoPoint::try_from(point).ok()), + ) + .map(|result| PodRistrettoPoint::from(&result)) + } + } +} + +#[cfg(target_arch = "bpf")] +#[allow(unused_variables)] +mod target_arch { + use { + super::*, + crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_RISTRETTO}, + }; + + pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool { + let mut validate_result = 0u8; + let result = unsafe { + sol_curve_validate_point( + CURVE25519_RISTRETTO, + &point.0 as *const u8, + &mut validate_result, + ) + }; + + result == 0 + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::curve25519::scalar::PodScalar, + curve25519_dalek::{ + constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, traits::Identity, + }, + }; + + #[test] + fn test_validate_ristretto() { + let pod = PodRistrettoPoint(G.compress().to_bytes()); + assert!(validate_ristretto(&pod)); + + let invalid_bytes = [ + 120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84, + 60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79, + ]; + + assert!(!validate_ristretto(&PodRistrettoPoint(invalid_bytes))); + } + + #[test] + fn test_add_subtract_ristretto() { + // identity + let identity = PodRistrettoPoint(RistrettoPoint::identity().compress().to_bytes()); + let point = PodRistrettoPoint([ + 210, 174, 124, 127, 67, 77, 11, 114, 71, 63, 168, 136, 113, 20, 141, 228, 195, 254, + 232, 229, 220, 249, 213, 232, 61, 238, 152, 249, 83, 225, 206, 16, + ]); + + assert_eq!(add_ristretto(&point, &identity).unwrap(), point); + assert_eq!(subtract_ristretto(&point, &identity).unwrap(), point); + + // associativity + let point_a = PodRistrettoPoint([ + 208, 165, 125, 204, 2, 100, 218, 17, 170, 194, 23, 9, 102, 156, 134, 136, 217, 190, 98, + 34, 183, 194, 228, 153, 92, 11, 108, 103, 28, 57, 88, 15, + ]); + let point_b = PodRistrettoPoint([ + 208, 241, 72, 163, 73, 53, 32, 174, 54, 194, 71, 8, 70, 181, 244, 199, 93, 147, 99, + 231, 162, 127, 25, 40, 39, 19, 140, 132, 112, 212, 145, 108, + ]); + let point_c = PodRistrettoPoint([ + 250, 61, 200, 25, 195, 15, 144, 179, 24, 17, 252, 167, 247, 44, 47, 41, 104, 237, 49, + 137, 231, 173, 86, 106, 121, 249, 245, 247, 70, 188, 31, 49, + ]); + + assert_eq!( + add_ristretto(&add_ristretto(&point_a, &point_b).unwrap(), &point_c), + add_ristretto(&point_a, &add_ristretto(&point_b, &point_c).unwrap()), + ); + + assert_eq!( + subtract_ristretto(&subtract_ristretto(&point_a, &point_b).unwrap(), &point_c), + subtract_ristretto(&point_a, &add_ristretto(&point_b, &point_c).unwrap()), + ); + + // commutativity + assert_eq!( + add_ristretto(&point_a, &point_b).unwrap(), + add_ristretto(&point_b, &point_a).unwrap(), + ); + + // subtraction + let point = PodRistrettoPoint(G.compress().to_bytes()); + let point_negated = PodRistrettoPoint((-G).compress().to_bytes()); + + assert_eq!( + point_negated, + subtract_ristretto(&identity, &point).unwrap(), + ) + } + + #[test] + fn test_multiply_ristretto() { + let scalar_x = PodScalar([ + 254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250, + 78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6, + ]); + let point_a = PodRistrettoPoint([ + 68, 80, 232, 181, 241, 77, 60, 81, 154, 51, 173, 35, 98, 234, 149, 37, 1, 39, 191, 201, + 193, 48, 88, 189, 97, 126, 63, 35, 144, 145, 203, 31, + ]); + let point_b = PodRistrettoPoint([ + 200, 236, 1, 12, 244, 130, 226, 214, 28, 125, 43, 163, 222, 234, 81, 213, 201, 156, 31, + 4, 167, 132, 240, 76, 164, 18, 45, 20, 48, 85, 206, 121, + ]); + + let ax = multiply_ristretto(&scalar_x, &point_a).unwrap(); + let bx = multiply_ristretto(&scalar_x, &point_b).unwrap(); + + assert_eq!( + add_ristretto(&ax, &bx), + multiply_ristretto(&scalar_x, &add_ristretto(&point_a, &point_b).unwrap()), + ); + } + + #[test] + fn test_multiscalar_multiplication_ristretto() { + let scalar = PodScalar([ + 123, 108, 109, 66, 154, 185, 88, 122, 178, 43, 17, 154, 201, 223, 31, 238, 59, 215, 71, + 154, 215, 143, 177, 158, 9, 136, 32, 223, 139, 13, 133, 5, + ]); + let point = PodRistrettoPoint([ + 158, 2, 130, 90, 148, 36, 172, 155, 86, 196, 74, 139, 30, 98, 44, 225, 155, 207, 135, + 111, 238, 167, 235, 67, 234, 125, 0, 227, 146, 31, 24, 113, + ]); + + let basic_product = multiply_ristretto(&scalar, &point).unwrap(); + let msm_product = multiscalar_multiply_ristretto(&[scalar], &[point]).unwrap(); + + assert_eq!(basic_product, msm_product); + + let scalar_a = PodScalar([ + 8, 161, 219, 155, 192, 137, 153, 26, 27, 40, 30, 17, 124, 194, 26, 41, 32, 7, 161, 45, + 212, 198, 212, 81, 133, 185, 164, 85, 95, 232, 106, 10, + ]); + let scalar_b = PodScalar([ + 135, 207, 106, 208, 107, 127, 46, 82, 66, 22, 136, 125, 105, 62, 69, 34, 213, 210, 17, + 196, 120, 114, 238, 237, 149, 170, 5, 243, 54, 77, 172, 12, + ]); + let point_x = PodRistrettoPoint([ + 130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240, + 179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112, + ]); + let point_y = PodRistrettoPoint([ + 152, 156, 155, 197, 152, 232, 92, 206, 219, 159, 193, 134, 121, 128, 139, 36, 56, 191, + 51, 143, 72, 204, 87, 76, 110, 124, 101, 96, 238, 158, 42, 108, + ]); + + let ax = multiply_ristretto(&scalar_a, &point_x).unwrap(); + let by = multiply_ristretto(&scalar_b, &point_y).unwrap(); + let basic_product = add_ristretto(&ax, &by).unwrap(); + let msm_product = + multiscalar_multiply_ristretto(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap(); + + assert_eq!(basic_product, msm_product); + } +} diff --git a/zk-token-sdk/src/curve25519/scalar.rs b/zk-token-sdk/src/curve25519/scalar.rs new file mode 100644 index 00000000000000..c9a08b05b2ea10 --- /dev/null +++ b/zk-token-sdk/src/curve25519/scalar.rs @@ -0,0 +1,22 @@ +pub use bytemuck::{Pod, Zeroable}; + +#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodScalar(pub [u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +mod target_arch { + use {super::*, curve25519_dalek::scalar::Scalar}; + + impl From<&Scalar> for PodScalar { + fn from(scalar: &Scalar) -> Self { + Self(scalar.to_bytes()) + } + } + + impl From<&PodScalar> for Scalar { + fn from(pod: &PodScalar) -> Self { + Scalar::from_bits(pod.0) + } + } +} diff --git a/zk-token-sdk/src/lib.rs b/zk-token-sdk/src/lib.rs index 6751705bdcfc3d..e81186caa0aeea 100644 --- a/zk-token-sdk/src/lib.rs +++ b/zk-token-sdk/src/lib.rs @@ -32,6 +32,7 @@ mod sigma_proofs; mod transcript; // TODO: re-organize visibility +pub mod curve25519; pub mod instruction; pub mod zk_token_elgamal; pub mod zk_token_proof_instruction;