From 114e00485579c307635d3c3371b50fdecfc5db99 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Thu, 5 Sep 2024 17:11:33 -0700 Subject: [PATCH] Added literal offloading checks across hetrogeneous tasks --- flytepropeller/pkg/controller/controller.go | 2 +- .../pkg/controller/nodes/executor.go | 36 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index e23b0c29c22..39047e811d8 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -443,7 +443,7 @@ func New(ctx context.Context, cfg *config.Config, kubeClientset kubernetes.Inter nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, launchPlanActor, launchPlanActor, storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, - catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) + catalogClient, recoveryClient, cfg.LiteralOffloadingConfig, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 47c91edc513..85d8ecb22c9 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -491,6 +491,7 @@ type nodeExecutor struct { defaultExecutionDeadline time.Duration enqueueWorkflow v1alpha1.EnqueueWorkflow eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig interruptibleFailureThreshold int32 maxNodeRetriesForSystemFailures uint32 metrics *nodeMetrics @@ -764,6 +765,10 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur } if nodeInputs != nil { + p := c.checkOffloadingCompat(ctx, nCtx, nodeInputs.Literals, node) + if p != handler.PhaseInfoUndefined { + return p, nil + } inputsFile := v1alpha1.GetInputsFile(dataDir) if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { c.metrics.InputsWriteFailure.Inc(ctx) @@ -790,6 +795,34 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur return handler.PhaseInfoNotReady("predecessor node not yet complete"), nil } +// checkOffloadingCompat checks if the upstream and downstream nodes are compatible with the literal offloading feature +func (c *nodeExecutor) checkOffloadingCompat(ctx context.Context, nCtx interfaces.NodeExecutionContext, inputLiterals map[string]*core.Literal, node v1alpha1.ExecutableNode) handler.PhaseInfo { + isOffloadLiteral := false + for _, val := range inputLiterals { + if val != nil && val.GetOffloadedMetadata() != nil { + isOffloadLiteral = true + break + } + } + switch node.GetKind() { + case v1alpha1.NodeKindTask: + taskID := *node.GetTaskID() + taskNode, err := nCtx.ExecutionContext().GetTask(taskID) + if err != nil { + return handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "GetTaskIDFailure", err.Error(), nil) + } + runtimeData := taskNode.CoreTask().GetMetadata().GetRuntime() + if isOffloadLiteral && (c.literalOffloadingConfig.Enabled && !c.literalOffloadingConfig.IsSupportedSDKVersion(runtimeData.GetType().String(), runtimeData.GetVersion())) { + logger.Debugf(ctx, "literal offloading : sdk version check failed for task [%s]", taskID) + return handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingNotSupported", "Literal offloading is not supported for this task", nil) + } + break + default: + logger.Warnf(ctx, "literal offloading : skipping sdk version check for node kind '%s'", node.GetKind()) + } + return handler.PhaseInfoUndefined +} + func isTimeoutExpired(queuedAt *metav1.Time, timeout time.Duration) bool { if !queuedAt.IsZero() && timeout != 0 { deadline := queuedAt.Add(timeout) @@ -1417,7 +1450,7 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, + catalogClient catalog.Client, recoveryClient recovery.Client, literalOffloadingConfig config.LiteralOffloadingConfig, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, nodeHandlerFactory interfaces.HandlerFactory, scope promutils.Scope) (interfaces.Node, error) { // TODO we may want to make this configurable. @@ -1469,6 +1502,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, enqueueWorkflow: enQWorkflow, eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, interruptibleFailureThreshold: nodeConfig.InterruptibleFailureThreshold, maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), metrics: metrics,