From 28c574f173e5e88f6893bfb732a4c164b98eb5f1 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 22 Nov 2024 10:15:08 -0800
Subject: [PATCH] Make finalizer protocol generic for MPC

---
 ipa-core/src/protocol/basics/shard_fin.rs     | 399 +++++++++++++++---
 .../src/protocol/context/dzkp_malicious.rs    |  32 +-
 ipa-core/src/protocol/hybrid/step.rs          |   2 +
 3 files changed, 365 insertions(+), 68 deletions(-)

diff --git a/ipa-core/src/protocol/basics/shard_fin.rs b/ipa-core/src/protocol/basics/shard_fin.rs
index 4c44d5d4c..45aae6303 100644
--- a/ipa-core/src/protocol/basics/shard_fin.rs
+++ b/ipa-core/src/protocol/basics/shard_fin.rs
@@ -1,20 +1,139 @@
-use std::ops::Add;
+use std::{future::Future, marker::PhantomData, ops::Add};
 
-use futures::{future, Stream, StreamExt, TryStreamExt};
+use futures::{Stream, StreamExt, TryStreamExt};
 use generic_array::ArrayLength;
+use ipa_step::Step;
 
 use crate::{
-    ff::{boolean_array::BooleanArray, Serializable},
+    error::{Error, LengthError},
+    ff::{boolean::Boolean, boolean_array::BooleanArray, Serializable},
     helpers::{Message, TotalRecords},
-    protocol::{context::ShardedContext, RecordId},
-    secret_sharing::replicated::semi_honest::AdditiveShare,
-    seq_join::seq_join,
+    protocol::{
+        boolean::step::EightBitStep,
+        context::{
+            dzkp_validator::DZKPValidator, DZKPContext, DZKPUpgradedMaliciousContext,
+            DZKPUpgradedSemiHonestContext, MaliciousProtocolSteps, ShardedContext,
+            ShardedMaliciousContext, ShardedSemiHonestContext, UpgradableContext,
+        },
+        ipa_prf::boolean_ops::addition_sequential::integer_sat_add,
+        BooleanProtocols, RecordId,
+    },
+    secret_sharing::{
+        replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd, TransposeFrom,
+    },
+    seq_join::{assert_send, seq_join},
+    sharding::Sharded,
 };
 
+/// This is just a step trait with thread safety bounds. Those make sense
+/// on a generic [`Step`] trait so maybe we can change it later.
+trait FinalizerStep: Step + Sync + 'static {}
+impl<S: Step + Sync + 'static> FinalizerStep for S {}
+
+/// Context to finalize sharded MPC executions. The finalization protocol
+/// is very simple - all shards just send data to the leader, which performs
+/// some sort of aggregation. The aggregation logic can be simple: shuffle
+/// just requires bundling all rows together. Or it can be complicated and
+/// require an MPC circuit. Histogram aggregation will need an addition in
+/// MPC to properly assemble the final result.
+///
+/// This trait provides a generic way to write protocols that require
+/// a shard aggregation step. It only supports ZKP.
+trait FinalizerContext: ShardedContext + UpgradableContext {
+    type FinalizingContext: ShardedContext + DZKPContext;
+    type Step<S: FinalizerStep>;
+
+    fn finalize<S: FinalizerStep, R: ShardAssembledResult<Self::FinalizingContext>>(
+        self,
+        step: Self::Step<S>,
+        inputs: R,
+    ) -> impl Future<Output = Result<R, Error>> + Send;
+}
+
+/// Trait for results obtained by running sharded MPC protocols. Many shards run MPC
+/// executions in parallel and at the end they need to synchronize and agree on final results.
+/// The details of that agreement are up to each individual protocol, however the interaction
+/// between them is the same:
+/// - Once a shard completes its computation, it sends its results to the leader shard
+/// - When the leader completes the MPC part of the computation, it blocks awaiting results from all
+///   other shards that participate. When all results are received, the leader merges them
+///   together to obtain the final result that is later shared with the report collector.
+///
+/// Based on that interaction, shard final results need to be mergeable and communicated
+/// over shard channels, and they also need to have a default value.
+trait ShardAssembledResult<C: ShardedContext>: Send + Sized {
+    /// Type of messages used to communicate the entire result over the network. Often, shards
+    /// hold a collection of shares, so this type will indicate the share type.
+    type SingleMessage: Message;
+
+    /// Return the empty value that will be used by all shards except the leader
+    /// to set their result of execution.
+    fn empty() -> Self;
+
+    /// Merges two assembled results together.
+    fn merge<'a>(
+        &'a mut self,
+        ctx: C,
+        record_id: RecordId,
+        other: Self,
+    ) -> impl Future<Output = Result<(), Error>> + Send + 'a
+    where
+        C: 'a;
+
+    /// Converts this into a form suitable to be sent over the wire
+    fn into_messages(self) -> impl ExactSizeIterator<Item = Self::SingleMessage> + Send;
+
+    /// Reverse conversion from the stream of messages back to this type.
+    fn from_message_stream<S>(stream: S) -> impl Future<Output = Result<Self, Error>> + Send
+    where
+        S: Stream<Item = Result<Self::SingleMessage, Error>> + Send;
+}
+
+impl<'a> FinalizerContext for ShardedMaliciousContext<'a> {
+    type FinalizingContext = DZKPUpgradedMaliciousContext<'a, Sharded>;
+    type Step<S: FinalizerStep> = MaliciousProtocolSteps<'a, S>;
+
+    #[allow(clippy::manual_async_fn)] // good luck with the `Send` is not general enough error, clippy
+    fn finalize<S: FinalizerStep, R: ShardAssembledResult<Self::FinalizingContext>>(
+        self,
+        step: Self::Step<S>,
+        inputs: R,
+    ) -> impl Future<Output = Result<R, Error>> + Send {
+        async move {
+            // We use a single batch here because the whole assumption of this protocol is that
+            // it is small and simple. If that is not the case, it will require adjustments.
+            let validator = self.dzkp_validator(step, usize::MAX);
+            let ctx = validator.context();
+            let r = semi_honest(ctx, inputs).await?;
+            validator.validate().await?;
+
+            Ok(r)
+        }
+    }
+}
+
+impl<'a> FinalizerContext for ShardedSemiHonestContext<'a> {
+    type FinalizingContext = DZKPUpgradedSemiHonestContext<'a, Sharded>;
+    type Step<S: FinalizerStep> = MaliciousProtocolSteps<'a, S>;
+
+    fn finalize<S: FinalizerStep, R: ShardAssembledResult<Self::FinalizingContext>>(
+        self,
+        step: Self::Step<S>,
+        inputs: R,
+    ) -> impl Future<Output = Result<R, Error>> + Send {
+        let v = self.dzkp_validator(step, usize::MAX);
+        semi_honest(v.context(), inputs)
+    }
+}
+
 /// This finalizes the MPC execution in sharded context by forwarding the computation results
 /// to the leader shard from all follower shards. Leader shard aggregates them and returns it,
-/// followers set their result to be empty.
-async fn finalize<C: ShardedContext, R: ShardAssembledResult>(
+/// followers set their result to be empty. This implementation only supports semi-honest
+/// security and shouldn't be used directly. Instead, [`FinalizerContext`] provides a means
+/// to finalize the execution.
+/// This is a generic implementation that works for both malicious and semi-honest settings. For
+/// the former, it requires a validation phase to be performed afterwards.
+async fn semi_honest<C: ShardedContext, R: ShardAssembledResult<C>>(
     ctx: C,
     inputs: R,
 ) -> Result<R, Error> {
@@ -25,9 +144,17 @@ async fn finalize<C: ShardedContext, R: ShardAssembledResult>(
         let stream = ctx.shard_recv_channel::<R::SingleMessage>(shard);
         R::from_message_stream(stream)
     })
-    .try_fold(inputs, |mut acc, r| {
-        acc.merge(r);
-        future::ok(acc)
+    .enumerate()
+    .map(|(i, va)| va.map(|v| (v, i)))
+    .try_fold(inputs, |mut acc, (r, record_id)| {
+        // we merge elements into a single accumulator one by one, thus
+        // the record count is indeterminate. A better strategy would be to do
+        // a tree-based merge
+        let ctx = ctx.set_total_records(TotalRecords::Indeterminate);
+        async move {
+            assert_send(acc.merge(ctx, RecordId::from(record_id), r)).await?;
+            Ok(acc)
+        }
     })
     .await?;
 
@@ -55,98 +182,204 @@
 
     }
 }
 
-/// Trait for results obtained by running sharded MPC protocols. Many shards run MPC
-/// executions in parallel and at the end they need to synchronize and agree on final results.
-/// The details of that agreement is up to each individual protocol, however the interactions
-/// between them is the same:
-/// - Once shard completes its computation, it sends its results to the leader shard
-/// - When leader completes the MPC part of computation, it blocks awaiting results from all
-///   other shards that participate. When all results are received, the leader merges them
-///   together to obtain the final result that is later shared with the report collector.
-///
-/// Based on that interaction, shard final results need to be mergeable and communicated
-/// over shard channels as well as they need to have a default value.
-#[async_trait::async_trait]
-trait ShardAssembledResult: Sized {
-    /// Type of messages used to communicate the entire result over the network. Often, shards
-    /// hold a collection of shares, so this type will indicate the share type.
-    type SingleMessage: Message;
-
-    /// Return empty value that will be used by all shards except the leader,
-    /// to set their result of execution.
-    fn empty() -> Self;
+/// This type exists to bind [`HV`] and [`B`] together and allow
+/// conversions from [`AdditiveShare`] to [`BitDecomposed<AdditiveShare<Boolean, B>>`]
+/// and vice versa. The decomposed view is used to perform additions, the
+/// share form is used to send data to other shards.
+#[derive(Debug, Default)]
+struct Histogram<HV: BooleanArray, const B: usize>
+where
+    Boolean: FieldSimd<B>,
+{
+    values: BitDecomposed<AdditiveShare<Boolean, B>>,
+    _marker: PhantomData<HV>,
+}
 
-    /// Merges two assembled results together.
-    fn merge(&mut self, other: Self);
+impl<HV: BooleanArray, const B: usize> Histogram<HV, B>
+where
+    BitDecomposed<AdditiveShare<Boolean, B>>:
+        for<'a> TransposeFrom<&'a Vec<AdditiveShare<HV>>, Error = LengthError>,
+    Vec<AdditiveShare<HV>>:
+        for<'a> TransposeFrom<&'a BitDecomposed<AdditiveShare<Boolean, B>>, Error = LengthError>,
+    Boolean: FieldSimd<B>,
+{
+    pub fn new(input: &Vec<AdditiveShare<HV>>) -> Result<Self, LengthError> {
+        Ok(Self {
+            values: BitDecomposed::transposed_from(input)?,
+            _marker: PhantomData,
+        })
+    }
 
-    /// Converts this into a form suitable to be sent over the wire
-    fn into_messages(self) -> impl ExactSizeIterator<Item = Self::SingleMessage> + Send;
+    pub fn compose(&self) -> Vec<AdditiveShare<HV>> {
+        if self.values.is_empty() {
+            Vec::new()
+        } else {
+            // unwrap here is safe because we converted values from a vector during
+            // initialization, so it must have the value we need
+            Vec::transposed_from(&self.values).unwrap()
+        }
+    }
+}
 
-    /// Reverse conversion from the stream of messages back to this type.
-    async fn from_message_stream<
-        S: Stream<Item = Result<Self::SingleMessage, Error>> + Send,
-    >(
-        stream: S,
-    ) -> Result<Self, Error>;
+#[cfg(test)]
+impl<HV: BooleanArray, const B: usize> crate::test_fixture::Reconstruct<Vec<HV>>
+    for [Histogram<HV, B>; 3]
+where
+    Boolean: FieldSimd<B>,
+    BitDecomposed<AdditiveShare<Boolean, B>>:
+        for<'a> TransposeFrom<&'a Vec<AdditiveShare<HV>>, Error = LengthError>,
+    Vec<AdditiveShare<HV>>:
+        for<'a> TransposeFrom<&'a BitDecomposed<AdditiveShare<Boolean, B>>, Error = LengthError>,
+{
+    fn reconstruct(&self) -> Vec<HV> {
+        let shares = self.each_ref().map(Histogram::compose);
+        shares.reconstruct()
+    }
 }
 
-#[async_trait::async_trait]
-impl<HV: BooleanArray> ShardAssembledResult for Vec<AdditiveShare<HV>>
+impl<C: ShardedContext, HV: BooleanArray, const B: usize> ShardAssembledResult<C>
+    for Histogram<HV, B>
 where
+    AdditiveShare<Boolean, B>: BooleanProtocols<C, B>,
+    Boolean: FieldSimd<B>,
+    // I mean... there must be a less-verbose way to write these bounds
+    Vec<AdditiveShare<HV>>:
+        for<'a> TransposeFrom<&'a BitDecomposed<AdditiveShare<Boolean, B>>, Error = LengthError>,
+    BitDecomposed<AdditiveShare<Boolean, B>>:
+        for<'a> TransposeFrom<&'a Vec<AdditiveShare<HV>>, Error = LengthError>,
     <HV as Serializable>::Size: Add<Output: ArrayLength>,
 {
     type SingleMessage = AdditiveShare<HV>;
 
     fn empty() -> Self {
-        Vec::new()
+        Self::default()
     }
 
-    fn merge(&mut self, other: Self) {
-        // this merges two histograms together by adding them up
-        for (a, b) in self.iter_mut().zip(other) {
-            *a += b;
+    #[allow(clippy::manual_async_fn)]
+    fn merge<'a>(
+        &'a mut self,
+        ctx: C,
+        record_id: RecordId,
+        other: Self,
+    ) -> impl Future<Output = Result<(), Error>> + Send + 'a
+    where
+        C: 'a,
+    {
+        async move {
+            // todo: EightBitStep only works for up to 256 breakdowns. It will panic if we try
+            // to add larger values
+            self.values =
+                integer_sat_add::<_, EightBitStep, B>(ctx, record_id, &self.values, &other.values)
+                    .await?;
+
+            Ok(())
         }
     }
 
     fn into_messages(self) -> impl ExactSizeIterator<Item = Self::SingleMessage> + Send {
-        self.into_iter()
+        self.compose().into_iter()
     }
 
-    async fn from_message_stream<
-        S: Stream<Item = Result<Self::SingleMessage, Error>> + Send,
-    >(
-        stream: S,
-    ) -> Result<Self, Error> {
-        stream.try_collect::<Vec<_>>().await
+    #[allow(clippy::manual_async_fn)]
+    fn from_message_stream<S>(stream: S) -> impl Future<Output = Result<Self, Error>> + Send
+    where
+        S: Stream<Item = Result<Self::SingleMessage, Error>> + Send,
+    {
+        async move { Ok(Self::new(&stream.try_collect::<Vec<_>>().await?)?) }
     }
 }
 
 #[cfg(all(test, unit_test))]
 mod tests {
+    use std::iter::repeat;
+
     use crate::{
-        ff::boolean_array::BA64,
-        protocol::basics::shard_fin::finalize,
-        secret_sharing::SharedValue,
+        ff::{boolean_array::BA8, U128Conversions},
+        helpers::{in_memory_config::MaliciousHelper, Role},
+        protocol::{
+            basics::shard_fin::{FinalizerContext, Histogram},
+            context::TEST_DZKP_STEPS,
+        },
+        sharding::ShardIndex,
         test_executor::run,
         test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
     };
 
+    /// Generate some data to validate the integer addition finalizer
+    fn gen<const SHARDS: usize>(values: [BA8; SHARDS]) -> impl Iterator<Item = BA8> + Clone {
+        let mut cnt = 0;
+        // each shard receives the same value
+        std::iter::from_fn(move || {
+            cnt += 1;
+            Some(values[(cnt - 1) % SHARDS])
+        })
+    }
+
+    #[test]
+    fn semi_honest() {
+        run(|| async {
+            const SHARDS: usize = 3;
+            let world: TestWorld<WithShards<SHARDS>> =
+                TestWorld::with_shards(TestWorldConfig::default());
+
+            let input = gen::<SHARDS>([
+                BA8::truncate_from(10_u128),
+                BA8::truncate_from(21_u128),
+                BA8::truncate_from(3_u128),
+            ])
+            .take(16 * SHARDS);
+            let results = world
+                .semi_honest(input.clone(), |ctx, input| async move {
+                    let input = Histogram::<BA8, 16>::new(&input).unwrap();
+                    ctx.finalize(TEST_DZKP_STEPS, input).await.unwrap()
+                })
+                .await;
+
+            // leader aggregates everything
+            let leader_shares = results[0].reconstruct();
+            assert_eq!(
+                repeat(BA8::truncate_from(34_u128))
+                    .take(16)
+                    .collect::<Vec<_>>(),
+                leader_shares
+            );
+
+            // followers have nothing
+            let f1 = results[1].reconstruct();
+            let f2 = results[2].reconstruct();
+            assert_eq!(f1, f2);
+            assert_eq!(0, f1.len());
+        });
+    }
+
     #[test]
-    fn shards_set_result() {
+    fn malicious() {
         run(|| async {
-            let world: TestWorld<WithShards<3>> =
+            const SHARDS: usize = 3;
+            let world: TestWorld<WithShards<SHARDS>> =
                 TestWorld::with_shards(TestWorldConfig::default());
-            let input = vec![BA64::ZERO, BA64::ZERO, BA64::ZERO];
+
+            let input = gen::<SHARDS>([
+                BA8::truncate_from(1_u128),
+                BA8::truncate_from(3_u128),
+                BA8::truncate_from(5_u128),
+            ])
+            .take(16 * SHARDS);
             let results = world
-                .semi_honest(input.into_iter(), |ctx, input| async move {
-                    assert_eq!(1, input.len());
-                    finalize(ctx, input).await.unwrap()
+                .malicious(input.clone(), |ctx, input| async move {
+                    let input = Histogram::<BA8, 16>::new(&input).unwrap();
+                    ctx.finalize(TEST_DZKP_STEPS, input).await.unwrap()
                 })
                 .await;
 
             // leader aggregates everything
             let leader_shares = results[0].reconstruct();
-            assert_eq!(vec![BA64::ZERO], leader_shares);
+            assert_eq!(
+                repeat(BA8::truncate_from(9_u128))
+                    .take(16)
+                    .collect::<Vec<_>>(),
+                leader_shares
+            );
 
             // followers have nothing
             let f1 = results[1].reconstruct();
@@ -155,4 +388,40 @@ mod tests {
             assert_eq!(0, f1.len());
         });
     }
+
+    #[test]
+    #[should_panic(expected = "DZKPValidationFailed")]
+    fn malicious_attack_resistant() {
+        run(|| async {
+            const SHARDS: usize = 3;
+            let mut config = TestWorldConfig::default();
+            config.stream_interceptor =
+                MaliciousHelper::new(Role::H2, config.role_assignment(), move |ctx, data| {
+                    if ctx
+                        .gate
+                        .as_ref()
+                        .contains(TEST_DZKP_STEPS.protocol.as_ref())
+                        && ctx.dest == Role::H1
+                        && ctx.shard == Some(ShardIndex::FIRST)
+                    {
+                        data[0] ^= 1u8;
+                    }
+                });
+            let world: TestWorld<WithShards<SHARDS>> = TestWorld::with_shards(config);
+
+            let input = gen::<SHARDS>([
+                BA8::truncate_from(1_u128),
+                BA8::truncate_from(3_u128),
+                BA8::truncate_from(5_u128),
+            ])
+            .take(16 * SHARDS);
+            world
+                .malicious(input, |ctx, input| async move {
+                    ctx.finalize(TEST_DZKP_STEPS, Histogram::<BA8, 16>::new(&input).unwrap())
+                        .await
+                        .unwrap()
+                })
+                .await;
+        });
+    }
 }
diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs
index 951b5b4c4..773b30e00 100644
--- a/ipa-core/src/protocol/context/dzkp_malicious.rs
+++ b/ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -8,18 +8,20 @@ use ipa_step::{Step, StepNarrow};
 
 use crate::{
     error::Error,
-    helpers::{MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords},
+    helpers::{
+        Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, TotalRecords,
+    },
     protocol::{
         context::{
            dzkp_validator::{Batch, MaliciousDZKPValidatorInner, Segment},
            prss::InstrumentedIndexedSharedRandomness,
            Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness,
-            MaliciousContext,
+            MaliciousContext, ShardedContext,
        },
         Gate, RecordId,
     },
     seq_join::SeqJoin,
-    sharding::ShardBinding,
+    sharding::{ShardBinding, ShardConfiguration, ShardIndex, Sharded},
     sync::{Arc, Weak},
 };
 
@@ -168,3 +170,27 @@ impl<B: ShardBinding> Debug for DZKPUpgraded<'_, B> {
         write!(f, "DZKPMaliciousContext")
     }
 }
+
+impl ShardConfiguration for DZKPUpgraded<'_, Sharded> {
+    fn shard_id(&self) -> ShardIndex {
+        self.base_ctx.shard_id()
+    }
+
+    fn shard_count(&self) -> ShardIndex {
+        self.base_ctx.shard_count()
+    }
+}
+
+impl ShardedContext for DZKPUpgraded<'_, Sharded> {
+    fn shard_send_channel<M: Message>(&self, dest_shard: ShardIndex) -> SendingEnd<ShardIndex, M> {
+        self.base_ctx.shard_send_channel(dest_shard)
+    }
+
+    fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
+        self.base_ctx.shard_recv_channel(origin)
+    }
+
+    fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
+        self.base_ctx.cross_shard_prss()
+    }
+}
diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs
index 93dcb0aee..b2b516416 100644
--- a/ipa-core/src/protocol/hybrid/step.rs
+++ b/ipa-core/src/protocol/hybrid/step.rs
@@ -15,4 +15,6 @@ pub(crate) enum HybridStep {
     #[step(child = crate::protocol::context::step::MaliciousProtocolStep)]
     EvalPrf,
     ReshardByPrf,
+    Finalize,
+    FinalizeValidate,
 }