diff --git a/program-runtime/src/compute_budget.rs b/program-runtime/src/compute_budget.rs index 0adbc64db7628f..a7ba9c80336d45 100644 --- a/program-runtime/src/compute_budget.rs +++ b/program-runtime/src/compute_budget.rs @@ -112,14 +112,14 @@ impl ComputeBudget { secp256k1_recover_cost: 25_000, syscall_base_cost: 100, zk_token_elgamal_op_cost: 25_000, - curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine curve25519 costs - curve25519_edwards_add_cost: 25_000, - curve25519_edwards_subtract_cost: 25_000, - curve25519_edwards_multiply_cost: 25_000, - curve25519_ristretto_validate_point_cost: 25_000, - curve25519_ristretto_add_cost: 25_000, - curve25519_ristretto_subtract_cost: 25_000, - curve25519_ristretto_multiply_cost: 25_000, + curve25519_edwards_validate_point_cost: 5_000, // TODO: precisely determine curve25519 costs + curve25519_edwards_add_cost: 5_000, + curve25519_edwards_subtract_cost: 5_000, + curve25519_edwards_multiply_cost: 10_000, + curve25519_ristretto_validate_point_cost: 5_000, + curve25519_ristretto_add_cost: 5_000, + curve25519_ristretto_subtract_cost: 5_000, + curve25519_ristretto_multiply_cost: 10_000, heap_size: None, heap_cost: 8, mem_op_base_cost: 10, diff --git a/zk-token-sdk/src/zk_token_elgamal/convert.rs b/zk-token-sdk/src/zk_token_elgamal/convert.rs index 2997d5de72888d..92e206359148b3 100644 --- a/zk-token-sdk/src/zk_token_elgamal/convert.rs +++ b/zk-token-sdk/src/zk_token_elgamal/convert.rs @@ -51,6 +51,7 @@ mod target_arch { use { super::pod, crate::{ + curve25519::scalar::PodScalar, encryption::{ auth_encryption::AeCiphertext, elgamal::{DecryptHandle, ElGamalCiphertext, ElGamalPubkey}, @@ -74,14 +75,14 @@ mod target_arch { std::convert::TryFrom, }; - impl From for pod::Scalar { + impl From for PodScalar { fn from(scalar: Scalar) -> Self { Self(scalar.to_bytes()) } } - impl From for Scalar { - fn from(pod: pod::Scalar) -> Self { + impl From for Scalar { + fn from(pod: PodScalar) -> Self { Scalar::from_bits(pod.0) } } diff --git a/zk-token-sdk/src/zk_token_elgamal/ops.rs b/zk-token-sdk/src/zk_token_elgamal/ops.rs index cf13ceca8e10a0..cb5c1afe96d54e 100644 --- a/zk-token-sdk/src/zk_token_elgamal/ops.rs +++ b/zk-token-sdk/src/zk_token_elgamal/ops.rs @@ -92,115 +92,130 @@ mod target_arch { #[cfg(target_os = "solana")] #[allow(unused_variables)] mod target_arch { - use {super::*, crate::zk_token_elgamal::pod, bytemuck::Zeroable}; + use crate::{ + curve25519::{ + ristretto::{add_ristretto, multiply_ristretto, subtract_ristretto, PodRistrettoPoint}, + scalar::PodScalar, + }, + zk_token_elgamal::pod, + }; - fn op( - op: u64, - ct_0: &pod::ElGamalCiphertext, - ct_1: &pod::ElGamalCiphertext, - ) -> Option { - let mut ct_result = pod::ElGamalCiphertext::zeroed(); - let result = unsafe { - solana_program::syscalls::sol_zk_token_elgamal_op( - op, - &ct_0.0 as *const u8, - &ct_1.0 as *const u8, - &mut ct_result.0 as *mut u8, - ) - }; - - if result == 0 { - Some(ct_result) - } else { - None - } - } + const SHIFT_BITS: usize = 16; - fn op_with_lo_hi( - op: u64, - ct_0: &pod::ElGamalCiphertext, - ct_1_lo: &pod::ElGamalCiphertext, - ct_1_hi: &pod::ElGamalCiphertext, - ) -> Option { - let mut ct_result = pod::ElGamalCiphertext::zeroed(); - let result = unsafe { - solana_program::syscalls::sol_zk_token_elgamal_op_with_lo_hi( - op, - &ct_0.0 as *const u8, - &ct_1_lo.0 as *const u8, - &ct_1_hi.0 as *const u8, - &mut ct_result.0 as *mut u8, - ) - }; - - if result == 0 { - Some(ct_result) - } else { - None - } - } - - fn op_with_scalar( - op: u64, - ct: &pod::ElGamalCiphertext, - scalar: u64, - ) -> Option { - let mut ct_result = pod::ElGamalCiphertext::zeroed(); - let result = unsafe { - solana_program::syscalls::sol_zk_token_elgamal_op_with_scalar( - op, - &ct.0 as *const u8, - scalar, - &mut ct_result.0 as *mut u8, - ) - }; - - if result == 0 { - Some(ct_result) - } else { - None - } - } + const G: PodRistrettoPoint = PodRistrettoPoint([ + 226, 242, 174, 10, 106, 188, 78, 113, 168, 132, 169, 97, 197, 0, 81, 95, 88, 227, 11, 106, + 165, 130, 221, 141, 182, 166, 89, 69, 224, 141, 45, 118, + ]); pub fn add( - ct_0: &pod::ElGamalCiphertext, - ct_1: &pod::ElGamalCiphertext, + left_ciphertext: &pod::ElGamalCiphertext, + right_ciphertext: &pod::ElGamalCiphertext, ) -> Option { - op(OP_ADD, ct_0, ct_1) + let (left_commitment, left_handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*left_ciphertext).into(); + let (right_commitment, right_handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*right_ciphertext).into(); + + let result_commitment: pod::PedersenCommitment = + add_ristretto(&left_commitment.into(), &right_commitment.into())?.into(); + let result_handle: pod::DecryptHandle = + add_ristretto(&left_handle.into(), &right_handle.into())?.into(); + + Some((result_commitment, result_handle).into()) } pub fn add_with_lo_hi( - ct_0: &pod::ElGamalCiphertext, - ct_1_lo: &pod::ElGamalCiphertext, - ct_1_hi: &pod::ElGamalCiphertext, + left_ciphertext: &pod::ElGamalCiphertext, + right_ciphertext_lo: &pod::ElGamalCiphertext, + right_ciphertext_hi: &pod::ElGamalCiphertext, ) -> Option { - op_with_lo_hi(OP_ADD, ct_0, ct_1_lo, ct_1_hi) + let shift_scalar = to_scalar(1_u64 << SHIFT_BITS); + let shifted_right_ciphertext_hi = scalar_ciphertext(&shift_scalar, &right_ciphertext_hi)?; + let combined_right_ciphertext = add(right_ciphertext_lo, &shifted_right_ciphertext_hi)?; + add(left_ciphertext, &combined_right_ciphertext) } pub fn subtract( - ct_0: &pod::ElGamalCiphertext, - ct_1: &pod::ElGamalCiphertext, + left_ciphertext: &pod::ElGamalCiphertext, + right_ciphertext: &pod::ElGamalCiphertext, ) -> Option { - op(OP_SUB, ct_0, ct_1) + let (left_commitment, left_handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*left_ciphertext).into(); + let (right_commitment, right_handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*right_ciphertext).into(); + + let result_commitment: pod::PedersenCommitment = + subtract_ristretto(&left_commitment.into(), &right_commitment.into())?.into(); + let result_handle: pod::DecryptHandle = + subtract_ristretto(&left_handle.into(), &right_handle.into())?.into(); + + Some((result_commitment, result_handle).into()) } pub fn subtract_with_lo_hi( - ct_0: &pod::ElGamalCiphertext, - ct_1_lo: &pod::ElGamalCiphertext, - ct_1_hi: &pod::ElGamalCiphertext, + left_ciphertext: &pod::ElGamalCiphertext, + right_ciphertext_lo: &pod::ElGamalCiphertext, + right_ciphertext_hi: &pod::ElGamalCiphertext, ) -> Option { - op_with_lo_hi(OP_SUB, ct_0, ct_1_lo, ct_1_hi) + let shift_scalar = to_scalar(1_u64 << SHIFT_BITS); + let shifted_right_ciphertext_hi = scalar_ciphertext(&shift_scalar, &right_ciphertext_hi)?; + let combined_right_ciphertext = add(right_ciphertext_lo, &shifted_right_ciphertext_hi)?; + subtract(left_ciphertext, &combined_right_ciphertext) } - pub fn add_to(ct: &pod::ElGamalCiphertext, amount: u64) -> Option { - op_with_scalar(OP_ADD, ct, amount) + pub fn add_to( + ciphertext: &pod::ElGamalCiphertext, + amount: u64, + ) -> Option { + let amount_scalar = to_scalar(amount); + let amount_point = multiply_ristretto(&amount_scalar, &G)?; + + let (commitment, handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*ciphertext).into(); + let commitment_point: PodRistrettoPoint = commitment.into(); + + let result_commitment: pod::PedersenCommitment = + add_ristretto(&commitment_point, &amount_point)?.into(); + Some((result_commitment, handle).into()) } pub fn subtract_from( - ct: &pod::ElGamalCiphertext, + ciphertext: &pod::ElGamalCiphertext, amount: u64, ) -> Option { - op_with_scalar(OP_SUB, ct, amount) + let amount_scalar = to_scalar(amount); + let amount_point = multiply_ristretto(&amount_scalar, &G)?; + + let (commitment, handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*ciphertext).into(); + let commitment_point: PodRistrettoPoint = commitment.into(); + + let result_commitment: pod::PedersenCommitment = + subtract_ristretto(&commitment_point, &amount_point)?.into(); + Some((result_commitment, handle).into()) + } + + fn to_scalar(amount: u64) -> PodScalar { + let mut bytes = [0u8; 32]; + bytes[..8].copy_from_slice(&amount.to_le_bytes()); + PodScalar(bytes) + } + + fn scalar_ciphertext( + scalar: &PodScalar, + ciphertext: &pod::ElGamalCiphertext, + ) -> Option { + let (commitment, handle): (pod::PedersenCommitment, pod::DecryptHandle) = + (*ciphertext).into(); + + let commitment_point: PodRistrettoPoint = commitment.into(); + let handle_point: PodRistrettoPoint = handle.into(); + + let result_commitment: pod::PedersenCommitment = + multiply_ristretto(scalar, &commitment_point)?.into(); + let result_handle: pod::DecryptHandle = multiply_ristretto(scalar, &handle_point)?.into(); + + Some((result_commitment, result_handle).into()) } } diff --git a/zk-token-sdk/src/zk_token_elgamal/pod.rs b/zk-token-sdk/src/zk_token_elgamal/pod.rs index 8e0784be8ec034..ab6e5d14e5b1de 100644 --- a/zk-token-sdk/src/zk_token_elgamal/pod.rs +++ b/zk-token-sdk/src/zk_token_elgamal/pod.rs @@ -29,10 +29,6 @@ impl From for u64 { } } -#[derive(Clone, Copy, Pod, Zeroable, PartialEq, Eq)] -#[repr(transparent)] -pub struct Scalar(pub [u8; 32]); - #[derive(Clone, Copy, Pod, Zeroable, PartialEq, Eq)] #[repr(transparent)] pub struct CompressedRistretto(pub [u8; 32]);