From cd0289f78bd526ab6dac18facf3e27cd17c2e3de Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Thu, 26 Oct 2023 12:23:59 -0700 Subject: [PATCH] perf: pre-garble key exchange and PRF (#371) * prf pre-garble * fix * update mpz version to 7669232 * fix ValueId dependency * PR feedback * bump mpz to 1ac6779 --- components/aead/Cargo.toml | 4 +- components/aead/src/aes_gcm/mod.rs | 20 +- components/aead/src/lib.rs | 2 +- components/cipher/Cargo.toml | 4 +- components/cipher/block-cipher/src/cipher.rs | 12 +- components/cipher/block-cipher/src/lib.rs | 16 +- .../cipher/stream-cipher/benches/mock.rs | 23 +- components/cipher/stream-cipher/src/config.rs | 2 +- components/cipher/stream-cipher/src/lib.rs | 16 +- .../cipher/stream-cipher/src/stream_cipher.rs | 91 ++- components/integration-tests/Cargo.toml | 6 +- components/integration-tests/tests/test.rs | 36 +- components/key-exchange/Cargo.toml | 8 +- components/key-exchange/src/exchange.rs | 188 +++++- components/key-exchange/src/lib.rs | 11 +- components/point-addition/Cargo.toml | 6 +- components/prf/Cargo.toml | 4 +- components/prf/hmac-sha256/Cargo.toml | 2 + components/prf/hmac-sha256/benches/prf.rs | 81 ++- components/prf/hmac-sha256/src/config.rs | 24 + components/prf/hmac-sha256/src/error.rs | 45 ++ components/prf/hmac-sha256/src/lib.rs | 79 ++- components/prf/hmac-sha256/src/prf.rs | 608 ++++++++++-------- components/tls/tls-mpc/Cargo.toml | 8 +- components/tls/tls-mpc/src/follower.rs | 12 +- components/tls/tls-mpc/src/leader.rs | 11 +- components/tls/tls-mpc/src/setup.rs | 14 +- components/universal-hash/Cargo.toml | 6 +- tlsn/Cargo.toml | 12 +- tlsn/tlsn-core/src/fixtures/mod.rs | 11 +- 30 files changed, 888 insertions(+), 474 deletions(-) create mode 100644 components/prf/hmac-sha256/src/config.rs create mode 100644 components/prf/hmac-sha256/src/error.rs diff --git a/components/aead/Cargo.toml b/components/aead/Cargo.toml index 577a0d3a7c..7eb0c150f4 100644 --- a/components/aead/Cargo.toml +++ b/components/aead/Cargo.toml @@ -25,8 +25,8 @@ tracing = [ tlsn-block-cipher = { path = "../cipher/block-cipher" } tlsn-stream-cipher = { path = "../cipher/stream-cipher" } tlsn-universal-hash = { path = "../universal-hash" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } async-trait = "0.1" diff --git a/components/aead/src/aes_gcm/mod.rs b/components/aead/src/aes_gcm/mod.rs index acf3702205..77280be53e 100644 --- a/components/aead/src/aes_gcm/mod.rs +++ b/components/aead/src/aes_gcm/mod.rs @@ -17,7 +17,7 @@ use futures::{SinkExt, StreamExt, TryFutureExt}; use block_cipher::{Aes128, BlockCipher}; use mpz_core::commit::HashCommit; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; use tlsn_stream_cipher::{Aes128Ctr, StreamCipher}; use tlsn_universal_hash::UniversalHash; use utils_aio::expect_msg_or_err; @@ -365,15 +365,25 @@ mod tests { let leader_thread = leader_vm.new_thread("test_thread").await.unwrap(); let leader_key = leader_thread - .new_public_array_input("key", key.clone()) + .new_public_array_input::("key", key.len()) .unwrap(); let leader_iv = leader_thread - .new_public_array_input("iv", iv.clone()) + .new_public_array_input::("iv", iv.len()) .unwrap(); + leader_thread.assign(&leader_key, key.clone()).unwrap(); + leader_thread.assign(&leader_iv, iv.clone()).unwrap(); + let follower_thread = follower_vm.new_thread("test_thread").await.unwrap(); - let follower_key = follower_thread.new_public_array_input("key", key).unwrap(); - let follower_iv = follower_thread.new_public_array_input("iv", iv).unwrap(); + let follower_key = follower_thread + .new_public_array_input::("key", key.len()) + .unwrap(); + let follower_iv = follower_thread + .new_public_array_input::("iv", iv.len()) + .unwrap(); + + follower_thread.assign(&follower_key, key.clone()).unwrap(); + follower_thread.assign(&follower_iv, iv.clone()).unwrap(); let leader_config = AesGcmConfigBuilder::default() .id("test".to_string()) diff --git a/components/aead/src/lib.rs b/components/aead/src/lib.rs index 08c5c49732..fa52f65dea 100644 --- a/components/aead/src/lib.rs +++ b/components/aead/src/lib.rs @@ -19,7 +19,7 @@ pub use msg::AeadMessage; use async_trait::async_trait; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; use utils_aio::duplex::Duplex; /// A channel for sending and receiving AEAD messages. diff --git a/components/cipher/Cargo.toml b/components/cipher/Cargo.toml index 0f7855c9cd..75d504f93b 100644 --- a/components/cipher/Cargo.toml +++ b/components/cipher/Cargo.toml @@ -4,8 +4,8 @@ resolver = "2" [workspace.dependencies] # tlsn -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } # crypto diff --git a/components/cipher/block-cipher/src/cipher.rs b/components/cipher/block-cipher/src/cipher.rs index e6a2ad74d7..21d2713606 100644 --- a/components/cipher/block-cipher/src/cipher.rs +++ b/components/cipher/block-cipher/src/cipher.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use async_trait::async_trait; -use mpz_garble::{Decode, DecodePrivate, Execute, Memory, ValueRef}; +use mpz_garble::{value::ValueRef, Decode, DecodePrivate, Execute, Memory}; use utils::id::NestedId; use crate::{BlockCipher, BlockCipherCircuit, BlockCipherConfig, BlockCipherError}; @@ -80,11 +80,13 @@ where let msg = self .executor - .new_private_input::(&format!("{}/msg", &id), Some(block))?; + .new_private_input::(&format!("{}/msg", &id))?; let ciphertext = self .executor .new_output::(&format!("{}/ciphertext", &id))?; + self.executor.assign(&msg, block)?; + self.executor .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) .await?; @@ -115,7 +117,7 @@ where let msg = self .executor - .new_private_input::(&format!("{}/msg", &id), None)?; + .new_blind_input::(&format!("{}/msg", &id))?; let ciphertext = self .executor .new_output::(&format!("{}/ciphertext", &id))?; @@ -155,11 +157,13 @@ where let msg = self .executor - .new_public_input::(&format!("{}/msg", &id), block)?; + .new_public_input::(&format!("{}/msg", &id))?; let ciphertext = self .executor .new_output::(&format!("{}/ciphertext", &id))?; + self.executor.assign(&msg, block)?; + self.executor .execute(C::circuit(), &[key, msg], &[ciphertext.clone()]) .await?; diff --git a/components/cipher/block-cipher/src/lib.rs b/components/cipher/block-cipher/src/lib.rs index 7f2d20c4b3..37b84afaaa 100644 --- a/components/cipher/block-cipher/src/lib.rs +++ b/components/cipher/block-cipher/src/lib.rs @@ -12,7 +12,7 @@ mod config; use async_trait::async_trait; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; pub use crate::{ cipher::MpcBlockCipher, @@ -96,8 +96,11 @@ mod tests { let follower_thread = follower_vm.new_thread("test").await.unwrap(); // Key is public just for this test, typically it is private - let leader_key = leader_thread.new_public_input("key", key).unwrap(); - let follower_key = follower_thread.new_public_input("key", key).unwrap(); + let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); + + leader_thread.assign(&leader_key, key).unwrap(); + follower_thread.assign(&follower_key, key).unwrap(); let mut leader = MpcBlockCipher::::new(leader_config, leader_thread); leader.set_key(leader_key); @@ -131,8 +134,11 @@ mod tests { let follower_thread = follower_vm.new_thread("test").await.unwrap(); // Key is public just for this test, typically it is private - let leader_key = leader_thread.new_public_input("key", key).unwrap(); - let follower_key = follower_thread.new_public_input("key", key).unwrap(); + let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); + + leader_thread.assign(&leader_key, key).unwrap(); + follower_thread.assign(&follower_key, key).unwrap(); let mut leader = MpcBlockCipher::::new(leader_config, leader_thread); leader.set_key(leader_key); diff --git a/components/cipher/stream-cipher/benches/mock.rs b/components/cipher/stream-cipher/benches/mock.rs index 64be27b076..b2a63333e9 100644 --- a/components/cipher/stream-cipher/benches/mock.rs +++ b/components/cipher/stream-cipher/benches/mock.rs @@ -1,18 +1,24 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory, Vm}; use tlsn_stream_cipher::{Aes128Ctr, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder}; -async fn bench_stream_cipher_public_encrypt(thread_count: usize, len: usize) { +async fn bench_stream_cipher_encrypt(thread_count: usize, len: usize) { let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; let leader_thread = leader_vm.new_thread("key_config").await.unwrap(); - let leader_key = leader_thread.new_public_input("key", [0u8; 16]).unwrap(); - let leader_iv = leader_thread.new_public_input("iv", [0u8; 4]).unwrap(); + let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); + let leader_iv = leader_thread.new_public_input::<[u8; 4]>("iv").unwrap(); + + leader_thread.assign(&leader_key, [0u8; 16]).unwrap(); + leader_thread.assign(&leader_iv, [0u8; 4]).unwrap(); let follower_thread = follower_vm.new_thread("key_config").await.unwrap(); - let follower_key = follower_thread.new_public_input("key", [0u8; 16]).unwrap(); - let follower_iv = follower_thread.new_public_input("iv", [0u8; 4]).unwrap(); + let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_iv = follower_thread.new_public_input::<[u8; 4]>("iv").unwrap(); + + follower_thread.assign(&follower_key, [0u8; 16]).unwrap(); + follower_thread.assign(&follower_iv, [0u8; 4]).unwrap(); let leader_thread_pool = leader_vm .new_thread_pool("mock", thread_count) @@ -60,9 +66,8 @@ fn criterion_benchmark(c: &mut Criterion) { group.throughput(Throughput::Bytes(len as u64)); group.bench_function(format!("{}", len), |b| { - b.to_async(&rt).iter(|| async { - black_box(bench_stream_cipher_public_encrypt(thread_count, len).await) - }) + b.to_async(&rt) + .iter(|| async { bench_stream_cipher_encrypt(thread_count, len).await }) }); drop(group); diff --git a/components/cipher/stream-cipher/src/config.rs b/components/cipher/stream-cipher/src/config.rs index a2cd9a2f5b..73e375a7ed 100644 --- a/components/cipher/stream-cipher/src/config.rs +++ b/components/cipher/stream-cipher/src/config.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use derive_builder::Builder; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; use std::fmt::Debug; use crate::CtrCircuit; diff --git a/components/cipher/stream-cipher/src/lib.rs b/components/cipher/stream-cipher/src/lib.rs index eeba308d60..2ec96966d0 100644 --- a/components/cipher/stream-cipher/src/lib.rs +++ b/components/cipher/stream-cipher/src/lib.rs @@ -24,7 +24,7 @@ pub use config::{StreamCipherConfig, StreamCipherConfigBuilder, StreamCipherConf pub use stream_cipher::MpcStreamCipher; use async_trait::async_trait; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; /// Error that can occur when using a stream cipher #[derive(Debug, thiserror::Error)] @@ -208,12 +208,18 @@ mod tests { let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; let leader_thread = leader_vm.new_thread("key_config").await.unwrap(); - let leader_key = leader_thread.new_public_input("key", key).unwrap(); - let leader_iv = leader_thread.new_public_input("iv", iv).unwrap(); + let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap(); + let leader_iv = leader_thread.new_public_input::<[u8; 4]>("iv").unwrap(); + + leader_thread.assign(&leader_key, key).unwrap(); + leader_thread.assign(&leader_iv, iv).unwrap(); let follower_thread = follower_vm.new_thread("key_config").await.unwrap(); - let follower_key = follower_thread.new_public_input("key", key).unwrap(); - let follower_iv = follower_thread.new_public_input("iv", iv).unwrap(); + let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap(); + let follower_iv = follower_thread.new_public_input::<[u8; 4]>("iv").unwrap(); + + follower_thread.assign(&follower_key, key).unwrap(); + follower_thread.assign(&follower_iv, iv).unwrap(); let leader_thread_pool = leader_vm .new_thread_pool("mock", thread_count) diff --git a/components/cipher/stream-cipher/src/stream_cipher.rs b/components/cipher/stream-cipher/src/stream_cipher.rs index 970d4bd8b1..921826aebf 100644 --- a/components/cipher/stream-cipher/src/stream_cipher.rs +++ b/components/cipher/stream-cipher/src/stream_cipher.rs @@ -2,7 +2,8 @@ use async_trait::async_trait; use std::{collections::HashMap, fmt::Debug, marker::PhantomData}; use mpz_garble::{ - Decode, DecodePrivate, Execute, Memory, Prove, Thread, ThreadPool, ValueRef, Verify, + value::ValueRef, Decode, DecodePrivate, Execute, Memory, MemoryError, Prove, Thread, + ThreadPool, Verify, }; use utils::id::NestedId; @@ -506,26 +507,29 @@ async fn plaintext_proof text .into_iter() .zip(ids) - .map(|(byte, id)| thread.new_public_input::(&id, byte)) - .collect::, _>>()?, + .map(|(byte, id)| { + let byte_ref = thread.new_public_input::(&id)?; + thread.assign(&byte_ref, byte)?; + Ok(byte_ref) + }) + .collect::, MemoryError>>()?, InputText::Private { ids, text } => text .into_iter() .zip(ids) - .map(|(byte, id)| thread.new_private_input::(&id, Some(byte))) - .collect::, _>>()?, + .map(|(byte, id)| { + let byte_ref = thread.new_private_input::(&id)?; + thread.assign(&byte_ref, byte)?; + Ok(byte_ref) + }) + .collect::, MemoryError>>()?, InputText::Blind { ids } => ids .iter() - .map(|id| thread.new_private_input::(id, None)) - .collect::, _>>()?, + .map(|id| thread.new_blind_input::(id)) + .collect::, MemoryError>>()?, }; // Collect into a single array. - let plaintext = ValueRef::Array( - plaintext - .iter() - .flat_map(|value_ref| value_ref.iter().cloned()) - .collect(), - ); + let plaintext = thread.array_from_values(&plaintext)?; let keystream = match keystream_config { InputText::Blind { ids } => ids @@ -540,12 +544,7 @@ async fn plaintext_proof ( @@ -558,12 +557,7 @@ async fn plaintext_proof { @@ -641,40 +635,42 @@ async fn apply_keyblock::NONCE>( &block_id.append_string("explicit_nonce").to_string(), - explicit_nonce, - )?; - let ctr = thread.new_public_input( - &block_id.append_string("ctr").to_string(), - ctr.to_be_bytes(), )?; + let ctr_ref = thread.new_public_input::<[u8; 4]>(&block_id.append_string("ctr").to_string())?; + + thread.assign(&explicit_nonce_ref, explicit_nonce)?; + thread.assign(&ctr_ref, ctr.to_be_bytes())?; // Sets up the input text values. let input_values = match input_text_config { InputText::Public { ids, text } => text .into_iter() .zip(ids) - .map(|(byte, id)| thread.new_public_input::(&id, byte)) - .collect::, _>>()?, + .map(|(byte, id)| { + let byte_ref = thread.new_public_input::(&id)?; + thread.assign(&byte_ref, byte)?; + Ok(byte_ref) + }) + .collect::, MemoryError>>()?, InputText::Private { ids, text } => text .into_iter() .zip(ids) - .map(|(byte, id)| thread.new_private_input::(&id, Some(byte))) - .collect::, _>>()?, + .map(|(byte, id)| { + let byte_ref = thread.new_private_input::(&id)?; + thread.assign(&byte_ref, byte)?; + Ok(byte_ref) + }) + .collect::, MemoryError>>()?, InputText::Blind { ids } => ids .iter() - .map(|id| thread.new_private_input::(id, None)) - .collect::, _>>()?, + .map(|id| thread.new_blind_input::(id)) + .collect::, MemoryError>>()?, }; // Concatenate the values into a single block - let input_block = ValueRef::Array( - input_values - .iter() - .flat_map(|value_ref| value_ref.iter().cloned()) - .collect(), - ); + let input_block = thread.array_from_values(&input_values)?; // Set up the output text values. let output_values = match &output_text_config { @@ -697,18 +693,13 @@ async fn apply_keyblock { server_key: Option, /// The config used for the key exchange protocol config: KeyExchangeConfig, + /// The state of the protocol + state: State, } impl Debug for KeyExchangeCore @@ -92,10 +118,26 @@ where private_key: None, server_key: None, config, + state: State::Initialized, } } async fn compute_pms_shares(&mut self) -> Result<(P256, P256), KeyExchangeError> { + let state = std::mem::replace(&mut self.state, State::Error); + + let State::Setup { + share_a, + share_b, + share_c, + share_d, + pms_1, + pms_2, + eq, + } = state + else { + todo!() + }; + let server_key = match self.config.role() { Role::Leader => { // Send server public key to follower @@ -144,6 +186,16 @@ where .compute_x_coordinate_share(encoded_point) )?; + self.state = State::KeyExchange { + share_a, + share_b, + share_c, + share_d, + pms_1, + pms_2, + eq, + }; + match self.config.role() { Role::Leader => Ok((sender_share, receiver_share)), Role::Follower => Ok((receiver_share, sender_share)), @@ -155,6 +207,21 @@ where pms_share1: P256, pms_share2: P256, ) -> Result { + let state = std::mem::replace(&mut self.state, State::Error); + + let State::KeyExchange { + share_a, + share_b, + share_c, + share_d, + pms_1, + pms_2, + eq, + } = state + else { + todo!() + }; + let pms_share1: [u8; 32] = pms_share1 .to_be_bytes() .try_into() @@ -164,27 +231,16 @@ where .try_into() .expect("pms share is 32 bytes"); - let (share_a, share_b, share_c, share_d) = match self.config.role() { - Role::Leader => (Some(pms_share1), None, Some(pms_share2), None), - Role::Follower => (None, Some(pms_share1), None, Some(pms_share2)), - }; - - let share_a = self - .executor - .new_private_input::<[u8; 32]>("pms/share_a", share_a)?; - let share_b = self - .executor - .new_private_input::<[u8; 32]>("pms/share_b", share_b)?; - let share_c = self - .executor - .new_private_input::<[u8; 32]>("pms/share_c", share_c)?; - let share_d = self - .executor - .new_private_input::<[u8; 32]>("pms/share_d", share_d)?; - - let pms_1 = self.executor.new_output::<[u8; 32]>("pms/1")?; - let pms_2 = self.executor.new_output::<[u8; 32]>("pms/2")?; - let eq = self.executor.new_output::<[u8; 32]>("pms/eq")?; + match self.config.role() { + Role::Leader => { + self.executor.assign(&share_a, pms_share1)?; + self.executor.assign(&share_c, pms_share2)?; + } + Role::Follower => { + self.executor.assign(&share_b, pms_share1)?; + self.executor.assign(&share_d, pms_share2)?; + } + } self.executor .execute( @@ -206,6 +262,8 @@ where return Err(KeyExchangeError::CheckFailed); } + self.state = State::Complete; + // Both parties use pms_1 as the pre-master secret Ok(Pms::new(pms_1)) } @@ -216,7 +274,7 @@ impl KeyExchange for KeyExchangeCore where PS: PointAddition + Send + Debug, PR: PointAddition + Send + Debug, - E: Memory + Execute + Decode + Send, + E: Memory + Load + Execute + Decode + Send, { #[cfg_attr( feature = "tracing", @@ -232,6 +290,88 @@ where self.server_key = Some(server_key); } + async fn setup(&mut self) -> Result { + let state = std::mem::replace(&mut self.state, State::Error); + + let State::Initialized = state else { + return Err(KeyExchangeError::InvalidState( + "expected to be in Initialized state".to_string(), + )); + }; + + let (share_a, share_b, share_c, share_d) = match self.config.role() { + Role::Leader => { + let share_a = self + .executor + .new_private_input::<[u8; 32]>("pms/share_a") + .unwrap(); + let share_b = self + .executor + .new_blind_input::<[u8; 32]>("pms/share_b") + .unwrap(); + let share_c = self + .executor + .new_private_input::<[u8; 32]>("pms/share_c") + .unwrap(); + let share_d = self + .executor + .new_blind_input::<[u8; 32]>("pms/share_d") + .unwrap(); + + (share_a, share_b, share_c, share_d) + } + Role::Follower => { + let share_a = self + .executor + .new_blind_input::<[u8; 32]>("pms/share_a") + .unwrap(); + let share_b = self + .executor + .new_private_input::<[u8; 32]>("pms/share_b") + .unwrap(); + let share_c = self + .executor + .new_blind_input::<[u8; 32]>("pms/share_c") + .unwrap(); + let share_d = self + .executor + .new_private_input::<[u8; 32]>("pms/share_d") + .unwrap(); + + (share_a, share_b, share_c, share_d) + } + }; + + let pms_1 = self.executor.new_output::<[u8; 32]>("pms/1")?; + let pms_2 = self.executor.new_output::<[u8; 32]>("pms/2")?; + let eq = self.executor.new_output::<[u8; 32]>("pms/eq")?; + + self.executor + .load( + build_pms_circuit(), + &[ + share_a.clone(), + share_b.clone(), + share_c.clone(), + share_d.clone(), + ], + &[pms_1.clone(), pms_2.clone(), eq.clone()], + ) + .await?; + + self.state = State::Setup { + share_a, + share_b, + share_c, + share_d, + pms_1: pms_1.clone(), + pms_2, + eq, + }; + + Ok(Pms::new(pms_1)) + } + /// Compute the client's public key /// /// The client's public key in this context is the combined public key (EC point addition) of @@ -466,6 +606,8 @@ mod tests { follower_private_key: SecretKey, server_public_key: PublicKey, ) -> PublicKey { + tokio::try_join!(leader.setup(), follower.setup()).unwrap(); + let (client_public_key, _) = tokio::try_join!( leader.compute_client_key(leader_private_key), follower.compute_client_key(follower_private_key) diff --git a/components/key-exchange/src/lib.rs b/components/key-exchange/src/lib.rs index 094abac928..945cc1959e 100644 --- a/components/key-exchange/src/lib.rs +++ b/components/key-exchange/src/lib.rs @@ -30,7 +30,7 @@ pub use msg::KeyExchangeMessage; pub type KeyExchangeChannel = Box>; use async_trait::async_trait; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; use p256::{PublicKey, SecretKey}; use utils_aio::duplex::Duplex; @@ -59,6 +59,8 @@ pub enum KeyExchangeError { #[error(transparent)] MemoryError(#[from] mpz_garble::MemoryError), #[error(transparent)] + LoadError(#[from] mpz_garble::LoadError), + #[error(transparent)] ExecutionError(#[from] mpz_garble::ExecutionError), #[error(transparent)] DecodeError(#[from] mpz_garble::DecodeError), @@ -72,6 +74,8 @@ pub enum KeyExchangeError { NoServerKey, #[error("Private key not set")] NoPrivateKey, + #[error("invalid state: {0}")] + InvalidState(String), #[error("PMS equality check failed")] CheckFailed, } @@ -85,6 +89,11 @@ pub trait KeyExchange { /// Set the server's public key fn set_server_key(&mut self, server_key: PublicKey); + /// Performs any necessary one-time setup, returning a reference to the PMS. + /// + /// The PMS will not be assigned until `compute_pms` is called. + async fn setup(&mut self) -> Result; + /// Compute the client's public key /// /// The client's public key in this context is the combined public key (EC point addition) of diff --git a/components/point-addition/Cargo.toml b/components/point-addition/Cargo.toml index 80fa9b2c2f..63ea5977ad 100644 --- a/components/point-addition/Cargo.toml +++ b/components/point-addition/Cargo.toml @@ -17,9 +17,9 @@ mock = ["dep:mpz-core"] tracing = ["dep:tracing"] [dependencies] -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922", optional = true } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779", optional = true } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } p256 = { version = "0.13", features = ["arithmetic"] } tracing = { version = "0.1", optional = true } async-trait = "0.1" diff --git a/components/prf/Cargo.toml b/components/prf/Cargo.toml index 5d1898370f..2086d21ca5 100644 --- a/components/prf/Cargo.toml +++ b/components/prf/Cargo.toml @@ -4,8 +4,8 @@ resolver = "2" [workspace.dependencies] # tlsn -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } # async async-trait = "0.1" diff --git a/components/prf/hmac-sha256/Cargo.toml b/components/prf/hmac-sha256/Cargo.toml index 45642ec3f7..77318b225f 100644 --- a/components/prf/hmac-sha256/Cargo.toml +++ b/components/prf/hmac-sha256/Cargo.toml @@ -26,6 +26,8 @@ async-trait.workspace = true futures.workspace = true thiserror.workspace = true tracing = { workspace = true, optional = true } +derive_builder = "0.12" +enum-try-as-inner = "0.1" [dev-dependencies] criterion = { workspace = true, features = ["async_tokio"] } diff --git a/components/prf/hmac-sha256/benches/prf.rs b/components/prf/hmac-sha256/benches/prf.rs index 2d068d68d0..25cafcfaa9 100644 --- a/components/prf/hmac-sha256/benches/prf.rs +++ b/components/prf/hmac-sha256/benches/prf.rs @@ -1,15 +1,57 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{criterion_group, criterion_main, Criterion}; -use hmac_sha256::{MpcPrf, Prf}; +use hmac_sha256::{MpcPrf, Prf, PrfConfig, Role}; use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory, Vm}; -async fn bench_prf() { +#[allow(clippy::unit_arg)] +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("prf"); + group.sample_size(10); + let rt = tokio::runtime::Runtime::new().unwrap(); + + group.bench_function("prf_setup", |b| b.to_async(&rt).iter(setup)); + group.bench_function("prf", |b| b.to_async(&rt).iter(prf)); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + +async fn setup() { let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("bench").await; - let mut leader = MpcPrf::new(leader_vm.new_thread("bench_thread").await.unwrap()); - let mut follower = MpcPrf::new(follower_vm.new_thread("bench_thread").await.unwrap()); + let mut leader = MpcPrf::new( + PrfConfig::builder().role(Role::Leader).build().unwrap(), + leader_vm.new_thread("prf/0").await.unwrap(), + leader_vm.new_thread("prf/1").await.unwrap(), + ); + let mut follower = MpcPrf::new( + PrfConfig::builder().role(Role::Follower).build().unwrap(), + follower_vm.new_thread("prf/0").await.unwrap(), + follower_vm.new_thread("prf/1").await.unwrap(), + ); + + let leader_thread = leader_vm.new_thread("setup").await.unwrap(); + let follower_thread = follower_vm.new_thread("setup").await.unwrap(); - futures::try_join!(leader.setup(), follower.setup()).unwrap(); + let leader_pms = leader_thread.new_public_input::<[u8; 32]>("pms").unwrap(); + let follower_pms = follower_thread.new_public_input::<[u8; 32]>("pms").unwrap(); + + futures::try_join!(leader.setup(leader_pms), follower.setup(follower_pms)).unwrap(); +} + +async fn prf() { + let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("bench").await; + + let mut leader = MpcPrf::new( + PrfConfig::builder().role(Role::Leader).build().unwrap(), + leader_vm.new_thread("prf/0").await.unwrap(), + leader_vm.new_thread("prf/1").await.unwrap(), + ); + let mut follower = MpcPrf::new( + PrfConfig::builder().role(Role::Follower).build().unwrap(), + follower_vm.new_thread("prf/0").await.unwrap(), + follower_vm.new_thread("prf/1").await.unwrap(), + ); let pms = [42u8; 32]; @@ -21,12 +63,17 @@ async fn bench_prf() { let leader_thread = leader_vm.new_thread("setup").await.unwrap(); let follower_thread = follower_vm.new_thread("setup").await.unwrap(); - let leader_pms = leader_thread.new_public_input("pms", pms).unwrap(); - let follower_pms = follower_thread.new_public_input("pms", pms).unwrap(); + let leader_pms = leader_thread.new_public_input::<[u8; 32]>("pms").unwrap(); + let follower_pms = follower_thread.new_public_input::<[u8; 32]>("pms").unwrap(); + + leader_thread.assign(&leader_pms, pms).unwrap(); + follower_thread.assign(&follower_pms, pms).unwrap(); + + futures::try_join!(leader.setup(leader_pms), follower.setup(follower_pms)).unwrap(); let (_leader_keys, _follower_keys) = futures::try_join!( - leader.compute_session_keys_private(client_random, server_random, leader_pms), - follower.compute_session_keys_blind(follower_pms) + leader.compute_session_keys_private(client_random, server_random), + follower.compute_session_keys_blind() ) .unwrap(); @@ -44,17 +91,3 @@ async fn bench_prf() { futures::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap(); } - -#[allow(clippy::unit_arg)] -fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("prf"); - group.sample_size(10); - let rt = tokio::runtime::Runtime::new().unwrap(); - group.bench_function("prf", |b| { - b.to_async(&rt) - .iter(|| async { black_box(bench_prf().await) }) - }); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/components/prf/hmac-sha256/src/config.rs b/components/prf/hmac-sha256/src/config.rs new file mode 100644 index 0000000000..c9e96c9cd4 --- /dev/null +++ b/components/prf/hmac-sha256/src/config.rs @@ -0,0 +1,24 @@ +use derive_builder::Builder; + +/// Role of this party in the PRF. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Role { + /// The leader provides the private inputs to the PRF. + Leader, + /// The follower is blind to the inputs to the PRF. + Follower, +} + +/// Configuration for the PRF. +#[derive(Debug, Builder)] +pub struct PrfConfig { + /// The role of this party in the PRF. + pub(crate) role: Role, +} + +impl PrfConfig { + /// Creates a new builder. + pub fn builder() -> PrfConfigBuilder { + PrfConfigBuilder::default() + } +} diff --git a/components/prf/hmac-sha256/src/error.rs b/components/prf/hmac-sha256/src/error.rs new file mode 100644 index 0000000000..1c30ff3b86 --- /dev/null +++ b/components/prf/hmac-sha256/src/error.rs @@ -0,0 +1,45 @@ +use std::error::Error; + +use crate::prf::state::StateError; + +/// Errors that can occur during PRF computation. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum PrfError { + #[error("MPC backend error: {0:?}")] + Mpc(Box), + #[error("role error: {0:?}")] + RoleError(String), + #[error("Invalid state: {0}")] + InvalidState(String), +} + +impl From for PrfError { + fn from(err: StateError) -> Self { + PrfError::InvalidState(err.to_string()) + } +} + +impl From for PrfError { + fn from(err: mpz_garble::MemoryError) -> Self { + PrfError::Mpc(Box::new(err)) + } +} + +impl From for PrfError { + fn from(err: mpz_garble::LoadError) -> Self { + PrfError::Mpc(Box::new(err)) + } +} + +impl From for PrfError { + fn from(err: mpz_garble::ExecutionError) -> Self { + PrfError::Mpc(Box::new(err)) + } +} + +impl From for PrfError { + fn from(err: mpz_garble::DecodeError) -> Self { + PrfError::Mpc(Box::new(err)) + } +} diff --git a/components/prf/hmac-sha256/src/lib.rs b/components/prf/hmac-sha256/src/lib.rs index a34db25807..57e450f89f 100644 --- a/components/prf/hmac-sha256/src/lib.rs +++ b/components/prf/hmac-sha256/src/lib.rs @@ -4,16 +4,23 @@ #![deny(clippy::all)] #![forbid(unsafe_code)] +mod config; +mod error; mod prf; +pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role}; +pub use error::PrfError; pub use prf::MpcPrf; use async_trait::async_trait; -use mpz_garble::ValueRef; -use prf::State; +use mpz_garble::value::ValueRef; + +pub(crate) static CF_LABEL: &[u8] = b"client finished"; +pub(crate) static SF_LABEL: &[u8] = b"server finished"; /// Session keys computed by the PRF. +#[derive(Debug)] pub struct SessionKeys { /// Client write key. pub client_write_key: ValueRef, @@ -25,34 +32,21 @@ pub struct SessionKeys { pub server_iv: ValueRef, } -/// Errors that can occur during PRF computation. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum PrfError { - #[error(transparent)] - MemoryError(#[from] mpz_garble::MemoryError), - #[error(transparent)] - ExecutionError(#[from] mpz_garble::ExecutionError), - #[error(transparent)] - DecodeError(#[from] mpz_garble::DecodeError), - #[error("role error: {0:?}")] - RoleError(String), - #[error("Invalid state: {0:?}")] - InvalidState(State), -} - /// PRF trait for computing TLS PRF. #[async_trait] pub trait Prf { /// Performs any necessary one-time setup. - async fn setup(&mut self) -> Result<(), PrfError>; + /// + /// # Arguments + /// + /// * `pms` - The pre-master secret. + async fn setup(&mut self, pms: ValueRef) -> Result<(), PrfError>; /// Computes the session keys using the provided client random, server random and PMS. async fn compute_session_keys_private( &mut self, client_random: [u8; 32], server_random: [u8; 32], - pms: ValueRef, ) -> Result; /// Computes the client finished verify data using the provided handshake hash. @@ -68,7 +62,7 @@ pub trait Prf { ) -> Result<[u8; 12], PrfError>; /// Computes the session keys using randoms provided by the other party. - async fn compute_session_keys_blind(&mut self, pms: ValueRef) -> Result; + async fn compute_session_keys_blind(&mut self) -> Result; /// Computes the client finished verify data using the handshake hash provided by the other party. async fn compute_client_finished_vd_blind(&mut self) -> Result<(), PrfError>; @@ -105,28 +99,43 @@ mod tests { #[ignore = "expensive"] #[tokio::test] async fn test_prf() { - let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; - - let mut leader_test_thread = leader_vm.new_thread("test").await.unwrap(); - let mut follower_test_thread = follower_vm.new_thread("test").await.unwrap(); - - let mut leader = MpcPrf::new(leader_vm.new_thread("prf").await.unwrap()); - let mut follower = MpcPrf::new(follower_vm.new_thread("prf").await.unwrap()); - - futures::try_join!(leader.setup(), follower.setup()).unwrap(); - let pms = [42u8; 32]; let client_random = [69u8; 32]; let server_random: [u8; 32] = [96u8; 32]; let ms = compute_ms(pms, client_random, server_random); + let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await; + + let mut leader_test_thread = leader_vm.new_thread("test").await.unwrap(); + let mut follower_test_thread = follower_vm.new_thread("test").await.unwrap(); + // Setup public PMS for testing - let leader_pms = leader_test_thread.new_public_input("pms", pms).unwrap(); - let follower_pms = follower_test_thread.new_public_input("pms", pms).unwrap(); + let leader_pms = leader_test_thread + .new_public_input::<[u8; 32]>("pms") + .unwrap(); + let follower_pms = follower_test_thread + .new_public_input::<[u8; 32]>("pms") + .unwrap(); + + leader_test_thread.assign(&leader_pms, pms).unwrap(); + follower_test_thread.assign(&follower_pms, pms).unwrap(); + + let mut leader = MpcPrf::new( + PrfConfig::builder().role(Role::Leader).build().unwrap(), + leader_vm.new_thread("prf/0").await.unwrap(), + leader_vm.new_thread("prf/1").await.unwrap(), + ); + let mut follower = MpcPrf::new( + PrfConfig::builder().role(Role::Follower).build().unwrap(), + follower_vm.new_thread("prf/0").await.unwrap(), + follower_vm.new_thread("prf/1").await.unwrap(), + ); + + futures::try_join!(leader.setup(leader_pms), follower.setup(follower_pms)).unwrap(); let (leader_session_keys, follower_session_keys) = futures::try_join!( - leader.compute_session_keys_private(client_random, server_random, leader_pms), - follower.compute_session_keys_blind(follower_pms) + leader.compute_session_keys_private(client_random, server_random), + follower.compute_session_keys_blind() ) .unwrap(); diff --git a/components/prf/hmac-sha256/src/prf.rs b/components/prf/hmac-sha256/src/prf.rs index 7161a0e763..d501dc19c4 100644 --- a/components/prf/hmac-sha256/src/prf.rs +++ b/components/prf/hmac-sha256/src/prf.rs @@ -7,10 +7,15 @@ use async_trait::async_trait; use hmac_sha256_circuits::{build_session_keys, build_verify_data}; use mpz_circuits::Circuit; -use mpz_garble::{Decode, DecodePrivate, Execute, Memory, ValueRef}; +use mpz_garble::{ + config::Visibility, value::ValueRef, Decode, DecodePrivate, Execute, Load, Memory, +}; use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; -use crate::{Prf, PrfError, SessionKeys}; +use crate::{Prf, PrfConfig, PrfError, Role, SessionKeys, CF_LABEL, SF_LABEL}; + +#[cfg(feature = "tracing")] +use tracing::instrument; /// Circuit for computing TLS session keys. static SESSION_KEYS_CIRC: OnceLock> = OnceLock::new(); @@ -19,357 +24,452 @@ static CLIENT_VD_CIRC: OnceLock> = OnceLock::new(); /// Circuit for computing TLS server verify data. static SERVER_VD_CIRC: OnceLock> = OnceLock::new(); +enum Msg { + Cf, + Sf, +} + +#[derive(Debug)] +pub(crate) struct Randoms { + pub(crate) client_random: ValueRef, + pub(crate) server_random: ValueRef, +} + +#[derive(Debug, Clone)] +pub(crate) struct HashState { + pub(crate) ms_outer_hash_state: ValueRef, + pub(crate) ms_inner_hash_state: ValueRef, +} + +#[derive(Debug)] +pub(crate) struct VerifyData { + pub(crate) handshake_hash: ValueRef, + pub(crate) vd: ValueRef, +} + /// MPC PRF for computing TLS HMAC-SHA256 PRF. -pub struct MpcPrf -where - E: Memory + Execute + DecodePrivate, -{ - state: State, - executor: E, +pub struct MpcPrf { + config: PrfConfig, + state: state::State, + thread_0: E, + thread_1: E, } -impl Debug for MpcPrf { +impl Debug for MpcPrf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MpcPrf") + .field("config", &self.config) .field("state", &self.state) - .field("executor", &"{{ ... }}") .finish() } } -/// Internal state of [MpcPrf]. -#[derive(Debug, Clone)] -#[allow(missing_docs)] -pub enum State { - Initialized, - SessionKeys, - ClientFinished { - ms_outer_hash_state: ValueRef, - ms_inner_hash_state: ValueRef, - }, - ServerFinished { - ms_outer_hash_state: ValueRef, - ms_inner_hash_state: ValueRef, - }, - Complete, - Error, -} - impl MpcPrf where - E: Memory + Execute + Decode + DecodePrivate, + E: Load + Memory + Execute + DecodePrivate + Send, { /// Creates a new instance of the PRF. - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(executor), ret) - )] - pub fn new(executor: E) -> MpcPrf { + pub fn new(config: PrfConfig, thread_0: E, thread_1: E) -> MpcPrf { MpcPrf { - state: State::Initialized, - executor, + config, + state: state::State::Initialized, + thread_0, + thread_1, } } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn internal_compute_session_keys( + /// Executes a circuit which computes TLS session keys. + async fn execute_session_keys( &mut self, - client_random: Option<[u8; 32]>, - server_random: Option<[u8; 32]>, - pms: ValueRef, + randoms: Option<([u8; 32], [u8; 32])>, ) -> Result { - let state = std::mem::replace(&mut self.state, State::Error); - - let State::SessionKeys = state else { - return Err(PrfError::InvalidState(state)); - }; - - let client_random = self - .executor - .new_private_input("client_random", client_random)?; - let server_random = self - .executor - .new_private_input("server_random", server_random)?; - - let client_write_key = self.executor.new_output::<[u8; 16]>("client_write_key")?; - let server_write_key = self.executor.new_output::<[u8; 16]>("server_write_key")?; - let client_iv = self.executor.new_output::<[u8; 4]>("client_write_iv")?; - let server_iv = self.executor.new_output::<[u8; 4]>("server_write_iv")?; - let ms_outer_hash_state = self - .executor - .new_output::<[u32; 8]>("ms_outer_hash_state")?; - let ms_inner_hash_state = self - .executor - .new_output::<[u32; 8]>("ms_inner_hash_state")?; + let state::SessionKeys { + pms, + randoms: randoms_refs, + hash_state, + keys, + cf_vd, + sf_vd, + } = std::mem::replace(&mut self.state, state::State::Error).try_into_session_keys()?; let circ = SESSION_KEYS_CIRC .get() .expect("session keys circuit is set"); - self.executor + if let Some((client_random, server_random)) = randoms { + self.thread_0 + .assign(&randoms_refs.client_random, client_random)?; + self.thread_0 + .assign(&randoms_refs.server_random, server_random)?; + } + + self.thread_0 .execute( circ.clone(), - &[pms, client_random, server_random], + &[pms, randoms_refs.client_random, randoms_refs.server_random], &[ - client_write_key.clone(), - server_write_key.clone(), - client_iv.clone(), - server_iv.clone(), - ms_outer_hash_state.clone(), - ms_inner_hash_state.clone(), + keys.client_write_key.clone(), + keys.server_write_key.clone(), + keys.client_iv.clone(), + keys.server_iv.clone(), + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), ], ) .await?; - self.state = State::ClientFinished { - ms_outer_hash_state, - ms_inner_hash_state, - }; + self.state = state::State::ClientFinished(state::ClientFinished { + hash_state, + cf_vd, + sf_vd, + }); - Ok(SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - }) + Ok(keys) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self, label), err) - )] - async fn internal_compute_vd( + async fn execute_cf_vd( &mut self, - label: &str, handshake_hash: Option<[u8; 32]>, - outer_state: ValueRef, - inner_state: ValueRef, ) -> Result, PrfError> { - let handshake_hash_value = self - .executor - .new_private_input(&format!("prf_label/{}/hash", label), handshake_hash)?; - let vd = self - .executor - .new_output::<[u8; 12]>(&format!("prf_label/{}/vd", label))?; - - let circ = match label { - "client finished" => CLIENT_VD_CIRC.get().expect("client vd circuit is set"), - "server finished" => SERVER_VD_CIRC.get().expect("server vd circuit is set"), - _ => unreachable!("invalid label"), - }; + let state::ClientFinished { + hash_state, + cf_vd, + sf_vd, + } = std::mem::replace(&mut self.state, state::State::Error).try_into_client_finished()?; + + let circ = CLIENT_VD_CIRC.get().expect("client vd circuit is set"); - self.executor + if let Some(handshake_hash) = handshake_hash { + self.thread_0 + .assign(&cf_vd.handshake_hash, handshake_hash)?; + } + + self.thread_0 .execute( circ.clone(), - &[outer_state, inner_state, handshake_hash_value], - &[vd.clone()], + &[ + hash_state.ms_outer_hash_state.clone(), + hash_state.ms_inner_hash_state.clone(), + cf_vd.handshake_hash, + ], + &[cf_vd.vd.clone()], ) .await?; - if handshake_hash.is_some() { - let mut outputs = self.executor.decode_private(&[vd]).await?; - + let vd = if handshake_hash.is_some() { + let mut outputs = self.thread_0.decode_private(&[cf_vd.vd]).await?; let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); - Ok(Some(vd)) + Some(vd) } else { - self.executor.decode_blind(&[vd]).await?; + self.thread_0.decode_blind(&[cf_vd.vd]).await?; + + None + }; - Ok(None) + self.state = state::State::ServerFinished(state::ServerFinished { hash_state, sf_vd }); + + Ok(vd) + } + + async fn execute_sf_vd( + &mut self, + handshake_hash: Option<[u8; 32]>, + ) -> Result, PrfError> { + let state::ServerFinished { hash_state, sf_vd } = + std::mem::replace(&mut self.state, state::State::Error).try_into_server_finished()?; + + let circ = SERVER_VD_CIRC.get().expect("server vd circuit is set"); + + if let Some(handshake_hash) = handshake_hash { + self.thread_1 + .assign(&sf_vd.handshake_hash, handshake_hash)?; } + + self.thread_1 + .execute( + circ.clone(), + &[ + hash_state.ms_outer_hash_state, + hash_state.ms_inner_hash_state, + sf_vd.handshake_hash, + ], + &[sf_vd.vd.clone()], + ) + .await?; + + let vd = if handshake_hash.is_some() { + let mut outputs = self.thread_1.decode_private(&[sf_vd.vd]).await?; + let vd: [u8; 12] = outputs.remove(0).try_into().expect("vd is 12 bytes"); + + Some(vd) + } else { + self.thread_1.decode_blind(&[sf_vd.vd]).await?; + + None + }; + + self.state = state::State::Complete; + + Ok(vd) } } #[async_trait] impl Prf for MpcPrf where - E: Memory + Execute + Decode + DecodePrivate + Send, + E: Memory + Load + Execute + Decode + DecodePrivate + Send, { - async fn setup(&mut self) -> Result<(), PrfError> { - let state = std::mem::replace(&mut self.state, State::Error); + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + async fn setup(&mut self, pms: ValueRef) -> Result<(), PrfError> { + std::mem::replace(&mut self.state, state::State::Error).try_into_initialized()?; - let State::Initialized = state else { - return Err(PrfError::InvalidState(state)); + let visibility = match self.config.role { + Role::Leader => Visibility::Private, + Role::Follower => Visibility::Blind, }; - // Pre-build all circuits - futures::join!( - async { - if SESSION_KEYS_CIRC.get().is_none() { - _ = SESSION_KEYS_CIRC.set(Backend::spawn(build_session_keys).await); - } - }, - async { - if CLIENT_VD_CIRC.get().is_none() { - _ = CLIENT_VD_CIRC - .set(Backend::spawn(|| build_verify_data(b"client finished")).await); - } - }, - async { - if SERVER_VD_CIRC.get().is_none() { - _ = SERVER_VD_CIRC - .set(Backend::spawn(|| build_verify_data(b"server finished")).await); - } - }, - ); - - self.state = State::SessionKeys; + // Perform pre-computation for all circuits. + let (randoms, hash_state, keys) = + setup_session_keys(&mut self.thread_0, pms.clone(), visibility).await?; + + let (cf_vd, sf_vd) = futures::try_join!( + setup_finished_msg(&mut self.thread_0, Msg::Cf, hash_state.clone(), visibility), + setup_finished_msg(&mut self.thread_1, Msg::Sf, hash_state.clone(), visibility), + )?; + + self.state = state::State::SessionKeys(state::SessionKeys { + pms, + randoms, + hash_state, + keys, + cf_vd, + sf_vd, + }); Ok(()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] async fn compute_session_keys_private( &mut self, client_random: [u8; 32], server_random: [u8; 32], - pms: ValueRef, ) -> Result { - self.internal_compute_session_keys(Some(client_random), Some(server_random), pms) + if self.config.role != Role::Leader { + return Err(PrfError::RoleError( + "only leader can provide inputs".to_string(), + )); + } + + self.execute_session_keys(Some((client_random, server_random))) .await } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] async fn compute_client_finished_vd_private( &mut self, handshake_hash: [u8; 32], ) -> Result<[u8; 12], PrfError> { - let state = std::mem::replace(&mut self.state, State::Error); - - let State::ClientFinished { - ms_outer_hash_state, - ms_inner_hash_state, - } = state - else { - return Err(PrfError::InvalidState(state)); - }; - - let vd = self - .internal_compute_vd( - "client finished", - Some(handshake_hash), - ms_outer_hash_state.clone(), - ms_inner_hash_state.clone(), - ) - .await? - .unwrap(); - - self.state = State::ServerFinished { - ms_outer_hash_state, - ms_inner_hash_state, - }; + if self.config.role != Role::Leader { + return Err(PrfError::RoleError( + "only leader can provide inputs".to_string(), + )); + } - Ok(vd) + self.execute_cf_vd(Some(handshake_hash)) + .await + .map(|hash| hash.expect("vd is decoded")) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] async fn compute_server_finished_vd_private( &mut self, handshake_hash: [u8; 32], ) -> Result<[u8; 12], PrfError> { - let state = std::mem::replace(&mut self.state, State::Error); - - let State::ServerFinished { - ms_outer_hash_state, - ms_inner_hash_state, - } = state - else { - return Err(PrfError::InvalidState(state)); - }; + if self.config.role != Role::Leader { + return Err(PrfError::RoleError( + "only leader can provide inputs".to_string(), + )); + } - let vd = self - .internal_compute_vd( - "server finished", - Some(handshake_hash), - ms_outer_hash_state.clone(), - ms_inner_hash_state.clone(), - ) - .await? - .unwrap(); + self.execute_sf_vd(Some(handshake_hash)) + .await + .map(|hash| hash.expect("vd is decoded")) + } - self.state = State::Complete; + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + async fn compute_session_keys_blind(&mut self) -> Result { + if self.config.role != Role::Follower { + return Err(PrfError::RoleError( + "leader must provide inputs".to_string(), + )); + } - Ok(vd) + self.execute_session_keys(None).await } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn compute_session_keys_blind(&mut self, pms: ValueRef) -> Result { - self.internal_compute_session_keys(None, None, pms).await + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + async fn compute_client_finished_vd_blind(&mut self) -> Result<(), PrfError> { + if self.config.role != Role::Follower { + return Err(PrfError::RoleError( + "leader must provide inputs".to_string(), + )); + } + + self.execute_cf_vd(None).await.map(|_| ()) } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn compute_client_finished_vd_blind(&mut self) -> Result<(), PrfError> { - let state = std::mem::replace(&mut self.state, State::Error); + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self), err))] + async fn compute_server_finished_vd_blind(&mut self) -> Result<(), PrfError> { + if self.config.role != Role::Follower { + return Err(PrfError::RoleError( + "leader must provide inputs".to_string(), + )); + } - let State::ClientFinished { - ms_outer_hash_state, - ms_inner_hash_state, - } = state - else { - return Err(PrfError::InvalidState(state)); - }; + self.execute_sf_vd(None).await.map(|_| ()) + } +} - _ = self - .internal_compute_vd( - "client finished", - None, - ms_outer_hash_state.clone(), - ms_inner_hash_state.clone(), - ) - .await?; +pub(crate) mod state { + use super::*; + use enum_try_as_inner::EnumTryAsInner; + + #[derive(Debug, EnumTryAsInner)] + #[derive_err(Debug)] + pub(crate) enum State { + Initialized, + SessionKeys(SessionKeys), + ClientFinished(ClientFinished), + ServerFinished(ServerFinished), + Complete, + Error, + } - self.state = State::ServerFinished { - ms_outer_hash_state, - ms_inner_hash_state, - }; + #[derive(Debug)] + pub(crate) struct SessionKeys { + pub(crate) pms: ValueRef, + pub(crate) randoms: Randoms, + pub(crate) hash_state: HashState, + pub(crate) keys: crate::SessionKeys, + pub(crate) cf_vd: VerifyData, + pub(crate) sf_vd: VerifyData, + } - Ok(()) + #[derive(Debug)] + pub(crate) struct ClientFinished { + pub(crate) hash_state: HashState, + pub(crate) cf_vd: VerifyData, + pub(crate) sf_vd: VerifyData, } - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "debug", skip(self), err) - )] - async fn compute_server_finished_vd_blind(&mut self) -> Result<(), PrfError> { - let state = std::mem::replace(&mut self.state, State::Error); + #[derive(Debug)] + pub(crate) struct ServerFinished { + pub(crate) hash_state: HashState, + pub(crate) sf_vd: VerifyData, + } +} - let State::ServerFinished { - ms_outer_hash_state, - ms_inner_hash_state, - } = state - else { - return Err(PrfError::InvalidState(state)); - }; +async fn setup_session_keys( + thread: &mut T, + pms: ValueRef, + visibility: Visibility, +) -> Result<(Randoms, HashState, SessionKeys), PrfError> { + let client_random = thread.new_input::<[u8; 32]>("client_finished", visibility)?; + let server_random = thread.new_input::<[u8; 32]>("server_finished", visibility)?; + + let client_write_key = thread.new_output::<[u8; 16]>("client_write_key")?; + let server_write_key = thread.new_output::<[u8; 16]>("server_write_key")?; + let client_iv = thread.new_output::<[u8; 4]>("client_write_iv")?; + let server_iv = thread.new_output::<[u8; 4]>("server_write_iv")?; + + let ms_outer_hash_state = thread.new_output::<[u32; 8]>("ms_outer_hash_state")?; + let ms_inner_hash_state = thread.new_output::<[u32; 8]>("ms_inner_hash_state")?; + + if SESSION_KEYS_CIRC.get().is_none() { + _ = SESSION_KEYS_CIRC.set(Backend::spawn(build_session_keys).await); + } - _ = self - .internal_compute_vd( - "server finished", - None, + let circ = SESSION_KEYS_CIRC + .get() + .expect("session keys circuit is set"); + + thread + .load( + circ.clone(), + &[pms, client_random.clone(), server_random.clone()], + &[ + client_write_key.clone(), + server_write_key.clone(), + client_iv.clone(), + server_iv.clone(), ms_outer_hash_state.clone(), ms_inner_hash_state.clone(), - ) - .await?; - - self.state = State::Complete; + ], + ) + .await?; + + Ok(( + Randoms { + client_random, + server_random, + }, + HashState { + ms_outer_hash_state, + ms_inner_hash_state, + }, + SessionKeys { + client_write_key, + server_write_key, + client_iv, + server_iv, + }, + )) +} - Ok(()) +async fn setup_finished_msg( + thread: &mut T, + msg: Msg, + hash_state: HashState, + visibility: Visibility, +) -> Result { + let name = match msg { + Msg::Cf => String::from("client_finished"), + Msg::Sf => String::from("server_finished"), + }; + + let handshake_hash = + thread.new_input::<[u8; 32]>(&format!("{name}/handshake_hash"), visibility)?; + let vd = thread.new_output::<[u8; 12]>(&format!("{name}/vd"))?; + + let circ = match msg { + Msg::Cf => &CLIENT_VD_CIRC, + Msg::Sf => &SERVER_VD_CIRC, + }; + + let label = match msg { + Msg::Cf => CF_LABEL, + Msg::Sf => SF_LABEL, + }; + + if circ.get().is_none() { + _ = circ.set(Backend::spawn(move || build_verify_data(label)).await); } + + let circ = circ.get().expect("session keys circuit is set"); + + thread + .load( + circ.clone(), + &[ + hash_state.ms_outer_hash_state, + hash_state.ms_inner_hash_state, + handshake_hash.clone(), + ], + &[vd.clone()], + ) + .await?; + + Ok(VerifyData { handshake_hash, vd }) } diff --git a/components/tls/tls-mpc/Cargo.toml b/components/tls/tls-mpc/Cargo.toml index 3b1b44daa0..1010c04f74 100644 --- a/components/tls/tls-mpc/Cargo.toml +++ b/components/tls/tls-mpc/Cargo.toml @@ -29,9 +29,9 @@ tracing = [ tlsn-tls-core = { path = "../tls-core", features = ["serde"] } tlsn-tls-backend = { path = "../tls-backend" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } tlsn-block-cipher = { path = "../../cipher/block-cipher" } tlsn-stream-cipher = { path = "../../cipher/stream-cipher" } @@ -56,7 +56,7 @@ tracing = { workspace = true, optional = true } tlsn-tls-client = { path = "../tls-client" } tlsn-tls-client-async = { path = "../tls-client-async" } tls-server-fixture = { path = "../tls-server-fixture" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } uid-mux = { path = "../../uid-mux" } tracing-subscriber.workspace = true diff --git a/components/tls/tls-mpc/src/follower.rs b/components/tls/tls-mpc/src/follower.rs index 3eedfae5a2..0a1b1cdc80 100644 --- a/components/tls/tls-mpc/src/follower.rs +++ b/components/tls/tls-mpc/src/follower.rs @@ -3,7 +3,7 @@ use futures::StreamExt; use hmac_sha256 as prf; use key_exchange as ke; use mpz_core::hash::Hash; -use mpz_garble::ValueRef; +use mpz_garble::value::ValueRef; use p256::elliptic_curve::sec1::ToEncodedPoint; use prf::SessionKeys; @@ -86,7 +86,8 @@ impl MpcTlsFollower { /// Performs any one-time setup operations. pub async fn setup(&mut self) -> Result<(), MpcTlsError> { - self.prf.setup().await?; + let pms = self.ke.setup().await?; + self.prf.setup(pms.into_value()).await?; Ok(()) } @@ -134,7 +135,7 @@ impl MpcTlsFollower { self.handshake_commitment = Some(handshake_commitment); } - let pms = self.ke.compute_pms().await?; + self.ke.compute_pms().await?; // PRF let SessionKeys { @@ -142,10 +143,7 @@ impl MpcTlsFollower { server_write_key, client_iv, server_iv, - } = self - .prf - .compute_session_keys_blind(pms.into_value()) - .await?; + } = self.prf.compute_session_keys_blind().await?; self.encrypter.set_key(client_write_key, client_iv).await?; self.decrypter.set_key(server_write_key, server_iv).await?; diff --git a/components/tls/tls-mpc/src/leader.rs b/components/tls/tls-mpc/src/leader.rs index fc67aa2cc2..dc5d4c983c 100644 --- a/components/tls/tls-mpc/src/leader.rs +++ b/components/tls/tls-mpc/src/leader.rs @@ -85,10 +85,6 @@ impl Default for ConnectionState { impl MpcTlsLeader { /// Create a new leader instance - #[cfg_attr( - feature = "tracing", - tracing::instrument(level = "info", skip(channel, ke, prf, encrypter, decrypter)) - )] pub fn new( config: MpcTlsLeaderConfig, channel: MpcTlsChannel, @@ -110,7 +106,8 @@ impl MpcTlsLeader { /// Performs any one-time setup operations. pub async fn setup(&mut self) -> Result<(), MpcTlsError> { - self.prf.setup().await?; + let pms = self.ke.setup().await?; + self.prf.setup(pms.into_value()).await?; Ok(()) } @@ -249,7 +246,7 @@ impl MpcTlsLeader { self.ke.set_server_key(server_key); - let pms = self.ke.compute_pms().await?; + self.ke.compute_pms().await?; let SessionKeys { client_write_key, @@ -258,7 +255,7 @@ impl MpcTlsLeader { server_iv, } = self .prf - .compute_session_keys_private(client_random.0, server_random.0, pms.into_value()) + .compute_session_keys_private(client_random.0, server_random.0) .await?; self.encrypter.set_key(client_write_key, client_iv).await?; diff --git a/components/tls/tls-mpc/src/setup.rs b/components/tls/tls-mpc/src/setup.rs index 6b6d8f2912..d66748052d 100644 --- a/components/tls/tls-mpc/src/setup.rs +++ b/components/tls/tls-mpc/src/setup.rs @@ -1,6 +1,6 @@ use hmac_sha256 as prf; use key_exchange as ke; -use mpz_garble::{Decode, DecodePrivate, Execute, Prove, Verify, Vm}; +use mpz_garble::{Decode, DecodePrivate, Execute, Load, Prove, Verify, Vm}; use mpz_share_conversion as ff; use point_addition as pa; use tlsn_stream_cipher as stream_cipher; @@ -44,7 +44,7 @@ pub async fn setup_components< MpcTlsError, > where - ::Thread: Execute + Decode + DecodePrivate + Prove + Verify + Send + Sync, + ::Thread: Execute + Load + Decode + DecodePrivate + Prove + Verify + Send + Sync, { // Set up channels let (mut mux_0, mut mux_1) = (mux.clone(), mux.clone()); @@ -81,7 +81,15 @@ where ); // PRF - let prf = prf::MpcPrf::new(vm.new_thread("prf").await?); + let prf_role = match role { + TlsRole::Leader => prf::Role::Leader, + TlsRole::Follower => prf::Role::Follower, + }; + let prf = prf::MpcPrf::new( + prf::PrfConfig::builder().role(prf_role).build().unwrap(), + vm.new_thread("prf/0").await?, + vm.new_thread("prf/1").await?, + ); // Encrypter let block_cipher = block_cipher::MpcBlockCipher::::new( diff --git a/components/universal-hash/Cargo.toml b/components/universal-hash/Cargo.toml index c450d0790a..b3c48fba13 100644 --- a/components/universal-hash/Cargo.toml +++ b/components/universal-hash/Cargo.toml @@ -16,9 +16,9 @@ mock = [] [dependencies] # tlsn -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } # async async-trait = "0.1" diff --git a/tlsn/Cargo.toml b/tlsn/Cargo.toml index 710847896b..666175fbd9 100644 --- a/tlsn/Cargo.toml +++ b/tlsn/Cargo.toml @@ -28,12 +28,12 @@ uid-mux = { path = "../components/uid-mux" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1f2c922" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" } futures = "0.3" diff --git a/tlsn/tlsn-core/src/fixtures/mod.rs b/tlsn/tlsn-core/src/fixtures/mod.rs index 59ce6e26cf..90aa7aa1bf 100644 --- a/tlsn/tlsn-core/src/fixtures/mod.rs +++ b/tlsn/tlsn-core/src/fixtures/mod.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use hex::FromHex; use mpz_circuits::types::ValueType; -use mpz_core::{commit::HashCommit, hash::Hash, value::ValueId}; +use mpz_core::{commit::HashCommit, hash::Hash, utils::blake3}; use mpz_garble_core::{ChaChaEncoder, Encoder}; use tls_core::{ cert::ServerCertDetails, @@ -29,6 +29,11 @@ use crate::{ EncodingProvider, }; +fn value_id(id: &str) -> u64 { + let hash = blake3(id.as_bytes()); + u64::from_be_bytes(hash[..8].try_into().unwrap()) +} + /// Returns a session header fixture using the given transcript lengths and merkle root. /// /// # Arguments @@ -52,12 +57,12 @@ pub fn encoding_provider(transcript_tx: &[u8], transcript_rx: &[u8]) -> Encoding let mut active_encodings = HashMap::new(); for (idx, byte) in transcript_tx.iter().enumerate() { let id = format!("tx/{idx}"); - let enc = encoder.encode_by_type(ValueId::new(&id).to_u64(), &ValueType::U8); + let enc = encoder.encode_by_type(value_id(&id), &ValueType::U8); active_encodings.insert(id, enc.select(*byte).unwrap()); } for (idx, byte) in transcript_rx.iter().enumerate() { let id = format!("rx/{idx}"); - let enc = encoder.encode_by_type(ValueId::new(&id).to_u64(), &ValueType::U8); + let enc = encoder.encode_by_type(value_id(&id), &ValueType::U8); active_encodings.insert(id, enc.select(*byte).unwrap()); }