From dbad94c71097d4a3de387000deb1efa3d2f984f2 Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Mon, 29 Apr 2024 07:09:22 -0700 Subject: [PATCH] [BUG] populate source id metadata for launch plans and subworkflows (#240) * populate source id metadata for launch plans and subworkflows Signed-off-by: Paul Dittamo * revert lint Signed-off-by: Paul Dittamo * revert lint Signed-off-by: Paul Dittamo * revert lint Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo --- .../catalog/cacheservice/cache_client.go | 2 +- .../nodes/catalog/cacheservice/transformer.go | 39 ++++++---- .../catalog/cacheservice/transformer_test.go | 74 +++++++++++++++++++ .../pkg/controller/nodes/catalog/config.go | 2 +- .../nodes/catalog/datacatalog/datacatalog.go | 6 +- .../catalog/datacatalog/datacatalog_test.go | 5 +- .../nodes/catalog/datacatalog/transformer.go | 36 +++++---- .../catalog/datacatalog/transformer_test.go | 24 +++++- 8 files changed, 152 insertions(+), 36 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/catalog/cacheservice/cache_client.go b/flytepropeller/pkg/controller/nodes/catalog/cacheservice/cache_client.go index 19f5eca56f..40935be7ba 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/cacheservice/cache_client.go +++ b/flytepropeller/pkg/controller/nodes/catalog/cacheservice/cache_client.go @@ -221,7 +221,7 @@ func NewCacheClient(ctx context.Context, dataStore *storage.DataStore, endpoint opts = append(opts, authOpt...) } - grpcOptions := []grpcRetry.CallOption{ + grpcOptions := []grpcRetry.CallOption{ grpcRetry.WithBackoff(grpcRetry.BackoffExponentialWithJitter(time.Duration(backoffScalar)*time.Millisecond, backoffJitter)), grpcRetry.WithCodes(codes.DeadlineExceeded, codes.Unavailable, codes.Canceled), grpcRetry.WithMax(maxRetries), diff --git a/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer.go b/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer.go index 0aceb68219..57c1b97e97 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer.go +++ b/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer.go @@ -43,27 +43,40 @@ func GenerateCacheKey(ctx context.Context, key catalog.Key) (string, error) { return cacheKey, nil } -// GenerateCacheMetadata creates a cache metadata to identify the source of a cache entry +// GenerateCacheMetadata creates cache metadata to identify the source of a cache entry func GenerateCacheMetadata(key catalog.Key, metadata catalog.Metadata) *cacheservice.Metadata { - if metadata.TaskExecutionIdentifier == nil { + if metadata.TaskExecutionIdentifier != nil { return &cacheservice.Metadata{ SourceIdentifier: &key.Identifier, + KeyMap: &cacheservice.KeyMapMetadata{ + Values: map[string]string{ + datacatalog.TaskVersionKey: metadata.TaskExecutionIdentifier.TaskId.GetVersion(), + datacatalog.ExecProjectKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetProject(), + datacatalog.ExecDomainKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetDomain(), + datacatalog.ExecNameKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetName(), + datacatalog.ExecNodeIDKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetNodeId(), + datacatalog.ExecTaskAttemptKey: strconv.Itoa(int(metadata.TaskExecutionIdentifier.GetRetryAttempt())), + datacatalog.ExecOrgKey: metadata.TaskExecutionIdentifier.GetNodeExecutionId().GetExecutionId().GetOrg(), + }, + }, + } + } else if metadata.NodeExecutionIdentifier != nil { + return &cacheservice.Metadata{ + SourceIdentifier: &key.Identifier, + KeyMap: &cacheservice.KeyMapMetadata{ + Values: map[string]string{ + datacatalog.ExecProjectKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetProject(), + datacatalog.ExecDomainKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetDomain(), + datacatalog.ExecNameKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetName(), + datacatalog.ExecNodeIDKey: metadata.NodeExecutionIdentifier.GetNodeId(), + datacatalog.ExecOrgKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetOrg(), + }, + }, } } return &cacheservice.Metadata{ SourceIdentifier: &key.Identifier, - KeyMap: &cacheservice.KeyMapMetadata{ - Values: map[string]string{ - datacatalog.TaskVersionKey: metadata.TaskExecutionIdentifier.TaskId.GetVersion(), - datacatalog.ExecProjectKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetProject(), - datacatalog.ExecDomainKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetDomain(), - datacatalog.ExecNameKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetName(), - datacatalog.ExecNodeIDKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetNodeId(), - datacatalog.ExecTaskAttemptKey: strconv.Itoa(int(metadata.TaskExecutionIdentifier.GetRetryAttempt())), - datacatalog.ExecOrgKey: metadata.TaskExecutionIdentifier.GetNodeExecutionId().GetExecutionId().GetOrg(), - }, - }, } } diff --git a/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer_test.go b/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer_test.go index b9fbd8e4fd..ff7ba806a2 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/catalog/cacheservice/transformer_test.go @@ -2,6 +2,7 @@ package cacheservice import ( "context" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -11,6 +12,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/catalog/datacatalog" ) func TestGenerateCachedTaskKey(t *testing.T) { @@ -118,3 +120,75 @@ func TestGenerateCachedTaskKey(t *testing.T) { }) } } + +func TestGenerateCacheMetadata(t *testing.T) { + + tID := &core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Name: "x", + Project: "project", + Domain: "development", + Version: "ver", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "wf", + Project: "p1", + Domain: "d1", + }, + NodeId: "n", + }, + RetryAttempt: 1, + } + nID := &core.NodeExecutionIdentifier{ + NodeId: "n", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "wf", + Project: "p1", + Domain: "d1", + Org: "org", + }, + } + + testCases := []struct { + name string + key catalog.Key + metadata catalog.Metadata + expected map[string]string + }{ + { + name: "task execution identifier", + key: catalog.Key{}, + metadata: catalog.Metadata{TaskExecutionIdentifier: tID}, + expected: map[string]string{ + datacatalog.ExecTaskAttemptKey: strconv.Itoa(int(tID.RetryAttempt)), + datacatalog.ExecProjectKey: tID.NodeExecutionId.ExecutionId.Project, + datacatalog.ExecDomainKey: tID.NodeExecutionId.ExecutionId.Domain, + datacatalog.ExecNodeIDKey: tID.NodeExecutionId.NodeId, + datacatalog.ExecNameKey: tID.NodeExecutionId.ExecutionId.Name, + datacatalog.ExecOrgKey: tID.NodeExecutionId.ExecutionId.Org, + datacatalog.TaskVersionKey: tID.TaskId.Version, + }, + }, + { + name: "node execution identifier", + key: catalog.Key{}, + metadata: catalog.Metadata{NodeExecutionIdentifier: nID}, + expected: map[string]string{ + datacatalog.ExecNodeIDKey: nID.NodeId, + datacatalog.ExecDomainKey: nID.ExecutionId.Domain, + datacatalog.ExecNameKey: nID.ExecutionId.Name, + datacatalog.ExecProjectKey: nID.ExecutionId.Project, + datacatalog.ExecOrgKey: nID.ExecutionId.Org, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + meta := GenerateCacheMetadata(tc.key, tc.metadata) + assert.Equal(t, tc.expected, meta.KeyMap.Values) + }) + } +} diff --git a/flytepropeller/pkg/controller/nodes/catalog/config.go b/flytepropeller/pkg/controller/nodes/catalog/config.go index 3108df17a9..4c82404eca 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/config.go +++ b/flytepropeller/pkg/controller/nodes/catalog/config.go @@ -11,8 +11,8 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/catalog/cacheservice" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/catalog/datacatalog" "github.com/flyteorg/flyte/flytestdlib/config" - "github.com/flyteorg/flyte/flytestdlib/storage" "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/storage" ) //go:generate pflags Config --default-var defaultConfig diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go index be3c467a55..481eacac0a 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go @@ -224,7 +224,7 @@ func (m *CatalogClient) CreateArtifact(ctx context.Context, key catalog.Key, dat Id: string(uuid.NewUUID()), Dataset: datasetID, Data: artifactDataList, - Metadata: GetArtifactMetadataForSource(metadata.TaskExecutionIdentifier), + Metadata: GetArtifactMetadataForSource(metadata), } createArtifactRequest := &datacatalog.CreateArtifactRequest{Artifact: cachedArtifact} @@ -286,7 +286,7 @@ func (m *CatalogClient) UpdateArtifact(ctx context.Context, key catalog.Key, dat Dataset: datasetID, QueryHandle: &datacatalog.UpdateArtifactRequest_TagName{TagName: tagName}, Data: artifactDataList, - Metadata: GetArtifactMetadataForSource(metadata.TaskExecutionIdentifier), + Metadata: GetArtifactMetadataForSource(metadata), } resp, err := m.client.UpdateArtifact(ctx, updateArtifactRequest) if err != nil { @@ -300,7 +300,7 @@ func (m *CatalogClient) UpdateArtifact(ctx context.Context, key catalog.Key, dat ArtifactId: resp.GetArtifactId(), } - source, err := GetSourceFromMetadata(GetDatasetMetadataForSource(metadata.TaskExecutionIdentifier), GetArtifactMetadataForSource(metadata.TaskExecutionIdentifier), key.Identifier) + source, err := GetSourceFromMetadata(GetDatasetMetadataForSource(metadata.TaskExecutionIdentifier), GetArtifactMetadataForSource(metadata), key.Identifier) if err != nil { return catalog.Status{}, fmt.Errorf("failed to get source from metadata. Error: %w", err) } diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog_test.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog_test.go index 3a6ca16268..b30f077538 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog_test.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog_test.go @@ -186,6 +186,9 @@ func TestCatalog_Get(t *testing.T) { Id: datasetID, Metadata: GetDatasetMetadataForSource(taskID), } + sampleMetadata := catalog.Metadata{ + TaskExecutionIdentifier: taskID, + } mockClient.On("GetDataset", ctx, @@ -199,7 +202,7 @@ func TestCatalog_Get(t *testing.T) { Id: "test-artifact", Dataset: sampleDataSet.Id, Data: []*datacatalog.ArtifactData{sampleArtifactData}, - Metadata: GetArtifactMetadataForSource(taskID), + Metadata: GetArtifactMetadataForSource(sampleMetadata), Tags: []*datacatalog.Tag{ { Name: "x", diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go index ef66b23037..c93d6ef09e 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go @@ -200,20 +200,30 @@ func GetDatasetMetadataForSource(taskExecutionID *core.TaskExecutionIdentifier) } } -func GetArtifactMetadataForSource(taskExecutionID *core.TaskExecutionIdentifier) *datacatalog.Metadata { - if taskExecutionID == nil { - return &datacatalog.Metadata{} - } - return &datacatalog.Metadata{ - KeyMap: map[string]string{ - ExecProjectKey: taskExecutionID.NodeExecutionId.GetExecutionId().GetProject(), - ExecDomainKey: taskExecutionID.NodeExecutionId.GetExecutionId().GetDomain(), - ExecNameKey: taskExecutionID.NodeExecutionId.GetExecutionId().GetName(), - ExecNodeIDKey: taskExecutionID.NodeExecutionId.GetNodeId(), - ExecTaskAttemptKey: strconv.Itoa(int(taskExecutionID.GetRetryAttempt())), - ExecOrgKey: taskExecutionID.GetNodeExecutionId().GetExecutionId().GetOrg(), - }, +func GetArtifactMetadataForSource(metadata catalog.Metadata) *datacatalog.Metadata { + if metadata.TaskExecutionIdentifier != nil { + return &datacatalog.Metadata{ + KeyMap: map[string]string{ + ExecProjectKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetProject(), + ExecDomainKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetDomain(), + ExecNameKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetExecutionId().GetName(), + ExecNodeIDKey: metadata.TaskExecutionIdentifier.NodeExecutionId.GetNodeId(), + ExecTaskAttemptKey: strconv.Itoa(int(metadata.TaskExecutionIdentifier.GetRetryAttempt())), + ExecOrgKey: metadata.TaskExecutionIdentifier.GetNodeExecutionId().GetExecutionId().GetOrg(), + }, + } + } else if metadata.NodeExecutionIdentifier != nil { + return &datacatalog.Metadata{ + KeyMap: map[string]string{ + ExecProjectKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetProject(), + ExecDomainKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetDomain(), + ExecNameKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetName(), + ExecNodeIDKey: metadata.NodeExecutionIdentifier.GetNodeId(), + ExecOrgKey: metadata.NodeExecutionIdentifier.GetExecutionId().GetOrg(), + }, + } } + return &datacatalog.Metadata{} } // GetSourceFromMetadata returns the Source TaskExecutionIdentifier from the catalog metadata diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go index 3f9b44a448..f953e387ef 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go @@ -147,7 +147,7 @@ func TestGetOrDefault(t *testing.T) { func TestGetArtifactMetadataForSource(t *testing.T) { type args struct { - taskExecutionID *core.TaskExecutionIdentifier + metadata catalog.Metadata } tID := &core.TaskExecutionIdentifier{ @@ -168,6 +168,15 @@ func TestGetArtifactMetadataForSource(t *testing.T) { }, RetryAttempt: 1, } + nID := &core.NodeExecutionIdentifier{ + NodeId: "n", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "wf", + Project: "p1", + Domain: "d1", + Org: "org", + }, + } tests := []struct { name string @@ -175,7 +184,7 @@ func TestGetArtifactMetadataForSource(t *testing.T) { want map[string]string }{ {"nil TaskExec", args{}, nil}, - {"TaskExec", args{tID}, map[string]string{ + {"TaskExec", args{catalog.Metadata{TaskExecutionIdentifier: tID}}, map[string]string{ ExecTaskAttemptKey: strconv.Itoa(int(tID.RetryAttempt)), ExecProjectKey: tID.NodeExecutionId.ExecutionId.Project, ExecDomainKey: tID.NodeExecutionId.ExecutionId.Domain, @@ -183,10 +192,17 @@ func TestGetArtifactMetadataForSource(t *testing.T) { ExecNameKey: tID.NodeExecutionId.ExecutionId.Name, ExecOrgKey: tID.NodeExecutionId.ExecutionId.Org, }}, + {"NodeExec", args{catalog.Metadata{NodeExecutionIdentifier: nID}}, map[string]string{ + ExecNodeIDKey: nID.NodeId, + ExecDomainKey: nID.ExecutionId.Domain, + ExecNameKey: nID.ExecutionId.Name, + ExecProjectKey: nID.ExecutionId.Project, + ExecOrgKey: nID.ExecutionId.Org, + }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := GetArtifactMetadataForSource(tt.args.taskExecutionID); !reflect.DeepEqual(got.KeyMap, tt.want) { + if got := GetArtifactMetadataForSource(tt.args.metadata); !reflect.DeepEqual(got.KeyMap, tt.want) { t.Errorf("GetMetadataForSource() = %v, want %v", got.KeyMap, tt.want) } }) @@ -270,7 +286,7 @@ func TestGetSourceFromMetadata(t *testing.T) { RetryAttempt: 0, }}, // Completely available - {"latest", args{datasetMd: GetDatasetMetadataForSource(&tID).KeyMap, artifactMd: GetArtifactMetadataForSource(&tID).KeyMap, currentID: currentTaskID}, &tID}, + {"latest", args{datasetMd: GetDatasetMetadataForSource(&tID).KeyMap, artifactMd: GetArtifactMetadataForSource(catalog.Metadata{TaskExecutionIdentifier: &tID}).KeyMap, currentID: currentTaskID}, &tID}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {