diff --git a/dataproxy/service.go b/dataproxy/service.go index 948c6a25c4..5ab9e32a83 100644 --- a/dataproxy/service.go +++ b/dataproxy/service.go @@ -7,8 +7,15 @@ import ( "fmt" "net/url" "reflect" + "strconv" + "strings" "time" + "github.com/flyteorg/flyteadmin/pkg/common" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteadmin/pkg/errors" "google.golang.org/grpc/codes" @@ -37,6 +44,7 @@ type Service struct { dataStore *storage.DataStore shardSelector ioutils.ShardSelector nodeExecutionManager interfaces.NodeExecutionInterface + taskExecutionManager interfaces.TaskExecutionInterface } // CreateUploadLocation creates a temporary signed url to allow callers to upload content. @@ -133,9 +141,17 @@ func (s Service) CreateDownloadLink(ctx context.Context, req *service.CreateDown return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err) } + u := []string{signedURLResp.URL.String()} + ts := timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration())) + + // return &service.CreateDownloadLinkResponse{ - SignedUrl: []string{signedURLResp.URL.String()}, - ExpiresAt: timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration())), + SignedUrl: u, + ExpiresAt: ts, + PreSignedUrls: &service.PreSignedURLs{ + SignedUrl: []string{signedURLResp.URL.String()}, + ExpiresAt: ts, + }, }, nil } @@ -231,9 +247,117 @@ func createStorageLocation(ctx context.Context, store *storage.DataStore, return storagePath, nil } +func (s Service) validateResolveArtifactRequest(req *service.GetDataRequest) error { + if len(req.GetFlyteUrl()) == 0 { + return fmt.Errorf("source is required. Provided empty string") + } + if !strings.HasPrefix(req.GetFlyteUrl(), "flyte://") { + return fmt.Errorf("request does not start with the correct prefix") + } + + return nil +} + +func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) { + taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{ + NodeExecutionId: &nodeExecID, + Limit: 1, + Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(attempt)), + }) + if err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to list task executions [%v]. Error: %v", nodeExecID, err) + } + if len(taskExecs.TaskExecutions) == 0 { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "no task executions were listed [%v]. Error: %v", nodeExecID, err) + } + taskExec := taskExecs.TaskExecutions[0] + return taskExec.Id, nil +} + +func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) ( + *service.GetDataResponse, error) { + + logger.Debugf(ctx, "resolving flyte url query: %s", req.GetFlyteUrl()) + err := s.validateResolveArtifactRequest(req) + if err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to validate resolve artifact request. Error: %v", err) + } + + nodeExecID, attempt, ioType, err := common.ParseFlyteURL(req.GetFlyteUrl()) + if err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse artifact url Error: %v", err) + } + + // Get the data location, then decide how/where to fetch it from + if attempt == nil { + resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{ + Id: &nodeExecID, + }) + if err != nil { + return nil, err + } + + var lm *core.LiteralMap + if ioType == common.ArtifactTypeI { + lm = resp.FullInputs + } else if ioType == common.ArtifactTypeO { + lm = resp.FullOutputs + } else { + // Assume deck, and create a download link request + dlRequest := service.CreateDownloadLinkRequest{ + ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK, + Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID}, + } + resp, err := s.CreateDownloadLink(ctx, &dlRequest) + if err != nil { + return nil, err + } + return &service.GetDataResponse{ + Data: &service.GetDataResponse_PreSignedUrls{ + PreSignedUrls: resp.PreSignedUrls, + }, + }, nil + } + + return &service.GetDataResponse{ + Data: &service.GetDataResponse_LiteralMap{ + LiteralMap: lm, + }, + }, nil + } + // Rest of the logic handles task attempt lookups + var lm *core.LiteralMap + taskExecID, err := s.GetTaskExecutionID(ctx, *attempt, nodeExecID) + if err != nil { + return nil, err + } + + reqT := admin.TaskExecutionGetDataRequest{ + Id: taskExecID, + } + resp, err := s.taskExecutionManager.GetTaskExecutionData(ctx, reqT) + if err != nil { + return nil, err + } + + if ioType == common.ArtifactTypeI { + lm = resp.FullInputs + } else if ioType == common.ArtifactTypeO { + lm = resp.FullOutputs + } else { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "deck type cannot be specified with a retry attempt, just use the node instead") + } + return &service.GetDataResponse{ + Data: &service.GetDataResponse_LiteralMap{ + LiteralMap: lm, + }, + }, nil +} + func NewService(cfg config.DataProxyConfig, nodeExec interfaces.NodeExecutionInterface, - dataStore *storage.DataStore) (Service, error) { + dataStore *storage.DataStore, + taskExec interfaces.TaskExecutionInterface) (Service, error) { // Context is not used in the constructor. Should ideally be removed. selector, err := ioutils.NewBase36PrefixShardSelector(context.TODO()) @@ -246,5 +370,6 @@ func NewService(cfg config.DataProxyConfig, dataStore: dataStore, shardSelector: selector, nodeExecutionManager: nodeExec, + taskExecutionManager: taskExec, }, nil } diff --git a/dataproxy/service_test.go b/dataproxy/service_test.go index 261b0f0863..db1c0e61d4 100644 --- a/dataproxy/service_test.go +++ b/dataproxy/service_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/golang/protobuf/proto" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -14,10 +15,10 @@ import ( commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" stdlibConfig "github.com/flyteorg/flytestdlib/config" - "google.golang.org/protobuf/types/known/durationpb" - + "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" + "google.golang.org/protobuf/types/known/durationpb" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" @@ -32,9 +33,10 @@ func TestNewService(t *testing.T) { assert.NoError(t, err) nodeExecutionManager := &mocks.MockNodeExecutionManager{} + taskExecutionManager := &mocks.MockTaskExecutionManager{} s, err := NewService(config.DataProxyConfig{ Upload: config.DataProxyUploadConfig{}, - }, nodeExecutionManager, dataStore) + }, nodeExecutionManager, dataStore, taskExecutionManager) assert.NoError(t, err) assert.NotNil(t, s) } @@ -57,7 +59,8 @@ func TestCreateUploadLocation(t *testing.T) { dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) nodeExecutionManager := &mocks.MockNodeExecutionManager{} - s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore) + taskExecutionManager := &mocks.MockTaskExecutionManager{} + s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager) assert.NoError(t, err) t.Run("No project/domain", func(t *testing.T) { _, err = s.CreateUploadLocation(context.Background(), &service.CreateUploadLocationRequest{}) @@ -92,8 +95,9 @@ func TestCreateDownloadLink(t *testing.T) { }, }, nil }) + taskExecutionManager := &mocks.MockTaskExecutionManager{} - s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore) + s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager) assert.NoError(t, err) t.Run("Invalid expiry", func(t *testing.T) { @@ -128,7 +132,8 @@ func TestCreateDownloadLink(t *testing.T) { func TestCreateDownloadLocation(t *testing.T) { dataStore := commonMocks.GetMockStorageClient() nodeExecutionManager := &mocks.MockNodeExecutionManager{} - s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore) + taskExecutionManager := &mocks.MockTaskExecutionManager{} + s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager) assert.NoError(t, err) t.Run("Invalid expiry", func(t *testing.T) { @@ -161,3 +166,155 @@ func TestCreateDownloadLocation(t *testing.T) { assert.NoError(t, err) }) } + +func TestService_GetData(t *testing.T) { + dataStore := commonMocks.GetMockStorageClient() + nodeExecutionManager := &mocks.MockNodeExecutionManager{} + taskExecutionManager := &mocks.MockTaskExecutionManager{} + s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager) + assert.NoError(t, err) + + inputsLM := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "input": { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "hello", + }, + }, + }, + }, + }, + }, + }, + } + outputsLM := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "output": { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "world", + }, + }, + }, + }, + }, + }, + }, + } + + nodeExecutionManager.SetGetNodeExecutionDataFunc( + func(ctx context.Context, request admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) { + return &admin.NodeExecutionGetDataResponse{ + FullInputs: inputsLM, + FullOutputs: outputsLM, + }, nil + }, + ) + taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) { + return &admin.TaskExecutionList{ + TaskExecutions: []*admin.TaskExecution{ + { + Id: &core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "proj", + Domain: "dev", + Name: "task", + Version: "v1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "proj", + Domain: "dev", + Name: "wfexecid", + }, + }, + RetryAttempt: 5, + }, + }, + }, + }, nil + }) + taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) { + return &admin.TaskExecutionGetDataResponse{ + FullInputs: inputsLM, + FullOutputs: outputsLM, + }, nil + }) + + t.Run("get a working set of urls without retry attempt", func(t *testing.T) { + res, err := s.GetData(context.Background(), &service.GetDataRequest{ + FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/i", + }) + assert.NoError(t, err) + assert.True(t, proto.Equal(inputsLM, res.GetLiteralMap())) + assert.Nil(t, res.GetPreSignedUrls()) + }) + + t.Run("get a working set of urls with a retry attempt", func(t *testing.T) { + res, err := s.GetData(context.Background(), &service.GetDataRequest{ + FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/5/o", + }) + assert.NoError(t, err) + assert.True(t, proto.Equal(outputsLM, res.GetLiteralMap())) + assert.Nil(t, res.GetPreSignedUrls()) + }) + + t.Run("Bad URL", func(t *testing.T) { + _, err = s.GetData(context.Background(), &service.GetDataRequest{ + FlyteUrl: "flyte://v3/blah/lorem/m0-fdj", + }) + assert.Error(t, err) + }) +} + +func TestService_Error(t *testing.T) { + dataStore := commonMocks.GetMockStorageClient() + nodeExecutionManager := &mocks.MockNodeExecutionManager{} + taskExecutionManager := &mocks.MockTaskExecutionManager{} + s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager) + assert.NoError(t, err) + + t.Run("get a working set of urls without retry attempt", func(t *testing.T) { + taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) { + return nil, errors.NewFlyteAdminErrorf(1, "not found") + }) + nodeExecID := core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "proj", + Domain: "dev", + Name: "wfexecid", + }, + } + _, err := s.GetTaskExecutionID(context.Background(), 0, nodeExecID) + assert.Error(t, err, "failed to list") + }) + + t.Run("get a working set of urls without retry attempt", func(t *testing.T) { + taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) { + return &admin.TaskExecutionList{ + TaskExecutions: nil, + Token: "", + }, nil + }) + nodeExecID := core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "proj", + Domain: "dev", + Name: "wfexecid", + }, + } + _, err := s.GetTaskExecutionID(context.Background(), 0, nodeExecID) + assert.Error(t, err, "no task executions") + }) +} diff --git a/flyteadmin_config.yaml b/flyteadmin_config.yaml index 964f83a818..e3d19f7326 100644 --- a/flyteadmin_config.yaml +++ b/flyteadmin_config.yaml @@ -114,7 +114,7 @@ externalEvents: eventTypes: all Logger: show-source: true - level: 6 + level: 5 storage: type: stow stow: @@ -129,7 +129,7 @@ storage: secret_key: miniostorage signedUrl: stowConfigOverride: - endpoint: http://localhost:30084 + endpoint: http://localhost:30002 cache: max_size_mbs: 10 target_gc_percent: 100 diff --git a/go.mod b/go.mod index 6710d33bbe..6d78ffdd1b 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible - github.com/flyteorg/flyteidl v1.3.14 + github.com/flyteorg/flyteidl v1.5.0 github.com/flyteorg/flyteplugins v1.0.40 github.com/flyteorg/flytepropeller v1.1.70 github.com/flyteorg/flytestdlib v1.0.15 diff --git a/go.sum b/go.sum index f2fe7c70b8..8e166ce641 100644 --- a/go.sum +++ b/go.sum @@ -312,8 +312,8 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8= -github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.5.0 h1:vdaA5Cg9eqi5NMuASSod/AE7RXlHvzdWjSL9abDyd/M= +github.com/flyteorg/flyteidl v1.5.0/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= github.com/flyteorg/flyteplugins v1.0.40 h1:RTsYingqmqr13qBbi4CB2ArXDHNHUOkAF+HTLJQiQ/s= github.com/flyteorg/flyteplugins v1.0.40/go.mod h1:qyUPqVspLcLGJpKxVwHDWf+kBpOGuItOxCaF6zAmDio= github.com/flyteorg/flytepropeller v1.1.70 h1:/d1qqz13rdVADM85ST70eerAdBstJJz9UUB/mNSZi0w= diff --git a/pkg/common/artifacttype_enumer.go b/pkg/common/artifacttype_enumer.go new file mode 100644 index 0000000000..6847e6b7dd --- /dev/null +++ b/pkg/common/artifacttype_enumer.go @@ -0,0 +1,51 @@ +// Code generated by "enumer --type=ArtifactType --trimprefix=ArtifactType -transform=snake"; DO NOT EDIT. + +package common + +import ( + "fmt" +) + +const _ArtifactTypeName = "undefinediod" + +var _ArtifactTypeIndex = [...]uint8{0, 9, 10, 11, 12} + +func (i ArtifactType) String() string { + if i < 0 || i >= ArtifactType(len(_ArtifactTypeIndex)-1) { + return fmt.Sprintf("ArtifactType(%d)", i) + } + return _ArtifactTypeName[_ArtifactTypeIndex[i]:_ArtifactTypeIndex[i+1]] +} + +var _ArtifactTypeValues = []ArtifactType{0, 1, 2, 3} + +var _ArtifactTypeNameToValueMap = map[string]ArtifactType{ + _ArtifactTypeName[0:9]: 0, + _ArtifactTypeName[9:10]: 1, + _ArtifactTypeName[10:11]: 2, + _ArtifactTypeName[11:12]: 3, +} + +// ArtifactTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ArtifactTypeString(s string) (ArtifactType, error) { + if val, ok := _ArtifactTypeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to ArtifactType values", s) +} + +// ArtifactTypeValues returns all values of the enum +func ArtifactTypeValues() []ArtifactType { + return _ArtifactTypeValues +} + +// IsAArtifactType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i ArtifactType) IsAArtifactType() bool { + for _, v := range _ArtifactTypeValues { + if i == v { + return true + } + } + return false +} diff --git a/pkg/common/flyte_url.go b/pkg/common/flyte_url.go new file mode 100644 index 0000000000..b3094d586d --- /dev/null +++ b/pkg/common/flyte_url.go @@ -0,0 +1,98 @@ +package common + +import ( + "fmt" + "regexp" + "strconv" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" +) + +// transform to snake case to make lower case +//go:generate enumer --type=ArtifactType --trimprefix=ArtifactType -transform=snake + +type ArtifactType int + +// The suffixes in these constants are used to match against the tail end of the flyte url, to keep tne flyte url simpler +const ( + ArtifactTypeUndefined ArtifactType = iota + ArtifactTypeI // inputs + ArtifactTypeO // outputs + ArtifactTypeD // deck +) + +var re = regexp.MustCompile("flyte://v1/(?P[a-zA-Z0-9_-]+)/(?P[a-zA-Z0-9_-]+)/(?P[a-zA-Z0-9_-]+)/(?P[a-zA-Z0-9_-]+)(?:/(?P[0-9]+))?/(?P[iod])$") + +func MatchRegex(reg *regexp.Regexp, input string) map[string]string { + names := reg.SubexpNames() + res := reg.FindAllStringSubmatch(input, -1) + if len(res) == 0 { + return nil + } + dict := make(map[string]string, len(names)) + for i := 1; i < len(res[0]); i++ { + dict[names[i]] = res[0][i] + } + return dict +} + +func ParseFlyteURL(flyteURL string) (core.NodeExecutionIdentifier, *int, ArtifactType, error) { + // flyteURL is of the form flyte://v1/project/domain/execution_id/node_id/attempt/[iod] + // where i stands for inputs.pb o for outputs.pb and d for the flyte deck + // If the retry attempt is missing, the io requested is assumed to be for the node instead of the task execution + matches := MatchRegex(re, flyteURL) + proj := matches["project"] + domain := matches["domain"] + executionID := matches["exec"] + nodeID := matches["node"] + var attemptPtr *int // nil means node execution, not a task execution + if attempt := matches["attempt"]; len(attempt) > 0 { + a, err := strconv.Atoi(attempt) + if err != nil { + return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, fmt.Errorf("failed to parse attempt [%v], %v", attempt, err) + } + attemptPtr = &a + } + ioType, err := ArtifactTypeString(matches["artifactType"]) + if err != nil { + return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, err + } + + return core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: proj, + Domain: domain, + Name: executionID, + }, + }, attemptPtr, ioType, nil +} + +func FlyteURLsFromNodeExecutionID(nodeExecutionID core.NodeExecutionIdentifier, deck bool) *admin.FlyteURLs { + base := fmt.Sprintf("flyte://v1/%s/%s/%s/%s", nodeExecutionID.ExecutionId.Project, + nodeExecutionID.ExecutionId.Domain, nodeExecutionID.ExecutionId.Name, nodeExecutionID.NodeId) + + res := &admin.FlyteURLs{ + Inputs: fmt.Sprintf("%s/%s", base, ArtifactTypeI), + Outputs: fmt.Sprintf("%s/%s", base, ArtifactTypeO), + } + if deck { + res.Deck = fmt.Sprintf("%s/%s", base, ArtifactTypeD) + } + return res +} + +func FlyteURLsFromTaskExecutionID(taskExecutionID core.TaskExecutionIdentifier, deck bool) *admin.FlyteURLs { + base := fmt.Sprintf("flyte://v1/%s/%s/%s/%s/%s", taskExecutionID.NodeExecutionId.ExecutionId.Project, + taskExecutionID.NodeExecutionId.ExecutionId.Domain, taskExecutionID.NodeExecutionId.ExecutionId.Name, taskExecutionID.NodeExecutionId.NodeId, strconv.Itoa(int(taskExecutionID.RetryAttempt))) + + res := &admin.FlyteURLs{ + Inputs: fmt.Sprintf("%s/%s", base, ArtifactTypeI), + Outputs: fmt.Sprintf("%s/%s", base, ArtifactTypeO), + } + if deck { + res.Deck = fmt.Sprintf("%s/%s", base, ArtifactTypeD) + } + return res +} diff --git a/pkg/common/flyte_url_test.go b/pkg/common/flyte_url_test.go new file mode 100644 index 0000000000..378860cf1d --- /dev/null +++ b/pkg/common/flyte_url_test.go @@ -0,0 +1,174 @@ +package common + +import ( + "testing" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" +) + +func TestParseFlyteUrl(t *testing.T) { + t.Run("valid", func(t *testing.T) { + ne, attempt, kind, err := ParseFlyteURL("flyte://v1/fs/dev/abc/n0/0/o") + assert.NoError(t, err) + assert.Equal(t, 0, *attempt) + assert.Equal(t, ArtifactTypeO, kind) + assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + }, &ne)) + ne, attempt, kind, err = ParseFlyteURL("flyte://v1/fs/dev/abc/n0/i") + assert.NoError(t, err) + assert.Nil(t, attempt) + assert.Equal(t, ArtifactTypeI, kind) + assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + }, &ne)) + + ne, attempt, kind, err = ParseFlyteURL("flyte://v1/fs/dev/abc/n0/d") + assert.NoError(t, err) + assert.Nil(t, attempt) + assert.Equal(t, ArtifactTypeD, kind) + assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + }, &ne)) + + ne, attempt, kind, err = ParseFlyteURL("flyte://v1/fs/dev/abc/n0-dn0-9-n0-n0/d") + assert.NoError(t, err) + assert.Nil(t, attempt) + assert.Equal(t, ArtifactTypeD, kind) + assert.True(t, proto.Equal(&core.NodeExecutionIdentifier{ + NodeId: "n0-dn0-9-n0-n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + }, &ne)) + }) + + t.Run("invalid", func(t *testing.T) { + // more than one character + _, attempt, kind, err := ParseFlyteURL("flyte://v1/fs/dev/abc/n0/0/od") + assert.Error(t, err) + assert.Nil(t, attempt) + assert.Equal(t, ArtifactTypeUndefined, kind) + + _, attempt, kind, err = ParseFlyteURL("flyte://v1/fs/dev/abc/n0/input") + assert.Error(t, err) + assert.Nil(t, attempt) + assert.Equal(t, ArtifactTypeUndefined, kind) + + // non integer for attempt + _, attempt, kind, err = ParseFlyteURL("flyte://v1/fs/dev/ab/n0/a/i") + assert.Error(t, err) + assert.Nil(t, attempt) + assert.Equal(t, ArtifactTypeUndefined, kind) + }) +} + +func TestFlyteURLsFromNodeExecutionID(t *testing.T) { + t.Run("with deck", func(t *testing.T) { + ne := core.NodeExecutionIdentifier{ + NodeId: "n0-dn0-n1", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + } + urls := FlyteURLsFromNodeExecutionID(ne, true) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0-dn0-n1/i", urls.GetInputs()) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0-dn0-n1/o", urls.GetOutputs()) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0-dn0-n1/d", urls.GetDeck()) + }) + + t.Run("without deck", func(t *testing.T) { + ne := core.NodeExecutionIdentifier{ + NodeId: "n0-dn0-n1", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + } + urls := FlyteURLsFromNodeExecutionID(ne, false) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0-dn0-n1/i", urls.GetInputs()) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0-dn0-n1/o", urls.GetOutputs()) + assert.Equal(t, "", urls.GetDeck()) + }) +} + +func TestFlyteURLsFromTaskExecutionID(t *testing.T) { + t.Run("with deck", func(t *testing.T) { + te := core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "fs", + Domain: "dev", + Name: "abc", + Version: "v1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + }, + RetryAttempt: 1, + } + urls := FlyteURLsFromTaskExecutionID(te, true) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0/1/i", urls.GetInputs()) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0/1/o", urls.GetOutputs()) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0/1/d", urls.GetDeck()) + }) + + t.Run("without deck", func(t *testing.T) { + te := core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "fs", + Domain: "dev", + Name: "abc", + Version: "v1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: "n0", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "fs", + Domain: "dev", + Name: "abc", + }, + }, + } + urls := FlyteURLsFromTaskExecutionID(te, false) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0/0/i", urls.GetInputs()) + assert.Equal(t, "flyte://v1/fs/dev/abc/n0/0/o", urls.GetOutputs()) + assert.Equal(t, "", urls.GetDeck()) + }) +} + +func TestMatchRegexDirectly(t *testing.T) { + result := MatchRegex(re, "flyte://v1/fs/dev/abc/n0-dn0-9-n0-n0/i") + assert.Equal(t, "", result["attempt"]) + + result = MatchRegex(re, "flyteff://v2/fs/dfdsaev/abc/n0-dn0-9-n0-n0/i") + assert.Nil(t, result) +} diff --git a/pkg/manager/impl/node_execution_manager.go b/pkg/manager/impl/node_execution_manager.go index ae5a7bb400..bcc4362dbd 100644 --- a/pkg/manager/impl/node_execution_manager.go +++ b/pkg/manager/impl/node_execution_manager.go @@ -521,6 +521,7 @@ func (m *NodeExecutionManager) GetNodeExecutionData( Outputs: outputURLBlob, FullInputs: inputs, FullOutputs: outputs, + FlyteUrls: common.FlyteURLsFromNodeExecutionID(*request.Id, nodeExecution.GetClosure() != nil && nodeExecution.GetClosure().GetDeckUri() != ""), } if len(nodeExecutionModel.DynamicWorkflowRemoteClosureReference) > 0 { diff --git a/pkg/manager/impl/node_execution_manager_test.go b/pkg/manager/impl/node_execution_manager_test.go index a1c43c36ba..1348803475 100644 --- a/pkg/manager/impl/node_execution_manager_test.go +++ b/pkg/manager/impl/node_execution_manager_test.go @@ -1314,5 +1314,10 @@ func TestGetNodeExecutionData(t *testing.T) { Id: dynamicWorkflowClosure.Primary.Template.Id, CompiledWorkflow: &dynamicWorkflowClosure, }, + FlyteUrls: &admin.FlyteURLs{ + Inputs: "flyte://v1/project/domain/name/node id/i", + Outputs: "flyte://v1/project/domain/name/node id/o", + Deck: "flyte://v1/project/domain/name/node id/d", + }, }, dataResponse)) } diff --git a/pkg/manager/impl/task_execution_manager.go b/pkg/manager/impl/task_execution_manager.go index 60825309b8..46967f264a 100644 --- a/pkg/manager/impl/task_execution_manager.go +++ b/pkg/manager/impl/task_execution_manager.go @@ -331,6 +331,7 @@ func (m *TaskExecutionManager) GetTaskExecutionData( Outputs: outputURLBlob, FullInputs: inputs, FullOutputs: outputs, + FlyteUrls: common.FlyteURLsFromTaskExecutionID(*request.Id, false), } m.metrics.TaskExecutionInputBytes.Observe(float64(response.Inputs.Bytes)) diff --git a/pkg/manager/impl/task_execution_manager_test.go b/pkg/manager/impl/task_execution_manager_test.go index cc59012bb7..4c190d0bff 100644 --- a/pkg/manager/impl/task_execution_manager_test.go +++ b/pkg/manager/impl/task_execution_manager_test.go @@ -958,5 +958,10 @@ func TestGetTaskExecutionData(t *testing.T) { }, FullInputs: fullInputs, FullOutputs: fullOutputs, + FlyteUrls: &admin.FlyteURLs{ + Inputs: "flyte://v1/project/domain/name/node-id/1/i", + Outputs: "flyte://v1/project/domain/name/node-id/1/o", + Deck: "", + }, }, dataResponse)) } diff --git a/pkg/server/service.go b/pkg/server/service.go index 4f1f58ffb5..1fe2f57c14 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -119,7 +119,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c service.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService()) } - dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, adminServer.NodeExecutionManager, dataStorageClient) + dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, adminServer.NodeExecutionManager, dataStorageClient, adminServer.TaskExecutionManager) if err != nil { return nil, fmt.Errorf("failed to initialize dataProxy service. Error: %w", err) }