Skip to content

Commit

Permalink
Tiny url improvements (flyteorg#565)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored May 23, 2023
1 parent ca0ebf1 commit 6567044
Show file tree
Hide file tree
Showing 4 changed files with 386 additions and 135 deletions.
142 changes: 98 additions & 44 deletions flyteadmin/dataproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,25 @@ func (s Service) validateResolveArtifactRequest(req *service.GetDataRequest) err
return nil
}

// GetCompleteTaskExecutionID returns the task execution identifier for the task execution with the Task ID filled in.
// The one coming from the node execution doesn't have this as this is not data encapsulated in the flyte url.
func (s Service) GetCompleteTaskExecutionID(ctx context.Context, taskExecID core.TaskExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {

taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
NodeExecutionId: taskExecID.GetNodeExecutionId(),
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(int(taskExecID.RetryAttempt))),
})
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to list task executions [%v]. Error: %v", taskExecID, err)
}
if len(taskExecs.TaskExecutions) == 0 {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "no task executions were listed [%v]. Error: %v", taskExecID, err)
}
taskExec := taskExecs.TaskExecutions[0]
return taskExec.Id, nil
}

func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
NodeExecutionId: &nodeExecID,
Expand All @@ -274,66 +293,61 @@ func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID
return taskExec.Id, nil
}

func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) (
func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID core.NodeExecutionIdentifier, ioType common.ArtifactType, name string) (
*service.GetDataResponse, error) {

logger.Debugf(ctx, "resolving flyte url query: %s", req.GetFlyteUrl())
err := s.validateResolveArtifactRequest(req)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to validate resolve artifact request. Error: %v", err)
}

nodeExecID, attempt, ioType, err := common.ParseFlyteURL(req.GetFlyteUrl())
resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{
Id: &nodeExecID,
})
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse artifact url Error: %v", err)
return nil, err
}

// Get the data location, then decide how/where to fetch it from
if attempt == nil {
resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{
Id: &nodeExecID,
})
var lm *core.LiteralMap
if ioType == common.ArtifactTypeI {
lm = resp.FullInputs
} else if ioType == common.ArtifactTypeO {
lm = resp.FullOutputs
} else {
// Assume deck, and create a download link request
dlRequest := service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID},
}
resp, err := s.CreateDownloadLink(ctx, &dlRequest)
if err != nil {
return nil, err
}
return &service.GetDataResponse{
Data: &service.GetDataResponse_PreSignedUrls{
PreSignedUrls: resp.PreSignedUrls,
},
}, nil
}

var lm *core.LiteralMap
if ioType == common.ArtifactTypeI {
lm = resp.FullInputs
} else if ioType == common.ArtifactTypeO {
lm = resp.FullOutputs
} else {
// Assume deck, and create a download link request
dlRequest := service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID},
}
resp, err := s.CreateDownloadLink(ctx, &dlRequest)
if err != nil {
return nil, err
}
if name != "" {
if literal, ok := lm.Literals[name]; ok {
return &service.GetDataResponse{
Data: &service.GetDataResponse_PreSignedUrls{
PreSignedUrls: resp.PreSignedUrls,
Data: &service.GetDataResponse_Literal{
Literal: literal,
},
}, nil
}

return &service.GetDataResponse{
Data: &service.GetDataResponse_LiteralMap{
LiteralMap: lm,
},
}, nil
}
// Rest of the logic handles task attempt lookups
var lm *core.LiteralMap
taskExecID, err := s.GetTaskExecutionID(ctx, *attempt, nodeExecID)
if err != nil {
return nil, err
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "name [%v] not found in node execution [%v]", name, nodeExecID)
}
return &service.GetDataResponse{
Data: &service.GetDataResponse_LiteralMap{
LiteralMap: lm,
},
}, nil
}

func (s Service) GetDataFromTaskExecution(ctx context.Context, taskExecID core.TaskExecutionIdentifier, ioType common.ArtifactType, name string) (
*service.GetDataResponse, error) {

var lm *core.LiteralMap
reqT := admin.TaskExecutionGetDataRequest{
Id: taskExecID,
Id: &taskExecID,
}
resp, err := s.taskExecutionManager.GetTaskExecutionData(ctx, reqT)
if err != nil {
Expand All @@ -347,11 +361,51 @@ func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) (
} else {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "deck type cannot be specified with a retry attempt, just use the node instead")
}

if name != "" {
if literal, ok := lm.Literals[name]; ok {
return &service.GetDataResponse{
Data: &service.GetDataResponse_Literal{
Literal: literal,
},
}, nil
}
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "name [%v] not found in task execution [%v]", name, taskExecID)
}

return &service.GetDataResponse{
Data: &service.GetDataResponse_LiteralMap{
LiteralMap: lm,
},
}, nil

}

func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) (
*service.GetDataResponse, error) {

logger.Debugf(ctx, "resolving flyte url query: %s", req.GetFlyteUrl())
err := s.validateResolveArtifactRequest(req)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to validate resolve artifact request. Error: %v", err)
}

execution, err := common.ParseFlyteURLToExecution(req.GetFlyteUrl())
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse artifact url Error: %v", err)
}

if execution.NodeExecID != nil {
return s.GetDataFromNodeExecution(ctx, *execution.NodeExecID, execution.IOType, execution.LiteralName)
} else if execution.PartialTaskExecID != nil {
taskExecID, err := s.GetCompleteTaskExecutionID(ctx, *execution.PartialTaskExecID)
if err != nil {
return nil, err
}
return s.GetDataFromTaskExecution(ctx, *taskExecID, execution.IOType, execution.LiteralName)
}

return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse get data request %v", req)
}

func NewService(cfg config.DataProxyConfig,
Expand Down
32 changes: 32 additions & 0 deletions flyteadmin/dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,38 @@ func TestService_GetData(t *testing.T) {
})
assert.Error(t, err)
})

t.Run("get individual literal without retry attempt", func(t *testing.T) {
res, err := s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/i/input",
})
assert.NoError(t, err)
assert.True(t, proto.Equal(inputsLM.GetLiterals()["input"], res.GetLiteral()))
assert.Nil(t, res.GetPreSignedUrls())
})

t.Run("get individual literal with a retry attempt", func(t *testing.T) {
res, err := s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/5/o/output",
})
assert.NoError(t, err)
assert.True(t, proto.Equal(outputsLM.GetLiterals()["output"], res.GetLiteral()))
assert.Nil(t, res.GetPreSignedUrls())
})

t.Run("error requesting missing name without retry attempt", func(t *testing.T) {
_, err := s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/i/o5",
})
assert.Error(t, err)
})

t.Run("error requesting missing name with a retry attempt", func(t *testing.T) {
_, err := s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/5/o/o1",
})
assert.Error(t, err)
})
}

func TestService_Error(t *testing.T) {
Expand Down
88 changes: 72 additions & 16 deletions flyteadmin/pkg/common/flyte_url.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ const (
ArtifactTypeD // deck
)

var re = regexp.MustCompile("flyte://v1/(?P<project>[a-zA-Z0-9_-]+)/(?P<domain>[a-zA-Z0-9_-]+)/(?P<exec>[a-zA-Z0-9_-]+)/(?P<node>[a-zA-Z0-9_-]+)(?:/(?P<attempt>[0-9]+))?/(?P<artifactType>[iod])$")
var re = regexp.MustCompile("flyte://v1/(?P<project>[a-zA-Z0-9_-]+)/(?P<domain>[a-zA-Z0-9_-]+)/(?P<exec>[a-zA-Z0-9_-]+)/(?P<node>[a-zA-Z0-9_-]+)(?:/(?P<attempt>[0-9]+))?/(?P<ioType>[iod])$")

var reSpecificOutput = regexp.MustCompile("flyte://v1/(?P<project>[a-zA-Z0-9_-]+)/(?P<domain>[a-zA-Z0-9_-]+)/(?P<exec>[a-zA-Z0-9_-]+)/(?P<node>[a-zA-Z0-9_-]+)(?:/(?P<attempt>[0-9]+))?/(?P<ioType>[io])/(?P<literalName>[a-zA-Z0-9_-]+)$")

func MatchRegex(reg *regexp.Regexp, input string) map[string]string {
names := reg.SubexpNames()
Expand All @@ -37,36 +39,90 @@ func MatchRegex(reg *regexp.Regexp, input string) map[string]string {
return dict
}

func ParseFlyteURL(flyteURL string) (core.NodeExecutionIdentifier, *int, ArtifactType, error) {
// flyteURL is of the form flyte://v1/project/domain/execution_id/node_id/attempt/[iod]
type ParsedExecution struct {
// Returned when the user does not request a specific attempt
NodeExecID *core.NodeExecutionIdentifier

// This is returned in the case where the user requested a specific attempt. But the TaskID portion of this
// will be empty since that information is not encapsulated in the flyte url.
PartialTaskExecID *core.TaskExecutionIdentifier

// The name of the input or output in the literal map
LiteralName string

IOType ArtifactType
}

func tryMatches(flyteURL string) map[string]string {
var matches map[string]string

if matches = MatchRegex(re, flyteURL); len(matches) > 0 {
return matches
} else if matches = MatchRegex(reSpecificOutput, flyteURL); len(matches) > 0 {
return matches
}
return nil
}

func ParseFlyteURLToExecution(flyteURL string) (ParsedExecution, error) {
// flyteURL can be of the following forms
// flyte://v1/project/domain/execution_id/node_id/attempt/[iod]
// flyte://v1/project/domain/execution_id/node_id/attempt/[io]/output_name
// flyte://v1/project/domain/execution_id/node_id/[io]/output_name
// flyte://v1/project/domain/execution_id/node_id/[iod]

// where i stands for inputs.pb o for outputs.pb and d for the flyte deck
// If the retry attempt is missing, the io requested is assumed to be for the node instead of the task execution
matches := MatchRegex(re, flyteURL)

matches := tryMatches(flyteURL)
if matches == nil {
return ParsedExecution{}, fmt.Errorf("failed to parse [%s]", flyteURL)
}

proj := matches["project"]
domain := matches["domain"]
executionID := matches["exec"]
nodeID := matches["node"]
var attemptPtr *int // nil means node execution, not a task execution
if attempt := matches["attempt"]; len(attempt) > 0 {
a, err := strconv.Atoi(attempt)
if err != nil {
return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, fmt.Errorf("failed to parse attempt [%v], %v", attempt, err)
}
attemptPtr = &a
}
ioType, err := ArtifactTypeString(matches["artifactType"])
ioType, err := ArtifactTypeString(matches["ioType"])
if err != nil {
return core.NodeExecutionIdentifier{}, nil, ArtifactTypeUndefined, err
return ParsedExecution{}, err
}
literalName := matches["literalName"] // may be empty

return core.NodeExecutionIdentifier{
// node and task level outputs
nodeExecID := core.NodeExecutionIdentifier{
NodeId: nodeID,
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: proj,
Domain: domain,
Name: executionID,
},
}, attemptPtr, ioType, nil
}

// if attempt is there, that means a task execution
if attempt := matches["attempt"]; len(attempt) > 0 {
a, err := strconv.Atoi(attempt)
if err != nil {
return ParsedExecution{}, fmt.Errorf("failed to parse attempt [%v], %v", attempt, err)
}
taskExecID := core.TaskExecutionIdentifier{
NodeExecutionId: &nodeExecID,
// checking for overflow here is probably unreasonable
RetryAttempt: uint32(a),
}
return ParsedExecution{
PartialTaskExecID: &taskExecID,
IOType: ioType,
LiteralName: literalName,
}, nil
}

return ParsedExecution{
NodeExecID: &nodeExecID,
IOType: ioType,
LiteralName: literalName,
}, nil

}

func FlyteURLsFromNodeExecutionID(nodeExecutionID core.NodeExecutionIdentifier, deck bool) *admin.FlyteURLs {
Expand Down
Loading

0 comments on commit 6567044

Please sign in to comment.