From f1005f282656cbdaf936b4021c67ae3dc062e2ab Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Tue, 2 Jul 2024 17:32:26 -0500 Subject: [PATCH] Fix fasttask backlog timeouts (#353) ## Overview The PR addresses two separate issues (1) the fasttask plugin can assign tasks to a worker that is at full capacity (parallelism + backlog_length). The worker than transparently drops these tasks and the plugin failsover to another worker and (2) the fasttask plugin has no notion of `PhaseVersion` which FlytePropeller uses to determine if any updates have occurred and consequently it needs to store state to etcd. This means that the `LastAccessedAt` field on the fasttask plugin state increment will never be persisted. Therefore, all backlogged tasks will fail after the grace period occurs regardless of whether updates are sent by the worker heartbeat or not. The former is addressed by making the worker backlog_length a suggestion, similar to how `max-parallelism` is applied within FlytePropeller. That is, the fasttask plugin will attempt to only assign tasks to full worker capacity (ie. parallelism + backlog_length), but if it assigns more (race condition) then the worker will backlog them. The latter is fixed by adding a `PhaseVersion` field on the fasttask plugin state that is incremented with each worker task status heartbeat. ## Test Plan This has been tested locally against a variety of backlog scenarios (ex. differing lengths, timeouts, etc). ## Rollout Plan (if applicable) This may be rolled out to all tenants immediately. ## Upstream Changes Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F). - [ ] To be upstreamed to OSS ## Issue https://linear.app/unionai/issue/COR-1455/execution-frequently-fails-due-to-missing-task-status-reporting ## Checklist * [x] Added tests * [ ] Ran a deploy dry run and shared the terraform plan * [x] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation --- fasttask/plugin/plugin.go | 31 ++++++++++++++++++---------- fasttask/plugin/plugin_test.go | 15 ++++++++++++++ fasttask/worker/bridge/src/bridge.rs | 14 ++----------- fasttask/worker/bridge/src/cli.rs | 2 +- fasttask/worker/bridge/src/task.rs | 20 +++++++----------- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/fasttask/plugin/plugin.go b/fasttask/plugin/plugin.go index a749a4a806..c303165bdf 100644 --- a/fasttask/plugin/plugin.go +++ b/fasttask/plugin/plugin.go @@ -66,6 +66,7 @@ func newPluginMetrics(scope promutils.Scope) pluginMetrics { // State maintains the current status of the task execution. type State struct { SubmissionPhase SubmissionPhase + PhaseVersion uint32 WorkerID string LastUpdated time.Time } @@ -293,11 +294,12 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co if len(workerID) > 0 { pluginState.SubmissionPhase = Submitted + pluginState.PhaseVersion = core.DefaultPhaseVersion pluginState.WorkerID = workerID pluginState.LastUpdated = time.Now() - phaseInfo = core.PhaseInfoQueued(time.Now(), core.DefaultPhaseVersion, fmt.Sprintf("task offered to worker %s", workerID)) + phaseInfo = core.PhaseInfoQueued(time.Now(), pluginState.PhaseVersion, fmt.Sprintf("task offered to worker %s", workerID)) } else { if pluginState.LastUpdated.IsZero() { pluginState.LastUpdated = time.Now() @@ -352,7 +354,8 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co phaseInfo = core.PhaseInfoSystemFailure("unknown", fmt.Sprintf("all workers have failed for queue %s\n%s", queueID, messageCollector.Summary(maxErrorMessageLength)), nil) } else { - phaseInfo = core.PhaseInfoWaitingForResourcesInfo(time.Now(), core.DefaultPhaseVersion, "no workers available", nil) + pluginState.PhaseVersion = core.DefaultPhaseVersion + phaseInfo = core.PhaseInfoWaitingForResourcesInfo(time.Now(), pluginState.PhaseVersion, "no workers available", nil) } } case Submitted: @@ -360,14 +363,19 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co phase, reason, err := p.fastTaskService.CheckStatus(ctx, taskID, fastTaskEnvironment.GetQueueId(), pluginState.WorkerID) now := time.Now() - if err != nil && !errors.Is(err, statusUpdateNotFoundError) && !errors.Is(err, taskContextNotFoundError) { - return core.UnknownTransition, err - } else if errors.Is(err, statusUpdateNotFoundError) && now.Sub(pluginState.LastUpdated) > GetConfig().GracePeriodStatusNotFound.Duration { - // if task has not been updated within the grace period we should abort - logger.Infof(ctx, "Task status update not reported within grace period for queue %s and worker %s", queueID, pluginState.WorkerID) - p.metrics.statusUpdateNotFoundTimeout.Inc() - - return core.DoTransition(core.PhaseInfoSystemRetryableFailure("unknown", fmt.Sprintf("task status update not reported within grace period for queue %s and worker %s", queueID, pluginState.WorkerID), nil)), nil + if err != nil { + if errors.Is(err, statusUpdateNotFoundError) && now.Sub(pluginState.LastUpdated) > GetConfig().GracePeriodStatusNotFound.Duration { + // if task has not been updated within the grace period we should abort + logger.Errorf(ctx, "Task status update not reported within grace period for queue %s and worker %s", queueID, pluginState.WorkerID) + p.metrics.statusUpdateNotFoundTimeout.Inc() + + return core.DoTransition(core.PhaseInfoSystemRetryableFailure("unknown", + fmt.Sprintf("task status update not reported within grace period for queue %s and worker %s", queueID, pluginState.WorkerID), nil)), nil + } else if errors.Is(err, statusUpdateNotFoundError) || errors.Is(err, taskContextNotFoundError) { + phaseInfo = core.PhaseInfoRunning(pluginState.PhaseVersion, nil) + } else { + return core.UnknownTransition, err + } } else if phase == core.PhaseSuccess { taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { @@ -387,8 +395,9 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co } else if phase == core.PhaseRetryableFailure { return core.DoTransition(core.PhaseInfoRetryableFailure("unknown", reason, nil)), nil } else { + pluginState.PhaseVersion++ pluginState.LastUpdated = now - phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, nil) + phaseInfo = core.PhaseInfoRunning(pluginState.PhaseVersion, nil) } } diff --git a/fasttask/plugin/plugin_test.go b/fasttask/plugin/plugin_test.go index 10a8d58e4a..bcc9b3db58 100644 --- a/fasttask/plugin/plugin_test.go +++ b/fasttask/plugin/plugin_test.go @@ -532,6 +532,7 @@ func TestHandleRunning(t *testing.T) { taskStatusReason string checkStatusError error expectedPhase core.Phase + expectedPhaseVersion uint32 expectedReason string expectedError error expectedLastUpdatedInc bool @@ -543,10 +544,23 @@ func TestHandleRunning(t *testing.T) { taskStatusReason: "", checkStatusError: nil, expectedPhase: core.PhaseRunning, + expectedPhaseVersion: 1, expectedReason: "", expectedError: nil, expectedLastUpdatedInc: true, }, + { + name: "RunningStatusNotFound", + lastUpdated: time.Now().Add(-5 * time.Second), + taskStatusPhase: core.PhaseRunning, + taskStatusReason: "", + checkStatusError: statusUpdateNotFoundError, + expectedPhase: core.PhaseRunning, + expectedPhaseVersion: 0, + expectedReason: "", + expectedError: nil, + expectedLastUpdatedInc: false, + }, { name: "RetryableFailure", lastUpdated: time.Now().Add(-5 * time.Second), @@ -639,6 +653,7 @@ func TestHandleRunning(t *testing.T) { transition, err := plugin.Handle(ctx, tCtx) assert.Equal(t, test.expectedError, err) assert.Equal(t, test.expectedPhase, transition.Info().Phase()) + assert.Equal(t, test.expectedPhaseVersion, arrayNodeStateOutput.PhaseVersion) if test.expectedLastUpdatedInc { assert.True(t, arrayNodeStateOutput.LastUpdated.After(test.lastUpdated)) diff --git a/fasttask/worker/bridge/src/bridge.rs b/fasttask/worker/bridge/src/bridge.rs index df6c48f342..cadecfa393 100644 --- a/fasttask/worker/bridge/src/bridge.rs +++ b/fasttask/worker/bridge/src/bridge.rs @@ -88,13 +88,7 @@ pub async fn run(args: BridgeArgs) -> Result<(), Box> { let task_statuses: Arc>> = Arc::new(RwLock::new(vec![])); let heartbeat_bool = Arc::new(Mutex::new(AsyncBool::new())); - let (backlog_tx, backlog_rx) = match args.backlog_length { - 0 => (None, None), - x => { - let (tx, rx) = async_channel::bounded(x); - (Some(tx), Some(rx)) - } - }; + let (backlog_tx, backlog_rx) = async_channel::unbounded(); // build executors let (build_executor_tx, build_executor_rx) = async_channel::unbounded(); @@ -203,17 +197,13 @@ pub async fn run(args: BridgeArgs) -> Result<(), Box> { // periodically send heartbeat let _ = heartbeater.trigger().await; - let backlogged = match backlog_rx_clone { - Some(ref rx) => rx.len() as i32, - None => 0, - }; let mut heartbeat_request = HeartbeatRequest { worker_id: worker_id_clone.clone(), queue_id: queue_id_clone.clone(), capacity: Some(Capacity { execution_count: parallelism_clone - (executor_rx_clone.len() as i32), execution_limit: parallelism_clone, - backlog_count: backlogged, + backlog_count: backlog_rx_clone.len() as i32, backlog_limit: backlog_length_clone, }), task_statuses: vec!(), diff --git a/fasttask/worker/bridge/src/cli.rs b/fasttask/worker/bridge/src/cli.rs index f679a01090..886bf25c73 100644 --- a/fasttask/worker/bridge/src/cli.rs +++ b/fasttask/worker/bridge/src/cli.rs @@ -64,7 +64,7 @@ pub struct BridgeArgs { long, value_name = "BACKLOG_LENGTH", default_value = "5", - help = "number of tasks to buffer before dropping assignments" + help = "suggested number of tasks to buffer for future execution, the actual number may be higher" )] pub backlog_length: usize, #[arg( diff --git a/fasttask/worker/bridge/src/task.rs b/fasttask/worker/bridge/src/task.rs index e2357d9899..c8d2b8bef4 100644 --- a/fasttask/worker/bridge/src/task.rs +++ b/fasttask/worker/bridge/src/task.rs @@ -6,7 +6,7 @@ use crate::common::{Executor, Response, Task}; use crate::common::{TaskContext, FAILED, QUEUED, RUNNING}; use crate::pb::fasttask::TaskStatus; -use async_channel::{self, Receiver, Sender, TryRecvError, TrySendError}; +use async_channel::{self, Receiver, Sender, TryRecvError}; use futures::sink::SinkExt; use futures::stream::StreamExt; use tracing::{debug, info, warn}; @@ -20,8 +20,8 @@ pub async fn execute( task_status_tx: Sender, task_status_report_interval_seconds: u64, last_ack_grace_period_seconds: u64, - backlog_tx: Option>, - backlog_rx: Option>, + backlog_tx: Sender<()>, + backlog_rx: Receiver<()>, executor_tx: Sender, executor_rx: Receiver, build_executor_tx: Sender<()>, @@ -60,7 +60,6 @@ pub async fn execute( // if backlogged we wait until we can execute let (mut phase, mut reason) = (QUEUED, "".to_string()); if backlogged { - let backlog_rx = backlog_rx.unwrap(); executor = match wait_in_backlog( task_contexts.clone(), &kill_rx, @@ -156,7 +155,7 @@ pub async fn execute( async fn is_executable( executor_rx: &Receiver, - backlog_tx: &Option>, + backlog_tx: &Sender<()>, ) -> Result<(Option, bool), String> { match executor_rx.try_recv() { Ok(executor) => return Ok((Some(executor), false)), @@ -164,15 +163,12 @@ async fn is_executable( Err(TryRecvError::Empty) => {} } - if let Some(backlog_tx) = backlog_tx { - match backlog_tx.try_send(()) { - Ok(_) => return Ok((None, true)), - Err(TrySendError::Closed(e)) => return Err(format!("backlog_tx is closed: {:?}", e)), - Err(TrySendError::Full(_)) => {} - } + match backlog_tx.send(()).await { + Ok(_) => {} + Err(e) => return Err(format!("failed to send to backlog_tx: {:?}", e)), } - Ok((None, false)) + Ok((None, true)) } async fn report_terminal_status(