Skip to content

Commit

Permalink
feat(prover): Optimize setup keys loading (#2847)
Browse files Browse the repository at this point in the history
## What ❔

- Loads setup keys to memory in parallel (for GPU prover only).
- Refactors a bunch of related code for simplicity.

## Why ❔

- Locally I've observed load time going from ~30s to ~12s, so ~60%
improvement for prover start time.
- Readability & maintainability.

## Checklist

<!-- Check your PR fulfills the following items. -->
<!-- For draft PRs check the boxes as you complete them. -->

- [ ] PR title corresponds to the body of PR (we generate changelog
entries from PRs).
- [ ] Tests for the changes have been added / updated.
- [ ] Documentation comments have been added / updated.
- [ ] Code has been formatted via `zk fmt` and `zk lint`.
  • Loading branch information
popzxc authored and itegulov committed Sep 13, 2024
1 parent 28eb685 commit 190ed8a
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 153 deletions.
59 changes: 33 additions & 26 deletions core/lib/basic_types/src/basic_fri_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,29 @@ impl AggregationRound {
AggregationRound::Scheduler => None,
}
}

/// Returns all the circuit IDs that correspond to a particular
/// aggregation round.
///
/// For example, in aggregation round 0, the circuit ids should be 1 to 15 + 255 (EIP4844).
/// In aggregation round 1, the circuit ids should be 3 to 18.
/// In aggregation round 2, the circuit ids should be 2.
/// In aggregation round 3, the circuit ids should be 255.
/// In aggregation round 4, the circuit ids should be 1.
pub fn circuit_ids(self) -> Vec<CircuitIdRoundTuple> {
match self {
AggregationRound::BasicCircuits => (1..=15)
.chain(once(255))
.map(|circuit_id| CircuitIdRoundTuple::new(circuit_id, self as u8))
.collect(),
AggregationRound::LeafAggregation => (3..=18)
.map(|circuit_id| CircuitIdRoundTuple::new(circuit_id, self as u8))
.collect(),
AggregationRound::NodeAggregation => vec![CircuitIdRoundTuple::new(2, self as u8)],
AggregationRound::RecursionTip => vec![CircuitIdRoundTuple::new(255, self as u8)],
AggregationRound::Scheduler => vec![CircuitIdRoundTuple::new(1, self as u8)],
}
}
}

impl std::fmt::Display for AggregationRound {
Expand Down Expand Up @@ -265,33 +288,17 @@ impl CircuitProverStats {

impl Default for CircuitProverStats {
fn default() -> Self {
let mut stats = HashMap::new();
for circuit in (1..=15).chain(once(255)) {
stats.insert(
CircuitIdRoundTuple::new(circuit, 0),
JobCountStatistics::default(),
);
}
for circuit in 3..=18 {
stats.insert(
CircuitIdRoundTuple::new(circuit, 1),
JobCountStatistics::default(),
);
}
stats.insert(
CircuitIdRoundTuple::new(2, 2),
JobCountStatistics::default(),
);
stats.insert(
CircuitIdRoundTuple::new(255, 3),
JobCountStatistics::default(),
);
stats.insert(
CircuitIdRoundTuple::new(1, 4),
JobCountStatistics::default(),
);
let circuits_prover_stats = AggregationRound::ALL_ROUNDS
.into_iter()
.flat_map(|round| {
let circuit_ids = round.circuit_ids();
circuit_ids.into_iter().map(|circuit_id_round_tuple| {
(circuit_id_round_tuple, JobCountStatistics::default())
})
})
.collect();
Self {
circuits_prover_stats: stats,
circuits_prover_stats,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/lib/config/src/configs/fri_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::Deserialize;

use crate::ObjectStoreConfig;

#[derive(Debug, Deserialize, Clone, PartialEq)]
#[derive(Debug, Deserialize, Clone, Copy, PartialEq)]
pub enum SetupLoadMode {
FromDisk,
FromMemory,
Expand Down
137 changes: 36 additions & 101 deletions core/lib/config/src/configs/fri_prover_group.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashSet;

use serde::Deserialize;
use zksync_basic_types::basic_fri_types::CircuitIdRoundTuple;
use zksync_basic_types::basic_fri_types::{AggregationRound, CircuitIdRoundTuple};

/// Configuration for the grouping of specialized provers.
#[derive(Debug, Deserialize, Clone, PartialEq)]
Expand Down Expand Up @@ -81,6 +81,7 @@ impl FriProverGroupConfig {
.flatten()
.collect()
}

/// check all_circuit ids present exactly once
/// and For each aggregation round, check that the circuit ids are in the correct range.
/// For example, in aggregation round 0, the circuit ids should be 1 to 15 + 255 (EIP4844).
Expand All @@ -89,7 +90,6 @@ impl FriProverGroupConfig {
/// In aggregation round 3, the circuit ids should be 255.
/// In aggregation round 4, the circuit ids should be 1.
pub fn validate(&self) -> anyhow::Result<()> {
let mut rounds: Vec<Vec<CircuitIdRoundTuple>> = vec![Vec::new(); 5];
let groups = [
&self.group_0,
&self.group_1,
Expand All @@ -107,110 +107,45 @@ impl FriProverGroupConfig {
&self.group_13,
&self.group_14,
];
for group in groups {
for circuit_round in group {
let round = match rounds.get_mut(circuit_round.aggregation_round as usize) {
Some(round) => round,
None => anyhow::bail!(
"Invalid aggregation round {}.",
circuit_round.aggregation_round
),
};
round.push(circuit_round.clone());
}
}

for (round, round_data) in rounds.iter().enumerate() {
let circuit_ids: Vec<u8> = round_data.iter().map(|x| x.circuit_id).collect();
let unique_circuit_ids: HashSet<u8> = circuit_ids.iter().copied().collect();
let duplicates: HashSet<u8> = circuit_ids
.iter()
.filter(|id| circuit_ids.iter().filter(|x| x == id).count() > 1)
.copied()
.collect();
let mut expected_circuit_ids: HashSet<_> = AggregationRound::ALL_ROUNDS
.into_iter()
.flat_map(|r| r.circuit_ids())
.collect();

let (missing_ids, not_in_range, expected_circuits_description) = match round {
0 => {
let mut expected_range: Vec<_> = (1..=15).collect();
expected_range.push(255);
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();

let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit IDs 1 to 15 and 255")
}
1 => {
let expected_range: Vec<_> = (3..=18).collect();
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit IDs 3 to 18")
}
2 => {
let expected_range: Vec<_> = vec![2];
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit ID 2")
let mut provided_circuit_ids = HashSet::new();
for (group_id, group) in groups.iter().enumerate() {
for circuit_id_round in group.iter() {
// Make sure that it's a known circuit.
if !expected_circuit_ids.contains(circuit_id_round) {
anyhow::bail!(
"Group {} contains unexpected circuit id: {:?}",
group_id,
circuit_id_round
);
}
3 => {
let expected_range: Vec<_> = vec![255];
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit ID 255")
}
4 => {
let expected_range: Vec<_> = vec![1];
let missing_ids: Vec<_> = expected_range
.iter()
.copied()
.filter(|id| !circuit_ids.contains(id))
.collect();
let not_in_range: Vec<_> = circuit_ids
.iter()
.filter(|&id| !expected_range.contains(id))
.collect();
(missing_ids, not_in_range, "circuit ID 1")
}
_ => {
anyhow::bail!("Unknown round {}", round);
// Remove this circuit from the expected set: later we will check that all circuits
// are present.
expected_circuit_ids.remove(circuit_id_round);

// Make sure that the circuit is not duplicated.
if provided_circuit_ids.contains(circuit_id_round) {
anyhow::bail!(
"Group {} contains duplicate circuit id: {:?}",
group_id,
circuit_id_round
);
}
};
if !missing_ids.is_empty() {
anyhow::bail!("Circuit IDs for round {round} are missing: {missing_ids:?}");
}
if circuit_ids.len() != unique_circuit_ids.len() {
anyhow::bail!("Circuit IDs: {duplicates:?} should be unique for round {round}.",);
}
if !not_in_range.is_empty() {
anyhow::bail!("Aggregation round {round} should only contain {expected_circuits_description}. Ids out of range: {not_in_range:?}");
provided_circuit_ids.insert(circuit_id_round.clone());
}
}
// All the circuit IDs should have been removed from the expected set.
if !expected_circuit_ids.is_empty() {
anyhow::bail!(
"Some circuit ids are missing from the groups: {:?}",
expected_circuit_ids
);
}

Ok(())
}
}
1 change: 1 addition & 0 deletions prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ tempfile = "3"
tokio = "1"
toml_edit = "0.14.4"
tracing = "0.1"
tracing-subscriber = { version = "0.3" }
tracing-subscriber = "0.3"
vise = "0.2.0"

# Proving dependencies
Expand Down
3 changes: 3 additions & 0 deletions prover/crates/bin/prover_fri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ reqwest = { workspace = true, features = ["blocking"] }
regex.workspace = true
clap = { workspace = true, features = ["derive"] }

[dev-dependencies]
tracing-subscriber.workspace = true

[features]
default = []
gpu = ["shivini", "zksync_prover_keystore/gpu"]
90 changes: 68 additions & 22 deletions prover/crates/bin/prover_fri/src/gpu_prover_job_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub mod gpu_prover {
ProverContextConfig,
};
use tokio::task::JoinHandle;
use zksync_config::configs::{fri_prover_group::FriProverGroupConfig, FriProverConfig};
use zksync_env_config::FromEnv;
use zksync_config::configs::fri_prover::SetupLoadMode as SetupLoadModeConfig;
use zksync_config::configs::FriProverConfig;
use zksync_object_store::ObjectStore;
use zksync_prover_dal::{ConnectionPool, ProverDal};
use zksync_prover_fri_types::{
Expand Down Expand Up @@ -341,38 +341,84 @@ pub mod gpu_prover {
}
}

pub fn load_setup_data_cache(
#[tracing::instrument(skip_all, fields(setup_load_mode = ?setup_load_mode, specialized_group_id = %specialized_group_id))]
pub async fn load_setup_data_cache(
keystore: &Keystore,
config: &FriProverConfig,
setup_load_mode: SetupLoadModeConfig,
specialized_group_id: u8,
circuit_ids: &[CircuitIdRoundTuple],
) -> anyhow::Result<SetupLoadMode> {
Ok(match config.setup_load_mode {
zksync_config::configs::fri_prover::SetupLoadMode::FromDisk => SetupLoadMode::FromDisk,
zksync_config::configs::fri_prover::SetupLoadMode::FromMemory => {
Ok(match setup_load_mode {
SetupLoadModeConfig::FromDisk => SetupLoadMode::FromDisk,
SetupLoadModeConfig::FromMemory => {
anyhow::ensure!(
!circuit_ids.is_empty(),
"Circuit IDs must be provided when using FromMemory mode"
);
let mut cache = HashMap::new();
tracing::info!(
"Loading setup data cache for group {}",
&config.specialized_group_id
&specialized_group_id
);
let prover_setup_metadata_list = FriProverGroupConfig::from_env()
.context("FriProverGroupConfig::from_env()")?
.get_circuit_ids_for_group_id(config.specialized_group_id)
.context(
"At least one circuit should be configured for group when running in FromMemory mode",
)?;
tracing::info!(
"for group {} configured setup metadata are {:?}",
&config.specialized_group_id,
prover_setup_metadata_list
&specialized_group_id,
circuit_ids
);
for prover_setup_metadata in prover_setup_metadata_list {
let key = setup_metadata_to_setup_data_key(&prover_setup_metadata);
let setup_data = keystore
.load_gpu_setup_data_for_circuit_type(key.clone())
.context("load_gpu_setup_data_for_circuit_type()")?;
cache.insert(key, Arc::new(setup_data));
// Load each file in parallel. Note that FS access is not necessarily parallel, but
// deserialization is (and it's not insignificant, as setup keys are large).
// Note: `collect` is important, because iterators are lazy and otherwise we won't actually
// spawn threads.
let handles: Vec<_> = circuit_ids
.into_iter()
.map(|prover_setup_metadata| {
let keystore = keystore.clone();
let prover_setup_metadata = prover_setup_metadata.clone();
tokio::task::spawn_blocking(move || {
let key = setup_metadata_to_setup_data_key(&prover_setup_metadata);
let setup_data = keystore
.load_gpu_setup_data_for_circuit_type(key.clone())
.context("load_gpu_setup_data_for_circuit_type()")?;
anyhow::Ok((key, Arc::new(setup_data)))
})
})
.collect();
for handle in futures::future::join_all(handles).await {
let (key, setup_data) = handle.context("Key loading future panicked")??;
cache.insert(key, setup_data);
}
SetupLoadMode::FromMemory(cache)
}
})
}

#[cfg(test)]
mod tests {
use zksync_types::basic_fri_types::AggregationRound;

use super::*;

#[tokio::test]
async fn test_load_setup_data_cache() {
tracing_subscriber::fmt::try_init().ok();

let keystore = Keystore::locate();
let mode = SetupLoadModeConfig::FromMemory;
let specialized_group_id = 0;
let ids: Vec<_> = AggregationRound::ALL_ROUNDS
.into_iter()
.flat_map(|r| r.circuit_ids())
.collect();
if !keystore.is_setup_data_present(&setup_metadata_to_setup_data_key(&ids[0])) {
// We don't want this test to fail on envs where setup keys are not present.
return;
}

let start = Instant::now();
let _cache = load_setup_data_cache(&keystore, mode, specialized_group_id, &ids)
.await
.expect("Unable to load keys");
tracing::info!("Cache load time: {:?}", start.elapsed());
}
}
}
Loading

0 comments on commit 190ed8a

Please sign in to comment.