Skip to content

Commit

Permalink
Address resolution (flyteorg#546)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored May 5, 2023
1 parent 1026231 commit 950ad15
Show file tree
Hide file tree
Showing 13 changed files with 632 additions and 15 deletions.
131 changes: 128 additions & 3 deletions dataproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ import (
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"

"github.com/flyteorg/flyteadmin/pkg/common"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/logger"

"github.com/flyteorg/flyteadmin/pkg/errors"
"google.golang.org/grpc/codes"

Expand Down Expand Up @@ -37,6 +44,7 @@ type Service struct {
dataStore *storage.DataStore
shardSelector ioutils.ShardSelector
nodeExecutionManager interfaces.NodeExecutionInterface
taskExecutionManager interfaces.TaskExecutionInterface
}

// CreateUploadLocation creates a temporary signed url to allow callers to upload content.
Expand Down Expand Up @@ -133,9 +141,17 @@ func (s Service) CreateDownloadLink(ctx context.Context, req *service.CreateDown
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err)
}

u := []string{signedURLResp.URL.String()}
ts := timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration()))

//
return &service.CreateDownloadLinkResponse{
SignedUrl: []string{signedURLResp.URL.String()},
ExpiresAt: timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration())),
SignedUrl: u,
ExpiresAt: ts,
PreSignedUrls: &service.PreSignedURLs{
SignedUrl: []string{signedURLResp.URL.String()},
ExpiresAt: ts,
},
}, nil
}

Expand Down Expand Up @@ -231,9 +247,117 @@ func createStorageLocation(ctx context.Context, store *storage.DataStore,
return storagePath, nil
}

func (s Service) validateResolveArtifactRequest(req *service.GetDataRequest) error {
if len(req.GetFlyteUrl()) == 0 {
return fmt.Errorf("source is required. Provided empty string")
}
if !strings.HasPrefix(req.GetFlyteUrl(), "flyte://") {
return fmt.Errorf("request does not start with the correct prefix")
}

return nil
}

func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
NodeExecutionId: &nodeExecID,
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(attempt)),
})
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to list task executions [%v]. Error: %v", nodeExecID, err)
}
if len(taskExecs.TaskExecutions) == 0 {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "no task executions were listed [%v]. Error: %v", nodeExecID, err)
}
taskExec := taskExecs.TaskExecutions[0]
return taskExec.Id, nil
}

func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) (
*service.GetDataResponse, error) {

logger.Debugf(ctx, "resolving flyte url query: %s", req.GetFlyteUrl())
err := s.validateResolveArtifactRequest(req)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to validate resolve artifact request. Error: %v", err)
}

nodeExecID, attempt, ioType, err := common.ParseFlyteURL(req.GetFlyteUrl())
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse artifact url Error: %v", err)
}

// Get the data location, then decide how/where to fetch it from
if attempt == nil {
resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{
Id: &nodeExecID,
})
if err != nil {
return nil, err
}

var lm *core.LiteralMap
if ioType == common.ArtifactTypeI {
lm = resp.FullInputs
} else if ioType == common.ArtifactTypeO {
lm = resp.FullOutputs
} else {
// Assume deck, and create a download link request
dlRequest := service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID},
}
resp, err := s.CreateDownloadLink(ctx, &dlRequest)
if err != nil {
return nil, err
}
return &service.GetDataResponse{
Data: &service.GetDataResponse_PreSignedUrls{
PreSignedUrls: resp.PreSignedUrls,
},
}, nil
}

return &service.GetDataResponse{
Data: &service.GetDataResponse_LiteralMap{
LiteralMap: lm,
},
}, nil
}
// Rest of the logic handles task attempt lookups
var lm *core.LiteralMap
taskExecID, err := s.GetTaskExecutionID(ctx, *attempt, nodeExecID)
if err != nil {
return nil, err
}

reqT := admin.TaskExecutionGetDataRequest{
Id: taskExecID,
}
resp, err := s.taskExecutionManager.GetTaskExecutionData(ctx, reqT)
if err != nil {
return nil, err
}

if ioType == common.ArtifactTypeI {
lm = resp.FullInputs
} else if ioType == common.ArtifactTypeO {
lm = resp.FullOutputs
} else {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "deck type cannot be specified with a retry attempt, just use the node instead")
}
return &service.GetDataResponse{
Data: &service.GetDataResponse_LiteralMap{
LiteralMap: lm,
},
}, nil
}

func NewService(cfg config.DataProxyConfig,
nodeExec interfaces.NodeExecutionInterface,
dataStore *storage.DataStore) (Service, error) {
dataStore *storage.DataStore,
taskExec interfaces.TaskExecutionInterface) (Service, error) {

// Context is not used in the constructor. Should ideally be removed.
selector, err := ioutils.NewBase36PrefixShardSelector(context.TODO())
Expand All @@ -246,5 +370,6 @@ func NewService(cfg config.DataProxyConfig,
dataStore: dataStore,
shardSelector: selector,
nodeExecutionManager: nodeExec,
taskExecutionManager: taskExec,
}, nil
}
169 changes: 163 additions & 6 deletions dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/golang/protobuf/proto"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

Expand All @@ -14,10 +15,10 @@ import (
commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks"
stdlibConfig "github.com/flyteorg/flytestdlib/config"

"google.golang.org/protobuf/types/known/durationpb"

"github.com/flyteorg/flyteadmin/pkg/errors"
"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"google.golang.org/protobuf/types/known/durationpb"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"

Expand All @@ -32,9 +33,10 @@ func TestNewService(t *testing.T) {
assert.NoError(t, err)

nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{
Upload: config.DataProxyUploadConfig{},
}, nodeExecutionManager, dataStore)
}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
assert.NotNil(t, s)
}
Expand All @@ -57,7 +59,8 @@ func TestCreateUploadLocation(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore)
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
t.Run("No project/domain", func(t *testing.T) {
_, err = s.CreateUploadLocation(context.Background(), &service.CreateUploadLocationRequest{})
Expand Down Expand Up @@ -92,8 +95,9 @@ func TestCreateDownloadLink(t *testing.T) {
},
}, nil
})
taskExecutionManager := &mocks.MockTaskExecutionManager{}

s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore)
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("Invalid expiry", func(t *testing.T) {
Expand Down Expand Up @@ -128,7 +132,8 @@ func TestCreateDownloadLink(t *testing.T) {
func TestCreateDownloadLocation(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore)
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("Invalid expiry", func(t *testing.T) {
Expand Down Expand Up @@ -161,3 +166,155 @@ func TestCreateDownloadLocation(t *testing.T) {
assert.NoError(t, err)
})
}

func TestService_GetData(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

inputsLM := &core.LiteralMap{
Literals: map[string]*core.Literal{
"input": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_StringValue{
StringValue: "hello",
},
},
},
},
},
},
},
}
outputsLM := &core.LiteralMap{
Literals: map[string]*core.Literal{
"output": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_StringValue{
StringValue: "world",
},
},
},
},
},
},
},
}

nodeExecutionManager.SetGetNodeExecutionDataFunc(
func(ctx context.Context, request admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
return &admin.NodeExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
}, nil
},
)
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: []*admin.TaskExecution{
{
Id: &core.TaskExecutionIdentifier{
TaskId: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: "proj",
Domain: "dev",
Name: "task",
Version: "v1",
},
NodeExecutionId: &core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Domain: "dev",
Name: "wfexecid",
},
},
RetryAttempt: 5,
},
},
},
}, nil
})
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
return &admin.TaskExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
}, nil
})

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
res, err := s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/i",
})
assert.NoError(t, err)
assert.True(t, proto.Equal(inputsLM, res.GetLiteralMap()))
assert.Nil(t, res.GetPreSignedUrls())
})

t.Run("get a working set of urls with a retry attempt", func(t *testing.T) {
res, err := s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v1/proj/dev/wfexecid/n0-d0/5/o",
})
assert.NoError(t, err)
assert.True(t, proto.Equal(outputsLM, res.GetLiteralMap()))
assert.Nil(t, res.GetPreSignedUrls())
})

t.Run("Bad URL", func(t *testing.T) {
_, err = s.GetData(context.Background(), &service.GetDataRequest{
FlyteUrl: "flyte://v3/blah/lorem/m0-fdj",
})
assert.Error(t, err)
})
}

func TestService_Error(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return nil, errors.NewFlyteAdminErrorf(1, "not found")
})
nodeExecID := core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Domain: "dev",
Name: "wfexecid",
},
}
_, err := s.GetTaskExecutionID(context.Background(), 0, nodeExecID)
assert.Error(t, err, "failed to list")
})

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: nil,
Token: "",
}, nil
})
nodeExecID := core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Domain: "dev",
Name: "wfexecid",
},
}
_, err := s.GetTaskExecutionID(context.Background(), 0, nodeExecID)
assert.Error(t, err, "no task executions")
})
}
Loading

0 comments on commit 950ad15

Please sign in to comment.