Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prover): Optimize setup keys loading #2847

Merged
merged 5 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions core/lib/basic_types/src/basic_fri_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,29 @@ impl AggregationRound {
AggregationRound::Scheduler => None,
}
}

/// Lists every circuit that belongs to this aggregation round, each paired
/// with the round itself.
///
/// Expected circuit ids per aggregation round:
/// - round 0: 1 to 15, plus 255 (EIP4844);
/// - round 1: 3 to 18;
/// - round 2: 2;
/// - round 3: 255;
/// - round 4: 1.
pub fn circuit_ids(self) -> Vec<CircuitIdRoundTuple> {
    // Every tuple produced below carries the numeric value of this round.
    let round = self as u8;
    let tuple = |circuit_id| CircuitIdRoundTuple::new(circuit_id, round);

    match self {
        AggregationRound::BasicCircuits => (1..=15).chain(once(255)).map(tuple).collect(),
        AggregationRound::LeafAggregation => (3..=18).map(tuple).collect(),
        AggregationRound::NodeAggregation => vec![tuple(2)],
        AggregationRound::RecursionTip => vec![tuple(255)],
        AggregationRound::Scheduler => vec![tuple(1)],
    }
}
}

impl std::fmt::Display for AggregationRound {
Expand Down Expand Up @@ -265,33 +288,17 @@ impl CircuitProverStats {

impl Default for CircuitProverStats {
fn default() -> Self {
let mut stats = HashMap::new();
for circuit in (1..=15).chain(once(255)) {
stats.insert(
CircuitIdRoundTuple::new(circuit, 0),
JobCountStatistics::default(),
);
}
for circuit in 3..=18 {
stats.insert(
CircuitIdRoundTuple::new(circuit, 1),
JobCountStatistics::default(),
);
}
stats.insert(
CircuitIdRoundTuple::new(2, 2),
JobCountStatistics::default(),
);
stats.insert(
CircuitIdRoundTuple::new(255, 3),
JobCountStatistics::default(),
);
stats.insert(
CircuitIdRoundTuple::new(1, 4),
JobCountStatistics::default(),
);
let circuits_prover_stats = AggregationRound::ALL_ROUNDS
EmilLuta marked this conversation as resolved.
Show resolved Hide resolved
.into_iter()
.flat_map(|round| {
let circuit_ids = round.circuit_ids();
circuit_ids.into_iter().map(|circuit_id_round_tuple| {
(circuit_id_round_tuple, JobCountStatistics::default())
})
})
.collect();
Self {
circuits_prover_stats: stats,
circuits_prover_stats,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/lib/config/src/configs/fri_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::Deserialize;

use crate::ObjectStoreConfig;

#[derive(Debug, Deserialize, Clone, PartialEq)]
#[derive(Debug, Deserialize, Clone, Copy, PartialEq)]
pub enum SetupLoadMode {
FromDisk,
FromMemory,
Expand Down
137 changes: 36 additions & 101 deletions core/lib/config/src/configs/fri_prover_group.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashSet;

use serde::Deserialize;
use zksync_basic_types::basic_fri_types::CircuitIdRoundTuple;
use zksync_basic_types::basic_fri_types::{AggregationRound, CircuitIdRoundTuple};

/// Configuration for the grouping of specialized provers.
#[derive(Debug, Deserialize, Clone, PartialEq)]
Expand Down Expand Up @@ -81,6 +81,7 @@ impl FriProverGroupConfig {
.flatten()
.collect()
}

/// Checks that all circuit ids are present exactly once and that, for each aggregation
/// round, the circuit ids are in the correct range.
/// For example, in aggregation round 0, the circuit ids should be 1 to 15 + 255 (EIP4844).
Expand All @@ -89,7 +90,6 @@ impl FriProverGroupConfig {
/// In aggregation round 3, the circuit ids should be 255.
/// In aggregation round 4, the circuit ids should be 1.
pub fn validate(&self) -> anyhow::Result<()> {
let mut rounds: Vec<Vec<CircuitIdRoundTuple>> = vec![Vec::new(); 5];
let groups = [
&self.group_0,
&self.group_1,
Expand All @@ -107,110 +107,45 @@ impl FriProverGroupConfig {
&self.group_13,
&self.group_14,
];
for group in groups {
for circuit_round in group {
let round = match rounds.get_mut(circuit_round.aggregation_round as usize) {
Some(round) => round,
None => anyhow::bail!(
"Invalid aggregation round {}.",
circuit_round.aggregation_round
),
};
round.push(circuit_round.clone());
}
}

for (round, round_data) in rounds.iter().enumerate() {
let circuit_ids: Vec<u8> = round_data.iter().map(|x| x.circuit_id).collect();
let unique_circuit_ids: HashSet<u8> = circuit_ids.iter().copied().collect();
let duplicates: HashSet<u8> = circuit_ids
.iter()
.filter(|id| circuit_ids.iter().filter(|x| x == id).count() > 1)
.copied()
.collect();
let mut expected_circuit_ids: HashSet<_> = AggregationRound::ALL_ROUNDS
.into_iter()
.flat_map(|r| r.circuit_ids())
.collect();

let (missing_ids, not_in_range, expected_circuits_description) = match round {
0 => {
let mut expected_range: Vec<_> = (1..=15).collect();
expected_range.push(255);
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();

let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit IDs 1 to 15 and 255")
}
1 => {
let expected_range: Vec<_> = (3..=18).collect();
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit IDs 3 to 18")
}
2 => {
let expected_range: Vec<_> = vec![2];
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit ID 2")
let mut provided_circuit_ids = HashSet::new();
EmilLuta marked this conversation as resolved.
Show resolved Hide resolved
for (group_id, group) in groups.iter().enumerate() {
for circuit_id_round in group.iter() {
// Make sure that it's a known circuit.
if !expected_circuit_ids.contains(circuit_id_round) {
anyhow::bail!(
"Group {} contains unexpected circuit id: {:?}",
group_id,
circuit_id_round
);
}
3 => {
let expected_range: Vec<_> = vec![255];
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit ID 255")
}
4 => {
let expected_range: Vec<_> = vec![1];
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit ID 1")
}
_ => {
anyhow::bail!("Unknown round {}", round);
// Remove this circuit from the expected set: later we will check that all circuits
// are present.
expected_circuit_ids.remove(circuit_id_round);

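// Make sure that the circuit is not duplicated.
// NOTE(review): a second occurrence of a circuit has already been removed from
// `expected_circuit_ids` above, so it appears to be rejected by the "unexpected circuit id"
// check before it can reach this one — confirm the intended order of the two checks.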
if provided_circuit_ids.contains(circuit_id_round) {
anyhow::bail!(
"Group {} contains duplicate circuit id: {:?}",
group_id,
circuit_id_round
);
}
};
if !missing_ids.is_empty() {
anyhow::bail!("Circuit IDs for round {round} are missing: {missing_ids:?}");
}
if circuit_ids.len() != unique_circuit_ids.len() {
anyhow::bail!("Circuit IDs: {duplicates:?} should be unique for round {round}.",);
}
if !not_in_range.is_empty() {
anyhow::bail!("Aggregation round {round} should only contain {expected_circuits_description}. Ids out of range: {not_in_range:?}");
provided_circuit_ids.insert(circuit_id_round.clone());
}
}
// All the circuit IDs should have been removed from the expected set.
if !expected_circuit_ids.is_empty() {
anyhow::bail!(
"Some circuit ids are missing from the groups: {:?}",
expected_circuit_ids
);
}

Ok(())
}
}
1 change: 1 addition & 0 deletions prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ tempfile = "3"
tokio = "1"
toml_edit = "0.14.4"
tracing = "0.1"
tracing-subscriber = { version = "0.3" }
tracing-subscriber = "0.3"
vise = "0.2.0"

# Proving dependencies
Expand Down
3 changes: 3 additions & 0 deletions prover/crates/bin/prover_fri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ reqwest = { workspace = true, features = ["blocking"] }
regex.workspace = true
clap = { workspace = true, features = ["derive"] }

[dev-dependencies]
tracing-subscriber.workspace = true

[features]
default = []
gpu = ["shivini", "zksync_prover_keystore/gpu"]
90 changes: 68 additions & 22 deletions prover/crates/bin/prover_fri/src/gpu_prover_job_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub mod gpu_prover {
ProverContextConfig,
};
use tokio::task::JoinHandle;
use zksync_config::configs::{fri_prover_group::FriProverGroupConfig, FriProverConfig};
use zksync_env_config::FromEnv;
use zksync_config::configs::fri_prover::SetupLoadMode as SetupLoadModeConfig;
use zksync_config::configs::FriProverConfig;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, ProverDal};
use zksync_prover_fri_types::{
Expand Down Expand Up @@ -341,38 +341,84 @@ pub mod gpu_prover {
}
}

pub fn load_setup_data_cache(
#[tracing::instrument(skip_all, fields(setup_load_mode = ?setup_load_mode, specialized_group_id = %specialized_group_id))]
pub async fn load_setup_data_cache(
keystore: &Keystore,
config: &FriProverConfig,
setup_load_mode: SetupLoadModeConfig,
specialized_group_id: u8,
circuit_ids: &[CircuitIdRoundTuple],
) -> anyhow::Result<SetupLoadMode> {
Ok(match config.setup_load_mode {
zksync_config::configs::fri_prover::SetupLoadMode::FromDisk => SetupLoadMode::FromDisk,
zksync_config::configs::fri_prover::SetupLoadMode::FromMemory => {
Ok(match setup_load_mode {
SetupLoadModeConfig::FromDisk => SetupLoadMode::FromDisk,
SetupLoadModeConfig::FromMemory => {
anyhow::ensure!(
!circuit_ids.is_empty(),
"Circuit IDs must be provided when using FromMemory mode"
);
let mut cache = HashMap::new();
tracing::info!(
"Loading setup data cache for group {}",
&config.specialized_group_id
&specialized_group_id
);
let prover_setup_metadata_list = FriProverGroupConfig::from_env()
.context("FriProverGroupConfig::from_env()")?
.get_circuit_ids_for_group_id(config.specialized_group_id)
.context(
"At least one circuit should be configured for group when running in FromMemory mode",
)?;
tracing::info!(
"for group {} configured setup metadata are {:?}",
&config.specialized_group_id,
prover_setup_metadata_list
&specialized_group_id,
circuit_ids
);
for prover_setup_metadata in prover_setup_metadata_list {
let key = setup_metadata_to_setup_data_key(&prover_setup_metadata);
let setup_data = keystore
.load_gpu_setup_data_for_circuit_type(key.clone())
.context("load_gpu_setup_data_for_circuit_type()")?;
cache.insert(key, Arc::new(setup_data));
// Load each file in parallel. Note that FS access is not necessarily parallel, but
// deserialization is (and it's not insignificant, as setup keys are large).
// Note: `collect` is important, because iterators are lazy and otherwise we won't actually
// spawn threads.
let handles: Vec<_> = circuit_ids
.into_iter()
.map(|prover_setup_metadata| {
let keystore = keystore.clone();
let prover_setup_metadata = prover_setup_metadata.clone();
tokio::task::spawn_blocking(move || {
let key = setup_metadata_to_setup_data_key(&prover_setup_metadata);
let setup_data = keystore
.load_gpu_setup_data_for_circuit_type(key.clone())
.context("load_gpu_setup_data_for_circuit_type()")?;
anyhow::Ok((key, Arc::new(setup_data)))
})
})
.collect();
for handle in futures::future::join_all(handles).await {
let (key, setup_data) = handle.context("Key loading future panicked")??;
cache.insert(key, setup_data);
}
SetupLoadMode::FromMemory(cache)
}
})
}

#[cfg(test)]
mod tests {
EmilLuta marked this conversation as resolved.
Show resolved Hide resolved
use zksync_types::basic_fri_types::AggregationRound;

use super::*;

/// Loads the setup data cache in `FromMemory` mode for the circuits of all aggregation
/// rounds and logs how long the load took. Returns early, without failing, when setup data
/// for the first circuit is not present in the keystore.
#[tokio::test]
async fn test_load_setup_data_cache() {
    // `.ok()` discards the result, so a failed logger init does not fail the test.
    tracing_subscriber::fmt::try_init().ok();

    let keystore = Keystore::locate();

    let mut all_circuits = Vec::new();
    for round in AggregationRound::ALL_ROUNDS {
        all_circuits.extend(round.circuit_ids());
    }

    let probe_key = setup_metadata_to_setup_data_key(&all_circuits[0]);
    if !keystore.is_setup_data_present(&probe_key) {
        // We don't want this test to fail on envs where setup keys are not present.
        return;
    }

    let timer = Instant::now();
    let _cache = load_setup_data_cache(
        &keystore,
        SetupLoadModeConfig::FromMemory,
        0,
        &all_circuits,
    )
    .await
    .expect("Unable to load keys");
    tracing::info!("Cache load time: {:?}", timer.elapsed());
}
}
}
Loading
Loading