From a3bc406b55f238f5e284cb06baa84b60ae81e641 Mon Sep 17 00:00:00 2001 From: samkim-crypto Date: Tue, 26 Mar 2024 19:54:06 +0900 Subject: [PATCH] [zk-token-sdk] Remove `std::thread` from wasm target (#379) --- zk-token-sdk/src/encryption/discrete_log.rs | 76 +++++++++++++-------- zk-token-sdk/src/encryption/elgamal.rs | 1 + 2 files changed, 48 insertions(+), 29 deletions(-) diff --git a/zk-token-sdk/src/encryption/discrete_log.rs b/zk-token-sdk/src/encryption/discrete_log.rs index b3e02a74625b61..5ffc1c206a6f68 100644 --- a/zk-token-sdk/src/encryption/discrete_log.rs +++ b/zk-token-sdk/src/encryption/discrete_log.rs @@ -16,6 +16,8 @@ #![cfg(not(target_os = "solana"))] +#[cfg(not(target_arch = "wasm32"))] +use std::thread; use { crate::RISTRETTO_POINT_LEN, curve25519_dalek::{ @@ -26,7 +28,7 @@ use { }, itertools::Itertools, serde::{Deserialize, Serialize}, - std::{collections::HashMap, thread}, + std::collections::HashMap, thiserror::Error, }; @@ -34,6 +36,7 @@ const TWO16: u64 = 65536; // 2^16 const TWO17: u64 = 131072; // 2^17 /// Maximum number of threads permitted for discrete log computation +#[cfg(not(target_arch = "wasm32"))] const MAX_THREAD: usize = 65536; #[derive(Error, Clone, Debug, Eq, PartialEq)] @@ -112,6 +115,7 @@ impl DiscreteLog { } /// Adjusts number of threads in a discrete log instance. + #[cfg(not(target_arch = "wasm32"))] pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> { // number of threads must be a positive power-of-two integer if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD { @@ -141,35 +145,48 @@ impl DiscreteLog { /// Solves the discrete log problem under the assumption that the solution /// is a positive 32-bit number. pub fn decode_u32(self) -> Option { - let mut starting_point = self.target; - let handles = (0..self.num_threads) - .map(|i| { - let ristretto_iterator = RistrettoIterator::new( - (starting_point, i as u64), - (-(&self.step_point), self.num_threads as u64), - ); - - let handle = thread::spawn(move || { - Self::decode_range( - ristretto_iterator, - self.range_bound, - self.compression_batch_size, - ) - }); - - starting_point -= G; - handle - }) - .collect::>(); - - let mut solution = None; - for handle in handles { - let discrete_log = handle.join().unwrap(); - if discrete_log.is_some() { - solution = discrete_log; - } + #[cfg(not(target_arch = "wasm32"))] + { + let mut starting_point = self.target; + let handles = (0..self.num_threads) + .map(|i| { + let ristretto_iterator = RistrettoIterator::new( + (starting_point, i as u64), + (-(&self.step_point), self.num_threads as u64), + ); + + let handle = thread::spawn(move || { + Self::decode_range( + ristretto_iterator, + self.range_bound, + self.compression_batch_size, + ) + }); + + starting_point -= G; + handle + }) + .collect::>(); + + handles + .into_iter() + .map_while(|h| h.join().ok()) + .find(|x| x.is_some()) + .flatten() + } + #[cfg(target_arch = "wasm32")] + { + let ristretto_iterator = RistrettoIterator::new( + (self.target, 0_u64), + (-(&self.step_point), self.num_threads as u64), + ); + + Self::decode_range( + ristretto_iterator, + self.range_bound, + self.compression_batch_size, + ) } - solution } fn decode_range( @@ -274,6 +291,7 @@ mod tests { println!("single thread discrete log computation secs: {computation_secs:?} sec"); } + #[cfg(not(target_arch = "wasm32"))] #[test] fn test_decode_correctness_threaded() { // general case diff --git a/zk-token-sdk/src/encryption/elgamal.rs b/zk-token-sdk/src/encryption/elgamal.rs index 5b4e2dba872530..e499106e1e58b2 100644 --- a/zk-token-sdk/src/encryption/elgamal.rs +++ b/zk-token-sdk/src/encryption/elgamal.rs @@ -791,6 +791,7 @@ mod tests { assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap()); } + #[cfg(not(target_arch = "wasm32"))] #[test] fn test_encrypt_decrypt_correctness_multithreaded() { let ElGamalKeypair { public, secret } = ElGamalKeypair::new_rand();