Skip to content

Commit

Permalink
Implement CreateDownloadLink (flyteorg#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
EngHabu authored Nov 17, 2022
1 parent 1398328 commit ec98637
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 34 deletions.
145 changes: 116 additions & 29 deletions flyteadmin/dataproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@ import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"time"

"github.com/flyteorg/flyteadmin/pkg/errors"
"google.golang.org/grpc/codes"

"github.com/flyteorg/flyteadmin/pkg/manager/interfaces"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"

"google.golang.org/protobuf/types/known/durationpb"
Expand All @@ -26,30 +33,31 @@ import (
type Service struct {
service.DataProxyServiceServer

cfg config.DataProxyConfig
dataStore *storage.DataStore
shardSelector ioutils.ShardSelector
cfg config.DataProxyConfig
dataStore *storage.DataStore
shardSelector ioutils.ShardSelector
nodeExecutionManager interfaces.NodeExecutionInterface
}

// CreateUploadLocation creates a temporary signed url to allow callers to upload content.
func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUploadLocationRequest) (
*service.CreateUploadLocationResponse, error) {

if len(req.Project) == 0 || len(req.Domain) == 0 {
return nil, fmt.Errorf("prjoect and domain are required parameters")
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "project and domain are required parameters")
}

if len(req.ContentMd5) == 0 {
return nil, fmt.Errorf("content_md5 is a required parameter")
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "content_md5 is a required parameter")
}

if expiresIn := req.ExpiresIn; expiresIn != nil {
if !expiresIn.IsValid() {
return nil, fmt.Errorf("expiresIn [%v] is invalid", expiresIn)
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "expiresIn [%v] is invalid", expiresIn)
}

if expiresIn.AsDuration() > s.cfg.Upload.MaxExpiresIn.Duration {
return nil, fmt.Errorf("expiresIn [%v] cannot exceed max allowed expiration [%v]",
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "expiresIn [%v] cannot exceed max allowed expiration [%v]",
expiresIn.AsDuration().String(), s.cfg.Upload.MaxExpiresIn.String())
}
} else {
Expand All @@ -66,7 +74,7 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp
storagePath, err := createStorageLocation(ctx, s.dataStore, s.cfg.Upload,
req.Project, req.Domain, urlSafeMd5, req.Filename)
if err != nil {
return nil, err
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create shardedStorageLocation, Error: %v", err)
}

resp, err := s.dataStore.CreateSignedURL(ctx, storagePath, storage.SignedURLProperties{
Expand All @@ -76,7 +84,7 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp
})

if err != nil {
return nil, fmt.Errorf("failed to create a signed url. Error: %w", err)
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err)
}

return &service.CreateUploadLocationResponse{
Expand All @@ -86,12 +94,57 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp
}, nil
}

// CreateDownloadLink retrieves the requested artifact type for a given execution (wf, node, task) as a signed url(s).
func (s Service) CreateDownloadLink(ctx context.Context, req *service.CreateDownloadLinkRequest) (
resp *service.CreateDownloadLinkResponse, err error) {
if req, err = s.validateCreateDownloadLinkRequest(req); err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "error while validating request. Error: %v", err)
}

// Lookup task, node, workflow execution
var nativeURL string
if nodeExecutionIDEnvelope, casted := req.GetSource().(*service.CreateDownloadLinkRequest_NodeExecutionId); casted {
node, err := s.nodeExecutionManager.GetNodeExecution(ctx, admin.NodeExecutionGetRequest{
Id: nodeExecutionIDEnvelope.NodeExecutionId,
})

if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to find node execution [%v]. Error: %v", nodeExecutionIDEnvelope.NodeExecutionId, err)
}

switch req.GetArtifactType() {
case service.ArtifactType_ARTIFACT_TYPE_DECK:
nativeURL = node.Closure.DeckUri
}
} else {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "unsupported source [%v]", reflect.TypeOf(req.GetSource()))
}

if len(nativeURL) == 0 {
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "no deckUrl found for request [%+v]", req)
}

signedURLResp, err := s.dataStore.CreateSignedURL(ctx, storage.DataReference(nativeURL), storage.SignedURLProperties{
Scope: stow.ClientMethodGet,
ExpiresIn: req.ExpiresIn.AsDuration(),
})

if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err)
}

return &service.CreateDownloadLinkResponse{
SignedUrl: []string{signedURLResp.URL.String()},
ExpiresAt: timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration())),
}, nil
}

// CreateDownloadLocation creates a temporary signed url to allow callers to download content.
func (s Service) CreateDownloadLocation(ctx context.Context, req *service.CreateDownloadLocationRequest) (
*service.CreateDownloadLocationResponse, error) {

if err := s.validateCreateDownloadLocationRequest(req); err != nil {
return nil, err
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "error while validating request: %v", err)
}

resp, err := s.dataStore.CreateSignedURL(ctx, storage.DataReference(req.NativeUrl), storage.SignedURLProperties{
Expand All @@ -100,7 +153,7 @@ func (s Service) CreateDownloadLocation(ctx context.Context, req *service.Create
})

if err != nil {
return nil, fmt.Errorf("failed to create a signed url. Error: %w", err)
return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err)
}

return &service.CreateDownloadLocationResponse{
Expand All @@ -110,22 +163,13 @@ func (s Service) CreateDownloadLocation(ctx context.Context, req *service.Create
}

func (s Service) validateCreateDownloadLocationRequest(req *service.CreateDownloadLocationRequest) error {
if expiresIn := req.ExpiresIn; expiresIn != nil {
if !expiresIn.IsValid() {
return fmt.Errorf("expiresIn [%v] is invalid", expiresIn)
}

if expiresIn.AsDuration() < 0 {
return fmt.Errorf("expiresIn [%v] should not less than 0",
expiresIn.AsDuration().String())
} else if expiresIn.AsDuration() > s.cfg.Download.MaxExpiresIn.Duration {
return fmt.Errorf("expiresIn [%v] cannot exceed max allowed expiration [%v]",
expiresIn.AsDuration().String(), s.cfg.Download.MaxExpiresIn.String())
}
} else {
req.ExpiresIn = durationpb.New(s.cfg.Download.MaxExpiresIn.Duration)
validatedExpiresIn, err := validateDuration(req.ExpiresIn, s.cfg.Download.MaxExpiresIn.Duration)
if err != nil {
return fmt.Errorf("expiresIn is invalid. Error: %w", err)
}

req.ExpiresIn = validatedExpiresIn

if _, err := url.Parse(req.NativeUrl); err != nil {
return fmt.Errorf("failed to parse native_url [%v]",
req.NativeUrl)
Expand All @@ -134,6 +178,45 @@ func (s Service) validateCreateDownloadLocationRequest(req *service.CreateDownlo
return nil
}

func validateDuration(input *durationpb.Duration, maxAllowed time.Duration) (*durationpb.Duration, error) {
if input == nil {
return durationpb.New(maxAllowed), nil
}

if !input.IsValid() {
return nil, fmt.Errorf("input duration [%v] is invalid", input)
}

if input.AsDuration() < 0 {
return nil, fmt.Errorf("input duration [%v] should not less than 0",
input.AsDuration().String())
} else if input.AsDuration() > maxAllowed {
return nil, fmt.Errorf("input duration [%v] cannot exceed max allowed expiration [%v]",
input.AsDuration(), maxAllowed)
}

return input, nil
}

func (s Service) validateCreateDownloadLinkRequest(req *service.CreateDownloadLinkRequest) (*service.CreateDownloadLinkRequest, error) {
validatedExpiresIn, err := validateDuration(req.ExpiresIn, s.cfg.Download.MaxExpiresIn.Duration)
if err != nil {
return nil, fmt.Errorf("expiresIn is invalid. Error: %w", err)
}

req.ExpiresIn = validatedExpiresIn

if req.GetArtifactType() == service.ArtifactType_ARTIFACT_TYPE_UNDEFINED {
return nil, fmt.Errorf("invalid artifact type [%v]", req.GetArtifactType())
}

if req.GetSource() == nil {
return nil, fmt.Errorf("source is required. Provided nil")
}

return req, nil
}

// createStorageLocation creates a location in storage destination to maximize read/write performance in most
// block stores. The final location should look something like: s3://<my bucket>/<file name>
func createStorageLocation(ctx context.Context, store *storage.DataStore,
Expand All @@ -148,16 +231,20 @@ func createStorageLocation(ctx context.Context, store *storage.DataStore,
return storagePath, nil
}

func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore) (Service, error) {
func NewService(cfg config.DataProxyConfig,
nodeExec interfaces.NodeExecutionInterface,
dataStore *storage.DataStore) (Service, error) {

// Context is not used in the constructor. Should ideally be removed.
selector, err := ioutils.NewBase36PrefixShardSelector(context.TODO())
if err != nil {
return Service{}, err
}

return Service{
cfg: cfg,
dataStore: dataStore,
shardSelector: selector,
cfg: cfg,
dataStore: dataStore,
shardSelector: selector,
nodeExecutionManager: nodeExec,
}, nil
}
59 changes: 56 additions & 3 deletions flyteadmin/dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import (
"testing"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"

"github.com/flyteorg/flyteadmin/pkg/manager/mocks"

commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks"
stdlibConfig "github.com/flyteorg/flytestdlib/config"

Expand All @@ -24,9 +30,11 @@ import (
func TestNewService(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{
Upload: config.DataProxyUploadConfig{},
}, dataStore)
}, nodeExecutionManager, dataStore)
assert.NoError(t, err)
assert.NotNil(t, s)
}
Expand All @@ -48,7 +56,8 @@ func Test_createStorageLocation(t *testing.T) {
func TestCreateUploadLocation(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
s, err := NewService(config.DataProxyConfig{}, dataStore)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore)
assert.NoError(t, err)
t.Run("No project/domain", func(t *testing.T) {
_, err = s.CreateUploadLocation(context.Background(), &service.CreateUploadLocationRequest{})
Expand All @@ -73,9 +82,53 @@ func TestCreateUploadLocation(t *testing.T) {
})
}

func TestCreateDownloadLink(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
return &admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
DeckUri: "s3://something/something",
},
}, nil
})

s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore)
assert.NoError(t, err)

t.Run("Invalid expiry", func(t *testing.T) {
_, err = s.CreateDownloadLink(context.Background(), &service.CreateDownloadLinkRequest{
ExpiresIn: durationpb.New(-time.Hour),
})
assert.Error(t, err)
})

t.Run("valid config", func(t *testing.T) {
_, err = s.CreateDownloadLink(context.Background(), &service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{
NodeExecutionId: &core.NodeExecutionIdentifier{},
},
ExpiresIn: durationpb.New(time.Hour),
})
assert.NoError(t, err)
})

t.Run("use default ExpiresIn", func(t *testing.T) {
_, err = s.CreateDownloadLink(context.Background(), &service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{
NodeExecutionId: &core.NodeExecutionIdentifier{},
},
})
assert.NoError(t, err)
})
}

func TestCreateDownloadLocation(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, dataStore)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore)
assert.NoError(t, err)

t.Run("Invalid expiry", func(t *testing.T) {
Expand Down
5 changes: 3 additions & 2 deletions flyteadmin/pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,14 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
}

configuration := runtime2.NewConfigurationProvider()
service.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, scope.NewSubScope("admin")))
adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, scope.NewSubScope("admin"))
service.RegisterAdminServiceServer(grpcServer, adminServer)
if cfg.Security.UseAuth {
service.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService())
service.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService())
}

dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, dataStorageClient)
dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, adminServer.NodeExecutionManager, dataStorageClient)
if err != nil {
return nil, fmt.Errorf("failed to initialize dataProxy service. Error: %w", err)
}
Expand Down

0 comments on commit ec98637

Please sign in to comment.