diff --git a/.github/workflows/ci-common-reusable.yml b/.github/workflows/ci-common-reusable.yml index 2e5d36feebff..3d28df592e98 100644 --- a/.github/workflows/ci-common-reusable.yml +++ b/.github/workflows/ci-common-reusable.yml @@ -22,6 +22,7 @@ jobs: echo "SCCACHE_GCS_SERVICE_ACCOUNT=gha-ci-runners@matterlabs-infra.iam.gserviceaccount.com" >> .env echo "SCCACHE_GCS_RW_MODE=READ_WRITE" >> .env echo "RUSTC_WRAPPER=sccache" >> .env + echo "RUSTFLAGS=--cfg=no_cuda" >> .env - name: Start services run: | diff --git a/.github/workflows/ci-core-lint-reusable.yml b/.github/workflows/ci-core-lint-reusable.yml index 404f0966b405..85e4be3ff5e3 100644 --- a/.github/workflows/ci-core-lint-reusable.yml +++ b/.github/workflows/ci-core-lint-reusable.yml @@ -19,6 +19,7 @@ jobs: echo "SCCACHE_GCS_SERVICE_ACCOUNT=gha-ci-runners@matterlabs-infra.iam.gserviceaccount.com" >> .env echo "SCCACHE_GCS_RW_MODE=READ_WRITE" >> .env echo "RUSTC_WRAPPER=sccache" >> .env + echo "RUSTFLAGS=--cfg=no_cuda" >> .env echo "prover_url=postgres://postgres:notsecurepassword@localhost:5432/zksync_local_prover" >> $GITHUB_ENV echo "core_url=postgres://postgres:notsecurepassword@localhost:5432/zksync_local" >> $GITHUB_ENV diff --git a/.github/workflows/ci-prover-reusable.yml b/.github/workflows/ci-prover-reusable.yml index 367a86c5f40f..6fa987b1cecf 100644 --- a/.github/workflows/ci-prover-reusable.yml +++ b/.github/workflows/ci-prover-reusable.yml @@ -57,6 +57,7 @@ jobs: echo "SCCACHE_GCS_SERVICE_ACCOUNT=gha-ci-runners@matterlabs-infra.iam.gserviceaccount.com" >> .env echo "SCCACHE_GCS_RW_MODE=READ_WRITE" >> .env echo "RUSTC_WRAPPER=sccache" >> .env + echo "RUSTFLAGS=--cfg=no_cuda" >> .env - name: Start services run: | diff --git a/core/lib/basic_types/src/prover_dal.rs b/core/lib/basic_types/src/prover_dal.rs index 7eb671448608..9764e6675c76 100644 --- a/core/lib/basic_types/src/prover_dal.rs +++ b/core/lib/basic_types/src/prover_dal.rs @@ -8,7 +8,7 @@ use crate::{ basic_fri_types::AggregationRound, protocol_version::ProtocolVersionId, L1BatchNumber, }; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct FriProverJobMetadata { pub id: u32, pub block_number: L1BatchNumber, diff --git a/docker/witness-generator/Dockerfile b/docker/witness-generator/Dockerfile index 2eebe07515e4..a9871be98b43 100644 --- a/docker/witness-generator/Dockerfile +++ b/docker/witness-generator/Dockerfile @@ -1,7 +1,7 @@ FROM matterlabs/zksync-build-base:latest AS builder ARG DEBIAN_FRONTEND=noninteractive -ARG RUST_FLAGS="" +ARG RUST_FLAGS="--cfg=no_cuda" ENV RUSTFLAGS=${RUST_FLAGS} WORKDIR /usr/src/zksync diff --git a/docker/witness-vector-generator/Dockerfile b/docker/witness-vector-generator/Dockerfile index e315f670101a..ce53beeb5fdc 100644 --- a/docker/witness-vector-generator/Dockerfile +++ b/docker/witness-vector-generator/Dockerfile @@ -1,6 +1,8 @@ FROM matterlabs/zksync-build-base:latest as builder ARG DEBIAN_FRONTEND=noninteractive +ARG RUST_FLAGS="--cfg=no_cuda" +ENV RUSTFLAGS=${RUST_FLAGS} WORKDIR /usr/src/zksync COPY . . diff --git a/docs/guides/setup-dev.md b/docs/guides/setup-dev.md index 10eb329628c1..7781e65e5bfb 100644 --- a/docs/guides/setup-dev.md +++ b/docs/guides/setup-dev.md @@ -48,6 +48,10 @@ cargo install sqlx-cli --version 0.8.1 # Foundry curl -L https://foundry.paradigm.xyz | bash foundryup --branch master + +# Non GPU setup, can be skipped if the machine has a GPU configured for provers +echo "export RUSTFLAGS='--cfg=no_cuda'" >> ~/.bashrc + # You will need to reload your `*rc` file here # Clone the repo to the desired location @@ -237,6 +241,28 @@ Go to the zksync folder and run `nix develop`. After it finishes, you are in a s [Foundry](https://book.getfoundry.sh/getting-started/installation) can be utilized for deploying smart contracts. For commands related to deployment, you can pass flags for Foundry integration. +## Non-GPU setup + +Circuit Prover requires a GPU (& CUDA bindings) to run. If you still want to be able to build everything locally on +non-GPU setup, you'll need to change your rustflags. + +For a single run, it's enough to export it on the shell: + +``` +export RUSTFLAGS='--cfg=no_cuda' +``` + +For persistent runs, you can either echo it in your ~/.rc file (discouraged), or configure it for your taste in +`config.toml`. + +For project level configuration, edit `/path/to/zksync/.cargo/config.toml`. For global cargo setup, +`~/.cargo/config.toml`. Add the following: + +```toml +[build] +rustflags = ["--cfg=no_cuda"] +``` + ## Environment Edit the lines below and add them to your shell profile file (e.g. `~/.bash_profile`, `~/.zshrc`): diff --git a/prover/Cargo.lock b/prover/Cargo.lock index d29f0110f217..7b46584fe5bc 100644 --- a/prover/Cargo.lock +++ b/prover/Cargo.lock @@ -7344,6 +7344,33 @@ dependencies = [ "zksync_pairing", ] +[[package]] +name = "zksync_circuit_prover" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bincode", + "clap 4.5.4", + "shivini", + "tokio", + "tokio-util", + "tracing", + "vise", + "zkevm_test_harness", + "zksync_config", + "zksync_core_leftovers", + "zksync_env_config", + "zksync_object_store", + "zksync_prover_dal", + "zksync_prover_fri_types", + "zksync_prover_fri_utils", + "zksync_prover_keystore", + "zksync_queued_job_processor", + "zksync_types", + "zksync_utils", +] + [[package]] name = "zksync_concurrency" version = "0.1.1" @@ -7948,6 +7975,7 @@ dependencies = [ "anyhow", "bincode", "circuit_definitions", + "futures 0.3.30", "hex", "md5", "once_cell", @@ -7955,6 +7983,7 @@ dependencies = [ "serde_json", "sha3 0.10.8", "shivini", + "tokio", "tracing", "zkevm_test_harness", "zksync_basic_types", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 624661adc8dc..c29f14500781 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -50,6 +50,7 @@ structopt = "0.3.26" strum = { version = "0.26" } tempfile = "3" tokio = "1" +tokio-util = "0.7.11" toml_edit = "0.14.4" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/prover/crates/bin/circuit_prover/Cargo.toml b/prover/crates/bin/circuit_prover/Cargo.toml new file mode 100644 index 000000000000..a5751a4cd9a6 --- /dev/null +++ b/prover/crates/bin/circuit_prover/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "zksync_circuit_prover" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +tokio = { workspace = true, features = ["macros", "time"] } +tokio-util.workspace = true +anyhow.workspace = true +async-trait.workspace = true +tracing.workspace = true +bincode.workspace = true +clap = { workspace = true, features = ["derive"] } + +zksync_config.workspace = true +zksync_object_store.workspace = true +zksync_prover_dal.workspace = true +zksync_prover_fri_types.workspace = true +zksync_prover_fri_utils.workspace = true +zksync_queued_job_processor.workspace = true +zksync_types.workspace = true +zksync_prover_keystore = { workspace = true, features = ["gpu"] } +zksync_env_config.workspace = true +zksync_core_leftovers.workspace = true +zksync_utils.workspace = true + +vise.workspace = true +shivini = { workspace = true, features = [ + "circuit_definitions", + "zksync", +] } +zkevm_test_harness.workspace = true diff --git a/prover/crates/bin/circuit_prover/src/backoff.rs b/prover/crates/bin/circuit_prover/src/backoff.rs new file mode 100644 index 000000000000..6ddb3d94be35 --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/backoff.rs @@ -0,0 +1,39 @@ +use std::{ops::Mul, time::Duration}; + +/// Backoff - convenience structure that takes care of backoff timings. +#[derive(Debug, Clone)] +pub struct Backoff { + base_delay: Duration, + current_delay: Duration, + max_delay: Duration, +} + +impl Backoff { + /// The delay multiplication coefficient. + // Currently it's hardcoded, but could be provided in the constructor. + const DELAY_MULTIPLIER: u32 = 2; + + /// Create a backoff with base_delay (first delay) and max_delay (maximum delay possible). + pub fn new(base_delay: Duration, max_delay: Duration) -> Self { + Backoff { + base_delay, + current_delay: base_delay, + max_delay, + } + } + + /// Get current delay, handling future delays if needed + pub fn delay(&mut self) -> Duration { + let delay = self.current_delay; + self.current_delay = self + .current_delay + .mul(Self::DELAY_MULTIPLIER) + .min(self.max_delay); + delay + } + + /// Reset the backoff time for to base delay + pub fn reset(&mut self) { + self.current_delay = self.base_delay; + } +} diff --git a/prover/crates/bin/circuit_prover/src/circuit_prover.rs b/prover/crates/bin/circuit_prover/src/circuit_prover.rs new file mode 100644 index 000000000000..1a5f8aa0d974 --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/circuit_prover.rs @@ -0,0 +1,397 @@ +use std::{sync::Arc, time::Instant}; + +use anyhow::Context; +use shivini::{ + gpu_proof_config::GpuProofConfig, gpu_prove_from_external_witness_data, ProverContext, + ProverContextConfig, +}; +use tokio::{sync::mpsc::Receiver, task::JoinHandle}; +use tokio_util::sync::CancellationToken; +use zkevm_test_harness::prover_utils::{verify_base_layer_proof, verify_recursion_layer_proof}; +use zksync_object_store::ObjectStore; +use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; +use zksync_prover_fri_types::{ + circuit_definitions::{ + base_layer_proof_config, + boojum::{ + cs::implementations::{pow::NoPow, witness::WitnessVec}, + field::goldilocks::GoldilocksField, + worker::Worker, + }, + circuit_definitions::{ + base_layer::ZkSyncBaseLayerProof, recursion_layer::ZkSyncRecursionLayerProof, + }, + recursion_layer_proof_config, + }, + CircuitWrapper, FriProofWrapper, ProverArtifacts, WitnessVectorArtifactsTemp, +}; +use zksync_prover_keystore::GoldilocksGpuProverSetupData; +use zksync_types::protocol_version::ProtocolSemanticVersion; +use zksync_utils::panic_extractor::try_extract_panic_message; + +use crate::{ + metrics::CIRCUIT_PROVER_METRICS, + types::{DefaultTranscript, DefaultTreeHasher, Proof, VerificationKey}, + SetupDataCache, +}; + +/// In charge of proving circuits, given a Witness Vector source. +/// Both job runner & job executor. +#[derive(Debug)] +pub struct CircuitProver { + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, + /// Witness Vector source receiver + receiver: Receiver, + /// Setup Data used for proving & proof verification + setup_data_cache: SetupDataCache, +} + +impl CircuitProver { + pub fn new( + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, + receiver: Receiver, + max_allocation: Option, + setup_data_cache: SetupDataCache, + ) -> anyhow::Result<(Self, ProverContext)> { + // VRAM allocation + let prover_context = match max_allocation { + Some(max_allocation) => ProverContext::create_with_config( + ProverContextConfig::default().with_maximum_device_allocation(max_allocation), + ) + .context("failed initializing fixed gpu prover context")?, + None => ProverContext::create().context("failed initializing gpu prover context")?, + }; + Ok(( + Self { + connection_pool, + object_store, + protocol_version, + receiver, + setup_data_cache, + }, + prover_context, + )) + } + + /// Continuously polls `receiver` for Witness Vectors and proves them. + /// All job executions are persisted. + pub async fn run(mut self, cancellation_token: CancellationToken) -> anyhow::Result<()> { + while !cancellation_token.is_cancelled() { + let time = Instant::now(); + + let artifact = self + .receiver + .recv() + .await + .context("no Witness Vector Generators are available")?; + tracing::info!( + "Circuit Prover received job {:?} after: {:?}", + artifact.prover_job.job_id, + time.elapsed() + ); + CIRCUIT_PROVER_METRICS.job_wait_time.observe(time.elapsed()); + + self.prove(artifact, cancellation_token.clone()) + .await + .context("failed to prove circuit proof")?; + } + tracing::info!("Circuit Prover shut down."); + Ok(()) + } + + /// Proves a job, with persistence of execution. + async fn prove( + &self, + artifact: WitnessVectorArtifactsTemp, + cancellation_token: CancellationToken, + ) -> anyhow::Result<()> { + let time = Instant::now(); + let block_number = artifact.prover_job.block_number; + let job_id = artifact.prover_job.job_id; + let job_start_time = artifact.time; + let setup_data_key = artifact.prover_job.setup_data_key.crypto_setup_key(); + let setup_data = self + .setup_data_cache + .get(&setup_data_key) + .context(format!( + "failed to get setup data for key {setup_data_key:?}" + ))? + .clone(); + let task = tokio::task::spawn_blocking(move || { + let _span = tracing::info_span!("prove_circuit_proof", %block_number).entered(); + Self::prove_circuit_proof(artifact, setup_data).context("failed to prove circuit") + }); + + self.finish_task( + job_id, + time, + job_start_time, + task, + cancellation_token.clone(), + ) + .await?; + tracing::info!( + "Circuit Prover finished job {:?} in: {:?}", + job_id, + time.elapsed() + ); + CIRCUIT_PROVER_METRICS + .job_finished_time + .observe(time.elapsed()); + CIRCUIT_PROVER_METRICS + .full_proving_time + .observe(job_start_time.elapsed()); + Ok(()) + } + + /// Proves a job using crypto primitives (proof generation & proof verification). + #[tracing::instrument( + name = "Prover::prove_circuit_proof", + skip_all, + fields(l1_batch = % witness_vector_artifacts.prover_job.block_number) + )] + pub fn prove_circuit_proof( + witness_vector_artifacts: WitnessVectorArtifactsTemp, + setup_data: Arc, + ) -> anyhow::Result { + let time = Instant::now(); + let WitnessVectorArtifactsTemp { + witness_vector, + prover_job, + .. + } = witness_vector_artifacts; + + let job_id = prover_job.job_id; + let circuit_wrapper = prover_job.circuit_wrapper; + let block_number = prover_job.block_number; + + let (proof, circuit_id) = + Self::generate_proof(&circuit_wrapper, witness_vector, &setup_data) + .context(format!("failed to generate proof for job id {job_id}"))?; + + Self::verify_proof(&circuit_wrapper, &proof, &setup_data.vk).context(format!( + "failed to verify proof with job_id {job_id}, circuit_id: {circuit_id}" + ))?; + + let proof_wrapper = match &circuit_wrapper { + CircuitWrapper::Base(_) => { + FriProofWrapper::Base(ZkSyncBaseLayerProof::from_inner(circuit_id, proof)) + } + CircuitWrapper::Recursive(_) => { + FriProofWrapper::Recursive(ZkSyncRecursionLayerProof::from_inner(circuit_id, proof)) + } + CircuitWrapper::BasePartial(_) => { + return Self::partial_proof_error(); + } + }; + CIRCUIT_PROVER_METRICS + .crypto_primitives_time + .observe(time.elapsed()); + Ok(ProverArtifacts::new(block_number, proof_wrapper)) + } + + /// Generates a proof from crypto primitives. + fn generate_proof( + circuit_wrapper: &CircuitWrapper, + witness_vector: WitnessVec, + setup_data: &Arc, + ) -> anyhow::Result<(Proof, u8)> { + let time = Instant::now(); + + let worker = Worker::new(); + + let (gpu_proof_config, proof_config, circuit_id) = match circuit_wrapper { + CircuitWrapper::Base(circuit) => ( + GpuProofConfig::from_base_layer_circuit(circuit), + base_layer_proof_config(), + circuit.numeric_circuit_type(), + ), + CircuitWrapper::Recursive(circuit) => ( + GpuProofConfig::from_recursive_layer_circuit(circuit), + recursion_layer_proof_config(), + circuit.numeric_circuit_type(), + ), + CircuitWrapper::BasePartial(_) => { + return Self::partial_proof_error(); + } + }; + + let proof = + gpu_prove_from_external_witness_data::( + &gpu_proof_config, + &witness_vector, + proof_config, + &setup_data.setup, + &setup_data.vk, + (), + &worker, + ) + .context("crypto primitive: failed to generate proof")?; + CIRCUIT_PROVER_METRICS + .generate_proof_time + .observe(time.elapsed()); + Ok((proof.into(), circuit_id)) + } + + /// Verifies a proof from crypto primitives + fn verify_proof( + circuit_wrapper: &CircuitWrapper, + proof: &Proof, + verification_key: &VerificationKey, + ) -> anyhow::Result<()> { + let time = Instant::now(); + + let is_valid = match circuit_wrapper { + CircuitWrapper::Base(base_circuit) => { + verify_base_layer_proof::(base_circuit, proof, verification_key) + } + CircuitWrapper::Recursive(recursive_circuit) => { + verify_recursion_layer_proof::(recursive_circuit, proof, verification_key) + } + CircuitWrapper::BasePartial(_) => { + return Self::partial_proof_error(); + } + }; + + CIRCUIT_PROVER_METRICS + .verify_proof_time + .observe(time.elapsed()); + + if !is_valid { + return Err(anyhow::anyhow!("crypto primitive: failed to verify proof")); + } + Ok(()) + } + + /// This code path should never trigger. All proofs are hydrated during Witness Vector Generator. + /// If this triggers, it means that proof hydration in Witness Vector Generator was not done -- logic bug. + fn partial_proof_error() -> anyhow::Result { + Err(anyhow::anyhow!("received unexpected dehydrated proof")) + } + + /// Runs task to completion and persists result. + /// NOTE: Task may be cancelled mid-flight. + async fn finish_task( + &self, + job_id: u32, + time: Instant, + job_start_time: Instant, + task: JoinHandle>, + cancellation_token: CancellationToken, + ) -> anyhow::Result<()> { + tokio::select! { + _ = cancellation_token.cancelled() => { + tracing::info!("Stop signal received, shutting down Circuit Prover..."); + return Ok(()) + } + result = task => { + let error_message = match result { + Ok(Ok(prover_artifact)) => { + tracing::info!("Circuit Prover executed job {:?} in: {:?}", job_id, time.elapsed()); + CIRCUIT_PROVER_METRICS.execution_time.observe(time.elapsed()); + self + .save_result(job_id, job_start_time, prover_artifact) + .await.context("failed to save result")?; + return Ok(()) + } + Ok(Err(error)) => error.to_string(), + Err(error) => try_extract_panic_message(error), + }; + tracing::error!( + "Circuit Prover failed on job {:?} with error {:?}", + job_id, + error_message + ); + + self.save_failure(job_id, error_message).await.context("failed to save failure")?; + } + } + + Ok(()) + } + + /// Persists proof generated. + /// Job metadata is saved to database, whilst artifacts go to object store. + async fn save_result( + &self, + job_id: u32, + job_start_time: Instant, + artifacts: ProverArtifacts, + ) -> anyhow::Result<()> { + let time = Instant::now(); + let mut connection = self + .connection_pool + .connection() + .await + .context("failed to get db connection")?; + let proof = artifacts.proof_wrapper; + + let (_circuit_type, is_scheduler_proof) = match &proof { + FriProofWrapper::Base(base) => (base.numeric_circuit_type(), false), + FriProofWrapper::Recursive(recursive_circuit) => match recursive_circuit { + ZkSyncRecursionLayerProof::SchedulerCircuit(_) => { + (recursive_circuit.numeric_circuit_type(), true) + } + _ => (recursive_circuit.numeric_circuit_type(), false), + }, + }; + + let upload_time = Instant::now(); + let blob_url = self + .object_store + .put(job_id, &proof) + .await + .context("failed to upload to object store")?; + CIRCUIT_PROVER_METRICS + .artifact_upload_time + .observe(upload_time.elapsed()); + + let mut transaction = connection + .start_transaction() + .await + .context("failed to start db transaction")?; + transaction + .fri_prover_jobs_dal() + .save_proof(job_id, job_start_time.elapsed(), &blob_url) + .await; + if is_scheduler_proof { + transaction + .fri_proof_compressor_dal() + .insert_proof_compression_job( + artifacts.block_number, + &blob_url, + self.protocol_version, + ) + .await; + } + transaction + .commit() + .await + .context("failed to commit db transaction")?; + + tracing::info!( + "Circuit Prover saved job {:?} after {:?}", + job_id, + time.elapsed() + ); + CIRCUIT_PROVER_METRICS.save_time.observe(time.elapsed()); + + Ok(()) + } + + /// Persists job execution error to database. + async fn save_failure(&self, job_id: u32, error: String) -> anyhow::Result<()> { + self.connection_pool + .connection() + .await + .context("failed to get db connection")? + .fri_prover_jobs_dal() + .save_proof_error(job_id, error) + .await; + Ok(()) + } +} diff --git a/prover/crates/bin/circuit_prover/src/lib.rs b/prover/crates/bin/circuit_prover/src/lib.rs new file mode 100644 index 000000000000..7d7ce1d96686 --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/lib.rs @@ -0,0 +1,13 @@ +#![allow(incomplete_features)] // We have to use generic const exprs. +#![feature(generic_const_exprs)] +pub use backoff::Backoff; +pub use circuit_prover::CircuitProver; +pub use metrics::PROVER_BINARY_METRICS; +pub use types::{FinalizationHintsCache, SetupDataCache}; +pub use witness_vector_generator::WitnessVectorGenerator; + +mod backoff; +mod circuit_prover; +mod metrics; +mod types; +mod witness_vector_generator; diff --git a/prover/crates/bin/circuit_prover/src/main.rs b/prover/crates/bin/circuit_prover/src/main.rs new file mode 100644 index 000000000000..e26f29ca995d --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/main.rs @@ -0,0 +1,201 @@ +use std::{ + path::PathBuf, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Context as _; +use clap::Parser; +use tokio_util::sync::CancellationToken; +use zksync_circuit_prover::{ + Backoff, CircuitProver, FinalizationHintsCache, SetupDataCache, WitnessVectorGenerator, + PROVER_BINARY_METRICS, +}; +use zksync_config::{ + configs::{FriProverConfig, ObservabilityConfig}, + ObjectStoreConfig, +}; +use zksync_core_leftovers::temp_config_store::{load_database_secrets, load_general_config}; +use zksync_object_store::{ObjectStore, ObjectStoreFactory}; +use zksync_prover_dal::{ConnectionPool, Prover}; +use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION; +use zksync_prover_keystore::keystore::Keystore; +use zksync_utils::wait_for_tasks::ManagedTasks; + +#[derive(Debug, Parser)] +#[command(author = "Matter Labs", version)] +struct Cli { + #[arg(long)] + pub(crate) config_path: Option, + #[arg(long)] + pub(crate) secrets_path: Option, + /// Number of WVG jobs to run in parallel. + /// Default value is 1. + #[arg(long, default_value_t = 1)] + pub(crate) witness_vector_generator_count: usize, + /// Max VRAM to allocate. Useful if you want to limit the size of VRAM used. + /// None corresponds to allocating all available VRAM. + #[arg(long)] + pub(crate) max_allocation: Option, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let time = Instant::now(); + let opt = Cli::parse(); + + let (observability_config, prover_config, object_store_config) = load_configs(opt.config_path)?; + + let _observability_guard = observability_config + .install() + .context("failed to install observability")?; + + let wvg_count = opt.witness_vector_generator_count as u32; + + let (connection_pool, object_store, setup_data_cache, hints) = load_resources( + opt.secrets_path, + object_store_config, + prover_config.setup_data_path.into(), + wvg_count, + ) + .await + .context("failed to load configs")?; + + PROVER_BINARY_METRICS.start_up.observe(time.elapsed()); + + let cancellation_token = CancellationToken::new(); + let backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(30)); + + let mut tasks = vec![]; + + let (sender, receiver) = tokio::sync::mpsc::channel(5); + + tracing::info!("Starting {wvg_count} Witness Vector Generators."); + + for _ in 0..wvg_count { + let wvg = WitnessVectorGenerator::new( + object_store.clone(), + connection_pool.clone(), + PROVER_PROTOCOL_SEMANTIC_VERSION, + sender.clone(), + hints.clone(), + ); + tasks.push(tokio::spawn( + wvg.run(cancellation_token.clone(), backoff.clone()), + )); + } + + // NOTE: Prover Context is the way VRAM is allocated. If it is dropped, the claim on VRAM allocation is dropped as well. + // It has to be kept until prover dies. Whilst it may be kept in prover struct, during cancellation, prover can `drop`, but the thread doing the processing can still be alive. + // This setup prevents segmentation faults and other nasty behavior during shutdown. + let (prover, _prover_context) = CircuitProver::new( + connection_pool, + object_store, + PROVER_PROTOCOL_SEMANTIC_VERSION, + receiver, + opt.max_allocation, + setup_data_cache, + ) + .context("failed to create circuit prover")?; + tasks.push(tokio::spawn(prover.run(cancellation_token.clone()))); + + let mut tasks = ManagedTasks::new(tasks); + tokio::select! { + _ = tasks.wait_single() => {}, + result = tokio::signal::ctrl_c() => { + match result { + Ok(_) => { + tracing::info!("Stop signal received, shutting down..."); + cancellation_token.cancel(); + }, + Err(_err) => { + tracing::error!("failed to set up ctrl c listener"); + } + } + } + } + PROVER_BINARY_METRICS.run_time.observe(time.elapsed()); + tasks.complete(Duration::from_secs(5)).await; + + Ok(()) +} + +/// Loads configs necessary for proving. +/// - observability config - for observability setup +/// - prover config - necessary for setup data +/// - object store config - for retrieving artifacts for WVG & CP +fn load_configs( + config_path: Option, +) -> anyhow::Result<(ObservabilityConfig, FriProverConfig, ObjectStoreConfig)> { + tracing::info!("loading configs..."); + let general_config = + load_general_config(config_path).context("failed loading general config")?; + let observability_config = general_config + .observability + .context("failed loading observability config")?; + let prover_config = general_config + .prover_config + .context("failed loading prover config")?; + let object_store_config = prover_config + .prover_object_store + .clone() + .context("failed loading prover object store config")?; + tracing::info!("Loaded configs."); + Ok((observability_config, prover_config, object_store_config)) +} + +/// Loads resources necessary for proving. +/// - connection pool - necessary to pick & store jobs from database +/// - object store - necessary for loading and storing artifacts to object store +/// - setup data - necessary for circuit proving +/// - finalization hints - necessary for generating witness vectors +async fn load_resources( + secrets_path: Option, + object_store_config: ObjectStoreConfig, + setup_data_path: PathBuf, + wvg_count: u32, +) -> anyhow::Result<( + ConnectionPool, + Arc, + SetupDataCache, + FinalizationHintsCache, +)> { + let database_secrets = + load_database_secrets(secrets_path).context("failed to load database secrets")?; + let database_url = database_secrets + .prover_url + .context("no prover DB URl present")?; + + // 1 connection for the prover and one for each vector generator + let max_connections = 1 + wvg_count; + let connection_pool = ConnectionPool::::builder(database_url, max_connections) + .build() + .await + .context("failed to build connection pool")?; + + let object_store = ObjectStoreFactory::new(object_store_config) + .create_store() + .await + .context("failed to create object store")?; + + tracing::info!("Loading mappings from disk..."); + + let keystore = Keystore::locate().with_setup_path(Some(setup_data_path)); + let setup_data_cache = keystore + .load_all_setup_key_mapping() + .await + .context("failed to load setup key mapping")?; + let finalization_hints = keystore + .load_all_finalization_hints_mapping() + .await + .context("failed to load finalization hints mapping")?; + + tracing::info!("Loaded mappings from disk."); + + Ok(( + connection_pool, + object_store, + setup_data_cache, + finalization_hints, + )) +} diff --git a/prover/crates/bin/circuit_prover/src/metrics.rs b/prover/crates/bin/circuit_prover/src/metrics.rs new file mode 100644 index 000000000000..e9f445914795 --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/metrics.rs @@ -0,0 +1,80 @@ +use std::time::Duration; + +use vise::{Buckets, Histogram, Metrics}; + +#[derive(Debug, Metrics)] +#[metrics(prefix = "prover_binary")] +pub struct ProverBinaryMetrics { + /// How long does it take for prover to load data before it can produce proofs? + #[metrics(buckets = Buckets::LATENCIES)] + pub start_up: Histogram, + /// How long has the prover been running? + #[metrics(buckets = Buckets::LATENCIES)] + pub run_time: Histogram, +} + +#[vise::register] +pub static PROVER_BINARY_METRICS: vise::Global = vise::Global::new(); + +#[derive(Debug, Metrics)] +#[metrics(prefix = "witness_vector_generator")] +pub struct WitnessVectorGeneratorMetrics { + /// How long does witness vector generator waits before a job is available? + #[metrics(buckets = Buckets::LATENCIES)] + pub job_wait_time: Histogram, + /// How long does it take to load object store artifacts for a witness vector job? + #[metrics(buckets = Buckets::LATENCIES)] + pub artifact_download_time: Histogram, + /// How long does the crypto witness generation primitive take? + #[metrics(buckets = Buckets::LATENCIES)] + pub crypto_primitive_time: Histogram, + /// How long does it take for a job to be executed, from the moment it's loaded? + #[metrics(buckets = Buckets::LATENCIES)] + pub execution_time: Histogram, + /// How long does it take to send a job to prover? + /// This is relevant because prover queue can apply back-pressure. + #[metrics(buckets = Buckets::LATENCIES)] + pub send_time: Histogram, + /// How long does it take for a job to be considered finished, from the moment it's been loaded? + #[metrics(buckets = Buckets::LATENCIES)] + pub job_finished_time: Histogram, +} + +#[vise::register] +pub static WITNESS_VECTOR_GENERATOR_METRICS: vise::Global = + vise::Global::new(); + +#[derive(Debug, Metrics)] +#[metrics(prefix = "circuit_prover")] +pub struct CircuitProverMetrics { + /// How long does circuit prover wait before a job is available? + #[metrics(buckets = Buckets::LATENCIES)] + pub job_wait_time: Histogram, + /// How long does the crypto primitives (proof generation & verification) take? + #[metrics(buckets = Buckets::LATENCIES)] + pub crypto_primitives_time: Histogram, + /// How long does proof generation (crypto primitive) take? + #[metrics(buckets = Buckets::LATENCIES)] + pub generate_proof_time: Histogram, + /// How long does verify proof (crypto primitive) take? + #[metrics(buckets = Buckets::LATENCIES)] + pub verify_proof_time: Histogram, + /// How long does it take for a job to be executed, from the moment it's loaded? + #[metrics(buckets = Buckets::LATENCIES)] + pub execution_time: Histogram, + /// How long does it take to upload proof to object store? + #[metrics(buckets = Buckets::LATENCIES)] + pub artifact_upload_time: Histogram, + /// How long does it take to save a job? + #[metrics(buckets = Buckets::LATENCIES)] + pub save_time: Histogram, + /// How long does it take for a job to be considered finished, from the moment it's been loaded? + #[metrics(buckets = Buckets::LATENCIES)] + pub job_finished_time: Histogram, + /// How long does it take a job to go from witness generation to having the proof saved? + #[metrics(buckets = Buckets::LATENCIES)] + pub full_proving_time: Histogram, +} + +#[vise::register] +pub static CIRCUIT_PROVER_METRICS: vise::Global = vise::Global::new(); diff --git a/prover/crates/bin/circuit_prover/src/types.rs b/prover/crates/bin/circuit_prover/src/types.rs new file mode 100644 index 000000000000..52cdd48b6b50 --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/types.rs @@ -0,0 +1,31 @@ +use std::{collections::HashMap, sync::Arc}; + +use zksync_prover_fri_types::{ + circuit_definitions::boojum::{ + algebraic_props::{ + round_function::AbsorptionModeOverwrite, sponge::GoldilocksPoseidon2Sponge, + }, + cs::implementations::{ + proof::Proof as CryptoProof, setup::FinalizationHintsForProver, + transcript::GoldilocksPoisedon2Transcript, + verifier::VerificationKey as CryptoVerificationKey, + }, + field::goldilocks::{GoldilocksExt2, GoldilocksField}, + }, + ProverServiceDataKey, +}; +use zksync_prover_keystore::GoldilocksGpuProverSetupData; + +// prover types +pub type DefaultTranscript = GoldilocksPoisedon2Transcript; +pub type DefaultTreeHasher = GoldilocksPoseidon2Sponge; + +type F = GoldilocksField; +type H = GoldilocksPoseidon2Sponge; +type Ext = GoldilocksExt2; +pub type Proof = CryptoProof; +pub type VerificationKey = CryptoVerificationKey; + +// cache types +pub type SetupDataCache = HashMap>; +pub type FinalizationHintsCache = HashMap>; diff --git a/prover/crates/bin/circuit_prover/src/witness_vector_generator.rs b/prover/crates/bin/circuit_prover/src/witness_vector_generator.rs new file mode 100644 index 000000000000..cb2d2a256df9 --- /dev/null +++ b/prover/crates/bin/circuit_prover/src/witness_vector_generator.rs @@ -0,0 +1,345 @@ +use std::{collections::HashMap, sync::Arc, time::Instant}; + +use anyhow::Context; +use tokio::{sync::mpsc::Sender, task::JoinHandle}; +use tokio_util::sync::CancellationToken; +use zksync_object_store::ObjectStore; +use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; +use zksync_prover_fri_types::{ + circuit_definitions::{ + boojum::{ + cs::implementations::setup::FinalizationHintsForProver, + field::goldilocks::GoldilocksField, + gadgets::queue::full_state_queue::FullStateCircuitQueueRawWitness, + }, + circuit_definitions::base_layer::ZkSyncBaseLayerCircuit, + }, + get_current_pod_name, + keys::RamPermutationQueueWitnessKey, + CircuitAuxData, CircuitWrapper, ProverJob, ProverServiceDataKey, RamPermutationQueueWitness, + WitnessVectorArtifactsTemp, +}; +use zksync_types::{protocol_version::ProtocolSemanticVersion, L1BatchNumber}; +use zksync_utils::panic_extractor::try_extract_panic_message; + +use crate::{metrics::WITNESS_VECTOR_GENERATOR_METRICS, Backoff, FinalizationHintsCache}; + +/// In charge of generating Witness Vectors and sending them to Circuit Prover. +/// Both job runner & job executor. +#[derive(Debug)] +pub struct WitnessVectorGenerator { + object_store: Arc, + connection_pool: ConnectionPool, + protocol_version: ProtocolSemanticVersion, + /// Finalization Hints used for Witness Vector generation + finalization_hints_cache: FinalizationHintsCache, + /// Witness Vector sender for Circuit Prover + sender: Sender, + pod_name: String, +} + +impl WitnessVectorGenerator { + pub fn new( + object_store: Arc, + connection_pool: ConnectionPool, + protocol_version: ProtocolSemanticVersion, + sender: Sender, + finalization_hints: HashMap>, + ) -> Self { + Self { + object_store, + connection_pool, + protocol_version, + finalization_hints_cache: finalization_hints, + sender, + pod_name: get_current_pod_name(), + } + } + + /// Continuously polls database for new prover jobs and generates witness vectors for them. + /// All job executions are persisted. + pub async fn run( + self, + cancellation_token: CancellationToken, + mut backoff: Backoff, + ) -> anyhow::Result<()> { + let mut get_job_timer = Instant::now(); + while !cancellation_token.is_cancelled() { + if let Some(prover_job) = self + .get_job() + .await + .context("failed to get next witness generation job")? + { + tracing::info!( + "Witness Vector Generator received job {:?} after: {:?}", + prover_job.job_id, + get_job_timer.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .job_wait_time + .observe(get_job_timer.elapsed()); + if let e @ Err(_) = self.generate(prover_job, cancellation_token.clone()).await { + // this means that the witness vector receiver is closed, no need to report the error, just return + if cancellation_token.is_cancelled() { + return Ok(()); + } + e.context("failed to generate witness")? + } + + // waiting for a job timer starts as soon as the other is finished + get_job_timer = Instant::now(); + backoff.reset(); + continue; + }; + self.backoff(&mut backoff, cancellation_token.clone()).await; + } + tracing::info!("Witness Vector Generator shut down."); + Ok(()) + } + + /// Retrieves a prover job from database, loads artifacts from object store and hydrates them. + async fn get_job(&self) -> anyhow::Result> { + let mut connection = self + .connection_pool + .connection() + .await + .context("failed to get db connection")?; + let prover_job_metadata = match connection + .fri_prover_jobs_dal() + .get_job(self.protocol_version, &self.pod_name) + .await + { + None => return Ok(None), + Some(job) => job, + }; + + let time = Instant::now(); + let circuit_wrapper = self + .object_store + .get(prover_job_metadata.into()) + .await + .context("failed to get circuit_wrapper from object store")?; + let artifact = match circuit_wrapper { + a @ CircuitWrapper::Base(_) => a, + a @ CircuitWrapper::Recursive(_) => a, + CircuitWrapper::BasePartial((circuit, aux_data)) => self + .fill_witness(circuit, aux_data, prover_job_metadata.block_number) + .await + .context("failed to fill witness")?, + }; + WITNESS_VECTOR_GENERATOR_METRICS + .artifact_download_time + .observe(time.elapsed()); + + let setup_data_key = ProverServiceDataKey { + circuit_id: prover_job_metadata.circuit_id, + round: prover_job_metadata.aggregation_round, + } + .crypto_setup_key(); + let prover_job = ProverJob::new( + prover_job_metadata.block_number, + prover_job_metadata.id, + artifact, + setup_data_key, + ); + Ok(Some(prover_job)) + } + + /// Prover artifact hydration. + async fn fill_witness( + &self, + circuit: ZkSyncBaseLayerCircuit, + aux_data: CircuitAuxData, + l1_batch_number: L1BatchNumber, + ) -> anyhow::Result { + if let ZkSyncBaseLayerCircuit::RAMPermutation(circuit_instance) = circuit { + let sorted_witness_key = RamPermutationQueueWitnessKey { + block_number: l1_batch_number, + circuit_subsequence_number: aux_data.circuit_subsequence_number as usize, + is_sorted: true, + }; + let sorted_witness: RamPermutationQueueWitness = self + .object_store + .get(sorted_witness_key) + .await + .context("failed to load sorted witness key")?; + + let unsorted_witness_key = RamPermutationQueueWitnessKey { + block_number: l1_batch_number, + circuit_subsequence_number: aux_data.circuit_subsequence_number as usize, + is_sorted: false, + }; + let unsorted_witness: RamPermutationQueueWitness = self + .object_store + .get(unsorted_witness_key) + .await + .context("failed to load unsorted witness key")?; + + let mut witness = circuit_instance.witness.take().unwrap(); + witness.unsorted_queue_witness = FullStateCircuitQueueRawWitness { + elements: unsorted_witness.witness.into(), + }; + witness.sorted_queue_witness = FullStateCircuitQueueRawWitness { + elements: sorted_witness.witness.into(), + }; + circuit_instance.witness.store(Some(witness)); + + return Ok(CircuitWrapper::Base( + ZkSyncBaseLayerCircuit::RAMPermutation(circuit_instance), + )); + } + Err(anyhow::anyhow!( + "unexpected circuit received with partial witness, expected RAM permutation, got {:?}", + circuit.short_description() + )) + } + + /// Generates witness vector, with persistence of execution. + async fn generate( + &self, + prover_job: ProverJob, + cancellation_token: CancellationToken, + ) -> anyhow::Result<()> { + let start_time = Instant::now(); + let finalization_hints = self + .finalization_hints_cache + .get(&prover_job.setup_data_key) + .context(format!( + "failed to get finalization hints for key {:?}", + &prover_job.setup_data_key + ))? + .clone(); + let job_id = prover_job.job_id; + let task = tokio::task::spawn_blocking(move || { + let block_number = prover_job.block_number; + let _span = tracing::info_span!("witness_vector_generator", %block_number).entered(); + Self::generate_witness_vector(prover_job, finalization_hints) + }); + + self.finish_task(job_id, start_time, task, cancellation_token.clone()) + .await?; + + tracing::info!( + "Witness Vector Generator finished job {:?} in: {:?}", + job_id, + start_time.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .job_finished_time + .observe(start_time.elapsed()); + Ok(()) + } + + /// Generates witness vector using crypto primitives. + #[tracing::instrument( + skip_all, + fields(l1_batch = % prover_job.block_number) + )] + pub fn generate_witness_vector( + prover_job: ProverJob, + finalization_hints: Arc, + ) -> anyhow::Result { + let time = Instant::now(); + let cs = match prover_job.circuit_wrapper.clone() { + CircuitWrapper::Base(base_circuit) => { + base_circuit.synthesis::(&finalization_hints) + } + CircuitWrapper::Recursive(recursive_circuit) => { + recursive_circuit.synthesis::(&finalization_hints) + } + // circuit must be hydrated during `get_job` + CircuitWrapper::BasePartial(_) => { + return Err(anyhow::anyhow!("received unexpected dehydrated proof")); + } + }; + WITNESS_VECTOR_GENERATOR_METRICS + .crypto_primitive_time + .observe(time.elapsed()); + Ok(WitnessVectorArtifactsTemp::new( + cs.witness.unwrap(), + prover_job, + time, + )) + } + + /// Runs task to completion and persists result. + /// NOTE: Task may be cancelled mid-flight. + async fn finish_task( + &self, + job_id: u32, + time: Instant, + task: JoinHandle>, + cancellation_token: CancellationToken, + ) -> anyhow::Result<()> { + tokio::select! { + _ = cancellation_token.cancelled() => { + tracing::info!("Stop signal received, shutting down Witness Vector Generator..."); + return Ok(()) + } + result = task => { + let error_message = match result { + Ok(Ok(witness_vector)) => { + tracing::info!("Witness Vector Generator executed job {:?} in: {:?}", job_id, time.elapsed()); + WITNESS_VECTOR_GENERATOR_METRICS.execution_time.observe(time.elapsed()); + self + .save_result(witness_vector, job_id) + .await + .context("failed to save result")?; + return Ok(()) + } + Ok(Err(error)) => error.to_string(), + Err(error) => try_extract_panic_message(error), + }; + tracing::error!("Witness Vector Generator failed on job {job_id:?} with error {error_message:?}"); + + self.save_failure(job_id, error_message).await.context("failed to save failure")?; + } + } + + Ok(()) + } + + /// Sends proof to Circuit Prover. + async fn save_result( + &self, + artifacts: WitnessVectorArtifactsTemp, + job_id: u32, + ) -> anyhow::Result<()> { + let time = Instant::now(); + self.sender + .send(artifacts) + .await + .context("failed to send witness vector to prover")?; + tracing::info!( + "Witness Vector Generator sent job {:?} after {:?}", + job_id, + time.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .send_time + .observe(time.elapsed()); + Ok(()) + } + + /// Persists job execution error to database + async fn save_failure(&self, job_id: u32, error: String) -> anyhow::Result<()> { + self.connection_pool + .connection() + .await + .context("failed to get db connection")? + .fri_prover_jobs_dal() + .save_proof_error(job_id, error) + .await; + Ok(()) + } + + /// Backs off, whilst being cancellation aware. + async fn backoff(&self, backoff: &mut Backoff, cancellation_token: CancellationToken) { + let backoff_duration = backoff.delay(); + tracing::info!("Backing off for {:?}...", backoff_duration); + // Error here corresponds to a timeout w/o receiving task cancel; we're OK with this. + tokio::time::timeout(backoff_duration, cancellation_token.cancelled()) + .await + .ok(); + } +} diff --git a/prover/crates/bin/prover_fri/src/prover_job_processor.rs b/prover/crates/bin/prover_fri/src/prover_job_processor.rs index bbfb1d5a8322..5e8740d1b728 100644 --- a/prover/crates/bin/prover_fri/src/prover_job_processor.rs +++ b/prover/crates/bin/prover_fri/src/prover_job_processor.rs @@ -90,7 +90,7 @@ impl Prover { let started_at = Instant::now(); let artifact: GoldilocksProverSetupData = self .keystore - .load_cpu_setup_data_for_circuit_type(key.clone()) + .load_cpu_setup_data_for_circuit_type(key) .context("get_cpu_setup_data_for_circuit_type()")?; METRICS.gpu_setup_data_load_time[&key.circuit_id.to_string()] .observe(started_at.elapsed()); @@ -226,7 +226,7 @@ impl JobProcessor for Prover { _started_at: Instant, ) -> JoinHandle> { let config = Arc::clone(&self.config); - let setup_data = self.get_setup_data(job.setup_data_key.clone()); + let setup_data = self.get_setup_data(job.setup_data_key); tokio::task::spawn_blocking(move || { let block_number = job.block_number; let _span = tracing::info_span!("cpu_prove", %block_number).entered(); @@ -307,7 +307,7 @@ pub fn load_setup_data_cache( for prover_setup_metadata in prover_setup_metadata_list { let key = setup_metadata_to_setup_data_key(&prover_setup_metadata); let setup_data = keystore - .load_cpu_setup_data_for_circuit_type(key.clone()) + .load_cpu_setup_data_for_circuit_type(key) .context("get_cpu_setup_data_for_circuit_type()")?; cache.insert(key, Arc::new(setup_data)); } diff --git a/prover/crates/bin/witness_vector_generator/src/generator.rs b/prover/crates/bin/witness_vector_generator/src/generator.rs index 6695905c07e3..646dd8ffda78 100644 --- a/prover/crates/bin/witness_vector_generator/src/generator.rs +++ b/prover/crates/bin/witness_vector_generator/src/generator.rs @@ -70,7 +70,7 @@ impl WitnessVectorGenerator { keystore: &Keystore, ) -> anyhow::Result { let finalization_hints = keystore - .load_finalization_hints(job.setup_data_key.clone()) + .load_finalization_hints(job.setup_data_key) .context("get_finalization_hints()")?; let cs = match job.circuit_wrapper.clone() { CircuitWrapper::Base(base_circuit) => { diff --git a/prover/crates/lib/keystore/Cargo.toml b/prover/crates/lib/keystore/Cargo.toml index 617030754f8b..4d9addc26bc0 100644 --- a/prover/crates/lib/keystore/Cargo.toml +++ b/prover/crates/lib/keystore/Cargo.toml @@ -27,6 +27,8 @@ once_cell.workspace = true md5.workspace = true sha3.workspace = true hex.workspace = true +tokio.workspace = true +futures = { workspace = true, features = ["compat"] } [features] default = [] diff --git a/prover/crates/lib/keystore/src/keystore.rs b/prover/crates/lib/keystore/src/keystore.rs index 28ce989287cc..6225943e3cd7 100644 --- a/prover/crates/lib/keystore/src/keystore.rs +++ b/prover/crates/lib/keystore/src/keystore.rs @@ -1,7 +1,9 @@ use std::{ + collections::HashMap, fs::{self, File}, io::Read, path::{Path, PathBuf}, + sync::Arc, }; use anyhow::Context as _; @@ -14,7 +16,7 @@ use circuit_definitions::{ }, zkevm_circuits::scheduler::aux::BaseLayerCircuitType, }; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use zkevm_test_harness::data_source::{in_memory_data_source::InMemoryDataSource, SetupDataSource}; use zksync_basic_types::basic_fri_types::AggregationRound; use zksync_prover_fri_types::ProverServiceDataKey; @@ -24,6 +26,7 @@ use zksync_utils::env::Workspace; use crate::GoldilocksGpuProverSetupData; use crate::{GoldilocksProverSetupData, VkCommitments}; +#[derive(Debug, Clone, Copy)] pub enum ProverServiceDataType { VerificationKey, SetupData, @@ -209,7 +212,7 @@ impl Keystore { key: ProverServiceDataKey, hint: &FinalizationHintsForProver, ) -> anyhow::Result<()> { - let filepath = self.get_file_path(key.clone(), ProverServiceDataType::FinalizationHints); + let filepath = self.get_file_path(key, ProverServiceDataType::FinalizationHints); tracing::info!("saving finalization hints for {:?} to: {:?}", key, filepath); let serialized = @@ -267,7 +270,7 @@ impl Keystore { &self, key: ProverServiceDataKey, ) -> anyhow::Result { - let filepath = self.get_file_path(key.clone(), ProverServiceDataType::SetupData); + let filepath = self.get_file_path(key, ProverServiceDataType::SetupData); let mut file = File::open(filepath.clone()) .with_context(|| format!("Failed reading setup-data from path: {filepath:?}"))?; @@ -286,7 +289,7 @@ impl Keystore { &self, key: ProverServiceDataKey, ) -> anyhow::Result { - let filepath = self.get_file_path(key.clone(), ProverServiceDataType::SetupData); + let filepath = self.get_file_path(key, ProverServiceDataType::SetupData); let mut file = File::open(filepath.clone()) .with_context(|| format!("Failed reading setup-data from path: {filepath:?}"))?; @@ -301,7 +304,7 @@ impl Keystore { } pub fn is_setup_data_present(&self, key: &ProverServiceDataKey) -> bool { - Path::new(&self.get_file_path(key.clone(), ProverServiceDataType::SetupData)).exists() + Path::new(&self.get_file_path(*key, ProverServiceDataType::SetupData)).exists() } pub fn save_setup_data_for_circuit_type( @@ -309,7 +312,7 @@ impl Keystore { key: ProverServiceDataKey, serialized_setup_data: &Vec, ) -> anyhow::Result<()> { - let filepath = self.get_file_path(key.clone(), ProverServiceDataType::SetupData); + let filepath = self.get_file_path(key, ProverServiceDataType::SetupData); tracing::info!("saving {:?} setup data to: {:?}", key, filepath); std::fs::write(filepath.clone(), serialized_setup_data) .with_context(|| format!("Failed saving setup-data at path: {filepath:?}")) @@ -465,4 +468,49 @@ impl Keystore { pub fn save_commitments(&self, commitments: &VkCommitments) -> anyhow::Result<()> { Self::save_json_pretty(self.get_base_path().join("commitments.json"), &commitments) } + + /// Async loads mapping of all circuits to setup key, if successful + pub async fn load_all_setup_key_mapping( + &self, + ) -> anyhow::Result>> { + self.load_key_mapping(ProverServiceDataType::SetupData) + .await + } + + /// Async loads mapping of all circuits to finalization hints, if successful + pub async fn load_all_finalization_hints_mapping( + &self, + ) -> anyhow::Result>> { + self.load_key_mapping(ProverServiceDataType::FinalizationHints) + .await + } + + /// Async function that loads mapping from disk. + /// Whilst IO is not parallelizable, ser/de is. + async fn load_key_mapping( + &self, + data_type: ProverServiceDataType, + ) -> anyhow::Result>> { + let mut mapping: HashMap> = HashMap::new(); + + // Load each file in parallel. Note that FS access is not necessarily parallel, but + // deserialization is. For larger files, it makes a big difference. + // Note: `collect` is important, because iterators are lazy, and otherwise we won't actually + // spawn threads. + let handles: Vec<_> = ProverServiceDataKey::all() + .into_iter() + .map(|key| { + let filepath = self.get_file_path(key, data_type); + tokio::task::spawn_blocking(move || { + let data = Self::load_bincode_from_file(filepath)?; + anyhow::Ok((key, Arc::new(data))) + }) + }) + .collect(); + for handle in futures::future::join_all(handles).await { + let (key, setup_data) = handle.context("future loading key panicked")??; + mapping.insert(key, setup_data); + } + Ok(mapping) + } } diff --git a/prover/crates/lib/keystore/src/setup_data_generator.rs b/prover/crates/lib/keystore/src/setup_data_generator.rs index e69184ee9364..c4790d67feaa 100644 --- a/prover/crates/lib/keystore/src/setup_data_generator.rs +++ b/prover/crates/lib/keystore/src/setup_data_generator.rs @@ -33,7 +33,7 @@ pub fn generate_setup_data_common( let (finalization, vk) = if circuit.is_base_layer() { ( - Some(keystore.load_finalization_hints(circuit.clone())?), + Some(keystore.load_finalization_hints(circuit)?), data_source .get_base_layer_vk(circuit.circuit_id) .unwrap() @@ -41,7 +41,7 @@ pub fn generate_setup_data_common( ) } else { ( - Some(keystore.load_finalization_hints(circuit.clone())?), + Some(keystore.load_finalization_hints(circuit)?), data_source .get_recursion_layer_vk(circuit.circuit_id) .unwrap() @@ -86,7 +86,7 @@ pub trait SetupDataGenerator { ); return Ok("Skipped".to_string()); } - let serialized = self.generate_setup_data(circuit.clone())?; + let serialized = self.generate_setup_data(circuit)?; let digest = md5::compute(&serialized); if !dry_run { @@ -109,7 +109,7 @@ pub trait SetupDataGenerator { .iter() .map(|circuit| { let digest = self - .generate_and_write_setup_data(circuit.clone(), dry_run, recompute_if_missing) + .generate_and_write_setup_data(*circuit, dry_run, recompute_if_missing) .context(circuit.name()) .unwrap(); (circuit.name(), digest) diff --git a/prover/crates/lib/prover_dal/.sqlx/query-7d20c0bf35625185c1f6c675aa8fcddbb47c5e9965443f118f8edd7d562734a2.json b/prover/crates/lib/prover_dal/.sqlx/query-7d20c0bf35625185c1f6c675aa8fcddbb47c5e9965443f118f8edd7d562734a2.json new file mode 100644 index 000000000000..140b8f126750 --- /dev/null +++ b/prover/crates/lib/prover_dal/.sqlx/query-7d20c0bf35625185c1f6c675aa8fcddbb47c5e9965443f118f8edd7d562734a2.json @@ -0,0 +1,60 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE prover_jobs_fri\n SET\n status = 'in_progress',\n attempts = attempts + 1,\n updated_at = NOW(),\n processing_started_at = NOW(),\n picked_by = $3\n WHERE\n id = (\n SELECT\n id\n FROM\n prover_jobs_fri\n WHERE\n status = 'queued'\n AND protocol_version = $1\n AND protocol_version_patch = $2\n ORDER BY\n l1_batch_number ASC,\n aggregation_round ASC,\n circuit_id ASC,\n id ASC\n LIMIT\n 1\n FOR UPDATE\n SKIP LOCKED\n )\n RETURNING\n prover_jobs_fri.id,\n prover_jobs_fri.l1_batch_number,\n prover_jobs_fri.circuit_id,\n prover_jobs_fri.aggregation_round,\n prover_jobs_fri.sequence_number,\n prover_jobs_fri.depth,\n prover_jobs_fri.is_node_final_proof\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "l1_batch_number", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "circuit_id", + "type_info": "Int2" + }, + { + "ordinal": 3, + "name": "aggregation_round", + "type_info": "Int2" + }, + { + "ordinal": 4, + "name": "sequence_number", + "type_info": "Int4" + }, + { + "ordinal": 5, + "name": "depth", + "type_info": "Int4" + }, + { + "ordinal": 6, + "name": "is_node_final_proof", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int4", + "Int4", + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "7d20c0bf35625185c1f6c675aa8fcddbb47c5e9965443f118f8edd7d562734a2" +} diff --git a/prover/crates/lib/prover_dal/src/fri_prover_dal.rs b/prover/crates/lib/prover_dal/src/fri_prover_dal.rs index 4e68154290da..e8bf5523693f 100644 --- a/prover/crates/lib/prover_dal/src/fri_prover_dal.rs +++ b/prover/crates/lib/prover_dal/src/fri_prover_dal.rs @@ -49,6 +49,78 @@ impl FriProverDal<'_, '_> { drop(latency); } + /// Retrieves the next prover job to be proven. Called by WVGs. + /// + /// Prover jobs must be thought of as ordered. + /// Prover must prioritize proving such jobs that will make the chain move forward the fastest. + /// Current ordering: + /// - pick the lowest batch + /// - within the lowest batch, look at the lowest aggregation level (move up the proof tree) + /// - pick the same type of circuit for as long as possible, this maximizes GPU cache reuse + /// + /// NOTE: Most of this function is a duplicate of `get_next_job()`. Get next job will be deleted together with old prover. + pub async fn get_job( + &mut self, + protocol_version: ProtocolSemanticVersion, + picked_by: &str, + ) -> Option { + sqlx::query!( + r#" + UPDATE prover_jobs_fri + SET + status = 'in_progress', + attempts = attempts + 1, + updated_at = NOW(), + processing_started_at = NOW(), + picked_by = $3 + WHERE + id = ( + SELECT + id + FROM + prover_jobs_fri + WHERE + status = 'queued' + AND protocol_version = $1 + AND protocol_version_patch = $2 + ORDER BY + l1_batch_number ASC, + aggregation_round ASC, + circuit_id ASC, + id ASC + LIMIT + 1 + FOR UPDATE + SKIP LOCKED + ) + RETURNING + prover_jobs_fri.id, + prover_jobs_fri.l1_batch_number, + prover_jobs_fri.circuit_id, + prover_jobs_fri.aggregation_round, + prover_jobs_fri.sequence_number, + prover_jobs_fri.depth, + prover_jobs_fri.is_node_final_proof + "#, + protocol_version.minor as i32, + protocol_version.patch.0 as i32, + picked_by, + ) + .fetch_optional(self.storage.conn()) + .await + .expect("failed to get prover job") + .map(|row| FriProverJobMetadata { + id: row.id as u32, + block_number: L1BatchNumber(row.l1_batch_number as u32), + circuit_id: row.circuit_id as u8, + aggregation_round: AggregationRound::try_from(i32::from(row.aggregation_round)) + .unwrap(), + sequence_number: row.sequence_number as usize, + depth: row.depth as u16, + is_node_final_proof: row.is_node_final_proof, + }) + } + pub async fn get_next_job( &mut self, protocol_version: ProtocolSemanticVersion, diff --git a/prover/crates/lib/prover_fri_types/src/keys.rs b/prover/crates/lib/prover_fri_types/src/keys.rs index 2948fc5f84ed..26aa679b4a94 100644 --- a/prover/crates/lib/prover_fri_types/src/keys.rs +++ b/prover/crates/lib/prover_fri_types/src/keys.rs @@ -1,6 +1,8 @@ //! Different key types for object store. -use zksync_types::{basic_fri_types::AggregationRound, L1BatchNumber}; +use zksync_types::{ + basic_fri_types::AggregationRound, prover_dal::FriProverJobMetadata, L1BatchNumber, +}; /// Storage key for a [AggregationWrapper`]. #[derive(Debug, Clone, Copy)] @@ -27,6 +29,18 @@ pub struct FriCircuitKey { pub depth: u16, } +impl From for FriCircuitKey { + fn from(prover_job_metadata: FriProverJobMetadata) -> Self { + FriCircuitKey { + block_number: prover_job_metadata.block_number, + sequence_number: prover_job_metadata.sequence_number, + circuit_id: prover_job_metadata.circuit_id, + aggregation_round: prover_job_metadata.aggregation_round, + depth: prover_job_metadata.depth, + } + } +} + /// Storage key for a [`ZkSyncCircuit`]. #[derive(Debug, Clone, Copy)] pub struct CircuitKey<'a> { diff --git a/prover/crates/lib/prover_fri_types/src/lib.rs b/prover/crates/lib/prover_fri_types/src/lib.rs index c14bc1905639..4a8a1b3e4064 100644 --- a/prover/crates/lib/prover_fri_types/src/lib.rs +++ b/prover/crates/lib/prover_fri_types/src/lib.rs @@ -1,4 +1,4 @@ -use std::env; +use std::{env, time::Instant}; pub use circuit_definitions; use circuit_definitions::{ @@ -66,7 +66,7 @@ impl StoredObject for CircuitWrapper { serialize_using_bincode!(); } -#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum FriProofWrapper { Base(ZkSyncBaseLayerProof), Recursive(ZkSyncRecursionLayerProof), @@ -98,6 +98,45 @@ impl WitnessVectorArtifacts { } } +/// This structure exists for the transition period between old prover and new prover. +/// We want the 2 codebases to coexist, without impacting each other. +/// Once old prover is deleted, this struct will be renamed to `WitnessVectorArtifacts`. +pub struct WitnessVectorArtifactsTemp { + pub witness_vector: WitnessVec, + pub prover_job: ProverJob, + pub time: Instant, +} + +impl WitnessVectorArtifactsTemp { + pub fn new( + witness_vector: WitnessVec, + prover_job: ProverJob, + time: Instant, + ) -> Self { + Self { + witness_vector, + prover_job, + time, + } + } +} + +/// Data structure containing the proof generated by the circuit prover. +#[derive(Debug)] +pub struct ProverArtifacts { + pub block_number: L1BatchNumber, + pub proof_wrapper: FriProofWrapper, +} + +impl ProverArtifacts { + pub fn new(block_number: L1BatchNumber, proof_wrapper: FriProofWrapper) -> Self { + Self { + block_number, + proof_wrapper, + } + } +} + #[derive(Clone, serde::Serialize, serde::Deserialize)] pub struct ProverJob { pub block_number: L1BatchNumber, @@ -122,12 +161,30 @@ impl ProverJob { } } -#[derive(Debug, Clone, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)] pub struct ProverServiceDataKey { pub circuit_id: u8, pub round: AggregationRound, } +impl ProverServiceDataKey { + /// Returns the crypto version of the setup key. + /// + /// Setup key is overloaded in our system. On one hand, it is used as identifier for figuring out which type of proofs are ready. + /// On the other hand, it is also a setup key from prover perspective. + /// The 2 overlap on all aggregation rounds, but NodeAggregation. + /// There's only 1 node key and that belongs to circuit 2. + pub fn crypto_setup_key(self) -> Self { + if let AggregationRound::NodeAggregation = self.round { + return Self { + circuit_id: 2, + round: self.round, + }; + } + self + } +} + fn get_round_for_recursive_circuit_type(circuit_type: u8) -> AggregationRound { match circuit_type { circuit_type if circuit_type == ZkSyncRecursionLayerStorageType::SchedulerCircuit as u8 => { @@ -186,6 +243,12 @@ impl ProverServiceDataKey { } } + pub fn all() -> Vec { + let mut keys = Self::all_boojum(); + keys.push(Self::snark()); + keys + } + pub fn is_base_layer(&self) -> bool { self.round == AggregationRound::BasicCircuits }