From 8a61dc748cd10999eab5aee50c54505fc5a041f5 Mon Sep 17 00:00:00 2001 From: "zhoujun.ma" Date: Tue, 17 Dec 2024 17:38:42 -0800 Subject: [PATCH] rounding v2 dryrun --- Cargo.lock | 1 + types/Cargo.toml | 1 + types/src/dkg/mod.rs | 6 +- types/src/dkg/real_dkg/mod.rs | 20 +++ types/src/dkg/real_dkg/rounding_v2.rs | 206 ++++++++++++++++++++++++++ 5 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 types/src/dkg/real_dkg/rounding_v2.rs diff --git a/Cargo.lock b/Cargo.lock index 15f52e4b304bb..d4ae90fe19a99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4428,6 +4428,7 @@ dependencies = [ "move-vm-types", "num-bigint 0.3.3", "num-derive", + "num-integer", "num-traits", "once_cell", "p256", diff --git a/types/Cargo.toml b/types/Cargo.toml index 5296eb1af3eab..051f7d6ecbc30 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -44,6 +44,7 @@ move-table-extension = { workspace = true } move-vm-types = { workspace = true } num-bigint = { workspace = true } num-derive = { workspace = true } +num-integer = { workspace = true } num-traits = { workspace = true } once_cell = { workspace = true } passkey-types = { workspace = true } diff --git a/types/src/dkg/mod.rs b/types/src/dkg/mod.rs index 2e3bde94085ab..e0c2420d91f5e 100644 --- a/types/src/dkg/mod.rs +++ b/types/src/dkg/mod.rs @@ -3,7 +3,7 @@ use self::real_dkg::RealDKG; use crate::{ - dkg::real_dkg::rounding::DKGRoundingProfile, + dkg::real_dkg::{rounding::DKGRoundingProfile, rounding_v2::RoundedV2}, on_chain_config::{OnChainConfig, OnChainRandomnessConfig, RandomnessConfigMoveStruct}, validator_verifier::{ValidatorConsensusInfo, ValidatorConsensusInfoMoveStruct}, }; @@ -156,10 +156,14 @@ impl OnChainConfig for DKGState { #[derive(Clone, Debug, Default)] pub struct RoundingSummary { + pub stakes: Vec, pub method: String, pub output: DKGRoundingProfile, pub error: Option, pub exec_time: Duration, + pub rounded_v2: Option, + pub rounding_v2_err: Option, + pub rounding_v2_time: Duration, } pub trait MayHaveRoundingSummary { diff --git a/types/src/dkg/real_dkg/mod.rs b/types/src/dkg/real_dkg/mod.rs index d0e055bb284fd..84f9e357346a5 100644 --- a/types/src/dkg/real_dkg/mod.rs +++ b/types/src/dkg/real_dkg/mod.rs @@ -19,12 +19,14 @@ use aptos_dkg::{ }, }; use fixed::types::U64F64; +use num_bigint::BigUint; use num_traits::Zero; use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use std::{collections::BTreeSet, sync::Arc, time::Instant}; pub mod rounding; +pub mod rounding_v2; pub type WTrx = pvss::das::WeightedTranscript; pub type DkgPP = ::PublicParameters; @@ -107,6 +109,20 @@ pub fn build_dkg_pvss_config( maybe_fast_path_secrecy_threshold, ); let rounding_time = timer.elapsed(); + + let timer = Instant::now(); + let stakes = validator_stakes.iter().map(|s| BigUint::from(*s)).collect(); + let (rounded_v2, rounding_v2_err) = match rounding_v2::main( + stakes, + BigUint::from(secrecy_threshold.to_bits()), + BigUint::from(reconstruct_threshold.to_bits()), + maybe_fast_path_secrecy_threshold.map(|v| BigUint::from(v.to_bits())), + ) { + Ok(rounded) => (Some(rounded), None), + Err(e) => (None, Some(e.to_string())), + }; + let rounding_v2_time = timer.elapsed(); + let validator_consensus_keys: Vec = next_validators .iter() .map(|vi| vi.public_key.clone()) @@ -120,10 +136,14 @@ pub fn build_dkg_pvss_config( let pp = DkgPP::default_with_bls_base(); let rounding_summary = RoundingSummary { + stakes: validator_stakes, method: rounding_method, output: profile, exec_time: rounding_time, error: rounding_error, + rounded_v2, + rounding_v2_err, + rounding_v2_time, }; DKGPvssConfig::new( diff --git a/types/src/dkg/real_dkg/rounding_v2.rs b/types/src/dkg/real_dkg/rounding_v2.rs new file mode 100644 index 0000000000000..5e094dd082e42 --- /dev/null +++ b/types/src/dkg/real_dkg/rounding_v2.rs @@ -0,0 +1,206 @@ +// Copyright (c) Aptos Foundation +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::{anyhow, Result}; +use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::{One, ToPrimitive, Zero}; +use std::cmp::{max, min}; + +#[derive(Clone, Debug, Default)] +pub struct RoundedV2 { + pub ideal_total_weight: u128, + pub weights: Vec, + pub reconstruct_threshold_default_path: u128, + pub reconstruct_threshold_fast_path: Option, +} + +#[derive(Debug)] +struct ReconstructThresholdInfo { + in_weights: BigUint, + in_stakes: BigUint, +} + +#[derive(Debug)] +struct Profile { + ideal_total_weight: BigUint, + validator_weights: Vec, + threshold_default_path: ReconstructThresholdInfo, + threshold_fast_path: Option, +} + +impl Profile { + fn naive(n: usize) -> Self { + Self { + ideal_total_weight: BigUint::from(n), + validator_weights: vec![BigUint::one(); n], + threshold_default_path: ReconstructThresholdInfo { + in_weights: BigUint::one(), + in_stakes: BigUint::one(), + }, + threshold_fast_path: None, + } + } +} + +pub fn main( + stakes: Vec, + mut secrecy_threshold_shl64: BigUint, + mut recon_threshold_shl64: BigUint, + fast_secrecy_thresh_shl64: Option, +) -> Result { + let n = stakes.len(); + // Ensure secrecy threshold is in [0,1). + secrecy_threshold_shl64 = min( + secrecy_threshold_shl64, + BigUint::from(0xFFFFFFFFFFFFFFFF_u64), + ); + // `recon_thresh > secrecy_thresh` should hold, otherwise it is invalid input. + recon_threshold_shl64 = max( + recon_threshold_shl64, + secrecy_threshold_shl64.clone() + BigUint::one(), + ); + recon_threshold_shl64 = min(recon_threshold_shl64, BigUint::from(1u128 << 64)); + let mut total_weight_max = BigUint::from(n) + BigUint::from(4u64); + total_weight_max <<= 64; + total_weight_max = total_weight_max.div_ceil( + &((recon_threshold_shl64.clone() - secrecy_threshold_shl64.clone()) * BigUint::from(2u64)), + ); + let stakes_total: BigUint = stakes.clone().into_iter().sum(); + let bar = (stakes_total.clone() * recon_threshold_shl64.clone()) >> 64; + let mut lo = 0; + let mut hi = total_weight_max + .to_u128() + .ok_or_else(|| anyhow!("total_weight_max is not a u128!"))? + * 2; + // This^ ensures the first `ideal_weight` to try is `total_weight_max`, + // which should always result in a valid weight assignment that satisfies `recon_threshold_shl64`. + + let mut profile = Profile::naive(n); + while lo + 1 < hi { + let ideal_weight = (lo + hi) / 2; + let mut weight_per_stake_shl64 = BigUint::from(ideal_weight); + weight_per_stake_shl64 <<= 64; + weight_per_stake_shl64 = weight_per_stake_shl64.div_ceil(&stakes_total); + let cur_profile = compute_profile( + secrecy_threshold_shl64.clone(), + fast_secrecy_thresh_shl64.clone(), + &stakes, + BigUint::from(ideal_weight), + weight_per_stake_shl64, + ); + if cur_profile.threshold_default_path.in_stakes <= bar { + hi = ideal_weight; + profile = cur_profile; + } else { + lo = ideal_weight; + } + } + + let Profile { + ideal_total_weight, + validator_weights, + threshold_default_path, + threshold_fast_path, + } = profile; + let mut weights = Vec::with_capacity(n); + for w in validator_weights { + let w = w.to_u64().ok_or_else(|| anyhow!("some w is not u64!"))?; + weights.push(w); + } + let reconstruct_threshold_fast_path = + if let Some(t) = threshold_fast_path { + Some(t.in_weights.to_u128().ok_or_else(|| { + anyhow!("reconstruct_threshold_fast_path.in_weights is not a u128!") + })?) + } else { + None + }; + + Ok(RoundedV2 { + ideal_total_weight: ideal_total_weight + .to_u128() + .ok_or_else(|| anyhow!("ideal_total_weight is not a u128"))?, + weights, + reconstruct_threshold_default_path: threshold_default_path + .in_weights + .to_u128() + .ok_or_else(|| anyhow!("threshold_default_path.in_weights is not a u128!"))?, + reconstruct_threshold_fast_path, + }) +} + +fn compute_threshold( + secrecy_threshold_shl64: BigUint, + weight_per_stake_shl64: BigUint, + stake_total: BigUint, + weight_total: BigUint, + weight_gain_shl64: BigUint, + weight_loss_shl64: BigUint, +) -> ReconstructThresholdInfo { + let mut final_thresh = (((weight_gain_shl64 << 64) + + stake_total * secrecy_threshold_shl64 * weight_per_stake_shl64.clone()) + >> 128) + + BigUint::one(); + final_thresh = min(final_thresh, weight_total); + let mut stakes_required = final_thresh.clone(); + stakes_required <<= 64; + stakes_required += weight_loss_shl64; + stakes_required = stakes_required.div_ceil(&weight_per_stake_shl64); + ReconstructThresholdInfo { + in_weights: final_thresh, + in_stakes: stakes_required, + } +} + +fn compute_profile( + secrecy_threshold_shl64: BigUint, + fast_path_secrecy_threshold_shl64: Option, + stakes: &[BigUint], + ideal_total_weight: BigUint, + weight_per_stake_shl64: BigUint, +) -> Profile { + let n = stakes.len(); + let mut validator_weights = Vec::with_capacity(n); + let mut weight_loss_shl64 = BigUint::zero(); + let mut weight_gain_shl64 = BigUint::zero(); + for stake in stakes { + let ideal_weight_shl64 = weight_per_stake_shl64.clone() * stake; + let mut rounded_weight = ideal_weight_shl64.clone() + BigUint::from(1u64 << 63); + rounded_weight >>= 64; + + validator_weights.push(rounded_weight.clone()); + let rounded_weight_shl64 = rounded_weight << 64; + if ideal_weight_shl64 > rounded_weight_shl64 { + weight_loss_shl64 += ideal_weight_shl64 - rounded_weight_shl64; + } else { + weight_gain_shl64 += rounded_weight_shl64 - ideal_weight_shl64; + } + } + let total_stake: BigUint = stakes.iter().cloned().sum(); + let total_weight: BigUint = validator_weights.clone().into_iter().sum(); + let threshold_default_path = compute_threshold( + secrecy_threshold_shl64, + weight_per_stake_shl64.clone(), + total_stake.clone(), + total_weight.clone(), + weight_gain_shl64.clone(), + weight_loss_shl64.clone(), + ); + let threshold_fast_path = fast_path_secrecy_threshold_shl64.map(|v| { + compute_threshold( + v, + weight_per_stake_shl64, + total_stake, + total_weight, + weight_gain_shl64, + weight_loss_shl64, + ) + }); + Profile { + ideal_total_weight, + validator_weights, + threshold_default_path, + threshold_fast_path, + } +}