From 19887ef21a8bbd26977353f8ee277b711850dfd2 Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Thu, 12 Sep 2024 16:25:21 +0400 Subject: [PATCH] feat(prover): Optimize setup keys loading (#2847) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What ❔ - Loads setup keys to memory in parallel (for GPU prover only). - Refactors a bunch of related code for simplicity. ## Why ❔ - Locally I've observed load time going from ~30s to ~12s, so ~60% improvement for prover start time. - Readability & maintainability. ## Checklist - [ ] PR title corresponds to the body of PR (we generate changelog entries from PRs). - [ ] Tests for the changes have been added / updated. - [ ] Documentation comments have been added / updated. - [ ] Code has been formatted via `zk fmt` and `zk lint`. --- core/lib/basic_types/src/basic_fri_types.rs | 59 ++++---- core/lib/config/src/configs/fri_prover.rs | 2 +- .../config/src/configs/fri_prover_group.rs | 137 +++++------------- prover/Cargo.lock | 1 + prover/Cargo.toml | 2 +- prover/crates/bin/prover_fri/Cargo.toml | 3 + .../src/gpu_prover_job_processor.rs | 90 +++++++++--- prover/crates/bin/prover_fri/src/main.rs | 10 +- 8 files changed, 151 insertions(+), 153 deletions(-) diff --git a/core/lib/basic_types/src/basic_fri_types.rs b/core/lib/basic_types/src/basic_fri_types.rs index 5969cca6b8c0..9de9920e86f6 100644 --- a/core/lib/basic_types/src/basic_fri_types.rs +++ b/core/lib/basic_types/src/basic_fri_types.rs @@ -152,6 +152,29 @@ impl AggregationRound { AggregationRound::Scheduler => None, } } + + /// Returns all the circuit IDs that correspond to a particular + /// aggregation round. + /// + /// For example, in aggregation round 0, the circuit ids should be 1 to 15 + 255 (EIP4844). + /// In aggregation round 1, the circuit ids should be 3 to 18. + /// In aggregation round 2, the circuit ids should be 2. + /// In aggregation round 3, the circuit ids should be 255. + /// In aggregation round 4, the circuit ids should be 1. + pub fn circuit_ids(self) -> Vec { + match self { + AggregationRound::BasicCircuits => (1..=15) + .chain(once(255)) + .map(|circuit_id| CircuitIdRoundTuple::new(circuit_id, self as u8)) + .collect(), + AggregationRound::LeafAggregation => (3..=18) + .map(|circuit_id| CircuitIdRoundTuple::new(circuit_id, self as u8)) + .collect(), + AggregationRound::NodeAggregation => vec![CircuitIdRoundTuple::new(2, self as u8)], + AggregationRound::RecursionTip => vec![CircuitIdRoundTuple::new(255, self as u8)], + AggregationRound::Scheduler => vec![CircuitIdRoundTuple::new(1, self as u8)], + } + } } impl std::fmt::Display for AggregationRound { @@ -265,33 +288,17 @@ impl CircuitProverStats { impl Default for CircuitProverStats { fn default() -> Self { - let mut stats = HashMap::new(); - for circuit in (1..=15).chain(once(255)) { - stats.insert( - CircuitIdRoundTuple::new(circuit, 0), - JobCountStatistics::default(), - ); - } - for circuit in 3..=18 { - stats.insert( - CircuitIdRoundTuple::new(circuit, 1), - JobCountStatistics::default(), - ); - } - stats.insert( - CircuitIdRoundTuple::new(2, 2), - JobCountStatistics::default(), - ); - stats.insert( - CircuitIdRoundTuple::new(255, 3), - JobCountStatistics::default(), - ); - stats.insert( - CircuitIdRoundTuple::new(1, 4), - JobCountStatistics::default(), - ); + let circuits_prover_stats = AggregationRound::ALL_ROUNDS + .into_iter() + .flat_map(|round| { + let circuit_ids = round.circuit_ids(); + circuit_ids.into_iter().map(|circuit_id_round_tuple| { + (circuit_id_round_tuple, JobCountStatistics::default()) + }) + }) + .collect(); Self { - circuits_prover_stats: stats, + circuits_prover_stats, } } } diff --git a/core/lib/config/src/configs/fri_prover.rs b/core/lib/config/src/configs/fri_prover.rs index f6a21beaa6dc..32558dd2244b 100644 --- a/core/lib/config/src/configs/fri_prover.rs +++ b/core/lib/config/src/configs/fri_prover.rs @@ -4,7 +4,7 @@ use serde::Deserialize; use crate::ObjectStoreConfig; -#[derive(Debug, Deserialize, Clone, PartialEq)] +#[derive(Debug, Deserialize, Clone, Copy, PartialEq)] pub enum SetupLoadMode { FromDisk, FromMemory, diff --git a/core/lib/config/src/configs/fri_prover_group.rs b/core/lib/config/src/configs/fri_prover_group.rs index 0fd752b5c286..294d4d1bbd44 100644 --- a/core/lib/config/src/configs/fri_prover_group.rs +++ b/core/lib/config/src/configs/fri_prover_group.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use serde::Deserialize; -use zksync_basic_types::basic_fri_types::CircuitIdRoundTuple; +use zksync_basic_types::basic_fri_types::{AggregationRound, CircuitIdRoundTuple}; /// Configuration for the grouping of specialized provers. #[derive(Debug, Deserialize, Clone, PartialEq)] @@ -81,6 +81,7 @@ impl FriProverGroupConfig { .flatten() .collect() } + /// check all_circuit ids present exactly once /// and For each aggregation round, check that the circuit ids are in the correct range. /// For example, in aggregation round 0, the circuit ids should be 1 to 15 + 255 (EIP4844). @@ -89,7 +90,6 @@ impl FriProverGroupConfig { /// In aggregation round 3, the circuit ids should be 255. /// In aggregation round 4, the circuit ids should be 1. pub fn validate(&self) -> anyhow::Result<()> { - let mut rounds: Vec> = vec![Vec::new(); 5]; let groups = [ &self.group_0, &self.group_1, @@ -107,110 +107,45 @@ impl FriProverGroupConfig { &self.group_13, &self.group_14, ]; - for group in groups { - for circuit_round in group { - let round = match rounds.get_mut(circuit_round.aggregation_round as usize) { - Some(round) => round, - None => anyhow::bail!( - "Invalid aggregation round {}.", - circuit_round.aggregation_round - ), - }; - round.push(circuit_round.clone()); - } - } - - for (round, round_data) in rounds.iter().enumerate() { - let circuit_ids: Vec = round_data.iter().map(|x| x.circuit_id).collect(); - let unique_circuit_ids: HashSet = circuit_ids.iter().copied().collect(); - let duplicates: HashSet = circuit_ids - .iter() - .filter(|id| circuit_ids.iter().filter(|x| x == id).count() > 1) - .copied() - .collect(); + let mut expected_circuit_ids: HashSet<_> = AggregationRound::ALL_ROUNDS + .into_iter() + .flat_map(|r| r.circuit_ids()) + .collect(); - let (missing_ids, not_in_range, expected_circuits_description) = match round { - 0 => { - let mut expected_range: Vec<_> = (1..=15).collect(); - expected_range.push(255); - let missing_ids: Vec<_> = expected_range - .iter() - .copied() - .filter(|id| !circuit_ids.contains(id)) - .collect(); - - let not_in_range: Vec<_> = circuit_ids - .iter() - .filter(|&id| !expected_range.contains(id)) - .collect(); - (missing_ids, not_in_range, "circuit IDs 1 to 15 and 255") - } - 1 => { - let expected_range: Vec<_> = (3..=18).collect(); - let missing_ids: Vec<_> = expected_range - .iter() - .copied() - .filter(|id| !circuit_ids.contains(id)) - .collect(); - let not_in_range: Vec<_> = circuit_ids - .iter() - .filter(|&id| !expected_range.contains(id)) - .collect(); - (missing_ids, not_in_range, "circuit IDs 3 to 18") - } - 2 => { - let expected_range: Vec<_> = vec![2]; - let missing_ids: Vec<_> = expected_range - .iter() - .copied() - .filter(|id| !circuit_ids.contains(id)) - .collect(); - let not_in_range: Vec<_> = circuit_ids - .iter() - .filter(|&id| !expected_range.contains(id)) - .collect(); - (missing_ids, not_in_range, "circuit ID 2") + let mut provided_circuit_ids = HashSet::new(); + for (group_id, group) in groups.iter().enumerate() { + for circuit_id_round in group.iter() { + // Make sure that it's a known circuit. + if !expected_circuit_ids.contains(circuit_id_round) { + anyhow::bail!( + "Group {} contains unexpected circuit id: {:?}", + group_id, + circuit_id_round + ); } - 3 => { - let expected_range: Vec<_> = vec![255]; - let missing_ids: Vec<_> = expected_range - .iter() - .copied() - .filter(|id| !circuit_ids.contains(id)) - .collect(); - let not_in_range: Vec<_> = circuit_ids - .iter() - .filter(|&id| !expected_range.contains(id)) - .collect(); - (missing_ids, not_in_range, "circuit ID 255") - } - 4 => { - let expected_range: Vec<_> = vec![1]; - let missing_ids: Vec<_> = expected_range - .iter() - .copied() - .filter(|id| !circuit_ids.contains(id)) - .collect(); - let not_in_range: Vec<_> = circuit_ids - .iter() - .filter(|&id| !expected_range.contains(id)) - .collect(); - (missing_ids, not_in_range, "circuit ID 1") - } - _ => { - anyhow::bail!("Unknown round {}", round); + // Remove this circuit from the expected set: later we will check that all circuits + // are present. + expected_circuit_ids.remove(circuit_id_round); + + // Make sure that the circuit is not duplicated. + if provided_circuit_ids.contains(circuit_id_round) { + anyhow::bail!( + "Group {} contains duplicate circuit id: {:?}", + group_id, + circuit_id_round + ); } - }; - if !missing_ids.is_empty() { - anyhow::bail!("Circuit IDs for round {round} are missing: {missing_ids:?}"); - } - if circuit_ids.len() != unique_circuit_ids.len() { - anyhow::bail!("Circuit IDs: {duplicates:?} should be unique for round {round}.",); - } - if !not_in_range.is_empty() { - anyhow::bail!("Aggregation round {round} should only contain {expected_circuits_description}. Ids out of range: {not_in_range:?}"); + provided_circuit_ids.insert(circuit_id_round.clone()); } } + // All the circuit IDs should have been removed from the expected set. + if !expected_circuit_ids.is_empty() { + anyhow::bail!( + "Some circuit ids are missing from the groups: {:?}", + expected_circuit_ids + ); + } + Ok(()) } } diff --git a/prover/Cargo.lock b/prover/Cargo.lock index 21e2ea8b21de..cea147deccf8 100644 --- a/prover/Cargo.lock +++ b/prover/Cargo.lock @@ -7974,6 +7974,7 @@ dependencies = [ "shivini", "tokio", "tracing", + "tracing-subscriber", "vise", "zkevm_test_harness", "zksync_config", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 403314cc13ca..251b3b0fb082 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -52,7 +52,7 @@ tempfile = "3" tokio = "1" toml_edit = "0.14.4" tracing = "0.1" -tracing-subscriber = { version = "0.3" } +tracing-subscriber = "0.3" vise = "0.2.0" # Proving dependencies diff --git a/prover/crates/bin/prover_fri/Cargo.toml b/prover/crates/bin/prover_fri/Cargo.toml index ae7853427e96..e41244cecbf7 100644 --- a/prover/crates/bin/prover_fri/Cargo.toml +++ b/prover/crates/bin/prover_fri/Cargo.toml @@ -43,6 +43,9 @@ reqwest = { workspace = true, features = ["blocking"] } regex.workspace = true clap = { workspace = true, features = ["derive"] } +[dev-dependencies] +tracing-subscriber.workspace = true + [features] default = [] gpu = ["shivini", "zksync_prover_keystore/gpu"] diff --git a/prover/crates/bin/prover_fri/src/gpu_prover_job_processor.rs b/prover/crates/bin/prover_fri/src/gpu_prover_job_processor.rs index 240251df15bf..be28f2bd97ee 100644 --- a/prover/crates/bin/prover_fri/src/gpu_prover_job_processor.rs +++ b/prover/crates/bin/prover_fri/src/gpu_prover_job_processor.rs @@ -8,8 +8,8 @@ pub mod gpu_prover { ProverContextConfig, }; use tokio::task::JoinHandle; - use zksync_config::configs::{fri_prover_group::FriProverGroupConfig, FriProverConfig}; - use zksync_env_config::FromEnv; + use zksync_config::configs::fri_prover::SetupLoadMode as SetupLoadModeConfig; + use zksync_config::configs::FriProverConfig; use zksync_object_store::ObjectStore; use zksync_prover_dal::{ConnectionPool, ProverDal}; use zksync_prover_fri_types::{ @@ -341,38 +341,84 @@ pub mod gpu_prover { } } - pub fn load_setup_data_cache( + #[tracing::instrument(skip_all, fields(setup_load_mode = ?setup_load_mode, specialized_group_id = %specialized_group_id))] + pub async fn load_setup_data_cache( keystore: &Keystore, - config: &FriProverConfig, + setup_load_mode: SetupLoadModeConfig, + specialized_group_id: u8, + circuit_ids: &[CircuitIdRoundTuple], ) -> anyhow::Result { - Ok(match config.setup_load_mode { - zksync_config::configs::fri_prover::SetupLoadMode::FromDisk => SetupLoadMode::FromDisk, - zksync_config::configs::fri_prover::SetupLoadMode::FromMemory => { + Ok(match setup_load_mode { + SetupLoadModeConfig::FromDisk => SetupLoadMode::FromDisk, + SetupLoadModeConfig::FromMemory => { + anyhow::ensure!( + !circuit_ids.is_empty(), + "Circuit IDs must be provided when using FromMemory mode" + ); let mut cache = HashMap::new(); tracing::info!( "Loading setup data cache for group {}", - &config.specialized_group_id + &specialized_group_id ); - let prover_setup_metadata_list = FriProverGroupConfig::from_env() - .context("FriProverGroupConfig::from_env()")? - .get_circuit_ids_for_group_id(config.specialized_group_id) - .context( - "At least one circuit should be configured for group when running in FromMemory mode", - )?; tracing::info!( "for group {} configured setup metadata are {:?}", - &config.specialized_group_id, - prover_setup_metadata_list + &specialized_group_id, + circuit_ids ); - for prover_setup_metadata in prover_setup_metadata_list { - let key = setup_metadata_to_setup_data_key(&prover_setup_metadata); - let setup_data = keystore - .load_gpu_setup_data_for_circuit_type(key.clone()) - .context("load_gpu_setup_data_for_circuit_type()")?; - cache.insert(key, Arc::new(setup_data)); + // Load each file in parallel. Note that FS access is not necessarily parallel, but + // deserialization is (and it's not insignificant, as setup keys are large). + // Note: `collect` is important, because iterators are lazy and otherwise we won't actually + // spawn threads. + let handles: Vec<_> = circuit_ids + .into_iter() + .map(|prover_setup_metadata| { + let keystore = keystore.clone(); + let prover_setup_metadata = prover_setup_metadata.clone(); + tokio::task::spawn_blocking(move || { + let key = setup_metadata_to_setup_data_key(&prover_setup_metadata); + let setup_data = keystore + .load_gpu_setup_data_for_circuit_type(key.clone()) + .context("load_gpu_setup_data_for_circuit_type()")?; + anyhow::Ok((key, Arc::new(setup_data))) + }) + }) + .collect(); + for handle in futures::future::join_all(handles).await { + let (key, setup_data) = handle.context("Key loading future panicked")??; + cache.insert(key, setup_data); } SetupLoadMode::FromMemory(cache) } }) } + + #[cfg(test)] + mod tests { + use zksync_types::basic_fri_types::AggregationRound; + + use super::*; + + #[tokio::test] + async fn test_load_setup_data_cache() { + tracing_subscriber::fmt::try_init().ok(); + + let keystore = Keystore::locate(); + let mode = SetupLoadModeConfig::FromMemory; + let specialized_group_id = 0; + let ids: Vec<_> = AggregationRound::ALL_ROUNDS + .into_iter() + .flat_map(|r| r.circuit_ids()) + .collect(); + if !keystore.is_setup_data_present(&setup_metadata_to_setup_data_key(&ids[0])) { + // We don't want this test to fail on envs where setup keys are not present. + return; + } + + let start = Instant::now(); + let _cache = load_setup_data_cache(&keystore, mode, specialized_group_id, &ids) + .await + .expect("Unable to load keys"); + tracing::info!("Cache load time: {:?}", start.elapsed()); + } + } } diff --git a/prover/crates/bin/prover_fri/src/main.rs b/prover/crates/bin/prover_fri/src/main.rs index 8191653efec6..cbba8d0ddb4f 100644 --- a/prover/crates/bin/prover_fri/src/main.rs +++ b/prover/crates/bin/prover_fri/src/main.rs @@ -231,8 +231,14 @@ async fn get_prover_tasks( let keystore = Keystore::locate().with_setup_path(Some(prover_config.setup_data_path.clone().into())); - let setup_load_mode = gpu_prover::load_setup_data_cache(&keystore, &prover_config) - .context("load_setup_data_cache()")?; + let setup_load_mode = gpu_prover::load_setup_data_cache( + &keystore, + prover_config.setup_load_mode, + prover_config.specialized_group_id, + &circuit_ids_for_round_to_be_proven, + ) + .await + .context("load_setup_data_cache()")?; let witness_vector_queue = FixedSizeQueue::new(prover_config.queue_capacity); let shared_witness_vector_queue = Arc::new(Mutex::new(witness_vector_queue)); let consumer = shared_witness_vector_queue.clone();