Skip to content

Commit

Permalink
rounding v2 dryrun
Browse files Browse the repository at this point in the history
  • Loading branch information
zjma committed Dec 18, 2024
1 parent dd0dcd2 commit 8a61dc7
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ move-table-extension = { workspace = true }
move-vm-types = { workspace = true }
num-bigint = { workspace = true }
num-derive = { workspace = true }
num-integer = { workspace = true }
num-traits = { workspace = true }
once_cell = { workspace = true }
passkey-types = { workspace = true }
Expand Down
6 changes: 5 additions & 1 deletion types/src/dkg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use self::real_dkg::RealDKG;
use crate::{
dkg::real_dkg::rounding::DKGRoundingProfile,
dkg::real_dkg::{rounding::DKGRoundingProfile, rounding_v2::RoundedV2},
on_chain_config::{OnChainConfig, OnChainRandomnessConfig, RandomnessConfigMoveStruct},
validator_verifier::{ValidatorConsensusInfo, ValidatorConsensusInfoMoveStruct},
};
Expand Down Expand Up @@ -156,10 +156,14 @@ impl OnChainConfig for DKGState {

#[derive(Clone, Debug, Default)]
pub struct RoundingSummary {
pub stakes: Vec<u64>,
pub method: String,
pub output: DKGRoundingProfile,
pub error: Option<String>,
pub exec_time: Duration,
pub rounded_v2: Option<RoundedV2>,
pub rounding_v2_err: Option<String>,
pub rounding_v2_time: Duration,
}

pub trait MayHaveRoundingSummary {
Expand Down
20 changes: 20 additions & 0 deletions types/src/dkg/real_dkg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ use aptos_dkg::{
},
};
use fixed::types::U64F64;
use num_bigint::BigUint;
use num_traits::Zero;
use rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
use std::{collections::BTreeSet, sync::Arc, time::Instant};

pub mod rounding;
pub mod rounding_v2;

pub type WTrx = pvss::das::WeightedTranscript;
pub type DkgPP = <WTrx as Transcript>::PublicParameters;
Expand Down Expand Up @@ -107,6 +109,20 @@ pub fn build_dkg_pvss_config(
maybe_fast_path_secrecy_threshold,
);
let rounding_time = timer.elapsed();

let timer = Instant::now();
let stakes = validator_stakes.iter().map(|s| BigUint::from(*s)).collect();
let (rounded_v2, rounding_v2_err) = match rounding_v2::main(
stakes,
BigUint::from(secrecy_threshold.to_bits()),
BigUint::from(reconstruct_threshold.to_bits()),
maybe_fast_path_secrecy_threshold.map(|v| BigUint::from(v.to_bits())),
) {
Ok(rounded) => (Some(rounded), None),
Err(e) => (None, Some(e.to_string())),
};
let rounding_v2_time = timer.elapsed();

let validator_consensus_keys: Vec<bls12381::PublicKey> = next_validators
.iter()
.map(|vi| vi.public_key.clone())
Expand All @@ -120,10 +136,14 @@ pub fn build_dkg_pvss_config(
let pp = DkgPP::default_with_bls_base();

let rounding_summary = RoundingSummary {
stakes: validator_stakes,
method: rounding_method,
output: profile,
exec_time: rounding_time,
error: rounding_error,
rounded_v2,
rounding_v2_err,
rounding_v2_time,
};

DKGPvssConfig::new(
Expand Down
206 changes: 206 additions & 0 deletions types/src/dkg/real_dkg/rounding_v2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// Copyright (c) Aptos Foundation
// SPDX-License-Identifier: Apache-2.0

use anyhow::{anyhow, Result};
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::{One, ToPrimitive, Zero};
use std::cmp::{max, min};

#[derive(Clone, Debug, Default)]
pub struct RoundedV2 {
pub ideal_total_weight: u128,
pub weights: Vec<u64>,
pub reconstruct_threshold_default_path: u128,
pub reconstruct_threshold_fast_path: Option<u128>,
}

#[derive(Debug)]
struct ReconstructThresholdInfo {
in_weights: BigUint,
in_stakes: BigUint,
}

#[derive(Debug)]
struct Profile {
ideal_total_weight: BigUint,
validator_weights: Vec<BigUint>,
threshold_default_path: ReconstructThresholdInfo,
threshold_fast_path: Option<ReconstructThresholdInfo>,
}

impl Profile {
fn naive(n: usize) -> Self {
Self {
ideal_total_weight: BigUint::from(n),
validator_weights: vec![BigUint::one(); n],
threshold_default_path: ReconstructThresholdInfo {
in_weights: BigUint::one(),
in_stakes: BigUint::one(),
},
threshold_fast_path: None,
}
}
}

pub fn main(
stakes: Vec<BigUint>,
mut secrecy_threshold_shl64: BigUint,
mut recon_threshold_shl64: BigUint,
fast_secrecy_thresh_shl64: Option<BigUint>,
) -> Result<RoundedV2> {
let n = stakes.len();
// Ensure secrecy threshold is in [0,1).
secrecy_threshold_shl64 = min(
secrecy_threshold_shl64,
BigUint::from(0xFFFFFFFFFFFFFFFF_u64),
);
// `recon_thresh > secrecy_thresh` should hold, otherwise it is invalid input.
recon_threshold_shl64 = max(
recon_threshold_shl64,
secrecy_threshold_shl64.clone() + BigUint::one(),
);
recon_threshold_shl64 = min(recon_threshold_shl64, BigUint::from(1u128 << 64));
let mut total_weight_max = BigUint::from(n) + BigUint::from(4u64);
total_weight_max <<= 64;
total_weight_max = total_weight_max.div_ceil(
&((recon_threshold_shl64.clone() - secrecy_threshold_shl64.clone()) * BigUint::from(2u64)),
);
let stakes_total: BigUint = stakes.clone().into_iter().sum();
let bar = (stakes_total.clone() * recon_threshold_shl64.clone()) >> 64;
let mut lo = 0;
let mut hi = total_weight_max
.to_u128()
.ok_or_else(|| anyhow!("total_weight_max is not a u128!"))?
* 2;
// This^ ensures the first `ideal_weight` to try is `total_weight_max`,
// which should always result in a valid weight assignment that satisfies `recon_threshold_shl64`.

let mut profile = Profile::naive(n);
while lo + 1 < hi {
let ideal_weight = (lo + hi) / 2;
let mut weight_per_stake_shl64 = BigUint::from(ideal_weight);
weight_per_stake_shl64 <<= 64;
weight_per_stake_shl64 = weight_per_stake_shl64.div_ceil(&stakes_total);
let cur_profile = compute_profile(
secrecy_threshold_shl64.clone(),
fast_secrecy_thresh_shl64.clone(),
&stakes,
BigUint::from(ideal_weight),
weight_per_stake_shl64,
);
if cur_profile.threshold_default_path.in_stakes <= bar {
hi = ideal_weight;
profile = cur_profile;
} else {
lo = ideal_weight;
}
}

let Profile {
ideal_total_weight,
validator_weights,
threshold_default_path,
threshold_fast_path,
} = profile;
let mut weights = Vec::with_capacity(n);
for w in validator_weights {
let w = w.to_u64().ok_or_else(|| anyhow!("some w is not u64!"))?;
weights.push(w);
}
let reconstruct_threshold_fast_path =
if let Some(t) = threshold_fast_path {
Some(t.in_weights.to_u128().ok_or_else(|| {
anyhow!("reconstruct_threshold_fast_path.in_weights is not a u128!")
})?)
} else {
None
};

Ok(RoundedV2 {
ideal_total_weight: ideal_total_weight
.to_u128()
.ok_or_else(|| anyhow!("ideal_total_weight is not a u128"))?,
weights,
reconstruct_threshold_default_path: threshold_default_path
.in_weights
.to_u128()
.ok_or_else(|| anyhow!("threshold_default_path.in_weights is not a u128!"))?,
reconstruct_threshold_fast_path,
})
}

fn compute_threshold(
secrecy_threshold_shl64: BigUint,
weight_per_stake_shl64: BigUint,
stake_total: BigUint,
weight_total: BigUint,
weight_gain_shl64: BigUint,
weight_loss_shl64: BigUint,
) -> ReconstructThresholdInfo {
let mut final_thresh = (((weight_gain_shl64 << 64)
+ stake_total * secrecy_threshold_shl64 * weight_per_stake_shl64.clone())
>> 128)
+ BigUint::one();
final_thresh = min(final_thresh, weight_total);
let mut stakes_required = final_thresh.clone();
stakes_required <<= 64;
stakes_required += weight_loss_shl64;
stakes_required = stakes_required.div_ceil(&weight_per_stake_shl64);
ReconstructThresholdInfo {
in_weights: final_thresh,
in_stakes: stakes_required,
}
}

fn compute_profile(
secrecy_threshold_shl64: BigUint,
fast_path_secrecy_threshold_shl64: Option<BigUint>,
stakes: &[BigUint],
ideal_total_weight: BigUint,
weight_per_stake_shl64: BigUint,
) -> Profile {
let n = stakes.len();
let mut validator_weights = Vec::with_capacity(n);
let mut weight_loss_shl64 = BigUint::zero();
let mut weight_gain_shl64 = BigUint::zero();
for stake in stakes {
let ideal_weight_shl64 = weight_per_stake_shl64.clone() * stake;
let mut rounded_weight = ideal_weight_shl64.clone() + BigUint::from(1u64 << 63);
rounded_weight >>= 64;

validator_weights.push(rounded_weight.clone());
let rounded_weight_shl64 = rounded_weight << 64;
if ideal_weight_shl64 > rounded_weight_shl64 {
weight_loss_shl64 += ideal_weight_shl64 - rounded_weight_shl64;
} else {
weight_gain_shl64 += rounded_weight_shl64 - ideal_weight_shl64;
}
}
let total_stake: BigUint = stakes.iter().cloned().sum();
let total_weight: BigUint = validator_weights.clone().into_iter().sum();
let threshold_default_path = compute_threshold(
secrecy_threshold_shl64,
weight_per_stake_shl64.clone(),
total_stake.clone(),
total_weight.clone(),
weight_gain_shl64.clone(),
weight_loss_shl64.clone(),
);
let threshold_fast_path = fast_path_secrecy_threshold_shl64.map(|v| {
compute_threshold(
v,
weight_per_stake_shl64,
total_stake,
total_weight,
weight_gain_shl64,
weight_loss_shl64,
)
});
Profile {
ideal_total_weight,
validator_weights,
threshold_default_path,
threshold_fast_path,
}
}

0 comments on commit 8a61dc7

Please sign in to comment.