Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prover): Add support for scaling WGs and compressor #3179

Merged
merged 10 commits into from
Oct 29, 2024
19 changes: 13 additions & 6 deletions core/lib/basic_types/src/prover_dal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ pub struct ExtendedJobCountStatistics {
pub successful: usize,
}

#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct JobCountStatistics {
pub queued: usize,
pub in_progress: usize,
}

impl Add for ExtendedJobCountStatistics {
type Output = ExtendedJobCountStatistics;

Expand All @@ -47,6 +41,19 @@ impl Add for ExtendedJobCountStatistics {
}
}

/// Per-queue job counts, split by state.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct JobCountStatistics {
    // Jobs waiting to be picked up by a worker.
    pub queued: usize,
    // Jobs currently being processed.
    pub in_progress: usize,
}

impl JobCountStatistics {
    /// Returns the total number of jobs, i.e. queued plus in-progress.
    pub fn all(&self) -> usize {
        self.in_progress + self.queued
    }
}

#[derive(Debug)]
pub struct StuckJobs {
pub id: u64,
Expand Down
43 changes: 43 additions & 0 deletions core/lib/config/src/configs/prover_autoscaler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub struct ProverAutoscalerScalerConfig {
/// Duration after which pending pod considered long pending.
#[serde(default = "ProverAutoscalerScalerConfig::default_long_pending_duration")]
pub long_pending_duration: Duration,
/// List of simple autoscaler targets.
pub scaler_targets: Vec<ScalerTarget>,
}

#[derive(
Expand Down Expand Up @@ -93,6 +95,41 @@ pub enum Gpu {
A100,
}

// TODO: generate this enum by QueueReport from https://github.com/matter-labs/zksync-era/blob/main/prover/crates/bin/prover_job_monitor/src/autoscaler_queue_reporter.rs#L23
// and remove allowing of non_camel_case_types by generating field name parser.
/// Names of the queue fields reported by prover-job-monitor; a `ScalerTarget`
/// selects one of these to decide which queue drives its scaling.
#[derive(Debug, Display, PartialEq, Eq, Hash, Clone, Deserialize, EnumString, Default)]
#[allow(non_camel_case_types)] // variant names mirror the report's JSON field names verbatim
pub enum QueueReportFields {
    #[strum(ascii_case_insensitive)]
    basic_witness_jobs,
    #[strum(ascii_case_insensitive)]
    leaf_witness_jobs,
    #[strum(ascii_case_insensitive)]
    node_witness_jobs,
    #[strum(ascii_case_insensitive)]
    recursion_tip_witness_jobs,
    #[strum(ascii_case_insensitive)]
    scheduler_witness_jobs,
    #[strum(ascii_case_insensitive)]
    proof_compressor_jobs,
    // Kept as the default so an unset target field still parses to something sensible.
    #[default]
    #[strum(ascii_case_insensitive)]
    prover_jobs,
}

/// ScalerTarget can be configured to autoscale any of services for which queue is reported by
/// prover-job-monitor, except of provers. Provers need special treatment due to GPU requirement.
#[derive(Debug, Clone, PartialEq, Deserialize, Default)]
pub struct ScalerTarget {
pub queue_report_field: QueueReportFields,
pub pod_name_prefix: String,
/// Max replicas per cluster.
pub max_replicas: HashMap<String, usize>,
/// The queue will be divided by the speed and rounded up to get number of replicas.
#[serde(default = "ScalerTarget::default_speed")]
yorik marked this conversation as resolved.
Show resolved Hide resolved
pub speed: usize,
}

impl ProverAutoscalerConfig {
/// Default graceful shutdown timeout -- 5 seconds
pub fn default_graceful_shutdown_timeout() -> Duration {
Expand Down Expand Up @@ -126,3 +163,9 @@ impl ProverAutoscalerScalerConfig {
Duration::minutes(10)
}
}

impl ScalerTarget {
    /// Default `speed` divisor used when none is configured (referenced by serde's
    /// `#[serde(default = "ScalerTarget::default_speed")]`). `const fn` so it can
    /// also be used in const contexts; still callable as a plain fn by serde.
    pub const fn default_speed() -> usize {
        1
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,28 @@ message MinProver {
optional uint32 min = 2; // required
}

// Per-cluster cap on replicas for a ScalerTarget: (cluster name, max replica count).
message MaxReplica {
    optional string cluster = 1; // required
    optional uint64 max = 2; // required
}

// Autoscaling target for a non-prover service; mirrors the Rust
// configs::prover_autoscaler::ScalerTarget.
message ScalerTarget {
    optional string queue_report_field = 1; // required; must parse into QueueReportFields
    optional string pod_name_prefix = 2; // required
    repeated MaxReplica max_replicas = 3; // required at least one
    optional uint64 speed = 4; // optional; defaults to 1 when unset
}

// Scaler part of the prover autoscaler configuration.
// NOTE: the flattened diff left two declarations for field 5; duplicate field
// numbers are invalid proto, so only the current line is kept.
message ProverAutoscalerScalerConfig {
    optional uint32 prometheus_port = 1; // required
    optional std.Duration scaler_run_interval = 2; // optional
    optional string prover_job_monitor_url = 3; // required
    repeated string agents = 4; // required at least one
    repeated ProtocolVersion protocol_versions = 5; // required at least one
    repeated ClusterPriority cluster_priorities = 6; // optional
    repeated ProverSpeed prover_speed = 7; // optional
    optional uint32 long_pending_duration_s = 8; // optional
    repeated MaxProver max_provers = 9; // optional
    repeated MinProver min_provers = 10; // optional
    repeated ScalerTarget scaler_targets = 11; // optional
}
61 changes: 61 additions & 0 deletions core/lib/protobuf_config/src/prover_autoscaler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ impl ProtoRepr for proto::ProverAutoscalerScalerConfig {
.map(|(i, e)| e.read().context(i))
.collect::<Result<_, _>>()
.context("min_provers")?,
scaler_targets: self
.scaler_targets
.iter()
.enumerate()
.map(|(i, x)| x.read().context(i).unwrap())
.collect::<Vec<_>>(),
})
}

Expand Down Expand Up @@ -151,6 +157,7 @@ impl ProtoRepr for proto::ProverAutoscalerScalerConfig {
.iter()
.map(|(k, v)| proto::MinProver::build(&(k.clone(), *v)))
.collect(),
scaler_targets: this.scaler_targets.iter().map(ProtoRepr::build).collect(),
}
}
}
Expand Down Expand Up @@ -238,3 +245,57 @@ impl ProtoRepr for proto::MinProver {
}
}
}

impl ProtoRepr for proto::MaxReplica {
    type Type = (String, usize);

    /// Reads the wire form into `(cluster_name, max_replicas)`.
    fn read(&self) -> anyhow::Result<Self::Type> {
        Ok((
            required(&self.cluster).context("cluster")?.clone(),
            // Checked narrowing: `as usize` would silently truncate on 32-bit targets.
            usize::try_from(*required(&self.max).context("max")?).context("max")?,
        ))
    }

    fn build(this: &Self::Type) -> Self {
        Self {
            cluster: Some(this.0.clone()),
            // usize -> u64 is lossless on all supported targets.
            max: Some(this.1 as u64),
        }
    }
}

impl ProtoRepr for proto::ScalerTarget {
    type Type = configs::prover_autoscaler::ScalerTarget;

    /// Reads the wire form into the config type, validating every field and
    /// attaching the field name as error context.
    fn read(&self) -> anyhow::Result<Self::Type> {
        Ok(Self::Type {
            queue_report_field: required(&self.queue_report_field)
                .and_then(|x| Ok((*x).parse()?))
                .context("queue_report_field")?,
            pod_name_prefix: required(&self.pod_name_prefix)
                .context("pod_name_prefix")?
                .clone(),
            max_replicas: self
                .max_replicas
                .iter()
                .enumerate()
                .map(|(i, e)| e.read().context(i))
                .collect::<Result<_, _>>()
                .context("max_replicas")?,
            speed: match self.speed {
                // Checked narrowing instead of `as usize`, which would silently
                // truncate on 32-bit targets.
                Some(x) => usize::try_from(x).context("speed")?,
                None => Self::Type::default_speed(),
            },
        })
    }

    fn build(this: &Self::Type) -> Self {
        Self {
            queue_report_field: Some(this.queue_report_field.to_string()),
            pod_name_prefix: Some(this.pod_name_prefix.clone()),
            max_replicas: this
                .max_replicas
                .iter()
                .map(|(k, v)| proto::MaxReplica::build(&(k.clone(), *v)))
                .collect(),
            speed: Some(this.speed as u64),
        }
    }
}
2 changes: 2 additions & 0 deletions prover/crates/bin/prover_autoscaler/src/cluster_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub struct Namespace {
#[serde(serialize_with = "ordered_map")]
pub deployments: HashMap<String, Deployment>,
pub pods: HashMap<String, Pod>,
#[serde(default)]
pub scale_errors: Vec<ScaleEvent>,
}

Expand All @@ -64,4 +65,5 @@ pub enum PodStatus {
Pending,
LongPending,
NeedToMove,
Failed,
}
49 changes: 38 additions & 11 deletions prover/crates/bin/prover_autoscaler/src/global/queuer.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,76 @@
use std::collections::HashMap;
use std::{collections::HashMap, ops::Deref};

use anyhow::{Context, Ok};
use reqwest::Method;
use zksync_prover_job_monitor::autoscaler_queue_reporter::VersionedQueueReport;
use zksync_config::configs::prover_autoscaler::QueueReportFields;
use zksync_prover_job_monitor::autoscaler_queue_reporter::{QueueReport, VersionedQueueReport};
use zksync_utils::http_with_retries::send_request_with_retries;

use crate::metrics::{AUTOSCALER_METRICS, DEFAULT_ERROR_CODE};

const MAX_RETRIES: usize = 5;

#[derive(Debug)]
pub struct Queue {
pub queue: HashMap<String, u64>,
pub struct Queue(HashMap<(String, QueueReportFields), u64>);

// Deref to the inner map so callers can use HashMap's read-only API
// (get, iter, len, ...) directly on a Queue.
impl Deref for Queue {
    type Target = HashMap<(String, QueueReportFields), u64>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Client that fetches queue reports from prover-job-monitor.
#[derive(Default)]
pub struct Queuer {
    // Base URL of the prover-job-monitor endpoint queried by `get_queue`.
    pub prover_job_monitor_url: String,
}

/// Extracts the total job count (queued plus in-progress) for the requested
/// report field from a single queue report.
fn target_to_queue(target: &QueueReportFields, report: &QueueReport) -> u64 {
    let total = match target {
        QueueReportFields::basic_witness_jobs => report.basic_witness_jobs.all(),
        QueueReportFields::leaf_witness_jobs => report.leaf_witness_jobs.all(),
        QueueReportFields::node_witness_jobs => report.node_witness_jobs.all(),
        QueueReportFields::recursion_tip_witness_jobs => report.recursion_tip_witness_jobs.all(),
        QueueReportFields::scheduler_witness_jobs => report.scheduler_witness_jobs.all(),
        QueueReportFields::proof_compressor_jobs => report.proof_compressor_jobs.all(),
        QueueReportFields::prover_jobs => report.prover_jobs.all(),
    };
    // usize -> u64 is lossless on all supported targets.
    total as u64
}

impl Queuer {
pub fn new(pjm_url: String) -> Self {
Self {
prover_job_monitor_url: pjm_url,
}
}

pub async fn get_queue(&self) -> anyhow::Result<Queue> {
/// Requests queue report from prover-job-monitor and parse it into Queue HashMap for provided
/// list of jobs.
pub async fn get_queue(&self, jobs: &[QueueReportFields]) -> anyhow::Result<Queue> {
let url = &self.prover_job_monitor_url;
let response = send_request_with_retries(url, MAX_RETRIES, Method::GET, None, None).await;
let response = response.map_err(|err| {
AUTOSCALER_METRICS.calls[&(url.clone(), DEFAULT_ERROR_CODE)].inc();
anyhow::anyhow!("Failed fetching queue from url: {url}: {err:?}")
anyhow::anyhow!("Failed fetching queue from URL: {url}: {err:?}")
})?;

AUTOSCALER_METRICS.calls[&(url.clone(), response.status().as_u16())].inc();
let response = response
.json::<Vec<VersionedQueueReport>>()
.await
.context("Failed to read response as json")?;
Ok(Queue {
queue: response
Ok(Queue(
response
.iter()
.map(|x| (x.version.to_string(), x.report.prover_jobs.queued as u64))
.flat_map(|versioned_report| {
jobs.iter().map(move |j| {
EmilLuta marked this conversation as resolved.
Show resolved Hide resolved
(
(versioned_report.version.to_string(), j.clone()),
target_to_queue(j, &versioned_report.report),
)
})
})
.collect::<HashMap<_, _>>(),
})
))
}
}
Loading
Loading