diff --git a/core/lib/config/src/configs/fri_prover.rs b/core/lib/config/src/configs/fri_prover.rs index 99e3d354536e..5cd25450531a 100644 --- a/core/lib/config/src/configs/fri_prover.rs +++ b/core/lib/config/src/configs/fri_prover.rs @@ -10,6 +10,18 @@ pub enum SetupLoadMode { FromMemory, } +/// Kind of cloud environment prover subsystem runs in. +/// +/// Currently will only affect how the prover zone is chosen. +#[derive(Debug, Default, Deserialize, Clone, Copy, PartialEq, Eq)] +pub enum CloudType { + /// Assumes that the prover runs in GCP. + #[default] + GCP, + /// Assumes that the prover runs locally. + Local, +} + /// Configuration for the fri prover application #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct FriProverConfig { @@ -28,6 +40,8 @@ pub struct FriProverConfig { pub shall_save_to_public_bucket: bool, pub prover_object_store: Option, pub public_object_store: Option, + #[serde(default)] + pub cloud_type: CloudType, } impl FriProverConfig { diff --git a/core/lib/config/src/testonly.rs b/core/lib/config/src/testonly.rs index a5e51131c3a8..e105c3282639 100644 --- a/core/lib/config/src/testonly.rs +++ b/core/lib/config/src/testonly.rs @@ -438,6 +438,16 @@ impl Distribution for EncodeDist { } } +impl Distribution for EncodeDist { + fn sample(&self, rng: &mut R) -> configs::fri_prover::CloudType { + type T = configs::fri_prover::CloudType; + match rng.gen_range(0..1) { + 0 => T::GCP, + _ => T::Local, + } + } +} + impl Distribution for EncodeDist { fn sample(&self, rng: &mut R) -> configs::FriProverConfig { configs::FriProverConfig { @@ -454,6 +464,7 @@ impl Distribution for EncodeDist { availability_check_interval_in_secs: self.sample(rng), prover_object_store: self.sample(rng), public_object_store: self.sample(rng), + cloud_type: self.sample(rng), } } } diff --git a/core/lib/env_config/src/fri_prover.rs b/core/lib/env_config/src/fri_prover.rs index 96069d6514ea..bdcf5291ee05 100644 --- a/core/lib/env_config/src/fri_prover.rs +++ b/core/lib/env_config/src/fri_prover.rs @@ -18,7 +18,10 @@ impl FromEnv for FriProverConfig { #[cfg(test)] mod tests { use zksync_config::{ - configs::{fri_prover::SetupLoadMode, object_store::ObjectStoreMode}, + configs::{ + fri_prover::{CloudType, SetupLoadMode}, + object_store::ObjectStoreMode, + }, ObjectStoreConfig, }; @@ -57,6 +60,7 @@ mod tests { local_mirror_path: None, }), availability_check_interval_in_secs: Some(1_800), + cloud_type: CloudType::GCP, } } diff --git a/core/lib/protobuf_config/src/proto/config/prover.proto b/core/lib/protobuf_config/src/proto/config/prover.proto index c50ebdde4eef..80d45f40bbcb 100644 --- a/core/lib/protobuf_config/src/proto/config/prover.proto +++ b/core/lib/protobuf_config/src/proto/config/prover.proto @@ -21,6 +21,11 @@ enum SetupLoadMode { FROM_MEMORY = 1; } +enum CloudType { + GCP = 0; + LOCAL = 1; +} + message Prover { optional string setup_data_path = 1; // required; fs path? optional uint32 prometheus_port = 2; // required; u16 @@ -35,6 +40,7 @@ message Prover { optional bool shall_save_to_public_bucket = 13; // required optional config.object_store.ObjectStore public_object_store = 22; optional config.object_store.ObjectStore prover_object_store = 23; + optional CloudType cloud_type = 24; // optional reserved 5, 6, 9; reserved "base_layer_circuit_ids_to_be_verified", "recursive_layer_circuit_ids_to_be_verified", "witness_vector_generator_thread_count"; } diff --git a/core/lib/protobuf_config/src/prover.rs b/core/lib/protobuf_config/src/prover.rs index 50782ab8e968..e1c31ee1fccd 100644 --- a/core/lib/protobuf_config/src/prover.rs +++ b/core/lib/protobuf_config/src/prover.rs @@ -292,6 +292,24 @@ impl proto::SetupLoadMode { } } +impl proto::CloudType { + fn new(x: &configs::fri_prover::CloudType) -> Self { + use configs::fri_prover::CloudType as From; + match x { + From::GCP => Self::Gcp, + From::Local => Self::Local, + } + } + + fn parse(&self) -> configs::fri_prover::CloudType { + use configs::fri_prover::CloudType as To; + match self { + Self::Gcp => To::GCP, + Self::Local => To::Local, + } + } +} + impl ProtoRepr for proto::Prover { type Type = configs::FriProverConfig; fn read(&self) -> anyhow::Result { @@ -338,6 +356,13 @@ impl ProtoRepr for proto::Prover { .context("shall_save_to_public_bucket")?, public_object_store, prover_object_store, + cloud_type: self + .cloud_type + .map(proto::CloudType::try_from) + .transpose() + .context("cloud_type")? + .map(|x| x.parse()) + .unwrap_or_default(), }) } @@ -356,6 +381,7 @@ impl ProtoRepr for proto::Prover { shall_save_to_public_bucket: Some(this.shall_save_to_public_bucket), prover_object_store: this.prover_object_store.as_ref().map(ProtoRepr::build), public_object_store: this.public_object_store.as_ref().map(ProtoRepr::build), + cloud_type: Some(proto::CloudType::new(&this.cloud_type).into()), } } } diff --git a/prover/proof_fri_compressor/Cargo.toml b/prover/proof_fri_compressor/Cargo.toml index 14fc44d5a3b2..0c01a40874f2 100644 --- a/prover/proof_fri_compressor/Cargo.toml +++ b/prover/proof_fri_compressor/Cargo.toml @@ -41,5 +41,6 @@ serde = { workspace = true, features = ["derive"] } wrapper_prover = { workspace = true, optional = true } [features] +default = [] gpu = ["wrapper_prover"] diff --git a/prover/prover_fri/src/gpu_prover_availability_checker.rs b/prover/prover_fri/src/gpu_prover_availability_checker.rs index 4b51b26e5d38..6e154ba553a9 100644 --- a/prover/prover_fri/src/gpu_prover_availability_checker.rs +++ b/prover/prover_fri/src/gpu_prover_availability_checker.rs @@ -4,6 +4,7 @@ pub mod availability_checker { use tokio::sync::Notify; use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; + use zksync_prover_fri_utils::region_fetcher::Zone; use zksync_types::prover_dal::{GpuProverInstanceStatus, SocketAddress}; use crate::metrics::{KillingReason, METRICS}; @@ -12,7 +13,7 @@ pub mod availability_checker { /// If the prover instance is not found in the database or marked as dead, the availability checker will shut down the prover. pub struct AvailabilityChecker { address: SocketAddress, - zone: String, + zone: Zone, polling_interval: Duration, pool: ConnectionPool, } @@ -20,7 +21,7 @@ pub mod availability_checker { impl AvailabilityChecker { pub fn new( address: SocketAddress, - zone: String, + zone: Zone, polling_interval_secs: u32, pool: ConnectionPool, ) -> Self { @@ -46,7 +47,7 @@ pub mod availability_checker { .await .unwrap() .fri_gpu_prover_queue_dal() - .get_prover_instance_status(self.address.clone(), self.zone.clone()) + .get_prover_instance_status(self.address.clone(), self.zone.to_string()) .await; // If the prover instance is not found in the database or marked as dead, we should shut down the prover diff --git a/prover/prover_fri/src/gpu_prover_job_processor.rs b/prover/prover_fri/src/gpu_prover_job_processor.rs index cbd363e9b4f4..6148ca3e0aed 100644 --- a/prover/prover_fri/src/gpu_prover_job_processor.rs +++ b/prover/prover_fri/src/gpu_prover_job_processor.rs @@ -28,6 +28,7 @@ pub mod gpu_prover { }, CircuitWrapper, FriProofWrapper, ProverServiceDataKey, WitnessVectorArtifacts, }; + use zksync_prover_fri_utils::region_fetcher::Zone; use zksync_queued_job_processor::{async_trait, JobProcessor}; use zksync_types::{ basic_fri_types::CircuitIdRoundTuple, protocol_version::ProtocolSemanticVersion, @@ -64,7 +65,7 @@ pub mod gpu_prover { witness_vector_queue: SharedWitnessVectorQueue, prover_context: ProverContext, address: SocketAddress, - zone: String, + zone: Zone, protocol_version: ProtocolSemanticVersion, } @@ -79,7 +80,7 @@ pub mod gpu_prover { circuit_ids_for_round_to_be_proven: Vec, witness_vector_queue: SharedWitnessVectorQueue, address: SocketAddress, - zone: String, + zone: Zone, protocol_version: ProtocolSemanticVersion, ) -> Self { Prover { @@ -230,7 +231,7 @@ pub mod gpu_prover { .fri_gpu_prover_queue_dal() .update_prover_instance_from_full_to_available( self.address.clone(), - self.zone.clone(), + self.zone.to_string(), ) .await; } diff --git a/prover/prover_fri/src/main.rs b/prover/prover_fri/src/main.rs index dfab8648d74c..e4b2fd5a6709 100644 --- a/prover/prover_fri/src/main.rs +++ b/prover/prover_fri/src/main.rs @@ -16,7 +16,10 @@ use zksync_env_config::FromEnv; use zksync_object_store::{ObjectStore, ObjectStoreFactory}; use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION; -use zksync_prover_fri_utils::{get_all_circuit_id_round_tuples_for, region_fetcher::get_zone}; +use zksync_prover_fri_utils::{ + get_all_circuit_id_round_tuples_for, + region_fetcher::{RegionFetcher, Zone}, +}; use zksync_queued_job_processor::JobProcessor; use zksync_types::{ basic_fri_types::CircuitIdRoundTuple, @@ -32,24 +35,20 @@ mod prover_job_processor; mod socket_listener; mod utils; -async fn graceful_shutdown(port: u16) -> anyhow::Result> { +async fn graceful_shutdown(zone: Zone, port: u16) -> anyhow::Result> { let database_secrets = DatabaseSecrets::from_env().context("DatabaseSecrets::from_env()")?; let pool = ConnectionPool::::singleton(database_secrets.prover_url()?) .build() .await .context("failed to build a connection pool")?; let host = local_ip().context("Failed obtaining local IP address")?; - let zone_url = &FriProverConfig::from_env() - .context("FriProverConfig::from_env()")? - .zone_read_url; - let zone = get_zone(zone_url).await.context("get_zone()")?; let address = SocketAddress { host, port }; Ok(async move { pool.connection() .await .unwrap() .fri_gpu_prover_queue_dal() - .update_prover_instance_status(address, GpuProverInstanceStatus::Dead, zone) + .update_prover_instance_status(address, GpuProverInstanceStatus::Dead, zone.to_string()) .await }) } @@ -107,6 +106,13 @@ async fn main() -> anyhow::Result<()> { }) .context("Error setting Ctrl+C handler")?; + let zone = RegionFetcher::new( + prover_config.cloud_type, + prover_config.zone_read_url.clone(), + ) + .get_zone() + .await?; + let (stop_sender, stop_receiver) = tokio::sync::watch::channel(false); let prover_object_store_config = prover_config .prover_object_store @@ -156,6 +162,7 @@ async fn main() -> anyhow::Result<()> { let prover_tasks = get_prover_tasks( prover_config, + zone.clone(), stop_receiver.clone(), object_store_factory, public_blob_store, @@ -174,7 +181,7 @@ async fn main() -> anyhow::Result<()> { tokio::select! { _ = tasks.wait_single() => { if cfg!(feature = "gpu") { - graceful_shutdown(port) + graceful_shutdown(zone, port) .await .context("failed to prepare graceful shutdown future")? .await; @@ -194,6 +201,7 @@ async fn main() -> anyhow::Result<()> { #[cfg(not(feature = "gpu"))] async fn get_prover_tasks( prover_config: FriProverConfig, + _zone: Zone, stop_receiver: Receiver, store_factory: ObjectStoreFactory, public_blob_store: Option>, @@ -228,6 +236,7 @@ async fn get_prover_tasks( #[cfg(feature = "gpu")] async fn get_prover_tasks( prover_config: FriProverConfig, + zone: Zone, stop_receiver: Receiver, store_factory: ObjectStoreFactory, public_blob_store: Option>, @@ -246,9 +255,6 @@ async fn get_prover_tasks( let shared_witness_vector_queue = Arc::new(Mutex::new(witness_vector_queue)); let consumer = shared_witness_vector_queue.clone(); - let zone = get_zone(&prover_config.zone_read_url) - .await - .context("get_zone()")?; let local_ip = local_ip().context("Failed obtaining local IP address")?; let address = SocketAddress { host: local_ip, diff --git a/prover/prover_fri/src/socket_listener.rs b/prover/prover_fri/src/socket_listener.rs index 5e857e651bcf..e65471409e1e 100644 --- a/prover/prover_fri/src/socket_listener.rs +++ b/prover/prover_fri/src/socket_listener.rs @@ -11,6 +11,7 @@ pub mod gpu_socket_listener { use zksync_object_store::bincode; use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; use zksync_prover_fri_types::WitnessVectorArtifacts; + use zksync_prover_fri_utils::region_fetcher::Zone; use zksync_types::{ protocol_version::ProtocolSemanticVersion, prover_dal::{GpuProverInstanceStatus, SocketAddress}, @@ -26,7 +27,7 @@ pub mod gpu_socket_listener { queue: SharedWitnessVectorQueue, pool: ConnectionPool, specialized_prover_group_id: u8, - zone: String, + zone: Zone, protocol_version: ProtocolSemanticVersion, } @@ -36,7 +37,7 @@ pub mod gpu_socket_listener { queue: SharedWitnessVectorQueue, pool: ConnectionPool, specialized_prover_group_id: u8, - zone: String, + zone: Zone, protocol_version: ProtocolSemanticVersion, ) -> Self { Self { @@ -68,7 +69,7 @@ pub mod gpu_socket_listener { .insert_prover_instance( self.address.clone(), self.specialized_prover_group_id, - self.zone.clone(), + self.zone.to_string(), self.protocol_version, ) .await; @@ -154,7 +155,7 @@ pub mod gpu_socket_listener { .await .unwrap() .fri_gpu_prover_queue_dal() - .update_prover_instance_status(self.address.clone(), status, self.zone.clone()) + .update_prover_instance_status(self.address.clone(), status, self.zone.to_string()) .await; tracing::info!( "Marked prover as {:?} after {:?}", diff --git a/prover/prover_fri_utils/src/region_fetcher.rs b/prover/prover_fri_utils/src/region_fetcher.rs index cae211c26cbe..c73e83d531b4 100644 --- a/prover/prover_fri_utils/src/region_fetcher.rs +++ b/prover/prover_fri_utils/src/region_fetcher.rs @@ -1,51 +1,98 @@ +use core::fmt; + use anyhow::Context; use regex::Regex; use reqwest::{ header::{HeaderMap, HeaderValue}, Method, }; +use zksync_config::configs::fri_prover::CloudType; use zksync_utils::http_with_retries::send_request_with_retries; -pub async fn get_zone(zone_url: &str) -> anyhow::Result { - let data = fetch_from_url(zone_url).await.context("fetch_from_url()")?; - parse_zone(&data).context("parse_zone") +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RegionFetcher { + cloud_type: CloudType, + zone_url: String, +} + +impl RegionFetcher { + pub fn new(cloud_type: CloudType, zone_url: String) -> Self { + Self { + cloud_type, + zone_url, + } + } + + pub async fn get_zone(&self) -> anyhow::Result { + match self.cloud_type { + CloudType::GCP => GcpZoneFetcher::get_zone(&self.zone_url).await, + CloudType::Local => Ok(Zone("local".to_string())), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Zone(String); + +impl fmt::Display for Zone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } } -async fn fetch_from_url(url: &str) -> anyhow::Result { - let mut headers = HeaderMap::new(); - headers.insert("Metadata-Flavor", HeaderValue::from_static("Google")); - let response = send_request_with_retries(url, 5, Method::GET, Some(headers), None).await; - response - .map_err(|err| anyhow::anyhow!("Failed fetching response from url: {url}: {err:?}"))? - .text() - .await - .context("Failed to read response as text") +impl Zone { + pub fn new(zone: T) -> Self { + Self(zone.to_string()) + } } -fn parse_zone(data: &str) -> anyhow::Result { - // Statically provided Regex should always compile. - let re = Regex::new(r"^projects/\d+/zones/(\w+-\w+-\w+)$").unwrap(); - if let Some(caps) = re.captures(data) { - let zone = &caps[1]; - return Ok(zone.to_string()); +#[derive(Debug, Clone, Copy)] +struct GcpZoneFetcher; + +impl GcpZoneFetcher { + pub async fn get_zone(zone_url: &str) -> anyhow::Result { + let data = Self::fetch_from_url(zone_url) + .await + .context("fetch_from_url()")?; + Self::parse_zone(&data).context("parse_zone") + } + + async fn fetch_from_url(url: &str) -> anyhow::Result { + let mut headers = HeaderMap::new(); + headers.insert("Metadata-Flavor", HeaderValue::from_static("Google")); + let response = send_request_with_retries(url, 5, Method::GET, Some(headers), None).await; + response + .map_err(|err| anyhow::anyhow!("Failed fetching response from url: {url}: {err:?}"))? + .text() + .await + .context("Failed to read response as text") + } + + fn parse_zone(data: &str) -> anyhow::Result { + // Statically provided Regex should always compile. + let re = Regex::new(r"^projects/\d+/zones/(\w+-\w+-\w+)$").unwrap(); + if let Some(caps) = re.captures(data) { + let zone = &caps[1]; + return Ok(Zone(zone.to_string())); + } + anyhow::bail!("failed to extract zone from: {data}"); } - anyhow::bail!("failed to extract zone from: {data}"); } #[cfg(test)] mod tests { - use crate::region_fetcher::parse_zone; + use super::*; #[test] fn test_parse_zone() { let data = "projects/295056426491/zones/us-central1-a"; - let zone = parse_zone(data).unwrap(); - assert_eq!(zone, "us-central1-a"); + let zone = GcpZoneFetcher::parse_zone(data).unwrap(); + assert_eq!(zone, Zone::new("us-central1-a")); } #[test] fn test_parse_zone_panic() { let data = "invalid data"; - assert!(parse_zone(data).is_err()); + assert!(GcpZoneFetcher::parse_zone(data).is_err()); } } diff --git a/prover/witness_vector_generator/src/generator.rs b/prover/witness_vector_generator/src/generator.rs index d2b13beccd61..5574f0f1578d 100644 --- a/prover/witness_vector_generator/src/generator.rs +++ b/prover/witness_vector_generator/src/generator.rs @@ -15,7 +15,7 @@ use zksync_prover_fri_types::{ WitnessVectorArtifacts, }; use zksync_prover_fri_utils::{ - fetch_next_circuit, get_numeric_circuit_id, socket_utils::send_assembly, + fetch_next_circuit, get_numeric_circuit_id, region_fetcher::Zone, socket_utils::send_assembly, }; use zksync_queued_job_processor::JobProcessor; use zksync_types::{ @@ -30,7 +30,7 @@ pub struct WitnessVectorGenerator { object_store: Arc, pool: ConnectionPool, circuit_ids_for_round_to_be_proven: Vec, - zone: String, + zone: Zone, config: FriWitnessVectorGeneratorConfig, protocol_version: ProtocolSemanticVersion, max_attempts: u32, @@ -43,7 +43,7 @@ impl WitnessVectorGenerator { object_store: Arc, prover_connection_pool: ConnectionPool, circuit_ids_for_round_to_be_proven: Vec, - zone: String, + zone: Zone, config: FriWitnessVectorGeneratorConfig, protocol_version: ProtocolSemanticVersion, max_attempts: u32, @@ -167,7 +167,7 @@ impl JobProcessor for WitnessVectorGenerator { .lock_available_prover( self.config.max_prover_reservation_duration(), self.config.specialized_group_id, - self.zone.clone(), + self.zone.to_string(), self.protocol_version, ) .await; @@ -179,7 +179,8 @@ impl JobProcessor for WitnessVectorGenerator { now.elapsed() ); let result = send_assembly(job_id, &serialized, &address); - handle_send_result(&result, job_id, &address, &self.pool, self.zone.clone()).await; + handle_send_result(&result, job_id, &address, &self.pool, self.zone.to_string()) + .await; if result.is_ok() { METRICS.prover_waiting_time[&circuit_type].observe(now.elapsed()); diff --git a/prover/witness_vector_generator/src/main.rs b/prover/witness_vector_generator/src/main.rs index cb61be4227c9..58db6d6d5eb4 100644 --- a/prover/witness_vector_generator/src/main.rs +++ b/prover/witness_vector_generator/src/main.rs @@ -11,7 +11,7 @@ use zksync_env_config::object_store::ProverObjectStoreConfig; use zksync_object_store::ObjectStoreFactory; use zksync_prover_dal::ConnectionPool; use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION; -use zksync_prover_fri_utils::{get_all_circuit_id_round_tuples_for, region_fetcher::get_zone}; +use zksync_prover_fri_utils::{get_all_circuit_id_round_tuples_for, region_fetcher::RegionFetcher}; use zksync_queued_job_processor::JobProcessor; use zksync_utils::wait_for_tasks::ManagedTasks; use zksync_vlog::prometheus::PrometheusExporterConfig; @@ -95,9 +95,14 @@ async fn main() -> anyhow::Result<()> { .unwrap_or_default(); let circuit_ids_for_round_to_be_proven = get_all_circuit_id_round_tuples_for(circuit_ids_for_round_to_be_proven); - let fri_prover_config = general_config.prover_config.context("prover config")?; - let zone_url = &fri_prover_config.zone_read_url; - let zone = get_zone(zone_url).await.context("get_zone()")?; + let prover_config = general_config.prover_config.context("prover config")?; + let zone = RegionFetcher::new( + prover_config.cloud_type, + prover_config.zone_read_url.clone(), + ) + .get_zone() + .await + .context("get_zone()")?; let protocol_version = PROVER_PROTOCOL_SEMANTIC_VERSION; @@ -108,8 +113,8 @@ async fn main() -> anyhow::Result<()> { zone.clone(), config, protocol_version, - fri_prover_config.max_attempts, - Some(fri_prover_config.setup_data_path.clone()), + prover_config.max_attempts, + Some(prover_config.setup_data_path.clone()), ); let (stop_sender, stop_receiver) = watch::channel(false);