Skip to content

Commit

Permalink
Rework region fetcher
Browse files Browse the repository at this point in the history
  • Loading branch information
popzxc committed Jul 22, 2024
1 parent bf1f3e9 commit a6b909d
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 61 deletions.
14 changes: 14 additions & 0 deletions core/lib/config/src/configs/fri_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ pub enum SetupLoadMode {
FromMemory,
}

/// Kind of cloud environment prover subsystem runs in.
///
/// Currently will only affect how the prover zone is chosen.
#[derive(Debug, Default, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum CloudType {
/// Assumes that the prover runs in GCP.
#[default]
GCP,
/// Assumes that the prover runs locally.
Local,
}

/// Configuration for the fri prover application
#[derive(Debug, Deserialize, Clone, PartialEq)]
pub struct FriProverConfig {
Expand All @@ -28,6 +40,8 @@ pub struct FriProverConfig {
pub shall_save_to_public_bucket: bool,
pub prover_object_store: Option<ObjectStoreConfig>,
pub public_object_store: Option<ObjectStoreConfig>,
#[serde(default)]
pub cloud_type: CloudType,
}

impl FriProverConfig {
Expand Down
11 changes: 11 additions & 0 deletions core/lib/config/src/testonly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,16 @@ impl Distribution<configs::fri_prover::SetupLoadMode> for EncodeDist {
}
}

impl Distribution<configs::fri_prover::CloudType> for EncodeDist {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> configs::fri_prover::CloudType {
type T = configs::fri_prover::CloudType;
match rng.gen_range(0..1) {
0 => T::GCP,
_ => T::Local,
}
}
}

impl Distribution<configs::FriProverConfig> for EncodeDist {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> configs::FriProverConfig {
configs::FriProverConfig {
Expand All @@ -454,6 +464,7 @@ impl Distribution<configs::FriProverConfig> for EncodeDist {
availability_check_interval_in_secs: self.sample(rng),
prover_object_store: self.sample(rng),
public_object_store: self.sample(rng),
cloud_type: self.sample(rng),
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion core/lib/env_config/src/fri_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ impl FromEnv for FriProverConfig {
#[cfg(test)]
mod tests {
use zksync_config::{
configs::{fri_prover::SetupLoadMode, object_store::ObjectStoreMode},
configs::{
fri_prover::{CloudType, SetupLoadMode},
object_store::ObjectStoreMode,
},
ObjectStoreConfig,
};

Expand Down Expand Up @@ -57,6 +60,7 @@ mod tests {
local_mirror_path: None,
}),
availability_check_interval_in_secs: Some(1_800),
cloud_type: CloudType::GCP,
}
}

Expand Down
6 changes: 6 additions & 0 deletions core/lib/protobuf_config/src/proto/config/prover.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ enum SetupLoadMode {
FROM_MEMORY = 1;
}

enum CloudType {
GCP = 0;
LOCAL = 1;
}

message Prover {
optional string setup_data_path = 1; // required; fs path?
optional uint32 prometheus_port = 2; // required; u16
Expand All @@ -35,6 +40,7 @@ message Prover {
optional bool shall_save_to_public_bucket = 13; // required
optional config.object_store.ObjectStore public_object_store = 22;
optional config.object_store.ObjectStore prover_object_store = 23;
optional CloudType cloud_type = 24; // optional
reserved 5, 6, 9; reserved "base_layer_circuit_ids_to_be_verified", "recursive_layer_circuit_ids_to_be_verified", "witness_vector_generator_thread_count";
}

Expand Down
26 changes: 26 additions & 0 deletions core/lib/protobuf_config/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,24 @@ impl proto::SetupLoadMode {
}
}

impl proto::CloudType {
fn new(x: &configs::fri_prover::CloudType) -> Self {
use configs::fri_prover::CloudType as From;
match x {
From::GCP => Self::Gcp,
From::Local => Self::Local,
}
}

fn parse(&self) -> configs::fri_prover::CloudType {
use configs::fri_prover::CloudType as To;
match self {
Self::Gcp => To::GCP,
Self::Local => To::Local,
}
}
}

impl ProtoRepr for proto::Prover {
type Type = configs::FriProverConfig;
fn read(&self) -> anyhow::Result<Self::Type> {
Expand Down Expand Up @@ -338,6 +356,13 @@ impl ProtoRepr for proto::Prover {
.context("shall_save_to_public_bucket")?,
public_object_store,
prover_object_store,
cloud_type: self
.cloud_type
.map(proto::CloudType::try_from)
.transpose()
.context("cloud_type")?
.map(|x| x.parse())
.unwrap_or_default(),
})
}

Expand All @@ -356,6 +381,7 @@ impl ProtoRepr for proto::Prover {
shall_save_to_public_bucket: Some(this.shall_save_to_public_bucket),
prover_object_store: this.prover_object_store.as_ref().map(ProtoRepr::build),
public_object_store: this.public_object_store.as_ref().map(ProtoRepr::build),
cloud_type: Some(proto::CloudType::new(&this.cloud_type).into()),
}
}
}
1 change: 1 addition & 0 deletions prover/proof_fri_compressor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ serde = { workspace = true, features = ["derive"] }
wrapper_prover = { workspace = true, optional = true }

[features]
default = ["gpu"]
gpu = ["wrapper_prover"]

7 changes: 4 additions & 3 deletions prover/prover_fri/src/gpu_prover_availability_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod availability_checker {

use tokio::sync::Notify;
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_utils::region_fetcher::Zone;
use zksync_types::prover_dal::{GpuProverInstanceStatus, SocketAddress};

use crate::metrics::{KillingReason, METRICS};
Expand All @@ -12,15 +13,15 @@ pub mod availability_checker {
/// If the prover instance is not found in the database or marked as dead, the availability checker will shut down the prover.
pub struct AvailabilityChecker {
address: SocketAddress,
zone: String,
zone: Zone,
polling_interval: Duration,
pool: ConnectionPool<Prover>,
}

impl AvailabilityChecker {
pub fn new(
address: SocketAddress,
zone: String,
zone: Zone,
polling_interval_secs: u32,
pool: ConnectionPool<Prover>,
) -> Self {
Expand All @@ -46,7 +47,7 @@ pub mod availability_checker {
.await
.unwrap()
.fri_gpu_prover_queue_dal()
.get_prover_instance_status(self.address.clone(), self.zone.clone())
.get_prover_instance_status(self.address.clone(), self.zone.to_string())
.await;

// If the prover instance is not found in the database or marked as dead, we should shut down the prover
Expand Down
7 changes: 4 additions & 3 deletions prover/prover_fri/src/gpu_prover_job_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod gpu_prover {
},
CircuitWrapper, FriProofWrapper, ProverServiceDataKey, WitnessVectorArtifacts,
};
use zksync_prover_fri_utils::region_fetcher::Zone;
use zksync_queued_job_processor::{async_trait, JobProcessor};
use zksync_types::{
basic_fri_types::CircuitIdRoundTuple, protocol_version::ProtocolSemanticVersion,
Expand Down Expand Up @@ -64,7 +65,7 @@ pub mod gpu_prover {
witness_vector_queue: SharedWitnessVectorQueue,
prover_context: ProverContext,
address: SocketAddress,
zone: String,
zone: Zone,
protocol_version: ProtocolSemanticVersion,
}

Expand All @@ -79,7 +80,7 @@ pub mod gpu_prover {
circuit_ids_for_round_to_be_proven: Vec<CircuitIdRoundTuple>,
witness_vector_queue: SharedWitnessVectorQueue,
address: SocketAddress,
zone: String,
zone: Zone,
protocol_version: ProtocolSemanticVersion,
) -> Self {
Prover {
Expand Down Expand Up @@ -230,7 +231,7 @@ pub mod gpu_prover {
.fri_gpu_prover_queue_dal()
.update_prover_instance_from_full_to_available(
self.address.clone(),
self.zone.clone(),
self.zone.to_string(),
)
.await;
}
Expand Down
27 changes: 16 additions & 11 deletions prover/prover_fri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use zksync_env_config::FromEnv;
use zksync_object_store::{ObjectStore, ObjectStoreFactory};
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION;
use zksync_prover_fri_utils::{get_all_circuit_id_round_tuples_for, region_fetcher::get_zone};
use zksync_prover_fri_utils::{
get_all_circuit_id_round_tuples_for,
region_fetcher::{RegionFetcher, Zone},
};
use zksync_queued_job_processor::JobProcessor;
use zksync_types::{
basic_fri_types::CircuitIdRoundTuple,
Expand All @@ -32,24 +35,20 @@ mod prover_job_processor;
mod socket_listener;
mod utils;

async fn graceful_shutdown(port: u16) -> anyhow::Result<impl Future<Output = ()>> {
async fn graceful_shutdown(zone: Zone, port: u16) -> anyhow::Result<impl Future<Output = ()>> {
let database_secrets = DatabaseSecrets::from_env().context("DatabaseSecrets::from_env()")?;
let pool = ConnectionPool::<Prover>::singleton(database_secrets.prover_url()?)
.build()
.await
.context("failed to build a connection pool")?;
let host = local_ip().context("Failed obtaining local IP address")?;
let zone_url = &FriProverConfig::from_env()
.context("FriProverConfig::from_env()")?
.zone_read_url;
let zone = get_zone(zone_url).await.context("get_zone()")?;
let address = SocketAddress { host, port };
Ok(async move {
pool.connection()
.await
.unwrap()
.fri_gpu_prover_queue_dal()
.update_prover_instance_status(address, GpuProverInstanceStatus::Dead, zone)
.update_prover_instance_status(address, GpuProverInstanceStatus::Dead, zone.to_string())
.await
})
}
Expand Down Expand Up @@ -107,6 +106,13 @@ async fn main() -> anyhow::Result<()> {
})
.context("Error setting Ctrl+C handler")?;

let zone = RegionFetcher::new(
prover_config.cloud_type,
prover_config.zone_read_url.clone(),
)
.get_zone()
.await?;

let (stop_sender, stop_receiver) = tokio::sync::watch::channel(false);
let prover_object_store_config = prover_config
.prover_object_store
Expand Down Expand Up @@ -156,6 +162,7 @@ async fn main() -> anyhow::Result<()> {

let prover_tasks = get_prover_tasks(
prover_config,
zone.clone(),
stop_receiver.clone(),
object_store_factory,
public_blob_store,
Expand All @@ -174,7 +181,7 @@ async fn main() -> anyhow::Result<()> {
tokio::select! {
_ = tasks.wait_single() => {
if cfg!(feature = "gpu") {
graceful_shutdown(port)
graceful_shutdown(zone, port)
.await
.context("failed to prepare graceful shutdown future")?
.await;
Expand Down Expand Up @@ -228,6 +235,7 @@ async fn get_prover_tasks(
#[cfg(feature = "gpu")]
async fn get_prover_tasks(
prover_config: FriProverConfig,
zone: Zone,
stop_receiver: Receiver<bool>,
store_factory: ObjectStoreFactory,
public_blob_store: Option<Arc<dyn ObjectStore>>,
Expand All @@ -246,9 +254,6 @@ async fn get_prover_tasks(
let shared_witness_vector_queue = Arc::new(Mutex::new(witness_vector_queue));
let consumer = shared_witness_vector_queue.clone();

let zone = get_zone(&prover_config.zone_read_url)
.await
.context("get_zone()")?;
let local_ip = local_ip().context("Failed obtaining local IP address")?;
let address = SocketAddress {
host: local_ip,
Expand Down
9 changes: 5 additions & 4 deletions prover/prover_fri/src/socket_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod gpu_socket_listener {
use zksync_object_store::bincode;
use zksync_prover_dal::{ConnectionPool, Prover, ProverDal};
use zksync_prover_fri_types::WitnessVectorArtifacts;
use zksync_prover_fri_utils::region_fetcher::Zone;
use zksync_types::{
protocol_version::ProtocolSemanticVersion,
prover_dal::{GpuProverInstanceStatus, SocketAddress},
Expand All @@ -26,7 +27,7 @@ pub mod gpu_socket_listener {
queue: SharedWitnessVectorQueue,
pool: ConnectionPool<Prover>,
specialized_prover_group_id: u8,
zone: String,
zone: Zone,
protocol_version: ProtocolSemanticVersion,
}

Expand All @@ -36,7 +37,7 @@ pub mod gpu_socket_listener {
queue: SharedWitnessVectorQueue,
pool: ConnectionPool<Prover>,
specialized_prover_group_id: u8,
zone: String,
zone: Zone,
protocol_version: ProtocolSemanticVersion,
) -> Self {
Self {
Expand Down Expand Up @@ -68,7 +69,7 @@ pub mod gpu_socket_listener {
.insert_prover_instance(
self.address.clone(),
self.specialized_prover_group_id,
self.zone.clone(),
self.zone.to_string(),
self.protocol_version,
)
.await;
Expand Down Expand Up @@ -154,7 +155,7 @@ pub mod gpu_socket_listener {
.await
.unwrap()
.fri_gpu_prover_queue_dal()
.update_prover_instance_status(self.address.clone(), status, self.zone.clone())
.update_prover_instance_status(self.address.clone(), status, self.zone.to_string())
.await;
tracing::info!(
"Marked prover as {:?} after {:?}",
Expand Down
Loading

0 comments on commit a6b909d

Please sign in to comment.