diff --git a/flyteadmin/dataproxy/service.go b/flyteadmin/dataproxy/service.go index e4932af54..7a8d689ef 100644 --- a/flyteadmin/dataproxy/service.go +++ b/flyteadmin/dataproxy/service.go @@ -5,6 +5,7 @@ import ( "encoding/base32" "encoding/base64" "fmt" + "net/url" "strings" "time" @@ -86,6 +87,54 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp }, nil } +// CreateDownloadLocation creates a temporary signed url to allow callers to download content. +func (s Service) CreateDownloadLocation(ctx context.Context, req *service.CreateDownloadLocationRequest) ( + *service.CreateDownloadLocationResponse, error) { + + if err := s.validateCreateDownloadLocationRequest(req); err != nil { + return nil, err + } + + resp, err := s.dataStore.CreateSignedURL(ctx, storage.DataReference(req.NativeUrl), storage.SignedURLProperties{ + Scope: stow.ClientMethodGet, + ExpiresIn: req.ExpiresIn.AsDuration(), + }) + + if err != nil { + return nil, fmt.Errorf("failed to create a signed url. Error: %w", err) + } + + return &service.CreateDownloadLocationResponse{ + SignedUrl: resp.URL.String(), + ExpiresAt: timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration())), + }, nil +} + +func (s Service) validateCreateDownloadLocationRequest(req *service.CreateDownloadLocationRequest) error { + if expiresIn := req.ExpiresIn; expiresIn != nil { + if !expiresIn.IsValid() { + return fmt.Errorf("expiresIn [%v] is invalid", expiresIn) + } + + if expiresIn.AsDuration() < 0 { + return fmt.Errorf("expiresIn [%v] should not less than 0", + expiresIn.AsDuration().String()) + } else if expiresIn.AsDuration() > s.cfg.Download.MaxExpiresIn.Duration { + return fmt.Errorf("expiresIn [%v] cannot exceed max allowed expiration [%v]", + expiresIn.AsDuration().String(), s.cfg.Download.MaxExpiresIn.String()) + } + } else { + req.ExpiresIn = durationpb.New(s.cfg.Download.MaxExpiresIn.Duration) + } + + if _, err := url.Parse(req.NativeUrl); err != nil { + return fmt.Errorf("failed to parse native_url [%v]", + req.NativeUrl) + } + + return nil +} + // createShardedStorageLocation creates a location in storage destination to maximize read/write performance in most // block stores. The final location should look something like: s3://// func createShardedStorageLocation(ctx context.Context, shardSelector ioutils.ShardSelector, store *storage.DataStore, diff --git a/flyteadmin/dataproxy/service_test.go b/flyteadmin/dataproxy/service_test.go index ed5f09f92..074c052e9 100644 --- a/flyteadmin/dataproxy/service_test.go +++ b/flyteadmin/dataproxy/service_test.go @@ -5,6 +5,9 @@ import ( "testing" "time" + commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" + stdlibConfig "github.com/flyteorg/flytestdlib/config" + "google.golang.org/protobuf/types/known/durationpb" "github.com/flyteorg/flytestdlib/contextutils" @@ -72,3 +75,39 @@ func TestCreateUploadLocation(t *testing.T) { assert.Error(t, err) }) } + +func TestCreateDownloadLocation(t *testing.T) { + dataStore := commonMocks.GetMockStorageClient() + s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, dataStore) + assert.NoError(t, err) + + t.Run("Invalid expiry", func(t *testing.T) { + _, err = s.CreateDownloadLocation(context.Background(), &service.CreateDownloadLocationRequest{ + NativeUrl: "s3://bucket/key", + ExpiresIn: durationpb.New(-time.Hour), + }) + assert.Error(t, err) + }) + + t.Run("valid config", func(t *testing.T) { + _, err = s.CreateDownloadLocation(context.Background(), &service.CreateDownloadLocationRequest{ + NativeUrl: "s3://bucket/key", + ExpiresIn: durationpb.New(time.Hour), + }) + assert.NoError(t, err) + }) + + t.Run("use default ExpiresIn", func(t *testing.T) { + _, err = s.CreateDownloadLocation(context.Background(), &service.CreateDownloadLocationRequest{ + NativeUrl: "s3://bucket/key", + }) + assert.NoError(t, err) + }) + + t.Run("invalid URL", func(t *testing.T) { + _, err = s.CreateDownloadLocation(context.Background(), &service.CreateDownloadLocationRequest{ + NativeUrl: "bucket/key", + }) + assert.NoError(t, err) + }) +} diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index d759c0eff..9d62a1c73 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -13,7 +13,7 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.9.0+incompatible - github.com/flyteorg/flyteidl v1.1.0 + github.com/flyteorg/flyteidl v1.1.4 github.com/flyteorg/flyteplugins v1.0.0 github.com/flyteorg/flytepropeller v1.1.3 github.com/flyteorg/flytestdlib v1.0.2 diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 866eee43f..9e6937269 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -380,8 +380,8 @@ github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4 github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v1.0.0/go.mod h1:JW0z1ZaHS9zWvDAwSMIyGhsf+V4zrzBBgh5IuqzMFCM= -github.com/flyteorg/flyteidl v1.1.0 h1:f8tdMXOuorS/d+4Ut2QarfDbdCOriK0S+EnlQzrwz9E= -github.com/flyteorg/flyteidl v1.1.0/go.mod h1:JW0z1ZaHS9zWvDAwSMIyGhsf+V4zrzBBgh5IuqzMFCM= +github.com/flyteorg/flyteidl v1.1.4 h1:P6YgFYcmBxoLcTegv301i5oYKBCvjHGW3ujRT9s4dvI= +github.com/flyteorg/flyteidl v1.1.4/go.mod h1:f1tvw5CDjqmrzNxKpRYr6BdAhHL8f7Wp1Duxl0ZOV4g= github.com/flyteorg/flyteplugins v1.0.0 h1:77hUJjiIxBmQ9rd3+cXjSGnzOVAFrSzCd59aIaYFB/8= github.com/flyteorg/flyteplugins v1.0.0/go.mod h1:4Cpn+9RfanIieTTh2XsuL6zPYXtsR5UDe8YaEmXONT4= github.com/flyteorg/flytepropeller v1.1.3 h1:RuS/mkbEhjGyUy2XIs7sHOaio1BK8TUZMGKiIN0/pqE= diff --git a/flyteadmin/pkg/common/mocks/storage.go b/flyteadmin/pkg/common/mocks/storage.go index 9993e0d66..9edc5c1ed 100644 --- a/flyteadmin/pkg/common/mocks/storage.go +++ b/flyteadmin/pkg/common/mocks/storage.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/url" "strings" "github.com/flyteorg/flytestdlib/storage" @@ -44,7 +45,11 @@ func (t *TestDataStore) GetBaseContainerFQN(ctx context.Context) storage.DataRef } func (t *TestDataStore) CreateSignedURL(ctx context.Context, reference storage.DataReference, properties storage.SignedURLProperties) (storage.SignedURLResponse, error) { - return storage.SignedURLResponse{}, fmt.Errorf("unsupported") + signedURL, err := url.Parse(reference.String()) + if err != nil { + return storage.SignedURLResponse{}, err + } + return storage.SignedURLResponse{URL: *signedURL}, nil } // Retrieves a byte array from the Blob store or an error diff --git a/flyteadmin/pkg/config/config.go b/flyteadmin/pkg/config/config.go index 95d094a8b..b2388b9c1 100644 --- a/flyteadmin/pkg/config/config.go +++ b/flyteadmin/pkg/config/config.go @@ -29,7 +29,12 @@ type ServerConfig struct { } type DataProxyConfig struct { - Upload DataProxyUploadConfig `json:"upload" pflag:",Defines data proxy upload configuration."` + Upload DataProxyUploadConfig `json:"upload" pflag:",Defines data proxy upload configuration."` + Download DataProxyDownloadConfig `json:"download" pflag:",Defines data proxy download configuration."` +} + +type DataProxyDownloadConfig struct { + MaxExpiresIn config.Duration `json:"maxExpiresIn" pflag:",Maximum allowed expiration duration."` } type DataProxyUploadConfig struct { @@ -86,6 +91,9 @@ var defaultServerConfig = &ServerConfig{ MaxExpiresIn: config.Duration{Duration: time.Hour}, DefaultFileNameLength: 20, }, + Download: DataProxyDownloadConfig{ + MaxExpiresIn: config.Duration{Duration: time.Hour}, + }, }, } var serverConfig = config.MustRegisterSection(SectionKey, defaultServerConfig) diff --git a/flyteadmin/pkg/config/serverconfig_flags.go b/flyteadmin/pkg/config/serverconfig_flags.go index 5bb8a48a9..fa2ee145f 100755 --- a/flyteadmin/pkg/config/serverconfig_flags.go +++ b/flyteadmin/pkg/config/serverconfig_flags.go @@ -73,5 +73,6 @@ func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dataProxy.upload.maxExpiresIn"), defaultServerConfig.DataProxy.Upload.MaxExpiresIn.String(), "Maximum allowed expiration duration.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "dataProxy.upload.defaultFileNameLength"), defaultServerConfig.DataProxy.Upload.DefaultFileNameLength, "Default length for the generated file name if not provided in the request.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dataProxy.upload.storagePrefix"), defaultServerConfig.DataProxy.Upload.StoragePrefix, "Storage prefix to use for all upload requests.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dataProxy.download.maxExpiresIn"), defaultServerConfig.DataProxy.Download.MaxExpiresIn.String(), "Maximum allowed expiration duration.") return cmdFlags } diff --git a/flyteadmin/pkg/config/serverconfig_flags_test.go b/flyteadmin/pkg/config/serverconfig_flags_test.go index 699a656c5..bd3bb83e9 100755 --- a/flyteadmin/pkg/config/serverconfig_flags_test.go +++ b/flyteadmin/pkg/config/serverconfig_flags_test.go @@ -421,4 +421,18 @@ func TestServerConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_dataProxy.download.maxExpiresIn", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultServerConfig.DataProxy.Download.MaxExpiresIn.String() + + cmdFlags.Set("dataProxy.download.maxExpiresIn", testValue) + if vString, err := cmdFlags.GetString("dataProxy.download.maxExpiresIn"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vString), &actual.DataProxy.Download.MaxExpiresIn) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go index 945612bca..7e243101e 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" + genModel "github.com/flyteorg/flyteadmin/pkg/repositories/gen/models" eventWriterMocks "github.com/flyteorg/flyteadmin/pkg/async/events/mocks" @@ -1192,8 +1194,9 @@ func TestGetNodeExecutionData(t *testing.T) { expectedClosure := admin.NodeExecutionClosure{ Phase: core.NodeExecution_SUCCEEDED, OutputResult: &admin.NodeExecutionClosure_OutputUri{ - OutputUri: "output uri", + OutputUri: util.OutputsFile, }, + DeckUri: util.DeckFile, } dynamicWorkflowClosureRef := "s3://my-s3-bucket/foo/bar/dynamic.pb" @@ -1233,7 +1236,7 @@ func TestGetNodeExecutionData(t *testing.T) { Url: "inputs", Bytes: 100, }, nil - } else if uri == "output uri" { + } else if uri == util.OutputsFile { return admin.UrlBlob{ Url: "outputs", Bytes: 200, @@ -1260,7 +1263,7 @@ func TestGetNodeExecutionData(t *testing.T) { marshalled, _ := proto.Marshal(fullInputs) _ = proto.Unmarshal(marshalled, msg) return nil - } else if reference.String() == "output uri" { + } else if reference.String() == util.OutputsFile { marshalled, _ := proto.Marshal(fullOutputs) _ = proto.Unmarshal(marshalled, msg) return nil diff --git a/flyteadmin/pkg/manager/impl/util/data.go b/flyteadmin/pkg/manager/impl/util/data.go index efa941e89..cd281b81e 100644 --- a/flyteadmin/pkg/manager/impl/util/data.go +++ b/flyteadmin/pkg/manager/impl/util/data.go @@ -13,6 +13,11 @@ import ( "github.com/golang/protobuf/proto" ) +const ( + OutputsFile = "outputs.pb" + DeckFile = "deck.html" +) + func shouldFetchData(config *runtimeInterfaces.RemoteDataConfig, urlBlob admin.UrlBlob) bool { return config.Scheme == common.Local || config.Scheme == common.None || config.MaxSizeInBytes == 0 || urlBlob.Bytes < config.MaxSizeInBytes diff --git a/flyteadmin/pkg/manager/impl/util/data_test.go b/flyteadmin/pkg/manager/impl/util/data_test.go index 2c3409a67..69f41567f 100644 --- a/flyteadmin/pkg/manager/impl/util/data_test.go +++ b/flyteadmin/pkg/manager/impl/util/data_test.go @@ -154,7 +154,6 @@ func TestGetInputs(t *testing.T) { assert.True(t, proto.Equal(fullInputs, testLiteralMap)) assert.Empty(t, inputURLBlob) }) - } func TestGetOutputs(t *testing.T) { diff --git a/flyteadmin/pkg/repositories/transformers/node_execution.go b/flyteadmin/pkg/repositories/transformers/node_execution.go index ebdcd02bb..c56a1c028 100644 --- a/flyteadmin/pkg/repositories/transformers/node_execution.go +++ b/flyteadmin/pkg/repositories/transformers/node_execution.go @@ -95,6 +95,8 @@ func addTerminalState( nodeExecutionModel.ErrorKind = &k nodeExecutionModel.ErrorCode = &request.Event.GetError().Code } + closure.DeckUri = request.Event.DeckUri + return nil } diff --git a/flyteadmin/pkg/server/service.go b/flyteadmin/pkg/server/service.go index 1d8019b7e..b998936df 100644 --- a/flyteadmin/pkg/server/service.go +++ b/flyteadmin/pkg/server/service.go @@ -210,6 +210,11 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig. return nil, errors.Wrap(err, "error registering identity service") } + err = service.RegisterDataProxyServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering data proxy service") + } + mux.Handle("/", gwmux) return mux, nil