Skip to content

Commit

Permalink
Make returning execution data via signed URL optional (flyteorg#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Nov 5, 2021
1 parent 2c5d940 commit fc01c41
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 45 deletions.
5 changes: 4 additions & 1 deletion flyteadmin/pkg/manager/impl/testutils/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ func GetApplicationConfigWithDefaultDomains() runtimeInterfaces.ApplicationConfi
Name: "domain",
},
})
config.SetRemoteDataConfig(runtimeInterfaces.RemoteDataConfig{Scheme: common.Local})
config.SetRemoteDataConfig(runtimeInterfaces.RemoteDataConfig{
Scheme: common.Local, SignedURL: runtimeInterfaces.SignedURL{
Enabled: true,
}})
return &config
}
14 changes: 9 additions & 5 deletions flyteadmin/pkg/manager/impl/util/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,18 @@ func GetInputs(ctx context.Context, urlData dataInterfaces.RemoteURLInterface,
if len(inputURI) == 0 {
return nil, nil, nil
}
inputsURLBlob, err := urlData.Get(ctx, inputURI)
if err != nil {
return nil, nil, err
var inputsURLBlob admin.UrlBlob
var err error
if remoteDataConfig.SignedURL.Enabled {
inputsURLBlob, err = urlData.Get(ctx, inputURI)
if err != nil {
return nil, nil, err
}
}

var fullInputs core.LiteralMap
if shouldFetchData(remoteDataConfig, inputsURLBlob) {
err := storageClient.ReadProtobuf(ctx, storage.DataReference(inputURI), &fullInputs)
err = storageClient.ReadProtobuf(ctx, storage.DataReference(inputURI), &fullInputs)
if err != nil {
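// If we fail to read the protobuf from the remote store, we shouldn't fail the request altogether.
// Instead we return the signed URL blob so that the client can use that to fetch the input data.
// NOTE(review): the blob is only populated when remoteDataConfig.SignedURL.Enabled is true; with signing
// disabled this fallback presumably returns an empty blob and no inline data — confirm callers tolerate that.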
Expand Down Expand Up @@ -90,7 +94,7 @@ func GetOutputs(ctx context.Context, urlData dataInterfaces.RemoteURLInterface,
return nil, nil, nil
}
var outputsURLBlob admin.UrlBlob
if len(closure.GetOutputUri()) > 0 {
if len(closure.GetOutputUri()) > 0 && remoteDataConfig.SignedURL.Enabled {
var err error
outputsURLBlob, err = urlData.Get(ctx, closure.GetOutputUri())
if err != nil {
Expand Down
90 changes: 60 additions & 30 deletions flyteadmin/pkg/manager/impl/util/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ func TestGetInputs(t *testing.T) {
assert.Equal(t, inputsURI, uri)
return expectedURLBlob, nil
}
remoteDataConfig := interfaces.RemoteDataConfig{}
remoteDataConfig.MaxSizeInBytes = 2000
remoteDataConfig := interfaces.RemoteDataConfig{
MaxSizeInBytes: 2000,
}

mockStorage := commonMocks.GetMockStorageClient()
mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func(
Expand All @@ -135,46 +136,75 @@ func TestGetInputs(t *testing.T) {
return nil
}

fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, inputsURI)
assert.NoError(t, err)
assert.True(t, proto.Equal(fullInputs, testLiteralMap))
assert.True(t, proto.Equal(inputURLBlob, &expectedURLBlob))
t.Run("should sign URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{
Enabled: true,
}
fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, inputsURI)
assert.NoError(t, err)
assert.True(t, proto.Equal(fullInputs, testLiteralMap))
assert.True(t, proto.Equal(inputURLBlob, &expectedURLBlob))
})
t.Run("should not sign URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{
Enabled: false,
}
fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, inputsURI)
assert.NoError(t, err)
assert.True(t, proto.Equal(fullInputs, testLiteralMap))
assert.Empty(t, inputURLBlob)
})

}

func TestGetOutputs(t *testing.T) {
t.Run("offloaded outputs", func(t *testing.T) {
expectedURLBlob := admin.UrlBlob{
Url: "s3://foo/signed/outputs.pb",
Bytes: 1000,
}
expectedURLBlob := admin.UrlBlob{
Url: "s3://foo/signed/outputs.pb",
Bytes: 1000,
}

mockRemoteURL := urlMocks.NewMockRemoteURL()
mockRemoteURL.(*urlMocks.MockRemoteURL).GetCallback = func(ctx context.Context, uri string) (admin.UrlBlob, error) {
assert.Equal(t, testOutputsURI, uri)
return expectedURLBlob, nil
}
remoteDataConfig := interfaces.RemoteDataConfig{}
remoteDataConfig.MaxSizeInBytes = 2000
mockRemoteURL := urlMocks.NewMockRemoteURL()
mockRemoteURL.(*urlMocks.MockRemoteURL).GetCallback = func(ctx context.Context, uri string) (admin.UrlBlob, error) {
assert.Equal(t, testOutputsURI, uri)
return expectedURLBlob, nil
}

mockStorage := commonMocks.GetMockStorageClient()
mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func(
ctx context.Context, reference storage.DataReference, msg proto.Message) error {
assert.Equal(t, testOutputsURI, reference.String())
marshalled, _ := proto.Marshal(testLiteralMap)
_ = proto.Unmarshal(marshalled, msg)
return nil
}
closure := &admin.NodeExecutionClosure{
OutputResult: &admin.NodeExecutionClosure_OutputUri{
OutputUri: testOutputsURI,
},
remoteDataConfig := interfaces.RemoteDataConfig{
MaxSizeInBytes: 2000,
}
mockStorage := commonMocks.GetMockStorageClient()
mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func(
ctx context.Context, reference storage.DataReference, msg proto.Message) error {
assert.Equal(t, testOutputsURI, reference.String())
marshalled, _ := proto.Marshal(testLiteralMap)
_ = proto.Unmarshal(marshalled, msg)
return nil
}
closure := &admin.NodeExecutionClosure{
OutputResult: &admin.NodeExecutionClosure_OutputUri{
OutputUri: testOutputsURI,
},
}
t.Run("offloaded outputs with signed URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{
Enabled: true,
}

fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure)
assert.NoError(t, err)
assert.True(t, proto.Equal(fullOutputs, testLiteralMap))
assert.True(t, proto.Equal(outputURLBlob, &expectedURLBlob))
})
t.Run("offloaded outputs without signed URL", func(t *testing.T) {
remoteDataConfig.SignedURL = interfaces.SignedURL{
Enabled: false,
}

fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure)
assert.NoError(t, err)
assert.True(t, proto.Equal(fullOutputs, testLiteralMap))
assert.Empty(t, outputURLBlob)
})
t.Run("inline outputs", func(t *testing.T) {
mockRemoteURL := urlMocks.NewMockRemoteURL()
mockRemoteURL.(*urlMocks.MockRemoteURL).GetCallback = func(ctx context.Context, uri string) (admin.UrlBlob, error) {
Expand Down
3 changes: 3 additions & 0 deletions flyteadmin/pkg/runtime/application_config_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ var schedulerConfig = config.MustRegisterSection(scheduler, &interfaces.Schedule
var remoteDataConfig = config.MustRegisterSection(remoteData, &interfaces.RemoteDataConfig{
Scheme: common.None,
MaxSizeInBytes: 2 * MB,
SignedURL: interfaces.SignedURL{
Enabled: false,
},
})
var notificationsConfig = config.MustRegisterSection(notifications, &interfaces.NotificationsConfig{
Type: common.Local,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ func (s *SchedulerConfig) GetReconnectDelaySeconds() int {

// Configuration specific to setting up signed urls.
type SignedURL struct {
// Whether signed URLs are returned in the GetExecutionData, GetNodeExecutionData and GetTaskExecutionData
// response objects.
Enabled bool `json:"enabled" pflag:",Whether signed urls should even be returned with GetExecutionData, GetNodeExecutionData and GetTaskExecutionData response objects."`
// The amount of time for which a signed URL is valid.
DurationMinutes int `json:"durationMinutes"`
// The principal that signs the URL. This is only applicable to GCS URL.
Expand Down
49 changes: 40 additions & 9 deletions flyteadmin/tests/task_execution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package tests
import (
"context"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -324,8 +323,24 @@ func TestGetTaskExecutionData(t *testing.T) {
if err != nil {
t.Fatalf("Failed to construct data reference [%s]. Error: %v", taskExecInputURI, err)
}
dataToStore := "task execution input data"
err = store.WriteRaw(ctx, inputRef, int64(len(dataToStore)), storage.Options{}, strings.NewReader(dataToStore))
taskInputs := core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_StringValue{
StringValue: "foo",
},
},
},
},
},
},
},
}
err = store.WriteProtobuf(ctx, inputRef, storage.Options{}, &taskInputs)
if err != nil {
t.Fatalf("Failed to write data. Error: %v", err)
}
Expand All @@ -335,8 +350,24 @@ func TestGetTaskExecutionData(t *testing.T) {
t.Fatalf("Failed to construct data reference. Error: %v", err)
}

dataToStore = "task execution output data"
err = store.WriteRaw(ctx, outputRef, int64(len(dataToStore)), storage.Options{}, strings.NewReader(dataToStore))
taskOutputs := core.LiteralMap{
Literals: map[string]*core.Literal{
"bar": {
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_StringValue{
StringValue: "bar",
},
},
},
},
},
},
},
}
err = store.WriteProtobuf(ctx, outputRef, storage.Options{}, &taskOutputs)
if err != nil {
t.Fatalf("Failed to write data. Error: %v", err)
}
Expand Down Expand Up @@ -375,8 +406,8 @@ func TestGetTaskExecutionData(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, resp)

assert.NotEmpty(t, resp.Inputs.Url)
assert.Equal(t, int64(25), resp.Inputs.Bytes)
assert.NotEmpty(t, resp.Outputs.Url)
assert.Equal(t, int64(26), resp.Outputs.Bytes)
assert.Empty(t, resp.Inputs)
assert.NotEmpty(t, resp.FullInputs)
assert.Empty(t, resp.Outputs)
assert.NotEmpty(t, resp.FullOutputs)
}

0 comments on commit fc01c41

Please sign in to comment.