diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 07787d6e7..b1f269989 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -595,7 +595,7 @@ func (in *NodeStatus) GetOrCreateArrayNodeStatus() MutableArrayNodeStatus { } func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string, err *core.ExecutionError) { - if in.Phase == p { + if in.Phase == p && in.Message == reason { // We will not update the phase multiple times. This prevents the comparison from returning false positive return } diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go b/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go index 1fbf4391d..698b437eb 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status_test.go @@ -263,7 +263,8 @@ func TestNodeStatus_UpdatePhase(t *testing.T) { t.Run("identical-phase", func(t *testing.T) { p := NodePhaseQueued ns := NodeStatus{ - Phase: p, + Phase: p, + Message: queued, } msg := queued ns.UpdatePhase(p, n, msg, nil) diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 57dcb9ba6..3aedbed64 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -25,11 +25,11 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog" errors3 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/workflow" "github.com/flyteorg/flytepropeller/pkg/controller/workflowstore" leader "github.com/flyteorg/flytepropeller/pkg/leaderelection" diff --git a/pkg/controller/nodes/array/handler.go b/pkg/controller/nodes/array/handler.go index 19641cb93..a8cc13c85 100644 --- a/pkg/controller/nodes/array/handler.go +++ b/pkg/controller/nodes/array/handler.go @@ -295,6 +295,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu retryAttempt := subNodeStatus.GetAttempts() + // fastcache will not emit task events for cache hits. we need to manually detect a + // transition to `SUCCEEDED` and add an `ExternalResourceInfo` for it. 
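+ // note: an executed subNode is expected to record at least one task event, so an empty
+ // arrayEventRecorder.TaskEvents() list combined with a CACHE_HIT status identifies a fastcache
+ // hit that would otherwise be missing from the externalResources reported for this index.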
+ if cacheStatus == idlcore.CatalogCacheStatus_CACHE_HIT && len(arrayEventRecorder.TaskEvents()) == 0 { + externalResources = append(externalResources, &event.ExternalResourceInfo{ + ExternalId: buildSubNodeID(nCtx, i, retryAttempt), + Index: uint32(i), + RetryAttempt: retryAttempt, + Phase: idlcore.TaskExecution_SUCCEEDED, + CacheStatus: cacheStatus, + }) + } + for _, taskExecutionEvent := range arrayEventRecorder.TaskEvents() { for _, log := range taskExecutionEvent.Logs { log.Name = fmt.Sprintf("%s-%d", log.Name, i) @@ -543,19 +555,17 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter inputReader := newStaticInputReader(nCtx.InputReader(), inputLiteralMap) - // if node has not yet started we automatically set to NodePhaseQueued to skip input resolution - if nodePhase == v1alpha1.NodePhaseNotYetStarted { - // TODO - to supprt fastcache we'll need to override the bindings to BindingScalars for the input resolution on the nCtx - // that way resolution is just reading a literal ... but does this still write a file then?!? - nodePhase = v1alpha1.NodePhaseQueued - } - // wrap node lookup subNodeSpec := *arrayNode.GetSubNodeSpec() subNodeID := fmt.Sprintf("%s-n%d", nCtx.NodeID(), subNodeIndex) subNodeSpec.ID = subNodeID subNodeSpec.Name = subNodeID + // mock the input bindings for the subNode to nil to bypass input resolution in the + // `nodeExecutor.preExecute` function. this is required because this function is the entrypoint + // for initial cache lookups. an alternative solution would be to mock the datastore to bypass + // writing the inputFile. + subNodeSpec.InputBindings = nil // TODO - if we want to support more plugin types we need to figure out the best way to store plugin state // currently just mocking based on node phase -> which works for all k8s plugins diff --git a/pkg/controller/nodes/array/handler_test.go b/pkg/controller/nodes/array/handler_test.go index f3e6f8bd1..3a5f84965 100644 --- a/pkg/controller/nodes/array/handler_test.go +++ b/pkg/controller/nodes/array/handler_test.go @@ -13,13 +13,13 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/config" execmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog" gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" recoverymocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" diff --git a/pkg/controller/nodes/array/node_lookup.go b/pkg/controller/nodes/array/node_lookup.go index 061b323af..b3b8dd03a 100644 --- a/pkg/controller/nodes/array/node_lookup.go +++ b/pkg/controller/nodes/array/node_lookup.go @@ -14,6 +14,14 @@ type arrayNodeLookup struct { subNodeStatus *v1alpha1.NodeStatus } +func (a *arrayNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + if id == a.subNodeID { + return nil, nil + } + + return a.NodeLookup.ToNode(id) +} + func (a *arrayNodeLookup) 
GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) {
 	if nodeID == a.subNodeID {
 		return a.subNodeSpec, true
diff --git a/pkg/controller/nodes/cache.go b/pkg/controller/nodes/cache.go
new file mode 100644
index 000000000..fbae02b16
--- /dev/null
+++ b/pkg/controller/nodes/cache.go
@@ -0,0 +1,234 @@
+package nodes
+
+import (
+	"context"
+	"strconv"
+	"time"
+
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"
+
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/encoding"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
+
+	"github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/common"
+	nodeserrors "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces"
+	"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task"
+
+	"github.com/flyteorg/flytestdlib/logger"
+	"github.com/flyteorg/flytestdlib/storage"
+
+	"github.com/pkg/errors"
+
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+// computeCatalogReservationOwnerID constructs a unique identifier which includes the node's
+// parent information, node ID, and retry attempt number. This is used to uniquely identify a task
+// when using the cache reservation API to serialize cached executions.
+func computeCatalogReservationOwnerID(nCtx interfaces.NodeExecutionContext) (string, error) {
+	currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID())
+	if err != nil {
+		return "", err
+	}
+
+	ownerID, err := encoding.FixedLengthUniqueIDForParts(task.IDMaxLength,
+		[]string{nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(nCtx.CurrentAttempt()))})
+	if err != nil {
+		return "", err
+	}
+
+	return ownerID, nil
+}
+
+// updatePhaseCacheInfo adds the cache and catalog reservation metadata to the PhaseInfo. This
+// ensures this information is reported in events and available within FlyteAdmin.
+func updatePhaseCacheInfo(phaseInfo handler.PhaseInfo, cacheStatus *catalog.Status, reservationStatus *core.CatalogReservation_Status) handler.PhaseInfo {
+	if cacheStatus == nil && reservationStatus == nil {
+		return phaseInfo
+	}
+
+	info := phaseInfo.GetInfo()
+	if info == nil {
+		info = &handler.ExecutionInfo{}
+	}
+
+	if info.TaskNodeInfo == nil {
+		info.TaskNodeInfo = &handler.TaskNodeInfo{}
+	}
+
+	if info.TaskNodeInfo.TaskNodeMetadata == nil {
+		info.TaskNodeInfo.TaskNodeMetadata = &event.TaskNodeMetadata{}
+	}
+
+	if cacheStatus != nil {
+		info.TaskNodeInfo.TaskNodeMetadata.CacheStatus = cacheStatus.GetCacheStatus()
+		info.TaskNodeInfo.TaskNodeMetadata.CatalogKey = cacheStatus.GetMetadata()
+	}
+
+	if reservationStatus != nil {
+		info.TaskNodeInfo.TaskNodeMetadata.ReservationStatus = *reservationStatus
+	}
+
+	return phaseInfo.WithInfo(info)
+}
+
+// CheckCatalogCache uses the handler and contexts to check if cached outputs for the current node
+// exist. If they exist, this function also copies the outputs to this node.
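+// A NotFound error from the catalog is translated into a CACHE_MISS entry rather than an error,
+// and on a cache hit any cached outputs are written to this node's outputs file in the datastore.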
+func (n *nodeExecutor) CheckCatalogCache(ctx context.Context, nCtx interfaces.NodeExecutionContext, cacheHandler interfaces.CacheableNodeHandler) (catalog.Entry, error) {
+	catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
+	if err != nil {
+		return catalog.Entry{}, errors.Wrapf(err, "failed to initialize the catalogKey")
+	}
+
+	entry, err := n.catalog.Get(ctx, catalogKey)
+	if err != nil {
+		causeErr := errors.Cause(err)
+		if taskStatus, ok := status.FromError(causeErr); ok && taskStatus.Code() == codes.NotFound {
+			n.metrics.catalogMissCount.Inc(ctx)
+			logger.Infof(ctx, "Catalog CacheMiss: Artifact not found in Catalog. Executing Task.")
+			return catalog.NewCatalogEntry(nil, catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil)), nil
+		}
+
+		n.metrics.catalogGetFailureCount.Inc(ctx)
+		logger.Errorf(ctx, "Catalog Failure: memoization check failed. err: %v", err.Error())
+		return catalog.Entry{}, errors.Wrapf(err, "Failed to check Catalog for previous results")
+	}
+
+	if entry.GetStatus().GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT {
+		logger.Errorf(ctx, "No CacheHIT and no Error received. Illegal state, Cache State: %s", entry.GetStatus().GetCacheStatus().String())
+		// TODO should this be an error?
+		return entry, nil
+	}
+
+	logger.Infof(ctx, "Catalog CacheHit: for task [%s/%s/%s/%s]", catalogKey.Identifier.Project,
+		catalogKey.Identifier.Domain, catalogKey.Identifier.Name, catalogKey.Identifier.Version)
+	n.metrics.catalogHitCount.Inc(ctx)
+
+	iface := catalogKey.TypedInterface
+	if iface.Outputs != nil && len(iface.Outputs.Variables) > 0 {
+		// copy cached outputs to node outputs
+		o, ee, err := entry.GetOutputs().Read(ctx)
+		if err != nil {
+			logger.Errorf(ctx, "failed to read from catalog, err: %s", err.Error())
+			return catalog.Entry{}, err
+		} else if ee != nil {
+			logger.Errorf(ctx, "got execution error from catalog output reader? This should not happen, err: %s", ee.String())
+			return catalog.Entry{}, nodeserrors.Errorf(nodeserrors.IllegalStateError, nCtx.NodeID(), "execution error from a cache output, bad state: %s", ee.String())
+		}
+
+		outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir())
+		if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, o); err != nil {
+			logger.Errorf(ctx, "failed to write cached value to datastore, err: %s", err.Error())
+			return catalog.Entry{}, err
+		}
+	}
+
+	return entry, nil
+}
+
+// GetOrExtendCatalogReservation attempts to acquire an artifact reservation if the task is
+// cacheable and cache serializable. If the reservation already exists for this owner, the
+// reservation is extended.
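+// The returned status is RESERVATION_ACQUIRED when the reservation owner matches this node's
+// computed ownerID, and RESERVATION_EXISTS when another execution currently holds it.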
+func (n *nodeExecutor) GetOrExtendCatalogReservation(ctx context.Context, nCtx interfaces.NodeExecutionContext,
+	cacheHandler interfaces.CacheableNodeHandler, heartbeatInterval time.Duration) (catalog.ReservationEntry, error) {
+
+	catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
+	if err != nil {
+		return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
+			errors.Wrapf(err, "failed to initialize the catalogKey")
+	}
+
+	ownerID, err := computeCatalogReservationOwnerID(nCtx)
+	if err != nil {
+		return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
+			errors.Wrapf(err, "failed to initialize the cache reservation ownerID")
+	}
+
+	reservation, err := n.catalog.GetOrExtendReservation(ctx, catalogKey, ownerID, heartbeatInterval)
+	if err != nil {
+		n.metrics.reservationGetFailureCount.Inc(ctx)
+		logger.Errorf(ctx, "Catalog Failure: reservation get or extend failed. err: %v", err.Error())
+		return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
+	}
+
+	var status core.CatalogReservation_Status
+	if reservation.OwnerId == ownerID {
+		status = core.CatalogReservation_RESERVATION_ACQUIRED
+	} else {
+		status = core.CatalogReservation_RESERVATION_EXISTS
+	}
+
+	n.metrics.reservationGetSuccessCount.Inc(ctx)
+	return catalog.NewReservationEntry(reservation.ExpiresAt.AsTime(),
+		reservation.HeartbeatInterval.AsDuration(), reservation.OwnerId, status), nil
+}
+
+// ReleaseCatalogReservation attempts to release an artifact reservation if the task is cacheable
+// and cache serializable. If the reservation does not exist for this owner (e.g. it never existed
+// or has been acquired by another owner) this call is still successful.
+func (n *nodeExecutor) ReleaseCatalogReservation(ctx context.Context, nCtx interfaces.NodeExecutionContext,
+	cacheHandler interfaces.CacheableNodeHandler) (catalog.ReservationEntry, error) {
+
+	catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx)
+	if err != nil {
+		return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
+			errors.Wrapf(err, "failed to initialize the catalogKey")
+	}
+
+	ownerID, err := computeCatalogReservationOwnerID(nCtx)
+	if err != nil {
+		return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED),
+			errors.Wrapf(err, "failed to initialize the cache reservation ownerID")
+	}
+
+	err = n.catalog.ReleaseReservation(ctx, catalogKey, ownerID)
+	if err != nil {
+		n.metrics.reservationReleaseFailureCount.Inc(ctx)
+		logger.Errorf(ctx, "Catalog Failure: release reservation failed. err: %v", err.Error())
+		return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
+	}
+
+	n.metrics.reservationReleaseSuccessCount.Inc(ctx)
+	return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_RELEASED), nil
+}
+
+// WriteCatalogCache relays the outputs of this node to the cache. This allows future executions
+// to reuse this data to avoid recomputation.
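+// Put failures are not returned as errors; a CACHE_PUT_FAILURE status is reported instead so a
+// failed cache write does not, by itself, fail the node.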
+func (n *nodeExecutor) WriteCatalogCache(ctx context.Context, nCtx interfaces.NodeExecutionContext, cacheHandler interfaces.CacheableNodeHandler) (catalog.Status, error) { + catalogKey, err := cacheHandler.GetCatalogKey(ctx, nCtx) + if err != nil { + return catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), errors.Wrapf(err, "failed to initialize the catalogKey") + } + + iface := catalogKey.TypedInterface + if iface.Outputs != nil && len(iface.Outputs.Variables) == 0 { + return catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), nil + } + + logger.Infof(ctx, "Catalog CacheEnabled. recording execution [%s/%s/%s/%s]", catalogKey.Identifier.Project, + catalogKey.Identifier.Domain, catalogKey.Identifier.Name, catalogKey.Identifier.Version) + + outputPaths := ioutils.NewReadOnlyOutputFilePaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir()) + outputReader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes()) + metadata := catalog.Metadata{ + TaskExecutionIdentifier: task.GetTaskExecutionIdentifier(nCtx), + } + + // ignores discovery write failures + status, err := n.catalog.Put(ctx, catalogKey, outputReader, metadata) + if err != nil { + n.metrics.catalogPutFailureCount.Inc(ctx) + logger.Errorf(ctx, "Failed to write results to catalog for Task [%v]. Error: %v", catalogKey.Identifier, err) + return catalog.NewStatus(core.CatalogCacheStatus_CACHE_PUT_FAILURE, status.GetMetadata()), nil + } + + n.metrics.catalogPutSuccessCount.Inc(ctx) + logger.Infof(ctx, "Successfully cached results to catalog - Task [%v]", catalogKey.Identifier) + return status, nil +} diff --git a/pkg/controller/nodes/cache_test.go b/pkg/controller/nodes/cache_test.go new file mode 100644 index 000000000..6d9c4fce9 --- /dev/null +++ b/pkg/controller/nodes/cache_test.go @@ -0,0 +1,448 @@ +package nodes + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + executorsmocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + interfacesmocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" + catalogmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + + "k8s.io/apimachinery/pkg/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + currentAttempt = uint32(0) + nodeID = "baz" + nodeOutputDir = storage.DataReference("output_directory") + parentUniqueID = "bar" + parentCurrentAttempt = uint32(1) + uniqueID = "foo" +) + +type mockTaskReader struct { + taskTemplate *core.TaskTemplate +} + +func (t mockTaskReader) Read(ctx context.Context) (*core.TaskTemplate, error) { + return t.taskTemplate, nil +} +func (t mockTaskReader) GetTaskType() v1alpha1.TaskType { return "" } +func (t 
mockTaskReader) GetTaskID() *core.Identifier { return nil } + +func setupCacheableNodeExecutionContext(dataStore *storage.DataStore, taskTemplate *core.TaskTemplate) *nodeExecContext { + mockNode := &mocks.ExecutableNode{} + mockNode.OnGetIDMatch(mock.Anything).Return(nodeID) + + mockNodeStatus := &mocks.ExecutableNodeStatus{} + mockNodeStatus.OnGetAttemptsMatch().Return(currentAttempt) + mockNodeStatus.OnGetOutputDir().Return(nodeOutputDir) + + mockParentInfo := &executorsmocks.ImmutableParentInfo{} + mockParentInfo.OnCurrentAttemptMatch().Return(parentCurrentAttempt) + mockParentInfo.OnGetUniqueIDMatch().Return(uniqueID) + + mockExecutionContext := &executorsmocks.ExecutionContext{} + mockExecutionContext.OnGetParentInfoMatch(mock.Anything).Return(mockParentInfo) + + mockNodeExecutionMetadata := &interfacesmocks.NodeExecutionMetadata{} + mockNodeExecutionMetadata.OnGetOwnerID().Return( + types.NamespacedName{ + Name: parentUniqueID, + }, + ) + mockNodeExecutionMetadata.OnGetNodeExecutionIDMatch().Return( + &core.NodeExecutionIdentifier{ + NodeId: nodeID, + }, + ) + + var taskReader interfaces.TaskReader + if taskTemplate != nil { + taskReader = mockTaskReader{ + taskTemplate: taskTemplate, + } + } + + return &nodeExecContext{ + ic: mockExecutionContext, + md: mockNodeExecutionMetadata, + node: mockNode, + nodeStatus: mockNodeStatus, + store: dataStore, + tr: taskReader, + } +} + +func TestComputeCatalogReservationOwnerID(t *testing.T) { + nCtx := setupCacheableNodeExecutionContext(nil, nil) + + ownerID, err := computeCatalogReservationOwnerID(nCtx) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("%s-%s-%d-%s-%d", parentUniqueID, uniqueID, parentCurrentAttempt, nodeID, currentAttempt), ownerID) +} + +func TestUpdatePhaseCacheInfo(t *testing.T) { + cacheStatus := catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil) + reservationStatus := core.CatalogReservation_RESERVATION_EXISTS + + tests := []struct { + name string + cacheStatus *catalog.Status + reservationStatus *core.CatalogReservation_Status + }{ + {"BothEmpty", nil, nil}, + {"CacheStatusOnly", &cacheStatus, nil}, + {"ReservationStatusOnly", nil, &reservationStatus}, + {"BothPopulated", &cacheStatus, &reservationStatus}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + phaseInfo := handler.PhaseInfoUndefined + phaseInfo = updatePhaseCacheInfo(phaseInfo, test.cacheStatus, test.reservationStatus) + + // do not create ExecutionInfo object if neither cacheStatus or reservationStatus exists + if test.cacheStatus == nil && test.reservationStatus == nil { + assert.Nil(t, phaseInfo.GetInfo()) + } + + // ensure cache and reservation status' are being set correctly + if test.cacheStatus != nil { + assert.Equal(t, cacheStatus.GetCacheStatus(), phaseInfo.GetInfo().TaskNodeInfo.TaskNodeMetadata.CacheStatus) + } + + if test.reservationStatus != nil { + assert.Equal(t, reservationStatus, phaseInfo.GetInfo().TaskNodeInfo.TaskNodeMetadata.ReservationStatus) + } + }) + } +} + +func TestCheckCatalogCache(t *testing.T) { + tests := []struct { + name string + cacheEntry catalog.Entry + cacheError error + catalogKey catalog.Key + expectedCacheStatus core.CatalogCacheStatus + assertOutputFile bool + outputFileExists bool + }{ + { + "CacheMiss", + catalog.Entry{}, + status.Error(codes.NotFound, ""), + catalog.Key{}, + core.CatalogCacheStatus_CACHE_MISS, + false, + false, + }, + { + "CacheHitWithOutputs", + catalog.NewCatalogEntry( + ioutils.NewInMemoryOutputReader(&core.LiteralMap{}, nil, nil), + 
catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil), + ), + nil, + catalog.Key{ + TypedInterface: core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "foo": nil, + }, + }, + }, + }, + core.CatalogCacheStatus_CACHE_HIT, + true, + true, + }, + { + "CacheHitWithoutOutputs", + catalog.NewCatalogEntry( + nil, + catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil), + ), + nil, + catalog.Key{}, + core.CatalogCacheStatus_CACHE_HIT, + true, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testScope := promutils.NewTestScope() + metrics := &nodeMetrics{ + catalogHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", testScope), + catalogMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", testScope), + catalogSkipCount: labeled.NewCounter("discovery_skip_count", "Task cached skipped in Discovery", testScope), + catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", testScope), + catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", testScope), + catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", testScope), + reservationGetFailureCount: labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", testScope), + reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", testScope), + reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", testScope), + reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", testScope), + } + + cacheableHandler := &interfacesmocks.CacheableNodeHandler{} + cacheableHandler.OnGetCatalogKeyMatch(mock.Anything, mock.Anything).Return(test.catalogKey, nil) + + catalogClient := &catalogmocks.Client{} + catalogClient.OnGetMatch(mock.Anything, mock.Anything).Return(test.cacheEntry, test.cacheError) + + dataStore, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + testScope.NewSubScope("data_store"), + ) + assert.NoError(t, err) + + nodeExecutor := &nodeExecutor{ + catalog: catalogClient, + metrics: metrics, + } + nCtx := setupCacheableNodeExecutionContext(dataStore, nil) + + // execute catalog cache check + cacheEntry, err := nodeExecutor.CheckCatalogCache(context.TODO(), nCtx, cacheableHandler) + assert.NoError(t, err) + + // validate the result cache entry status + assert.Equal(t, test.expectedCacheStatus, cacheEntry.GetStatus().GetCacheStatus()) + + if test.assertOutputFile { + // assert the outputs file exists + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + metadata, err := nCtx.DataStore().Head(context.TODO(), outputFile) + assert.NoError(t, err) + assert.Equal(t, test.outputFileExists, metadata.Exists()) + } + }) + } +} + +func TestGetOrExtendCatalogReservation(t *testing.T) { + tests := []struct { + name string + reservationOwnerID string + expectedReservationStatus core.CatalogReservation_Status + }{ + { + "Acquired", + "bar-foo-1-baz-0", + core.CatalogReservation_RESERVATION_ACQUIRED, + }, + { + "Exists", + "some-other-owner", + core.CatalogReservation_RESERVATION_EXISTS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testScope := 
promutils.NewTestScope() + metrics := &nodeMetrics{ + catalogHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", testScope), + catalogMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", testScope), + catalogSkipCount: labeled.NewCounter("discovery_skip_count", "Task cached skipped in Discovery", testScope), + catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", testScope), + catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", testScope), + catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", testScope), + reservationGetFailureCount: labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", testScope), + reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", testScope), + reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", testScope), + reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", testScope), + } + + cacheableHandler := &interfacesmocks.CacheableNodeHandler{} + cacheableHandler.OnGetCatalogKeyMatch(mock.Anything, mock.Anything).Return(catalog.Key{}, nil) + + catalogClient := &catalogmocks.Client{} + catalogClient.OnGetOrExtendReservationMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return( + &datacatalog.Reservation{ + OwnerId: test.reservationOwnerID, + }, + nil, + ) + + nodeExecutor := &nodeExecutor{ + catalog: catalogClient, + metrics: metrics, + } + nCtx := setupCacheableNodeExecutionContext(nil, &core.TaskTemplate{}) + + // execute catalog cache check + reservationEntry, err := nodeExecutor.GetOrExtendCatalogReservation(context.TODO(), nCtx, cacheableHandler, time.Second*30) + assert.NoError(t, err) + + // validate the result cache entry status + assert.Equal(t, test.expectedReservationStatus, reservationEntry.GetStatus()) + }) + } +} + +func TestReleaseCatalogReservation(t *testing.T) { + tests := []struct { + name string + releaseError error + expectedReservationStatus core.CatalogReservation_Status + }{ + { + "Success", + nil, + core.CatalogReservation_RESERVATION_RELEASED, + }, + { + "Failure", + errors.New("failed to release"), + core.CatalogReservation_RESERVATION_FAILURE, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testScope := promutils.NewTestScope() + metrics := &nodeMetrics{ + catalogHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", testScope), + catalogMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", testScope), + catalogSkipCount: labeled.NewCounter("discovery_skip_count", "Task cached skipped in Discovery", testScope), + catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", testScope), + catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", testScope), + catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", testScope), + reservationGetFailureCount: labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", testScope), + reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", 
"Reservation GetOrExtend success count", testScope), + reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", testScope), + reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", testScope), + } + + cacheableHandler := &interfacesmocks.CacheableNodeHandler{} + cacheableHandler.OnGetCatalogKeyMatch(mock.Anything, mock.Anything).Return(catalog.Key{}, nil) + + catalogClient := &catalogmocks.Client{} + catalogClient.OnReleaseReservationMatch(mock.Anything, mock.Anything, mock.Anything).Return(test.releaseError) + + nodeExecutor := &nodeExecutor{ + catalog: catalogClient, + metrics: metrics, + } + nCtx := setupCacheableNodeExecutionContext(nil, &core.TaskTemplate{}) + + // execute catalog cache check + reservationEntry, err := nodeExecutor.ReleaseCatalogReservation(context.TODO(), nCtx, cacheableHandler) + if test.releaseError == nil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + + // validate the result cache entry status + assert.Equal(t, test.expectedReservationStatus, reservationEntry.GetStatus()) + }) + } +} + +func TestWriteCatalogCache(t *testing.T) { + tests := []struct { + name string + cacheStatus catalog.Status + cacheError error + catalogKey catalog.Key + expectedCacheStatus core.CatalogCacheStatus + }{ + { + "NoOutputs", + catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), + nil, + catalog.Key{}, + core.CatalogCacheStatus_CACHE_DISABLED, + }, + { + "OutputsExist", + catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), + nil, + catalog.Key{ + TypedInterface: core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "foo": nil, + }, + }, + }, + }, + core.CatalogCacheStatus_CACHE_POPULATED, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testScope := promutils.NewTestScope() + metrics := &nodeMetrics{ + catalogHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", testScope), + catalogMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", testScope), + catalogSkipCount: labeled.NewCounter("discovery_skip_count", "Task cached skipped in Discovery", testScope), + catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", testScope), + catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", testScope), + catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", testScope), + reservationGetFailureCount: labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", testScope), + reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", testScope), + reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", testScope), + reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", testScope), + } + + cacheableHandler := &interfacesmocks.CacheableNodeHandler{} + cacheableHandler.OnGetCatalogKeyMatch(mock.Anything, mock.Anything).Return(test.catalogKey, nil) + + catalogClient := &catalogmocks.Client{} + catalogClient.OnPutMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(test.cacheStatus, nil) + + 
dataStore, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + testScope.NewSubScope("data_store"), + ) + assert.NoError(t, err) + + nodeExecutor := &nodeExecutor{ + catalog: catalogClient, + metrics: metrics, + } + nCtx := setupCacheableNodeExecutionContext(dataStore, &core.TaskTemplate{}) + + // execute catalog cache check + cacheStatus, err := nodeExecutor.WriteCatalogCache(context.TODO(), nCtx, cacheableHandler) + assert.NoError(t, err) + + // validate the result cache entry status + assert.Equal(t, test.expectedCacheStatus, cacheStatus.GetCacheStatus()) + }) + } +} diff --git a/pkg/controller/nodes/task/catalog/config.go b/pkg/controller/nodes/catalog/config.go similarity index 96% rename from pkg/controller/nodes/task/catalog/config.go rename to pkg/controller/nodes/catalog/config.go index 20663cc1e..a77ba6f74 100644 --- a/pkg/controller/nodes/task/catalog/config.go +++ b/pkg/controller/nodes/catalog/config.go @@ -8,7 +8,7 @@ import ( "github.com/flyteorg/flytestdlib/config" "google.golang.org/grpc" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog/datacatalog" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog/datacatalog" ) //go:generate pflags Config --default-var defaultConfig diff --git a/pkg/controller/nodes/task/catalog/config_flags.go b/pkg/controller/nodes/catalog/config_flags.go similarity index 100% rename from pkg/controller/nodes/task/catalog/config_flags.go rename to pkg/controller/nodes/catalog/config_flags.go diff --git a/pkg/controller/nodes/task/catalog/config_flags_test.go b/pkg/controller/nodes/catalog/config_flags_test.go similarity index 100% rename from pkg/controller/nodes/task/catalog/config_flags_test.go rename to pkg/controller/nodes/catalog/config_flags_test.go diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go b/pkg/controller/nodes/catalog/datacatalog/datacatalog.go similarity index 100% rename from pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go rename to pkg/controller/nodes/catalog/datacatalog/datacatalog.go diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog_test.go b/pkg/controller/nodes/catalog/datacatalog/datacatalog_test.go similarity index 100% rename from pkg/controller/nodes/task/catalog/datacatalog/datacatalog_test.go rename to pkg/controller/nodes/catalog/datacatalog/datacatalog_test.go diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/catalog/datacatalog/transformer.go similarity index 100% rename from pkg/controller/nodes/task/catalog/datacatalog/transformer.go rename to pkg/controller/nodes/catalog/datacatalog/transformer.go diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer_test.go b/pkg/controller/nodes/catalog/datacatalog/transformer_test.go similarity index 100% rename from pkg/controller/nodes/task/catalog/datacatalog/transformer_test.go rename to pkg/controller/nodes/catalog/datacatalog/transformer_test.go diff --git a/pkg/controller/nodes/task/catalog/noop_catalog.go b/pkg/controller/nodes/catalog/noop_catalog.go similarity index 100% rename from pkg/controller/nodes/task/catalog/noop_catalog.go rename to pkg/controller/nodes/catalog/noop_catalog.go diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go index 369b26ee6..d2c0eb38c 100644 --- a/pkg/controller/nodes/dynamic/handler.go +++ b/pkg/controller/nodes/dynamic/handler.go @@ -7,7 +7,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" 
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -31,10 +30,10 @@ import ( const dynamicNodeID = "dynamic-node" type TaskNodeHandler interface { - interfaces.NodeHandler - ValidateOutputAndCacheAdd(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader, + interfaces.CacheableNodeHandler + ValidateOutput(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader, r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, - tr ioutils.SimpleTaskReader, m catalog.Metadata) (catalog.Status, *io.ExecutionError, error) + tr ioutils.SimpleTaskReader) (*io.ExecutionError, error) } type metrics struct { @@ -145,12 +144,9 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n // These outputPaths only reads the output metadata. So the sandbox is completely optional here and hence it is nil. // The sandbox creation as it uses hashing can be expensive and we skip that expense. outputPaths := ioutils.NewReadOnlyOutputFilePaths(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetOutputDir()) - execID := task.GetTaskExecutionIdentifier(nCtx) outputReader := ioutils.NewRemoteFileOutputReader(ctx, nCtx.DataStore(), outputPaths, nCtx.MaxDatasetSizeBytes()) - status, ee, err := d.TaskNodeHandler.ValidateOutputAndCacheAdd(ctx, nCtx.NodeID(), nCtx.InputReader(), - outputReader, nil, nCtx.ExecutionContext().GetExecutionConfig(), nCtx.TaskReader(), catalog.Metadata{ - TaskExecutionIdentifier: execID, - }) + ee, err := d.TaskNodeHandler.ValidateOutput(ctx, nCtx.NodeID(), nCtx.InputReader(), + outputReader, nil, nCtx.ExecutionContext().GetExecutionConfig(), nCtx.TaskReader()) if err != nil { return handler.UnknownTransition, prevState, err @@ -163,10 +159,9 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n return trns.WithInfo(handler.PhaseInfoFailureErr(ee.ExecutionError, trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil } - taskNodeInfoMetadata := &event.TaskNodeMetadata{CacheStatus: status.GetCacheStatus(), CatalogKey: status.GetMetadata()} + trns = trns.WithInfo(trns.Info().WithInfo(&handler.ExecutionInfo{ - OutputInfo: trns.Info().GetInfo().OutputInfo, - TaskNodeInfo: &handler.TaskNodeInfo{TaskNodeMetadata: taskNodeInfoMetadata}, + OutputInfo: trns.Info().GetInfo().OutputInfo, })) } diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go index ae0bb6912..980e7601a 100644 --- a/pkg/controller/nodes/dynamic/handler_test.go +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -7,7 +7,6 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -25,7 +24,6 @@ import ( lpMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" executorMocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" @@ -520,45 +518,18 @@ func 
Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { return nCtx } - validCachePopulatedStatus := catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, &core.CatalogMetadata{ - DatasetId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Project: "project", - Domain: "domain", - Name: "name", - Version: "version", - }, - ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}, - }) execInfoOutputOnly := &handler.ExecutionInfo{ OutputInfo: &handler.OutputInfo{ OutputURI: "output-dir/outputs.pb", }, } - execInfoWithTaskNodeMeta := &handler.ExecutionInfo{ - OutputInfo: &handler.OutputInfo{ - OutputURI: "output-dir/outputs.pb", - }, - TaskNodeInfo: &handler.TaskNodeInfo{ - TaskNodeMetadata: &event.TaskNodeMetadata{ - CacheStatus: validCachePopulatedStatus.GetCacheStatus(), - CatalogKey: &core.CatalogMetadata{ - DatasetId: validCachePopulatedStatus.GetMetadata().DatasetId, - ArtifactTag: validCachePopulatedStatus.GetMetadata().ArtifactTag, - SourceExecution: validCachePopulatedStatus.GetMetadata().SourceExecution, - }, - ReservationStatus: core.CatalogReservation_RESERVATION_DISABLED, - }, - }, - } type args struct { - s interfaces.NodeStatus - isErr bool - dj *core.DynamicJobSpec - validErr *io.ExecutionError - validCacheStatus *catalog.Status - generateOutputs bool + s interfaces.NodeStatus + isErr bool + dj *core.DynamicJobSpec + validErr *io.ExecutionError + generateOutputs bool } type want struct { p handler.EPhase @@ -573,7 +544,7 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { }{ {"error", args{isErr: true, dj: createDynamicJobSpec()}, want{isErr: true}}, {"success", args{s: interfaces.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true, validCacheStatus: &validCachePopulatedStatus}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoWithTaskNodeMeta}}, + {"complete", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoOutputOnly}}, {"complete-no-outputs", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, {"complete-valid-error-retryable", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, {"complete-valid-error", args{s: interfaces.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, @@ -596,17 +567,9 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} if tt.args.validErr != nil { - h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), tt.args.validErr, nil) + h.OnValidateOutputMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, 
mock.Anything, mock.Anything).Return(tt.args.validErr, nil) } else { - var validCacheStatus catalog.Status - if tt.args.validCacheStatus == nil { - validCacheStatus = catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, &core.CatalogMetadata{ - ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}, - }) - } else { - validCacheStatus = *tt.args.validCacheStatus - } - h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(validCacheStatus, nil, nil) + h.OnValidateOutputMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) } n := &nodeMocks.Node{} if tt.args.isErr { @@ -795,9 +758,9 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { mockLPLauncher := &lpMocks.Reader{} h := &mocks.TaskNodeHandler{} if tt.args.validErr != nil { - h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), tt.args.validErr, nil) + h.OnValidateOutputMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.validErr, nil) } else { - h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, &core.CatalogMetadata{ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}}), nil, nil) + h.OnValidateOutputMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) } n := &nodeMocks.Node{} if tt.args.isErr { diff --git a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go index e8d8cc6d7..eeba97048 100644 --- a/pkg/controller/nodes/dynamic/mocks/task_node_handler.go +++ b/pkg/controller/nodes/dynamic/mocks/task_node_handler.go @@ -121,6 +121,45 @@ func (_m *TaskNodeHandler) FinalizeRequired() bool { return r0 } +type TaskNodeHandler_GetCatalogKey struct { + *mock.Call +} + +func (_m TaskNodeHandler_GetCatalogKey) Return(_a0 catalog.Key, _a1 error) *TaskNodeHandler_GetCatalogKey { + return &TaskNodeHandler_GetCatalogKey{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TaskNodeHandler) OnGetCatalogKey(ctx context.Context, executionContext interfaces.NodeExecutionContext) *TaskNodeHandler_GetCatalogKey { + c_call := _m.On("GetCatalogKey", ctx, executionContext) + return &TaskNodeHandler_GetCatalogKey{Call: c_call} +} + +func (_m *TaskNodeHandler) OnGetCatalogKeyMatch(matchers ...interface{}) *TaskNodeHandler_GetCatalogKey { + c_call := _m.On("GetCatalogKey", matchers...) 
+ return &TaskNodeHandler_GetCatalogKey{Call: c_call} +} + +// GetCatalogKey provides a mock function with given fields: ctx, executionContext +func (_m *TaskNodeHandler) GetCatalogKey(ctx context.Context, executionContext interfaces.NodeExecutionContext) (catalog.Key, error) { + ret := _m.Called(ctx, executionContext) + + var r0 catalog.Key + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) catalog.Key); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(catalog.Key) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type TaskNodeHandler_Handle struct { *mock.Call } @@ -160,6 +199,52 @@ func (_m *TaskNodeHandler) Handle(ctx context.Context, executionContext interfac return r0, r1 } +type TaskNodeHandler_IsCacheable struct { + *mock.Call +} + +func (_m TaskNodeHandler_IsCacheable) Return(_a0 bool, _a1 bool, _a2 error) *TaskNodeHandler_IsCacheable { + return &TaskNodeHandler_IsCacheable{Call: _m.Call.Return(_a0, _a1, _a2)} +} + +func (_m *TaskNodeHandler) OnIsCacheable(ctx context.Context, executionContext interfaces.NodeExecutionContext) *TaskNodeHandler_IsCacheable { + c_call := _m.On("IsCacheable", ctx, executionContext) + return &TaskNodeHandler_IsCacheable{Call: c_call} +} + +func (_m *TaskNodeHandler) OnIsCacheableMatch(matchers ...interface{}) *TaskNodeHandler_IsCacheable { + c_call := _m.On("IsCacheable", matchers...) + return &TaskNodeHandler_IsCacheable{Call: c_call} +} + +// IsCacheable provides a mock function with given fields: ctx, executionContext +func (_m *TaskNodeHandler) IsCacheable(ctx context.Context, executionContext interfaces.NodeExecutionContext) (bool, bool, error) { + ret := _m.Called(ctx, executionContext) + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) bool); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 bool + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) bool); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Get(1).(bool) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r2 = rf(ctx, executionContext) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + type TaskNodeHandler_Setup struct { *mock.Call } @@ -192,50 +277,43 @@ func (_m *TaskNodeHandler) Setup(ctx context.Context, setupContext interfaces.Se return r0 } -type TaskNodeHandler_ValidateOutputAndCacheAdd struct { +type TaskNodeHandler_ValidateOutput struct { *mock.Call } -func (_m TaskNodeHandler_ValidateOutputAndCacheAdd) Return(_a0 catalog.Status, _a1 *io.ExecutionError, _a2 error) *TaskNodeHandler_ValidateOutputAndCacheAdd { - return &TaskNodeHandler_ValidateOutputAndCacheAdd{Call: _m.Call.Return(_a0, _a1, _a2)} +func (_m TaskNodeHandler_ValidateOutput) Return(_a0 *io.ExecutionError, _a1 error) *TaskNodeHandler_ValidateOutput { + return &TaskNodeHandler_ValidateOutput{Call: _m.Call.Return(_a0, _a1)} } -func (_m *TaskNodeHandler) OnValidateOutputAndCacheAdd(ctx context.Context, nodeID string, i io.InputReader, r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, tr ioutils.SimpleTaskReader, m catalog.Metadata) *TaskNodeHandler_ValidateOutputAndCacheAdd { - c_call := _m.On("ValidateOutputAndCacheAdd", ctx, nodeID, i, r, outputCommitter, 
executionConfig, tr, m) - return &TaskNodeHandler_ValidateOutputAndCacheAdd{Call: c_call} +func (_m *TaskNodeHandler) OnValidateOutput(ctx context.Context, nodeID string, i io.InputReader, r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, tr ioutils.SimpleTaskReader) *TaskNodeHandler_ValidateOutput { + c_call := _m.On("ValidateOutput", ctx, nodeID, i, r, outputCommitter, executionConfig, tr) + return &TaskNodeHandler_ValidateOutput{Call: c_call} } -func (_m *TaskNodeHandler) OnValidateOutputAndCacheAddMatch(matchers ...interface{}) *TaskNodeHandler_ValidateOutputAndCacheAdd { - c_call := _m.On("ValidateOutputAndCacheAdd", matchers...) - return &TaskNodeHandler_ValidateOutputAndCacheAdd{Call: c_call} +func (_m *TaskNodeHandler) OnValidateOutputMatch(matchers ...interface{}) *TaskNodeHandler_ValidateOutput { + c_call := _m.On("ValidateOutput", matchers...) + return &TaskNodeHandler_ValidateOutput{Call: c_call} } -// ValidateOutputAndCacheAdd provides a mock function with given fields: ctx, nodeID, i, r, outputCommitter, executionConfig, tr, m -func (_m *TaskNodeHandler) ValidateOutputAndCacheAdd(ctx context.Context, nodeID string, i io.InputReader, r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, tr ioutils.SimpleTaskReader, m catalog.Metadata) (catalog.Status, *io.ExecutionError, error) { - ret := _m.Called(ctx, nodeID, i, r, outputCommitter, executionConfig, tr, m) +// ValidateOutput provides a mock function with given fields: ctx, nodeID, i, r, outputCommitter, executionConfig, tr +func (_m *TaskNodeHandler) ValidateOutput(ctx context.Context, nodeID string, i io.InputReader, r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, tr ioutils.SimpleTaskReader) (*io.ExecutionError, error) { + ret := _m.Called(ctx, nodeID, i, r, outputCommitter, executionConfig, tr) - var r0 catalog.Status - if rf, ok := ret.Get(0).(func(context.Context, string, io.InputReader, io.OutputReader, io.OutputWriter, v1alpha1.ExecutionConfig, ioutils.SimpleTaskReader, catalog.Metadata) catalog.Status); ok { - r0 = rf(ctx, nodeID, i, r, outputCommitter, executionConfig, tr, m) + var r0 *io.ExecutionError + if rf, ok := ret.Get(0).(func(context.Context, string, io.InputReader, io.OutputReader, io.OutputWriter, v1alpha1.ExecutionConfig, ioutils.SimpleTaskReader) *io.ExecutionError); ok { + r0 = rf(ctx, nodeID, i, r, outputCommitter, executionConfig, tr) } else { - r0 = ret.Get(0).(catalog.Status) - } - - var r1 *io.ExecutionError - if rf, ok := ret.Get(1).(func(context.Context, string, io.InputReader, io.OutputReader, io.OutputWriter, v1alpha1.ExecutionConfig, ioutils.SimpleTaskReader, catalog.Metadata) *io.ExecutionError); ok { - r1 = rf(ctx, nodeID, i, r, outputCommitter, executionConfig, tr, m) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*io.ExecutionError) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*io.ExecutionError) } } - var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, io.InputReader, io.OutputReader, io.OutputWriter, v1alpha1.ExecutionConfig, ioutils.SimpleTaskReader, catalog.Metadata) error); ok { - r2 = rf(ctx, nodeID, i, r, outputCommitter, executionConfig, tr, m) + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, io.InputReader, io.OutputReader, io.OutputWriter, v1alpha1.ExecutionConfig, ioutils.SimpleTaskReader) error); ok { + r1 = rf(ctx, nodeID, i, r, outputCommitter, executionConfig, tr) } else { - r2 = ret.Error(2) + r1 = 
ret.Error(1) } - return r0, r1, r2 + return r0, r1 } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 23fbafc46..011acd0bb 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -57,6 +57,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const cacheSerializedReason = "waiting on serialized cache" + type nodeMetrics struct { Scope promutils.Scope FailureDuration labeled.StopWatch @@ -83,6 +85,17 @@ type nodeMetrics struct { QueuingLatency labeled.StopWatch NodeExecutionTime labeled.StopWatch NodeInputGatherLatency labeled.StopWatch + + catalogPutFailureCount labeled.Counter + catalogGetFailureCount labeled.Counter + catalogPutSuccessCount labeled.Counter + catalogMissCount labeled.Counter + catalogHitCount labeled.Counter + catalogSkipCount labeled.Counter + reservationGetSuccessCount labeled.Counter + reservationGetFailureCount labeled.Counter + reservationReleaseSuccessCount labeled.Counter + reservationReleaseFailureCount labeled.Counter } // recursiveNodeExector implements the executors.Node interfaces and is the starting point for @@ -464,6 +477,7 @@ func (c *recursiveNodeExecutor) WithNodeExecutionContextBuilder(nCtxBuilder inte // nodeExecutor implements the NodeExecutor interface and is responsible for executing a single node. type nodeExecutor struct { + catalog catalog.Client clusterID string defaultActiveDeadline time.Duration defaultDataSandbox storage.DataReference @@ -498,40 +512,6 @@ func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executor } } -/*func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *event.NodeExecutionEvent) error { - if nodeEvent == nil { - return fmt.Errorf("event recording attempt of Nil Node execution event") - } - - if nodeEvent.Id == nil { - return fmt.Errorf("event recording attempt of with nil node Event ID") - } - - logger.Infof(ctx, "Recording NodeEvent [%s] phase[%s]", nodeEvent.GetId().String(), nodeEvent.Phase.String()) - err := c.nodeRecorder.RecordNodeEvent(ctx, nodeEvent, c.eventConfig) - if err != nil { - if nodeEvent.GetId().NodeId == v1alpha1.EndNodeID { - return nil - } - - if eventsErr.IsAlreadyExists(err) { - logger.Infof(ctx, "Node event phase: %s, nodeId %s already exist", - nodeEvent.Phase.String(), nodeEvent.GetId().NodeId) - return nil - } else if eventsErr.IsEventAlreadyInTerminalStateError(err) { - if IsTerminalNodePhase(nodeEvent.Phase) { - // Event was trying to record a different terminal phase for an already terminal event. ignoring. - logger.Infof(ctx, "Node event phase: %s, nodeId %s already in terminal phase. 
err: %s", - nodeEvent.Phase.String(), nodeEvent.GetId().NodeId, err.Error()) - return nil - } - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return errors.Wrapf(errors.IllegalStateError, nodeEvent.Id.NodeId, err, "phase mis-match mismatch between propeller and control plane; Trying to record Node p: %s", nodeEvent.Phase) - } - } - return err -}*/ - func (c *nodeExecutor) recoverInputs(ctx context.Context, nCtx interfaces.NodeExecutionContext, recovered *admin.NodeExecution, recoveredData *admin.NodeExecutionGetDataResponse) (*core.LiteralMap, error) { @@ -939,7 +919,9 @@ func (c *nodeExecutor) Finalize(ctx context.Context, h interfaces.NodeHandler, n return h.Finalize(ctx, nCtx) } -func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, _ interfaces.NodeHandler) (interfaces.NodeStatus, error) { +func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + logger.Debugf(ctx, "Node not yet started, running pre-execute") defer logger.Debugf(ctx, "Node pre-execute completed") occurredAt := time.Now() @@ -957,12 +939,54 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor return interfaces.NodeStatusPending, nil } + var cacheStatus *catalog.Status + if cacheHandler, ok := h.(interfaces.CacheableNodeHandler); ok { + cacheable, _, err := cacheHandler.IsCacheable(ctx, nCtx) + if err != nil { + logger.Errorf(ctx, "failed to determine if node is cacheable with err '%s'", err.Error()) + return interfaces.NodeStatusUndefined, err + } else if cacheable && nCtx.ExecutionContext().GetExecutionConfig().OverwriteCache { + logger.Info(ctx, "execution config forced cache skip, not checking catalog") + status := catalog.NewStatus(core.CatalogCacheStatus_CACHE_SKIPPED, nil) + cacheStatus = &status + c.metrics.catalogSkipCount.Inc(ctx) + } else if cacheable { + entry, err := c.CheckCatalogCache(ctx, nCtx, cacheHandler) + if err != nil { + logger.Errorf(ctx, "failed to check the catalog cache with err '%s'", err.Error()) + return interfaces.NodeStatusUndefined, err + } + + status := entry.GetStatus() + cacheStatus = &status + if entry.GetStatus().GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT { + // if cache hit we immediately transition the node to successful + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + p = handler.PhaseInfoSuccess(&handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: outputFile, + }, + TaskNodeInfo: &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CacheStatus: entry.GetStatus().GetCacheStatus(), + CatalogKey: entry.GetStatus().GetMetadata(), + }, + }, + }) + } + } + } + np, err := ToNodePhase(p.GetPhase()) if err != nil { return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") } - nodeStatus := nCtx.NodeStatus() + if np == v1alpha1.NodePhaseSucceeding && (!h.FinalizeRequired() || (cacheStatus != nil && cacheStatus.GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT)) { + logger.Infof(ctx, "Finalize not required, moving node to Succeeded") + np = v1alpha1.NodePhaseSucceeded + } + if np != nodeStatus.GetPhase() { // assert np == Queued! 
logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) @@ -999,25 +1023,136 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx interfaces.NodeExecutionContext, h interfaces.NodeHandler) (interfaces.NodeStatus, error) { nodeStatus := nCtx.NodeStatus() currentPhase := nodeStatus.GetPhase() + p := handler.PhaseInfoUndefined // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) defer logger.Debugf(ctx, "node execution completed") + var cacheStatus *catalog.Status + var catalogReservationStatus *core.CatalogReservation_Status + cacheHandler, cacheHandlerOk := h.(interfaces.CacheableNodeHandler) + if cacheHandlerOk { + // if node is cacheable we attempt to check the cache if in queued phase or get / extend a + // catalog reservation + _, cacheSerializable, err := cacheHandler.IsCacheable(ctx, nCtx) + if err != nil { + logger.Errorf(ctx, "failed to determine if node is cacheable with err '%s'", err.Error()) + return interfaces.NodeStatusUndefined, err + } + + if cacheSerializable && nCtx.ExecutionContext().GetExecutionConfig().OverwriteCache { + status := catalog.NewStatus(core.CatalogCacheStatus_CACHE_SKIPPED, nil) + cacheStatus = &status + } else if cacheSerializable && currentPhase == v1alpha1.NodePhaseQueued && nodeStatus.GetMessage() == cacheSerializedReason { + // since we already check the cache before transitioning to Phase Queued we only need to check it again if + // the cache is serialized and that causes the node to stay in the Queued phase. the easiest way to detect + // this is verifying the NodeStatus Reason is what we set it during cache serialization. 
+ + entry, err := c.CheckCatalogCache(ctx, nCtx, cacheHandler) + if err != nil { + logger.Errorf(ctx, "failed to check the catalog cache with err '%s'", err.Error()) + return interfaces.NodeStatusUndefined, err + } + + status := entry.GetStatus() + cacheStatus = &status + if entry.GetStatus().GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT { + // if cache hit we immediately transition the node to successful + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + p = handler.PhaseInfoSuccess(&handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: outputFile, + }, + TaskNodeInfo: &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CacheStatus: entry.GetStatus().GetCacheStatus(), + CatalogKey: entry.GetStatus().GetMetadata(), + }, + }, + }) + } + } + + if cacheSerializable && !nCtx.ExecutionContext().GetExecutionConfig().OverwriteCache && + (cacheStatus == nil || (cacheStatus.GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT)) { + + entry, err := c.GetOrExtendCatalogReservation(ctx, nCtx, cacheHandler, config.GetConfig().WorkflowReEval.Duration) + if err != nil { + logger.Errorf(ctx, "failed to check for catalog reservation with err '%s'", err.Error()) + return interfaces.NodeStatusUndefined, err + } + + status := entry.GetStatus() + catalogReservationStatus = &status + if status == core.CatalogReservation_RESERVATION_ACQUIRED && currentPhase == v1alpha1.NodePhaseQueued { + logger.Infof(ctx, "acquired cache reservation") + } else if status == core.CatalogReservation_RESERVATION_EXISTS { + // if reservation is held by another owner we stay in the queued phase + p = handler.PhaseInfoQueued(cacheSerializedReason, nil) + } + } + } + // Since we reset node status inside execute for retryable failure, we use lastAttemptStartTime to carry that information // across execute which is used to emit metrics lastAttemptStartTime := nodeStatus.GetLastAttemptStartedAt() - p, err := c.execute(ctx, h, nCtx, nodeStatus) - if err != nil { - logger.Errorf(ctx, "failed Execute for node. Error: %s", err.Error()) - return interfaces.NodeStatusUndefined, err + // we execute the node if: + // (1) caching is disabled (ie. cacheStatus == nil) + // (2) there was no cache hit and the cache is not blocked by a cache reservation + // (3) the node is already running, this covers the scenario where the node held the cache + // reservation, but it expired and was captured by a different node + if currentPhase != v1alpha1.NodePhaseQueued || + ((cacheStatus == nil || cacheStatus.GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT) && + (catalogReservationStatus == nil || *catalogReservationStatus != core.CatalogReservation_RESERVATION_EXISTS)) { + + var err error + p, err = c.execute(ctx, h, nCtx, nodeStatus) + if err != nil { + logger.Errorf(ctx, "failed Execute for node. 
Error: %s", err.Error()) + return interfaces.NodeStatusUndefined, err + } } if p.GetPhase() == handler.EPhaseUndefined { return interfaces.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } + if p.GetPhase() == handler.EPhaseSuccess && cacheHandlerOk { + // if node is cacheable we attempt to write outputs to the cache and release catalog reservation + cacheable, cacheSerializable, err := cacheHandler.IsCacheable(ctx, nCtx) + if err != nil { + logger.Errorf(ctx, "failed to determine if node is cacheable with err '%s'", err.Error()) + return interfaces.NodeStatusUndefined, err + } + + if cacheable && (cacheStatus == nil || cacheStatus.GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT) { + status, err := c.WriteCatalogCache(ctx, nCtx, cacheHandler) + if err != nil { + // ignore failure to write to catalog + logger.Warnf(ctx, "failed to write to the catalog cache with err '%s'", err.Error()) + } + + cacheStatus = &status + } + + if cacheSerializable && (cacheStatus == nil || cacheStatus.GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT) { + entry, err := c.ReleaseCatalogReservation(ctx, nCtx, cacheHandler) + if err != nil { + // ignore failure to release the catalog reservation + logger.Warnf(ctx, "failed to write to the catalog cache with err '%s'", err.Error()) + } else { + status := entry.GetStatus() + catalogReservationStatus = &status + } + } + } + + // update phase info with catalog cache and reservation information while maintaining all + // other metadata + p = updatePhaseCacheInfo(p, cacheStatus, catalogReservationStatus) + np, err := ToNodePhase(p.GetPhase()) if err != nil { return interfaces.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") @@ -1068,11 +1203,12 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx inter finalStatus = interfaces.NodeStatusTimedOut } - if np == v1alpha1.NodePhaseSucceeding && !h.FinalizeRequired() { + if np == v1alpha1.NodePhaseSucceeding && (!h.FinalizeRequired() || (cacheStatus != nil && cacheStatus.GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT)) { logger.Infof(ctx, "Finalize not required, moving node to Succeeded") np = v1alpha1.NodePhaseSucceeded finalStatus = interfaces.NodeStatusSuccess } + if np == v1alpha1.NodePhaseRecovered { logger.Infof(ctx, "Finalize not required, moving node to Recovered") finalStatus = interfaces.NodeStatusRecovered @@ -1253,29 +1389,40 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora nodeScope := scope.NewSubScope("node") metrics := &nodeMetrics{ - Scope: nodeScope, - FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, 
labeled.EmitUnlabeledMetric), - UnknownErrorDuration: labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentUserErrorDuration: labeled.NewStopWatch("perma_user_error_duration", "Indicates the total execution time before non recoverable user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentSystemErrorDuration: labeled.NewStopWatch("perma_system_error_duration", "Indicates the total execution time before non recoverable system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - PermanentUnknownErrorDuration: labeled.NewStopWatch("perma_unknown_error_duration", "Indicates the total execution time before non recoverable unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), - TimedOutFailure: labeled.NewCounter("timeout_fail", "Indicates failure due to timeout", nodeScope), - InterruptedThresholdHit: labeled.NewCounter("interrupted_threshold", "Indicates the node interruptible disabled because it hit max failure count", nodeScope), - InterruptibleNodesRunning: labeled.NewCounter("interruptible_nodes_running", "number of interruptible nodes running", nodeScope), - InterruptibleNodesTerminated: labeled.NewCounter("interruptible_nodes_terminated", "number of interruptible nodes finished running", nodeScope), - ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), - TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), - NodeInputGatherLatency: labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + Scope: nodeScope, + FailureDuration: labeled.NewStopWatch("failure_duration", "Indicates the total execution time of a failed workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SuccessDuration: labeled.NewStopWatch("success_duration", "Indicates the total execution time of a successful workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + RecoveryDuration: labeled.NewStopWatch("recovery_duration", "Indicates the total execution time of a recovered workflow.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + UserErrorDuration: labeled.NewStopWatch("user_error_duration", "Indicates the total execution time before user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + SystemErrorDuration: labeled.NewStopWatch("system_error_duration", "Indicates the total execution time before system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + UnknownErrorDuration: 
labeled.NewStopWatch("unknown_error_duration", "Indicates the total execution time before unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentUserErrorDuration: labeled.NewStopWatch("perma_user_error_duration", "Indicates the total execution time before non recoverable user error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentSystemErrorDuration: labeled.NewStopWatch("perma_system_error_duration", "Indicates the total execution time before non recoverable system error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + PermanentUnknownErrorDuration: labeled.NewStopWatch("perma_unknown_error_duration", "Indicates the total execution time before non recoverable unknown error", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + InputsWriteFailure: labeled.NewCounter("inputs_write_fail", "Indicates failure in writing node inputs to metastore", nodeScope), + TimedOutFailure: labeled.NewCounter("timeout_fail", "Indicates failure due to timeout", nodeScope), + InterruptedThresholdHit: labeled.NewCounter("interrupted_threshold", "Indicates the node interruptible disabled because it hit max failure count", nodeScope), + InterruptibleNodesRunning: labeled.NewCounter("interruptible_nodes_running", "number of interruptible nodes running", nodeScope), + InterruptibleNodesTerminated: labeled.NewCounter("interruptible_nodes_terminated", "number of interruptible nodes finished running", nodeScope), + ResolutionFailure: labeled.NewCounter("input_resolve_fail", "Indicates failure in resolving node inputs", nodeScope), + TransitionLatency: labeled.NewStopWatch("transition_latency", "Measures the latency between the last parent node stoppedAt time and current node's queued time.", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + QueuingLatency: labeled.NewStopWatch("queueing_latency", "Measures the latency between the time a node's been queued to the time the handler reported the executable moved to running state", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeExecutionTime: labeled.NewStopWatch("node_exec_latency", "Measures the time taken to execute one node, a node can be complex so it may encompass sub-node latency.", time.Microsecond, nodeScope, labeled.EmitUnlabeledMetric), + NodeInputGatherLatency: labeled.NewStopWatch("node_input_latency", "Measures the latency to aggregate inputs and check readiness of a node", time.Millisecond, nodeScope, labeled.EmitUnlabeledMetric), + catalogHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", scope), + catalogMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope), + catalogSkipCount: labeled.NewCounter("discovery_skip_count", "Task cached skipped in Discovery", scope), + catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", scope), + catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope), + catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", scope), + reservationGetFailureCount: labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", scope), + reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", scope), + reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", 
scope), + reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", scope), } nodeExecutor := &nodeExecutor{ + catalog: catalogClient, clusterID: clusterID, defaultActiveDeadline: nodeConfig.DefaultDeadlines.DefaultNodeActiveDeadline.Duration, defaultDataSandbox: defaultRawOutputPrefix, @@ -1302,7 +1449,5 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora store: store, metrics: metrics, } - /*nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, eventConfig, clusterID, signalClient, nodeScope) - exec.nodeHandlerFactory = nodeHandlerFactory*/ return exec, err } diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 71b7c3df7..3a474b39e 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -11,8 +11,11 @@ import ( "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + pluginscatalog "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" + catalogmocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog/mocks" mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flytepropeller/events" @@ -23,13 +26,13 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" mocks4 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog" gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" "github.com/flyteorg/flytestdlib/contextutils" @@ -2549,3 +2552,216 @@ func (e existsMetadata) Size() int64 { func (e existsMetadata) Etag() string { return "" } + +func TestNodeExecutor_RecursiveNodeHandler_Cache(t *testing.T) { + currentNodeID := "node-0" + downstreamNodeID := "node-1" + taskID := taskID + + createMockWorkflow := func(currentNodePhase, downstreamNodePhase v1alpha1.NodePhase, dataStore *storage.DataStore) *v1alpha1.FlyteWorkflow { + return &v1alpha1.FlyteWorkflow{ + Tasks: map[v1alpha1.TaskID]*v1alpha1.TaskSpec{ + taskID: { + TaskTemplate: &core.TaskTemplate{}, + }, + }, + Status: v1alpha1.WorkflowStatus{ + NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ + currentNodeID: &v1alpha1.NodeStatus{ + Phase: currentNodePhase, + Message: cacheSerializedReason, + }, + downstreamNodeID: &v1alpha1.NodeStatus{ + Phase: downstreamNodePhase, + }, + }, + DataDir: "data", + }, + WorkflowSpec: &v1alpha1.WorkflowSpec{ + ID: "wf", + Nodes: map[v1alpha1.NodeID]*v1alpha1.NodeSpec{ + currentNodeID: &v1alpha1.NodeSpec{ + ID: currentNodeID, + TaskRef: &taskID, + Kind: 
v1alpha1.NodeKindTask, + }, + downstreamNodeID: &v1alpha1.NodeSpec{ + ID: downstreamNodeID, + TaskRef: &taskID, + Kind: v1alpha1.NodeKindTask, + }, + }, + Connections: v1alpha1.Connections{ + Upstream: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + downstreamNodeID: {currentNodeID}, + }, + Downstream: map[v1alpha1.NodeID][]v1alpha1.NodeID{ + currentNodeID: {downstreamNodeID}, + }, + }, + }, + DataReferenceConstructor: dataStore, + RawOutputDataConfig: v1alpha1.RawOutputDataConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""}, + }, + } + } + + setupNodeExecutor := func(t *testing.T, catalogClient pluginscatalog.Client, dataStore *storage.DataStore, mockHandler interfaces.CacheableNodeHandler, testScope promutils.Scope) interfaces.Node { + ctx := context.TODO() + + // create mocks + adminClient := launchplan.NewFailFastLaunchPlanExecutor() + enqueueWorkflow := func(workflowID v1alpha1.WorkflowID) {} + eventConfig := &config.EventConfig{ + RawOutputPolicy: config.RawOutputPolicyReference, + } + fakeKubeClient := mocks4.NewFakeKubeClient() + maxDatasetSize := int64(10) + mockEventSink := eventMocks.NewMockEventSink() + nodeConfig := config.GetConfig().NodeConfig + rawOutputPrefix := storage.DataReference("s3://bucket/") + recoveryClient := &recoveryMocks.Client{} + testClusterID := "cluster1" + + // initialize node executor + mockHandlerFactory := &nodemocks.HandlerFactory{} + mockHandlerFactory.OnGetHandler(v1alpha1.NodeKindTask).Return(mockHandler, nil) + nodeExecutor, err := NewExecutor(ctx, nodeConfig, dataStore, enqueueWorkflow, mockEventSink, + adminClient, adminClient, maxDatasetSize, rawOutputPrefix, fakeKubeClient, catalogClient, + recoveryClient, eventConfig, testClusterID, signalClient, mockHandlerFactory, testScope) + assert.NoError(t, err) + + return nodeExecutor + } + + tests := []struct { + name string + cacheable bool + cacheStatus core.CatalogCacheStatus + cacheSerializable bool + cacheReservationOwnerID string + currentNodePhase v1alpha1.NodePhase + nextNodePhase v1alpha1.NodePhase + currentDownstreamNodePhase v1alpha1.NodePhase + nextDownstreamNodePhase v1alpha1.NodePhase + }{ + { + "NotYetStarted->CacheMiss->Queued", + true, + core.CatalogCacheStatus_CACHE_MISS, + false, + "", + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseNotYetStarted, + }, + { + "NotYetStarted->CacheHit->Success", + true, + core.CatalogCacheStatus_CACHE_HIT, + false, + "", + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseSucceeded, + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseNotYetStarted, + }, + { + "Queued->CacheHit->Success", + true, + core.CatalogCacheStatus_CACHE_HIT, + true, + "another-node", + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseSucceeded, + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseNotYetStarted, + }, + { + "Queued->CacheMiss->Queued", + true, + core.CatalogCacheStatus_CACHE_MISS, + true, + "another-node", + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseNotYetStarted, + }, + { + "Queued->ReservationAcquired->Running", + true, + core.CatalogCacheStatus_CACHE_MISS, + true, + fmt.Sprintf("%s-%d", currentNodeID, 0), + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseNotYetStarted, + }, + { + "Running->ReservationExists->Running", + true, + core.CatalogCacheStatus_CACHE_MISS, + true, + "another-node", + v1alpha1.NodePhaseRunning, + 
v1alpha1.NodePhaseRunning, + v1alpha1.NodePhaseNotYetStarted, + v1alpha1.NodePhaseNotYetStarted, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testScope := promutils.NewTestScope() + + dataStore := createInmemoryDataStore(t, testScope.NewSubScope("data_store")) + mockWorkflow := createMockWorkflow(test.currentNodePhase, test.currentDownstreamNodePhase, dataStore) + + // retrieve current node references + currentNodeSpec, ok := mockWorkflow.WorkflowSpec.Nodes[currentNodeID] + assert.Equal(t, true, ok) + + currentNodeStatus, ok := mockWorkflow.Status.NodeStatus[currentNodeID] + assert.Equal(t, true, ok) + + downstreamNodeStatus, ok := mockWorkflow.Status.NodeStatus[downstreamNodeID] + assert.Equal(t, true, ok) + + // initialize nodeExecutor + catalogClient := &catalogmocks.Client{} + catalogClient.OnGetMatch(mock.Anything, mock.Anything). + Return(pluginscatalog.NewCatalogEntry(nil, pluginscatalog.NewStatus(test.cacheStatus, nil)), nil) + catalogClient.OnGetOrExtendReservationMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(&datacatalog.Reservation{OwnerId: test.cacheReservationOwnerID}, nil) + + mockHandler := &nodemocks.CacheableNodeHandler{} + mockHandler.OnIsCacheableMatch( + mock.Anything, + mock.MatchedBy(func(nCtx interfaces.NodeExecutionContext) bool { return nCtx.NodeID() == currentNodeID }), + ).Return(test.cacheable, test.cacheSerializable, nil) + mockHandler.OnIsCacheableMatch( + mock.Anything, + mock.MatchedBy(func(nCtx interfaces.NodeExecutionContext) bool { return nCtx.NodeID() == downstreamNodeID }), + ).Return(false, false, nil) + mockHandler.OnGetCatalogKeyMatch(mock.Anything, mock.Anything). + Return(pluginscatalog.Key{Identifier: core.Identifier{Name: currentNodeID}}, nil) + mockHandler.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil) + mockHandler.OnFinalizeRequiredMatch(mock.Anything).Return(false) + + nodeExecutor := setupNodeExecutor(t, catalogClient, dataStore, mockHandler, testScope.NewSubScope("node_executor")) + + execContext := executors.NewExecutionContext(mockWorkflow, mockWorkflow, mockWorkflow, nil, executors.InitializeControlFlow()) + + // execute RecursiveNodeHandler + _, err := nodeExecutor.RecursiveNodeHandler(context.Background(), execContext, mockWorkflow, mockWorkflow, currentNodeSpec) + assert.NoError(t, err) + + // validate node phase transitions + assert.Equal(t, test.nextNodePhase, currentNodeStatus.Phase) + assert.Equal(t, test.nextDownstreamNodePhase, downstreamNodeStatus.Phase) + }) + } +} diff --git a/pkg/controller/nodes/interfaces/handler.go b/pkg/controller/nodes/interfaces/handler.go index 16ef73274..5cf82a27a 100644 --- a/pkg/controller/nodes/interfaces/handler.go +++ b/pkg/controller/nodes/interfaces/handler.go @@ -3,6 +3,7 @@ package interfaces import ( "context" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytestdlib/promutils" @@ -36,6 +37,18 @@ type NodeHandler interface { Finalize(ctx context.Context, executionContext NodeExecutionContext) error } +// CacheableNodeHandler is a node that supports caching +type CacheableNodeHandler interface { + NodeHandler + + // GetCatalogKey returns the unique key for the node represented by the NodeExecutionContext + GetCatalogKey(ctx context.Context, executionContext 
NodeExecutionContext) (catalog.Key, error) + + // IsCacheable returns two booleans representing if the node represented by the + // NodeExecutionContext is cacheable and cache serializable respectively. + IsCacheable(ctx context.Context, executionContext NodeExecutionContext) (bool, bool, error) +} + type SetupContext interface { EnqueueOwner() func(string) OwnerKind() string diff --git a/pkg/controller/nodes/interfaces/mocks/cacheable_node_handler.go b/pkg/controller/nodes/interfaces/mocks/cacheable_node_handler.go new file mode 100644 index 000000000..39456010b --- /dev/null +++ b/pkg/controller/nodes/interfaces/mocks/cacheable_node_handler.go @@ -0,0 +1,272 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + catalog "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" + + handler "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + + interfaces "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// CacheableNodeHandler is an autogenerated mock type for the CacheableNodeHandler type +type CacheableNodeHandler struct { + mock.Mock +} + +type CacheableNodeHandler_Abort struct { + *mock.Call +} + +func (_m CacheableNodeHandler_Abort) Return(_a0 error) *CacheableNodeHandler_Abort { + return &CacheableNodeHandler_Abort{Call: _m.Call.Return(_a0)} +} + +func (_m *CacheableNodeHandler) OnAbort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) *CacheableNodeHandler_Abort { + c_call := _m.On("Abort", ctx, executionContext, reason) + return &CacheableNodeHandler_Abort{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnAbortMatch(matchers ...interface{}) *CacheableNodeHandler_Abort { + c_call := _m.On("Abort", matchers...) + return &CacheableNodeHandler_Abort{Call: c_call} +} + +// Abort provides a mock function with given fields: ctx, executionContext, reason +func (_m *CacheableNodeHandler) Abort(ctx context.Context, executionContext interfaces.NodeExecutionContext, reason string) error { + ret := _m.Called(ctx, executionContext, reason) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext, string) error); ok { + r0 = rf(ctx, executionContext, reason) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type CacheableNodeHandler_Finalize struct { + *mock.Call +} + +func (_m CacheableNodeHandler_Finalize) Return(_a0 error) *CacheableNodeHandler_Finalize { + return &CacheableNodeHandler_Finalize{Call: _m.Call.Return(_a0)} +} + +func (_m *CacheableNodeHandler) OnFinalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) *CacheableNodeHandler_Finalize { + c_call := _m.On("Finalize", ctx, executionContext) + return &CacheableNodeHandler_Finalize{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnFinalizeMatch(matchers ...interface{}) *CacheableNodeHandler_Finalize { + c_call := _m.On("Finalize", matchers...) 
+ return &CacheableNodeHandler_Finalize{Call: c_call} +} + +// Finalize provides a mock function with given fields: ctx, executionContext +func (_m *CacheableNodeHandler) Finalize(ctx context.Context, executionContext interfaces.NodeExecutionContext) error { + ret := _m.Called(ctx, executionContext) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type CacheableNodeHandler_FinalizeRequired struct { + *mock.Call +} + +func (_m CacheableNodeHandler_FinalizeRequired) Return(_a0 bool) *CacheableNodeHandler_FinalizeRequired { + return &CacheableNodeHandler_FinalizeRequired{Call: _m.Call.Return(_a0)} +} + +func (_m *CacheableNodeHandler) OnFinalizeRequired() *CacheableNodeHandler_FinalizeRequired { + c_call := _m.On("FinalizeRequired") + return &CacheableNodeHandler_FinalizeRequired{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnFinalizeRequiredMatch(matchers ...interface{}) *CacheableNodeHandler_FinalizeRequired { + c_call := _m.On("FinalizeRequired", matchers...) + return &CacheableNodeHandler_FinalizeRequired{Call: c_call} +} + +// FinalizeRequired provides a mock function with given fields: +func (_m *CacheableNodeHandler) FinalizeRequired() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type CacheableNodeHandler_GetCatalogKey struct { + *mock.Call +} + +func (_m CacheableNodeHandler_GetCatalogKey) Return(_a0 catalog.Key, _a1 error) *CacheableNodeHandler_GetCatalogKey { + return &CacheableNodeHandler_GetCatalogKey{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *CacheableNodeHandler) OnGetCatalogKey(ctx context.Context, executionContext interfaces.NodeExecutionContext) *CacheableNodeHandler_GetCatalogKey { + c_call := _m.On("GetCatalogKey", ctx, executionContext) + return &CacheableNodeHandler_GetCatalogKey{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnGetCatalogKeyMatch(matchers ...interface{}) *CacheableNodeHandler_GetCatalogKey { + c_call := _m.On("GetCatalogKey", matchers...) 
+ return &CacheableNodeHandler_GetCatalogKey{Call: c_call} +} + +// GetCatalogKey provides a mock function with given fields: ctx, executionContext +func (_m *CacheableNodeHandler) GetCatalogKey(ctx context.Context, executionContext interfaces.NodeExecutionContext) (catalog.Key, error) { + ret := _m.Called(ctx, executionContext) + + var r0 catalog.Key + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) catalog.Key); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(catalog.Key) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type CacheableNodeHandler_Handle struct { + *mock.Call +} + +func (_m CacheableNodeHandler_Handle) Return(_a0 handler.Transition, _a1 error) *CacheableNodeHandler_Handle { + return &CacheableNodeHandler_Handle{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *CacheableNodeHandler) OnHandle(ctx context.Context, executionContext interfaces.NodeExecutionContext) *CacheableNodeHandler_Handle { + c_call := _m.On("Handle", ctx, executionContext) + return &CacheableNodeHandler_Handle{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnHandleMatch(matchers ...interface{}) *CacheableNodeHandler_Handle { + c_call := _m.On("Handle", matchers...) + return &CacheableNodeHandler_Handle{Call: c_call} +} + +// Handle provides a mock function with given fields: ctx, executionContext +func (_m *CacheableNodeHandler) Handle(ctx context.Context, executionContext interfaces.NodeExecutionContext) (handler.Transition, error) { + ret := _m.Called(ctx, executionContext) + + var r0 handler.Transition + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) handler.Transition); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(handler.Transition) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type CacheableNodeHandler_IsCacheable struct { + *mock.Call +} + +func (_m CacheableNodeHandler_IsCacheable) Return(_a0 bool, _a1 bool, _a2 error) *CacheableNodeHandler_IsCacheable { + return &CacheableNodeHandler_IsCacheable{Call: _m.Call.Return(_a0, _a1, _a2)} +} + +func (_m *CacheableNodeHandler) OnIsCacheable(ctx context.Context, executionContext interfaces.NodeExecutionContext) *CacheableNodeHandler_IsCacheable { + c_call := _m.On("IsCacheable", ctx, executionContext) + return &CacheableNodeHandler_IsCacheable{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnIsCacheableMatch(matchers ...interface{}) *CacheableNodeHandler_IsCacheable { + c_call := _m.On("IsCacheable", matchers...) 
+ return &CacheableNodeHandler_IsCacheable{Call: c_call} +} + +// IsCacheable provides a mock function with given fields: ctx, executionContext +func (_m *CacheableNodeHandler) IsCacheable(ctx context.Context, executionContext interfaces.NodeExecutionContext) (bool, bool, error) { + ret := _m.Called(ctx, executionContext) + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, interfaces.NodeExecutionContext) bool); ok { + r0 = rf(ctx, executionContext) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 bool + if rf, ok := ret.Get(1).(func(context.Context, interfaces.NodeExecutionContext) bool); ok { + r1 = rf(ctx, executionContext) + } else { + r1 = ret.Get(1).(bool) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, interfaces.NodeExecutionContext) error); ok { + r2 = rf(ctx, executionContext) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +type CacheableNodeHandler_Setup struct { + *mock.Call +} + +func (_m CacheableNodeHandler_Setup) Return(_a0 error) *CacheableNodeHandler_Setup { + return &CacheableNodeHandler_Setup{Call: _m.Call.Return(_a0)} +} + +func (_m *CacheableNodeHandler) OnSetup(ctx context.Context, setupContext interfaces.SetupContext) *CacheableNodeHandler_Setup { + c_call := _m.On("Setup", ctx, setupContext) + return &CacheableNodeHandler_Setup{Call: c_call} +} + +func (_m *CacheableNodeHandler) OnSetupMatch(matchers ...interface{}) *CacheableNodeHandler_Setup { + c_call := _m.On("Setup", matchers...) + return &CacheableNodeHandler_Setup{Call: c_call} +} + +// Setup provides a mock function with given fields: ctx, setupContext +func (_m *CacheableNodeHandler) Setup(ctx context.Context, setupContext interfaces.SetupContext) error { + ret := _m.Called(ctx, setupContext) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interfaces.SetupContext) error); ok { + r0 = rf(ctx, setupContext) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/controller/nodes/interfaces/node_exec_context.go b/pkg/controller/nodes/interfaces/node_exec_context.go index 3b8afc384..fa9d32388 100644 --- a/pkg/controller/nodes/interfaces/node_exec_context.go +++ b/pkg/controller/nodes/interfaces/node_exec_context.go @@ -54,7 +54,6 @@ type NodeExecutionContext interface { DataStore() *storage.DataStore InputReader() io.InputReader - //EventsRecorder() events.TaskEventRecorder EventsRecorder() EventRecorder NodeID() v1alpha1.NodeID Node() v1alpha1.ExecutableNode diff --git a/pkg/controller/nodes/task/cache.go b/pkg/controller/nodes/task/cache.go new file mode 100644 index 000000000..479a49c5b --- /dev/null +++ b/pkg/controller/nodes/task/cache.go @@ -0,0 +1,67 @@ +package task + +import ( + "context" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + + errors2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces" + + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/logger" +) + +func (t *Handler) GetCatalogKey(ctx context.Context, nCtx interfaces.NodeExecutionContext) (catalog.Key, error) { + // read task template + taskTemplatePath, err := ioutils.GetTaskTemplatePath(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetDataDir()) + if err != nil { + return catalog.Key{}, err + } + + taskReader := ioutils.NewLazyUploadingTaskReader(nCtx.TaskReader(), taskTemplatePath, nCtx.DataStore()) + taskTemplate, err := taskReader.Read(ctx) + if 
err != nil { + logger.Errorf(ctx, "failed to read TaskTemplate, error :%s", err.Error()) + return catalog.Key{}, err + } + + return catalog.Key{ + Identifier: *taskTemplate.Id, + CacheVersion: taskTemplate.Metadata.DiscoveryVersion, + TypedInterface: *taskTemplate.Interface, + InputReader: nCtx.InputReader(), + }, nil +} + +func (t *Handler) IsCacheable(ctx context.Context, nCtx interfaces.NodeExecutionContext) (bool, bool, error) { + // check if plugin has caching disabled + ttype := nCtx.TaskReader().GetTaskType() + ctx = contextutils.WithTaskType(ctx, ttype) + p, err := t.ResolvePlugin(ctx, ttype, nCtx.ExecutionContext().GetExecutionConfig()) + if err != nil { + return false, false, errors2.Wrapf(errors2.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") + } + + checkCatalog := !p.GetProperties().DisableNodeLevelCaching + if !checkCatalog { + logger.Infof(ctx, "Node level caching is disabled. Skipping catalog read.") + return false, false, nil + } + + // read task template + taskTemplatePath, err := ioutils.GetTaskTemplatePath(ctx, nCtx.DataStore(), nCtx.NodeStatus().GetDataDir()) + if err != nil { + return false, false, err + } + + taskReader := ioutils.NewLazyUploadingTaskReader(nCtx.TaskReader(), taskTemplatePath, nCtx.DataStore()) + taskTemplate, err := taskReader.Read(ctx) + if err != nil { + logger.Errorf(ctx, "failed to read TaskTemplate, error :%s", err.Error()) + return false, false, err + } + + return taskTemplate.Metadata.Discoverable, taskTemplate.Metadata.Discoverable && taskTemplate.Metadata.CacheSerializable, nil +} diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index c3b8833fa..1cffebb80 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -44,20 +44,10 @@ import ( const pluginContextKey = contextutils.Key("plugin") type metrics struct { - pluginPanics labeled.Counter - unsupportedTaskType labeled.Counter - catalogPutFailureCount labeled.Counter - catalogGetFailureCount labeled.Counter - catalogPutSuccessCount labeled.Counter - catalogMissCount labeled.Counter - catalogHitCount labeled.Counter - catalogSkipCount labeled.Counter - pluginExecutionLatency labeled.StopWatch - pluginQueueLatency labeled.StopWatch - reservationGetSuccessCount labeled.Counter - reservationGetFailureCount labeled.Counter - reservationReleaseSuccessCount labeled.Counter - reservationReleaseFailureCount labeled.Counter + pluginPanics labeled.Counter + unsupportedTaskType labeled.Counter + pluginExecutionLatency labeled.StopWatch + pluginQueueLatency labeled.StopWatch // TODO We should have a metric to capture custom state size scope promutils.Scope @@ -487,11 +477,8 @@ func (t Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta // ------------------------------------- logger.Debugf(ctx, "Task success detected, calling on Task success") outputCommitter := ioutils.NewRemoteFileOutputWriter(ctx, tCtx.DataStore(), tCtx.OutputWriter()) - execID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID() - cacheStatus, ee, err := t.ValidateOutputAndCacheAdd(ctx, tCtx.NodeID(), tCtx.InputReader(), tCtx.ow.GetReader(), - outputCommitter, tCtx.ExecutionContext().GetExecutionConfig(), tCtx.tr, catalog.Metadata{ - TaskExecutionIdentifier: &execID, - }) + ee, err := t.ValidateOutput(ctx, tCtx.NodeID(), tCtx.InputReader(), tCtx.ow.GetReader(), + outputCommitter, tCtx.ExecutionContext().GetExecutionConfig(), tCtx.tr) if err != nil { return nil, err } @@ -515,8 +502,6 @@ func (t 
Handler) invokePlugin(ctx context.Context, p pluginCore.Plugin, tCtx *ta } pluginTrns.ObserveSuccess(tCtx.ow.GetOutputPath(), deckURI, &event.TaskNodeMetadata{ - CacheStatus: cacheStatus.GetCacheStatus(), - CatalogKey: cacheStatus.GetMetadata(), CheckpointUri: tCtx.ow.GetCheckpointPrefix().String(), }) } @@ -540,11 +525,6 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex return handler.UnknownTransition, errors.Wrapf(errors.UnsupportedTaskTypeError, nCtx.NodeID(), err, "unable to resolve plugin") } - checkCatalog := !p.GetProperties().DisableNodeLevelCaching - if !checkCatalog { - logger.Infof(ctx, "Node level caching is disabled. Skipping catalog read.") - } - tCtx, err := t.newTaskExecutionContext(ctx, nCtx, p) if err != nil { return handler.UnknownTransition, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "unable to create Handler execution context") @@ -586,84 +566,6 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex } } } - // STEP 1: Check Cache - if (ts.PluginPhase == pluginCore.PhaseUndefined || ts.PluginPhase == pluginCore.PhaseWaitingForCache) && checkCatalog { - // This is assumed to be first time. we will check catalog and call handle - // If the cache should be skipped (requested by user for the execution), do not check datacatalog for any cached - // data, but instead always perform calculations again and overwrite the stored data after successful execution. - if nCtx.ExecutionContext().GetExecutionConfig().OverwriteCache { - logger.Info(ctx, "Execution config forced cache skip, not checking catalog") - pluginTrns.PopulateCacheInfo(catalog.NewCatalogEntry(nil, cacheSkipped)) - t.metrics.catalogSkipCount.Inc(ctx) - } else { - entry, err := t.CheckCatalogCache(ctx, tCtx.tr, nCtx.InputReader(), tCtx.ow) - if err != nil { - logger.Errorf(ctx, "failed to check catalog cache with error") - return handler.UnknownTransition, err - } - - if entry.GetStatus().GetCacheStatus() == core.CatalogCacheStatus_CACHE_HIT { - r := tCtx.ow.GetReader() - if r == nil { - return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "failed to reader outputs from a CacheHIT. Unexpected!") - } - - // TODO @kumare this can be optimized, if we have paths then the reader could be pipelined to a sink - o, ee, err := r.Read(ctx) - if err != nil { - logger.Errorf(ctx, "failed to read from catalog, err: %s", err.Error()) - return handler.UnknownTransition, err - } - - if ee != nil { - logger.Errorf(ctx, "got execution error from catalog output reader? This should not happen, err: %s", ee.String()) - return handler.UnknownTransition, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "execution error from a cache output, bad state: %s", ee.String()) - } - - if err := nCtx.DataStore().WriteProtobuf(ctx, tCtx.ow.GetOutputPath(), storage.Options{}, o); err != nil { - logger.Errorf(ctx, "failed to write cached value to datastore, err: %s", err.Error()) - return handler.UnknownTransition, err - } - - pluginTrns.CacheHit(tCtx.ow.GetOutputPath(), nil, entry) - } else { - logger.Infof(ctx, "No CacheHIT. 
Status [%s]", entry.GetStatus().GetCacheStatus().String()) - pluginTrns.PopulateCacheInfo(entry) - } - } - } - - // Check catalog for cache reservation and acquire if none exists - if checkCatalog && (pluginTrns.execInfo.TaskNodeInfo == nil || pluginTrns.execInfo.TaskNodeInfo.TaskNodeMetadata.CacheStatus != core.CatalogCacheStatus_CACHE_HIT) { - ownerID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - reservation, err := t.GetOrExtendCatalogReservation(ctx, ownerID, controllerConfig.GetConfig().WorkflowReEval.Duration, tCtx.tr, nCtx.InputReader()) - if err != nil { - logger.Errorf(ctx, "failed to get or extend catalog reservation with error") - return handler.UnknownTransition, err - } - - pluginTrns.PopulateReservationInfo(reservation) - - if reservation.GetStatus() == core.CatalogReservation_RESERVATION_ACQUIRED && - (ts.PluginPhase == pluginCore.PhaseUndefined || ts.PluginPhase == pluginCore.PhaseWaitingForCache) { - logger.Infof(ctx, "Acquired cache reservation") - } - - // If we do not own the reservation then we transition to WaitingForCache phase. If we are - // already running (ie. in a phase other than PhaseUndefined or PhaseWaitingForCache) and - // somehow lost the reservation (ex. by expiration), continue to execute until completion. - if reservation.GetStatus() == core.CatalogReservation_RESERVATION_EXISTS { - if ts.PluginPhase == pluginCore.PhaseUndefined || ts.PluginPhase == pluginCore.PhaseWaitingForCache { - pluginTrns.ttype = handler.TransitionTypeEphemeral - pluginTrns.pInfo = pluginCore.PhaseInfoWaitingForCache(pluginCore.DefaultPhaseVersion, nil) - } - - if ts.PluginPhase == pluginCore.PhaseWaitingForCache { - logger.Debugf(ctx, "No state change for Task, previously observed same transition. Short circuiting.") - return pluginTrns.FinalTransition(ctx) - } - } - } occurredAt := time.Now() // STEP 2: If no cache-hit and not transitioning to PhaseWaitingForCache, then lets invoke the plugin and wait for a transition out of undefined @@ -771,6 +673,84 @@ func (t Handler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContex return pluginTrns.FinalTransition(ctx) } +func (t *Handler) ValidateOutput(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader, + r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, + tr ioutils.SimpleTaskReader) (*io.ExecutionError, error) { + + tk, err := tr.Read(ctx) + if err != nil { + logger.Errorf(ctx, "Failed to read TaskTemplate, error :%s", err.Error()) + return nil, err + } + + iface := tk.Interface + outputsDeclared := iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 + + if r == nil { + if outputsDeclared { + // Whack! plugin did not return any outputs for this task + // Also When an error is observed, cache is automatically disabled + return &io.ExecutionError{ + ExecutionError: &core.ExecutionError{ + Code: "OutputsNotGenerated", + Message: "Output Reader was nil. 
Plugin/Platform problem.", + }, + IsRecoverable: true, + }, nil + } + return nil, nil + } + // Reader exists, we can check for error, even if this task may not have any outputs declared + y, err := r.IsError(ctx) + if err != nil { + return nil, err + } + if y { + taskErr, err := r.ReadError(ctx) + if err != nil { + return nil, err + } + + if taskErr.ExecutionError == nil { + taskErr.ExecutionError = &core.ExecutionError{Kind: core.ExecutionError_UNKNOWN, Code: "Unknown", Message: "Unknown"} + } + return &taskErr, nil + } + + // Do this if we have outputs declared for the Handler interface! + if !outputsDeclared { + return nil, nil + } + ok, err := r.Exists(ctx) + if err != nil { + logger.Errorf(ctx, "Failed to check if the output file exists. Error: %s", err.Error()) + return nil, err + } + + if !ok { + // Does not exist + return &io.ExecutionError{ + ExecutionError: &core.ExecutionError{ + Code: "OutputsNotFound", + Message: "Outputs not generated by task execution", + }, + IsRecoverable: true, + }, nil + } + + if !r.IsFile(ctx) { + // Read output and write to file + // No need to check for Execution Error here as we have done so above this block. + err = outputCommitter.Put(ctx, r) + if err != nil { + logger.Errorf(ctx, "Failed to commit output to remote location. Error: %v", err) + return nil, err + } + } + + return nil, nil +} + func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext, reason string) error { taskNodeState := nCtx.NodeStateReader().GetTaskNodeState() currentPhase := taskNodeState.PluginPhase @@ -860,13 +840,6 @@ func (t Handler) Finalize(ctx context.Context, nCtx interfaces.NodeExecutionCont } }() - // release catalog reservation (if exists) - ownerID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - _, err = t.ReleaseCatalogReservation(ctx, ownerID, tCtx.tr, tCtx.InputReader()) - if err != nil { - return errors.Wrapf(errors.CatalogCallFailed, nCtx.NodeID(), err, "failed to release reservation") - } - childCtx := context.WithValue(ctx, pluginContextKey, p.GetID()) err = p.Finalize(childCtx, tCtx) return @@ -891,21 +864,11 @@ func New(ctx context.Context, kubeClient executors.Client, client catalog.Client pluginsForType: make(map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin), taskMetricsMap: make(map[MetricKey]*taskMetrics), metrics: &metrics{ - pluginPanics: labeled.NewCounter("plugin_panic", "Task plugin paniced when trying to execute a Handler.", scope), - unsupportedTaskType: labeled.NewCounter("unsupported_tasktype", "No Handler plugin configured for Handler type", scope), - catalogHitCount: labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", scope), - catalogMissCount: labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope), - catalogSkipCount: labeled.NewCounter("discovery_skip_count", "Task lookup skipped in Discovery", scope), - catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", scope), - catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope), - catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", scope), - pluginExecutionLatency: labeled.NewStopWatch("plugin_exec_latency", "Time taken to invoke plugin for one round", time.Microsecond, scope), - pluginQueueLatency: labeled.NewStopWatch("plugin_queue_latency", "Time spent by plugin in queued phase", time.Microsecond, scope), - reservationGetFailureCount: 
labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", scope), - reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", scope), - reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", scope), - reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", scope), - scope: scope, + pluginPanics: labeled.NewCounter("plugin_panic", "Task plugin paniced when trying to execute a Handler.", scope), + unsupportedTaskType: labeled.NewCounter("unsupported_tasktype", "No Handler plugin configured for Handler type", scope), + pluginExecutionLatency: labeled.NewStopWatch("plugin_exec_latency", "Time taken to invoke plugin for one round", time.Microsecond, scope), + pluginQueueLatency: labeled.NewStopWatch("plugin_queue_latency", "Time spent by plugin in queued phase", time.Microsecond, scope), + scope: scope, }, pluginScope: scope.NewSubScope("plugin"), kubeClient: kubeClient, diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index 0a009cb65..38506c4c7 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -16,7 +16,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/resourcemanager" @@ -25,7 +24,6 @@ import ( "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginCatalogMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog/mocks" @@ -765,518 +763,6 @@ func Test_task_Handle_NoCatalog(t *testing.T) { } } -func Test_task_Handle_Catalog(t *testing.T) { - - createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { - wfExecID := &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - } - - nodeID := "n1" - - nm := &nodeMocks.NodeExecutionMetadata{} - nm.OnGetAnnotations().Return(map[string]string{}) - nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ - NodeId: nodeID, - ExecutionId: wfExecID, - }) - nm.OnGetK8sServiceAccount().Return("service-account") - nm.OnGetLabels().Return(map[string]string{}) - nm.OnGetNamespace().Return("namespace") - nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.OnGetOwnerReference().Return(v12.OwnerReference{ - Kind: "sample", - Name: "name", - }) - nm.OnIsInterruptible().Return(true) - - taskID := &core.Identifier{} - tk := &core.TaskTemplate{ - Id: taskID, - Type: "test", - Metadata: &core.TaskMetadata{ - Discoverable: true, - }, - Interface: &core.TypedInterface{ - Outputs: &core.VariableMap{ - Variables: map[string]*core.Variable{ - "x": { - Type: &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_BOOLEAN, - }, - }, - }, - }, - }, - }, - } - tr := &nodeMocks.TaskReader{} - 
tr.OnGetTaskID().Return(taskID) - tr.OnGetTaskType().Return(ttype) - tr.OnReadMatch(mock.Anything).Return(tk, nil) - - ns := &flyteMocks.ExecutableNodeStatus{} - ns.OnGetDataDir().Return(storage.DataReference("data-dir")) - ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) - - res := &v1.ResourceRequirements{} - n := &flyteMocks.ExecutableNode{} - ma := 5 - n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma}) - n.OnGetResources().Return(res) - - ir := &ioMocks.InputReader{} - ir.OnGetInputPath().Return(storage.DataReference("input")) - ir.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) - nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.OnNodeExecutionMetadata().Return(nm) - nCtx.OnNode().Return(n) - nCtx.OnInputReader().Return(ir) - ds, err := storage.NewDataStore( - &storage.Config{ - Type: storage.TypeMemory, - }, - promutils.NewTestScope(), - ) - assert.NoError(t, err) - nCtx.OnDataStore().Return(ds) - nCtx.OnCurrentAttempt().Return(uint32(1)) - nCtx.OnTaskReader().Return(tr) - nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) - nCtx.OnNodeStatus().Return(ns) - nCtx.OnNodeID().Return(nodeID) - nCtx.OnEventsRecorder().Return(recorder) - nCtx.OnEnqueueOwnerFunc().Return(nil) - - executionContext := &mocks.ExecutionContext{} - executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{OverwriteCache: overwriteCache}) - executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) - executionContext.OnGetParentInfo().Return(nil) - nCtx.OnExecutionContext().Return(executionContext) - - nCtx.OnRawOutputPrefix().Return("s3://sandbox/") - nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) - - st := bytes.NewBuffer([]byte{}) - cod := codex.GobStateCodec{} - assert.NoError(t, cod.Encode(&fakeplugins.NextPhaseState{ - Phase: pluginCore.PhaseSuccess, - OutputExists: true, - }, st)) - nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ - PluginState: st.Bytes(), - }) - nCtx.OnNodeStateReader().Return(nr) - nCtx.OnNodeStateWriter().Return(s) - return nCtx - } - - noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - - type args struct { - catalogFetch bool - catalogFetchError bool - catalogWriteError bool - catalogSkip bool - } - type want struct { - handlerPhase handler.EPhase - wantErr bool - eventPhase core.TaskExecution_Phase - } - tests := []struct { - name string - args args - want want - }{ - { - "cache-hit", - args{ - catalogFetch: true, - catalogWriteError: true, - }, - want{ - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-err", - args{ - catalogFetchError: true, - catalogWriteError: true, - }, - want{ - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-write", - args{}, - want{ - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-write-err", - args{ - catalogWriteError: true, - }, - want{ - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-skip-hit", - args{ - catalogFetch: true, - catalogSkip: true, - }, - want{ - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-skip-miss", - args{ - catalogFetch: false, - catalogSkip: true, - }, - want{ - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - } - for _, tt := range tests { - 
t.Run(tt.name, func(t *testing.T) { - state := &taskNodeStateHolder{} - ev := &fakeBufferedEventRecorder{} - nCtx := createNodeContext(ev, "test", state, tt.args.catalogSkip) - c := &pluginCatalogMocks.Client{} - if tt.args.catalogFetch { - or := &ioMocks.OutputReader{} - or.OnDeckExistsMatch(mock.Anything).Return(true, nil) - or.OnReadMatch(mock.Anything).Return(&core.LiteralMap{}, nil, nil) - c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewCatalogEntry(or, catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil)), nil) - } else { - c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewFailedCatalogEntry(catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil)), nil) - } - if tt.args.catalogFetchError { - c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.Entry{}, fmt.Errorf("failed to read from catalog")) - } - if tt.args.catalogWriteError { - c.OnPutMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.Status{}, fmt.Errorf("failed to write to catalog")) - c.OnUpdateMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.Status{}, fmt.Errorf("failed to write to catalog")) - } else { - c.OnPutMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), nil) - c.OnUpdateMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), nil) - } - tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, eventConfig, testClusterID, promutils.NewTestScope()) - assert.NoError(t, err) - tk.defaultPlugins = map[pluginCore.TaskType]pluginCore.Plugin{ - "test": fakeplugins.NewPhaseBasedPlugin(), - } - tk.catalog = c - tk.resourceManager = noopRm - got, err := tk.Handle(context.TODO(), nCtx) - if (err != nil) != tt.want.wantErr { - t.Errorf("Handler.Handle() error = %v, wantErr %v", err, tt.want.wantErr) - return - } - if err == nil { - assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String()) - if assert.Equal(t, 1, len(ev.evs)) { - e := ev.evs[0] - assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String()) - } - assert.Equal(t, pluginCore.PhaseSuccess.String(), state.s.PluginPhase.String()) - assert.Equal(t, uint32(0), state.s.PluginPhaseVersion) - if tt.args.catalogFetch { - if assert.NotNil(t, got.Info().GetInfo().TaskNodeInfo) { - assert.NotNil(t, got.Info().GetInfo().TaskNodeInfo.TaskNodeMetadata) - if tt.args.catalogSkip { - assert.Equal(t, core.CatalogCacheStatus_CACHE_POPULATED, got.Info().GetInfo().TaskNodeInfo.TaskNodeMetadata.CacheStatus) - } else { - assert.Equal(t, core.CatalogCacheStatus_CACHE_HIT, got.Info().GetInfo().TaskNodeInfo.TaskNodeMetadata.CacheStatus) - } - } - assert.NotNil(t, got.Info().GetInfo().OutputInfo) - s := storage.DataReference("/output-dir/outputs.pb") - assert.Equal(t, s, got.Info().GetInfo().OutputInfo.OutputURI) - r, err := nCtx.DataStore().Head(context.TODO(), s) - assert.NoError(t, err) - assert.Equal(t, !tt.args.catalogSkip, r.Exists()) - } - if tt.args.catalogSkip { - c.AssertNotCalled(t, "Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - c.AssertCalled(t, "Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - } - } - }) - } -} - -func Test_task_Handle_Reservation(t *testing.T) { - - createNodeContext := func(recorder interfaces.EventRecorder, ttype string, s *taskNodeStateHolder, overwriteCache bool) *nodeMocks.NodeExecutionContext { - wfExecID := 
&core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - } - - nodeID := "n1" - - nm := &nodeMocks.NodeExecutionMetadata{} - nm.OnGetAnnotations().Return(map[string]string{}) - nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ - NodeId: nodeID, - ExecutionId: wfExecID, - }) - nm.OnGetK8sServiceAccount().Return("service-account") - nm.OnGetLabels().Return(map[string]string{}) - nm.OnGetNamespace().Return("namespace") - nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.OnGetOwnerReference().Return(v12.OwnerReference{ - Kind: "sample", - Name: "name", - }) - nm.OnIsInterruptible().Return(true) - - taskID := &core.Identifier{} - tk := &core.TaskTemplate{ - Id: taskID, - Type: "test", - Metadata: &core.TaskMetadata{ - Discoverable: true, - CacheSerializable: true, - }, - Interface: &core.TypedInterface{ - Outputs: &core.VariableMap{ - Variables: map[string]*core.Variable{ - "x": { - Type: &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_BOOLEAN, - }, - }, - }, - }, - }, - }, - } - tr := &nodeMocks.TaskReader{} - tr.OnGetTaskID().Return(taskID) - tr.OnGetTaskType().Return(ttype) - tr.OnReadMatch(mock.Anything).Return(tk, nil) - - ns := &flyteMocks.ExecutableNodeStatus{} - ns.OnGetDataDir().Return(storage.DataReference("data-dir")) - ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) - - res := &v1.ResourceRequirements{} - n := &flyteMocks.ExecutableNode{} - ma := 5 - n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma}) - n.OnGetResources().Return(res) - - ir := &ioMocks.InputReader{} - ir.OnGetInputPath().Return(storage.DataReference("input")) - ir.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) - nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.OnNodeExecutionMetadata().Return(nm) - nCtx.OnNode().Return(n) - nCtx.OnInputReader().Return(ir) - nCtx.OnInputReader().Return(ir) - ds, err := storage.NewDataStore( - &storage.Config{ - Type: storage.TypeMemory, - }, - promutils.NewTestScope(), - ) - assert.NoError(t, err) - nCtx.OnDataStore().Return(ds) - nCtx.OnCurrentAttempt().Return(uint32(1)) - nCtx.OnTaskReader().Return(tr) - nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) - nCtx.OnNodeStatus().Return(ns) - nCtx.OnNodeID().Return(nodeID) - nCtx.OnEventsRecorder().Return(recorder) - nCtx.OnEnqueueOwnerFunc().Return(nil) - - executionContext := &mocks.ExecutionContext{} - executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{OverwriteCache: overwriteCache}) - executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) - executionContext.OnGetParentInfo().Return(nil) - executionContext.OnIncrementParallelism().Return(1) - nCtx.OnExecutionContext().Return(executionContext) - - nCtx.OnRawOutputPrefix().Return("s3://sandbox/") - nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) - - nCtx.OnNodeStateWriter().Return(s) - return nCtx - } - - noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - - type args struct { - catalogFetch bool - catalogSkip bool - pluginPhase pluginCore.Phase - ownerID string - } - type want struct { - pluginPhase pluginCore.Phase - handlerPhase handler.EPhase - eventPhase core.TaskExecution_Phase - } - tests := []struct { - name string - args args - want want - }{ - { - "reservation-create-or-update", - args{ - catalogFetch: false, - pluginPhase: pluginCore.PhaseUndefined, - ownerID: "name-n1-1", - }, - want{ - pluginPhase: 
pluginCore.PhaseSuccess, - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "reservation-exists", - args{ - catalogFetch: false, - pluginPhase: pluginCore.PhaseUndefined, - ownerID: "nilOwner", - }, - want{ - pluginPhase: pluginCore.PhaseWaitingForCache, - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_UNDEFINED, - }, - }, - { - "cache-hit", - args{ - catalogFetch: true, - pluginPhase: pluginCore.PhaseWaitingForCache, - }, - want{ - pluginPhase: pluginCore.PhaseSuccess, - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-skip-miss", - args{ - catalogFetch: false, - catalogSkip: true, - pluginPhase: pluginCore.PhaseUndefined, - ownerID: "name-n1-1", - }, - want{ - pluginPhase: pluginCore.PhaseSuccess, - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - { - "cache-skip-hit", - args{ - catalogFetch: true, - catalogSkip: true, - pluginPhase: pluginCore.PhaseWaitingForCache, - ownerID: "name-n1-1", - }, - want{ - pluginPhase: pluginCore.PhaseSuccess, - handlerPhase: handler.EPhaseSuccess, - eventPhase: core.TaskExecution_SUCCEEDED, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - state := &taskNodeStateHolder{} - ev := &fakeBufferedEventRecorder{} - nCtx := createNodeContext(ev, "test", state, tt.args.catalogSkip) - c := &pluginCatalogMocks.Client{} - nr := &nodeMocks.NodeStateReader{} - st := bytes.NewBuffer([]byte{}) - cod := codex.GobStateCodec{} - assert.NoError(t, cod.Encode(&fakeplugins.NextPhaseState{ - Phase: pluginCore.PhaseSuccess, - OutputExists: true, - }, st)) - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ - PluginPhase: tt.args.pluginPhase, - PluginState: st.Bytes(), - }) - nCtx.OnNodeStateReader().Return(nr) - if tt.args.catalogFetch { - or := &ioMocks.OutputReader{} - or.OnDeckExistsMatch(mock.Anything).Return(true, nil) - or.OnReadMatch(mock.Anything).Return(&core.LiteralMap{}, nil, nil) - c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewCatalogEntry(or, catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil)), nil) - } else { - c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewFailedCatalogEntry(catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil)), nil) - } - c.OnPutMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), nil) - c.OnUpdateMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), nil) - c.OnGetOrExtendReservationMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&datacatalog.Reservation{OwnerId: tt.args.ownerID}, nil) - tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, eventConfig, testClusterID, promutils.NewTestScope()) - assert.NoError(t, err) - tk.defaultPlugins = map[pluginCore.TaskType]pluginCore.Plugin{ - "test": fakeplugins.NewPhaseBasedPlugin(), - } - tk.catalog = c - tk.resourceManager = noopRm - got, err := tk.Handle(context.TODO(), nCtx) - if err != nil { - t.Errorf("Handler.Handle() error = %v", err) - return - } - if err == nil { - assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String()) - if assert.Equal(t, 1, len(ev.evs)) { - e := ev.evs[0] - assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String()) - } - assert.Equal(t, tt.want.pluginPhase.String(), state.s.PluginPhase.String()) - 
assert.Equal(t, uint32(0), state.s.PluginPhaseVersion) - // verify catalog.Put was called appropriately (overwrite param should be `true` if catalog cache is skipped) - // Put only gets called in the tests defined above that succeed and have an owner ID defined - if tt.want.pluginPhase == pluginCore.PhaseSuccess && len(tt.args.ownerID) > 0 { - if tt.args.catalogSkip { - c.AssertNotCalled(t, "Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - c.AssertCalled(t, "Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - } else { - c.AssertCalled(t, "Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - c.AssertNotCalled(t, "Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything) - } - } - } - }) - } -} - func Test_task_Abort(t *testing.T) { createNodeCtx := func(ev interfaces.EventRecorder) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ @@ -1603,7 +1089,7 @@ func Test_task_Abort_v1(t *testing.T) { func Test_task_Finalize(t *testing.T) { - createNodeContext := func(cacheSerializable bool) *nodeMocks.NodeExecutionContext { + createNodeContext := func() *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -1631,10 +1117,6 @@ func Test_task_Finalize(t *testing.T) { tk := &core.TaskTemplate{ Id: taskID, Type: "test", - Metadata: &core.TaskMetadata{ - CacheSerializable: cacheSerializable, - Discoverable: cacheSerializable, - }, Interface: &core.TypedInterface{ Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ @@ -1714,61 +1196,35 @@ func Test_task_Finalize(t *testing.T) { type fields struct { defaultPluginCallback func() pluginCore.Plugin } - type args struct { - releaseReservation bool - releaseReservationError bool - } tests := []struct { name string fields fields - args args wantErr bool finalize bool }{ {"no-plugin", fields{defaultPluginCallback: func() pluginCore.Plugin { return nil - }}, args{}, true, false}, - + }}, true, false}, {"finalize-fails", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) return p - }}, args{}, true, true}, + }}, true, true}, {"finalize-success", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Finalize", mock.Anything, mock.Anything).Return(nil) return p - }}, args{}, false, true}, - {"release-reservation", fields{defaultPluginCallback: func() pluginCore.Plugin { - p := &pluginCoreMocks.Plugin{} - p.On("GetID").Return("id") - p.OnGetProperties().Return(pluginCore.PluginProperties{}) - p.On("Finalize", mock.Anything, mock.Anything).Return(nil) - return p - }}, args{releaseReservation: true}, false, true}, - {"release-reservation-error", fields{defaultPluginCallback: func() pluginCore.Plugin { - p := &pluginCoreMocks.Plugin{} - p.On("GetID").Return("id") - p.OnGetProperties().Return(pluginCore.PluginProperties{}) - p.On("Finalize", mock.Anything, mock.Anything).Return(nil) - return p - }}, args{releaseReservation: true, releaseReservationError: true}, true, false}, + }}, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - nCtx := createNodeContext(tt.args.releaseReservation) + nCtx := createNodeContext() catalog := &pluginCatalogMocks.Client{} 
- if tt.args.releaseReservationError { - catalog.OnReleaseReservationMatch(mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("failed to release reservation")) - } else { - catalog.OnReleaseReservationMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) - } - m := tt.fields.defaultPluginCallback() tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), catalog, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) diff --git a/pkg/controller/nodes/task/pre_post_execution.go b/pkg/controller/nodes/task/pre_post_execution.go deleted file mode 100644 index c571d4ed4..000000000 --- a/pkg/controller/nodes/task/pre_post_execution.go +++ /dev/null @@ -1,273 +0,0 @@ -package task - -import ( - "context" - "time" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" - pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" - "github.com/flyteorg/flytestdlib/logger" - "github.com/pkg/errors" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - errors2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" -) - -var ( - cacheDisabled = catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil) - cacheSkipped = catalog.NewStatus(core.CatalogCacheStatus_CACHE_SKIPPED, nil) -) - -func (t *Handler) CheckCatalogCache(ctx context.Context, tr pluginCore.TaskReader, inputReader io.InputReader, outputWriter io.OutputWriter) (catalog.Entry, error) { - tk, err := tr.Read(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to read TaskTemplate, error :%s", err.Error()) - return catalog.Entry{}, err - } - - if tk.Metadata.Discoverable { - logger.Infof(ctx, "Catalog CacheEnabled: Looking up catalog Cache.") - key := catalog.Key{ - Identifier: *tk.Id, - CacheVersion: tk.Metadata.DiscoveryVersion, - TypedInterface: *tk.Interface, - InputReader: inputReader, - } - - resp, err := t.catalog.Get(ctx, key) - if err != nil { - causeErr := errors.Cause(err) - if taskStatus, ok := status.FromError(causeErr); ok && taskStatus.Code() == codes.NotFound { - t.metrics.catalogMissCount.Inc(ctx) - logger.Infof(ctx, "Catalog CacheMiss: Artifact not found in Catalog. Executing Task.") - return catalog.NewCatalogEntry(nil, catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil)), nil - } - - t.metrics.catalogGetFailureCount.Inc(ctx) - logger.Errorf(ctx, "Catalog Failure: memoization check failed. err: %v", err.Error()) - return catalog.Entry{}, errors.Wrapf(err, "Failed to check Catalog for previous results") - } - - if resp.GetStatus().GetCacheStatus() != core.CatalogCacheStatus_CACHE_HIT { - logger.Errorf(ctx, "No CacheHIT and no Error received. Illegal state, Cache State: %s", resp.GetStatus().GetCacheStatus().String()) - // TODO should this be an error? 
- return resp, nil - } - - logger.Infof(ctx, "Catalog CacheHit: for task [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version) - t.metrics.catalogHitCount.Inc(ctx) - if iface := tk.Interface; iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 { - if err := outputWriter.Put(ctx, resp.GetOutputs()); err != nil { - logger.Errorf(ctx, "failed to write data to Storage, err: %v", err.Error()) - return catalog.Entry{}, errors.Wrapf(err, "failed to copy cached results for task.") - } - } - // SetCached. - return resp, nil - } - logger.Infof(ctx, "Catalog CacheDisabled: for Task [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version) - return catalog.NewCatalogEntry(nil, cacheDisabled), nil -} - -// GetOrExtendCatalogReservation attempts to acquire an artifact reservation if the task is -// cachable and cache serializable. If the reservation already exists for this owner, the -// reservation is extended. -func (t *Handler) GetOrExtendCatalogReservation(ctx context.Context, ownerID string, heartbeatInterval time.Duration, tr pluginCore.TaskReader, inputReader io.InputReader) (catalog.ReservationEntry, error) { - tk, err := tr.Read(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to read TaskTemplate, error :%s", err.Error()) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err - } - - if tk.Metadata.Discoverable && tk.Metadata.CacheSerializable { - logger.Infof(ctx, "Catalog CacheSerializeEnabled: creating catalog reservation.") - key := catalog.Key{ - Identifier: *tk.Id, - CacheVersion: tk.Metadata.DiscoveryVersion, - TypedInterface: *tk.Interface, - InputReader: inputReader, - } - - reservation, err := t.catalog.GetOrExtendReservation(ctx, key, ownerID, heartbeatInterval) - if err != nil { - t.metrics.reservationGetFailureCount.Inc(ctx) - logger.Errorf(ctx, "Catalog Failure: reservation get or extend failed. err: %v", err.Error()) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err - } - - expiresAt := reservation.ExpiresAt.AsTime() - heartbeatInterval := reservation.HeartbeatInterval.AsDuration() - - var status core.CatalogReservation_Status - if reservation.OwnerId == ownerID { - status = core.CatalogReservation_RESERVATION_ACQUIRED - } else { - status = core.CatalogReservation_RESERVATION_EXISTS - } - - t.metrics.reservationGetSuccessCount.Inc(ctx) - return catalog.NewReservationEntry(expiresAt, heartbeatInterval, reservation.OwnerId, status), nil - } - logger.Infof(ctx, "Catalog CacheSerializeDisabled: for Task [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED), nil -} - -func (t *Handler) ValidateOutputAndCacheAdd(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader, - r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig, - tr ioutils.SimpleTaskReader, m catalog.Metadata) (catalog.Status, *io.ExecutionError, error) { - - tk, err := tr.Read(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to read TaskTemplate, error :%s", err.Error()) - return cacheDisabled, nil, err - } - - iface := tk.Interface - outputsDeclared := iface != nil && iface.Outputs != nil && len(iface.Outputs.Variables) > 0 - - if r == nil { - if outputsDeclared { - // Whack! 
plugin did not return any outputs for this task - // Also When an error is observed, cache is automatically disabled - return cacheDisabled, &io.ExecutionError{ - ExecutionError: &core.ExecutionError{ - Code: "OutputsNotGenerated", - Message: "Output Reader was nil. Plugin/Platform problem.", - }, - IsRecoverable: true, - }, nil - } - return cacheDisabled, nil, nil - } - // Reader exists, we can check for error, even if this task may not have any outputs declared - y, err := r.IsError(ctx) - if err != nil { - return cacheDisabled, nil, err - } - if y { - taskErr, err := r.ReadError(ctx) - if err != nil { - return cacheDisabled, nil, err - } - - if taskErr.ExecutionError == nil { - taskErr.ExecutionError = &core.ExecutionError{Kind: core.ExecutionError_UNKNOWN, Code: "Unknown", Message: "Unknown"} - } - return cacheDisabled, &taskErr, nil - } - - // Do this if we have outputs declared for the Handler interface! - if !outputsDeclared { - return cacheDisabled, nil, nil - } - ok, err := r.Exists(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to check if the output file exists. Error: %s", err.Error()) - return cacheDisabled, nil, err - } - - if !ok { - // Does not exist - return cacheDisabled, - &io.ExecutionError{ - ExecutionError: &core.ExecutionError{ - Code: "OutputsNotFound", - Message: "Outputs not generated by task execution", - }, - IsRecoverable: true, - }, nil - } - - if !r.IsFile(ctx) { - // Read output and write to file - // No need to check for Execution Error here as we have done so above this block. - err = outputCommitter.Put(ctx, r) - if err != nil { - logger.Errorf(ctx, "Failed to commit output to remote location. Error: %v", err) - return cacheDisabled, nil, err - } - } - - p, err := t.ResolvePlugin(ctx, tk.Type, executionConfig) - if err != nil { - return cacheDisabled, nil, errors2.Wrapf(errors2.UnsupportedTaskTypeError, nodeID, err, "unable to resolve plugin") - } - writeToCatalog := !p.GetProperties().DisableNodeLevelCaching - - if !tk.Metadata.Discoverable || !writeToCatalog { - if !writeToCatalog { - logger.Infof(ctx, "Node level caching is disabled. Skipping catalog write.") - } - return cacheDisabled, nil, nil - } - - cacheVersion := "0" - if tk.Metadata != nil { - cacheVersion = tk.Metadata.DiscoveryVersion - } - - key := catalog.Key{ - Identifier: *tk.Id, - CacheVersion: cacheVersion, - TypedInterface: *tk.Interface, - InputReader: i, - } - - logger.Infof(ctx, "Catalog CacheEnabled. recording execution [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version) - // ignores discovery write failures - var s catalog.Status - if executionConfig.OverwriteCache { - // Overwrite existing artifact (will create instead of no existing data was found) - s, err = t.catalog.Update(ctx, key, r, m) - } else { - // Explicitly create new artifact - s, err = t.catalog.Put(ctx, key, r, m) - } - if err != nil { - t.metrics.catalogPutFailureCount.Inc(ctx) - logger.Errorf(ctx, "Failed to write results to catalog for Task [%v]. Error: %v", tk.GetId(), err) - return catalog.NewStatus(core.CatalogCacheStatus_CACHE_PUT_FAILURE, s.GetMetadata()), nil, nil - } - t.metrics.catalogPutSuccessCount.Inc(ctx) - logger.Infof(ctx, "Successfully cached results to catalog - Task [%v]", tk.GetId()) - return s, nil, nil -} - -// ReleaseCatalogReservation attempts to release an artifact reservation if the task is cachable -// and cache serializable. If the reservation does not exist for this owner (e.x. 
it never existed -// or has been acquired by another owner) this call is still successful. -func (t *Handler) ReleaseCatalogReservation(ctx context.Context, ownerID string, tr pluginCore.TaskReader, inputReader io.InputReader) (catalog.ReservationEntry, error) { - tk, err := tr.Read(ctx) - if err != nil { - logger.Errorf(ctx, "Failed to read TaskTemplate, error :%s", err.Error()) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err - } - - if tk.Metadata.Discoverable && tk.Metadata.CacheSerializable { - logger.Infof(ctx, "Catalog CacheSerializeEnabled: releasing catalog reservation.") - key := catalog.Key{ - Identifier: *tk.Id, - CacheVersion: tk.Metadata.DiscoveryVersion, - TypedInterface: *tk.Interface, - InputReader: inputReader, - } - - err := t.catalog.ReleaseReservation(ctx, key, ownerID) - if err != nil { - t.metrics.reservationReleaseFailureCount.Inc(ctx) - logger.Errorf(ctx, "Catalog Failure: release reservation failed. err: %v", err.Error()) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err - } - - t.metrics.reservationReleaseSuccessCount.Inc(ctx) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_RELEASED), nil - } - logger.Infof(ctx, "Catalog CacheSerializeDisabled: for Task [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version) - return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED), nil -} diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go index b8c06ce3d..6e09c103b 100644 --- a/pkg/controller/nodes/transformers.go +++ b/pkg/controller/nodes/transformers.go @@ -229,8 +229,8 @@ func ToK8sTime(t time.Time) v1.Time { } func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n interfaces.NodeStateReader, s v1alpha1.ExecutableNodeStatus) { - // We update the phase only if it is not already updated - if np != s.GetPhase() { + // We update the phase and / or reason only if they are not already updated + if np != s.GetPhase() || p.GetReason() != s.GetMessage() { s.UpdatePhase(np, ToK8sTime(p.GetOccurredAt()), p.GetReason(), p.GetErr()) } // Update TaskStatus diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index ddd553dd8..6633b11b4 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -32,10 +32,10 @@ import ( eventsErr "github.com/flyteorg/flytepropeller/events/errors" eventMocks "github.com/flyteorg/flytepropeller/events/mocks" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/factory" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" nodemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/interfaces/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/fakeplugins" wfErrors "github.com/flyteorg/flytepropeller/pkg/controller/workflow/errors"