diff --git a/pkg/manager/impl/testutils/config.go b/pkg/manager/impl/testutils/config.go index 672a6269e4..c2fe20139d 100644 --- a/pkg/manager/impl/testutils/config.go +++ b/pkg/manager/impl/testutils/config.go @@ -26,6 +26,9 @@ func GetApplicationConfigWithDefaultDomains() runtimeInterfaces.ApplicationConfi Name: "domain", }, }) - config.SetRemoteDataConfig(runtimeInterfaces.RemoteDataConfig{Scheme: common.Local}) + config.SetRemoteDataConfig(runtimeInterfaces.RemoteDataConfig{ + Scheme: common.Local, SignedURL: runtimeInterfaces.SignedURL{ + Enabled: true, + }}) return &config } diff --git a/pkg/manager/impl/util/data.go b/pkg/manager/impl/util/data.go index 235f2be855..5a9f5017c4 100644 --- a/pkg/manager/impl/util/data.go +++ b/pkg/manager/impl/util/data.go @@ -29,14 +29,18 @@ func GetInputs(ctx context.Context, urlData dataInterfaces.RemoteURLInterface, if len(inputURI) == 0 { return nil, nil, nil } - inputsURLBlob, err := urlData.Get(ctx, inputURI) - if err != nil { - return nil, nil, err + var inputsURLBlob admin.UrlBlob + var err error + if remoteDataConfig.SignedURL.Enabled { + inputsURLBlob, err = urlData.Get(ctx, inputURI) + if err != nil { + return nil, nil, err + } } var fullInputs core.LiteralMap if shouldFetchData(remoteDataConfig, inputsURLBlob) { - err := storageClient.ReadProtobuf(ctx, storage.DataReference(inputURI), &fullInputs) + err = storageClient.ReadProtobuf(ctx, storage.DataReference(inputURI), &fullInputs) if err != nil { // If we fail to read the protobuf from the remote store, we shouldn't fail the request altogether. // Instead we return the signed URL blob so that the client can use that to fetch the input data. @@ -90,7 +94,7 @@ func GetOutputs(ctx context.Context, urlData dataInterfaces.RemoteURLInterface, return nil, nil, nil } var outputsURLBlob admin.UrlBlob - if len(closure.GetOutputUri()) > 0 { + if len(closure.GetOutputUri()) > 0 && remoteDataConfig.SignedURL.Enabled { var err error outputsURLBlob, err = urlData.Get(ctx, closure.GetOutputUri()) if err != nil { diff --git a/pkg/manager/impl/util/data_test.go b/pkg/manager/impl/util/data_test.go index b010caa5fb..2c3409a67a 100644 --- a/pkg/manager/impl/util/data_test.go +++ b/pkg/manager/impl/util/data_test.go @@ -123,8 +123,9 @@ func TestGetInputs(t *testing.T) { assert.Equal(t, inputsURI, uri) return expectedURLBlob, nil } - remoteDataConfig := interfaces.RemoteDataConfig{} - remoteDataConfig.MaxSizeInBytes = 2000 + remoteDataConfig := interfaces.RemoteDataConfig{ + MaxSizeInBytes: 2000, + } mockStorage := commonMocks.GetMockStorageClient() mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func( @@ -135,39 +136,58 @@ func TestGetInputs(t *testing.T) { return nil } - fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, inputsURI) - assert.NoError(t, err) - assert.True(t, proto.Equal(fullInputs, testLiteralMap)) - assert.True(t, proto.Equal(inputURLBlob, &expectedURLBlob)) + t.Run("should sign URL", func(t *testing.T) { + remoteDataConfig.SignedURL = interfaces.SignedURL{ + Enabled: true, + } + fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, inputsURI) + assert.NoError(t, err) + assert.True(t, proto.Equal(fullInputs, testLiteralMap)) + assert.True(t, proto.Equal(inputURLBlob, &expectedURLBlob)) + }) + t.Run("should not sign URL", func(t *testing.T) { + remoteDataConfig.SignedURL = interfaces.SignedURL{ + Enabled: false, + } + fullInputs, inputURLBlob, err := GetInputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, inputsURI) + assert.NoError(t, err) + assert.True(t, proto.Equal(fullInputs, testLiteralMap)) + assert.Empty(t, inputURLBlob) + }) + } func TestGetOutputs(t *testing.T) { - t.Run("offloaded outputs", func(t *testing.T) { - expectedURLBlob := admin.UrlBlob{ - Url: "s3://foo/signed/outputs.pb", - Bytes: 1000, - } + expectedURLBlob := admin.UrlBlob{ + Url: "s3://foo/signed/outputs.pb", + Bytes: 1000, + } - mockRemoteURL := urlMocks.NewMockRemoteURL() - mockRemoteURL.(*urlMocks.MockRemoteURL).GetCallback = func(ctx context.Context, uri string) (admin.UrlBlob, error) { - assert.Equal(t, testOutputsURI, uri) - return expectedURLBlob, nil - } - remoteDataConfig := interfaces.RemoteDataConfig{} - remoteDataConfig.MaxSizeInBytes = 2000 + mockRemoteURL := urlMocks.NewMockRemoteURL() + mockRemoteURL.(*urlMocks.MockRemoteURL).GetCallback = func(ctx context.Context, uri string) (admin.UrlBlob, error) { + assert.Equal(t, testOutputsURI, uri) + return expectedURLBlob, nil + } - mockStorage := commonMocks.GetMockStorageClient() - mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func( - ctx context.Context, reference storage.DataReference, msg proto.Message) error { - assert.Equal(t, testOutputsURI, reference.String()) - marshalled, _ := proto.Marshal(testLiteralMap) - _ = proto.Unmarshal(marshalled, msg) - return nil - } - closure := &admin.NodeExecutionClosure{ - OutputResult: &admin.NodeExecutionClosure_OutputUri{ - OutputUri: testOutputsURI, - }, + remoteDataConfig := interfaces.RemoteDataConfig{ + MaxSizeInBytes: 2000, + } + mockStorage := commonMocks.GetMockStorageClient() + mockStorage.ComposedProtobufStore.(*commonMocks.TestDataStore).ReadProtobufCb = func( + ctx context.Context, reference storage.DataReference, msg proto.Message) error { + assert.Equal(t, testOutputsURI, reference.String()) + marshalled, _ := proto.Marshal(testLiteralMap) + _ = proto.Unmarshal(marshalled, msg) + return nil + } + closure := &admin.NodeExecutionClosure{ + OutputResult: &admin.NodeExecutionClosure_OutputUri{ + OutputUri: testOutputsURI, + }, + } + t.Run("offloaded outputs with signed URL", func(t *testing.T) { + remoteDataConfig.SignedURL = interfaces.SignedURL{ + Enabled: true, } fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure) @@ -175,6 +195,16 @@ func TestGetOutputs(t *testing.T) { assert.True(t, proto.Equal(fullOutputs, testLiteralMap)) assert.True(t, proto.Equal(outputURLBlob, &expectedURLBlob)) }) + t.Run("offloaded outputs without signed URL", func(t *testing.T) { + remoteDataConfig.SignedURL = interfaces.SignedURL{ + Enabled: false, + } + + fullOutputs, outputURLBlob, err := GetOutputs(context.TODO(), mockRemoteURL, &remoteDataConfig, mockStorage, closure) + assert.NoError(t, err) + assert.True(t, proto.Equal(fullOutputs, testLiteralMap)) + assert.Empty(t, outputURLBlob) + }) t.Run("inline outputs", func(t *testing.T) { mockRemoteURL := urlMocks.NewMockRemoteURL() mockRemoteURL.(*urlMocks.MockRemoteURL).GetCallback = func(ctx context.Context, uri string) (admin.UrlBlob, error) { diff --git a/pkg/runtime/application_config_provider.go b/pkg/runtime/application_config_provider.go index c582d5c5fb..400a981009 100644 --- a/pkg/runtime/application_config_provider.go +++ b/pkg/runtime/application_config_provider.go @@ -60,6 +60,9 @@ var schedulerConfig = config.MustRegisterSection(scheduler, &interfaces.Schedule var remoteDataConfig = config.MustRegisterSection(remoteData, &interfaces.RemoteDataConfig{ Scheme: common.None, MaxSizeInBytes: 2 * MB, + SignedURL: interfaces.SignedURL{ + Enabled: false, + }, }) var notificationsConfig = config.MustRegisterSection(notifications, &interfaces.NotificationsConfig{ Type: common.Local, diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index 42c477dc63..926e3a106b 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -295,6 +295,9 @@ func (s *SchedulerConfig) GetReconnectDelaySeconds() int { // Configuration specific to setting up signed urls. type SignedURL struct { + // Whether signed urls should even be returned with GetExecutionData, GetNodeExecutionData and GetTaskExecutionData + // response objects. + Enabled bool `json:"enabled" pflag:",Whether signed urls should even be returned with GetExecutionData, GetNodeExecutionData and GetTaskExecutionData response objects."` // The amount of time for which a signed URL is valid. DurationMinutes int `json:"durationMinutes"` // The principal that signs the URL. This is only applicable to GCS URL. diff --git a/tests/task_execution_test.go b/tests/task_execution_test.go index d961dad3b4..b3955d4764 100644 --- a/tests/task_execution_test.go +++ b/tests/task_execution_test.go @@ -5,7 +5,6 @@ package tests import ( "context" "net/url" - "strings" "testing" "time" @@ -324,8 +323,24 @@ func TestGetTaskExecutionData(t *testing.T) { if err != nil { t.Fatalf("Failed to construct data reference [%s]. Error: %v", taskExecInputURI, err) } - dataToStore := "task execution input data" - err = store.WriteRaw(ctx, inputRef, int64(len(dataToStore)), storage.Options{}, strings.NewReader(dataToStore)) + taskInputs := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "foo", + }, + }, + }, + }, + }, + }, + }, + } + err = store.WriteProtobuf(ctx, inputRef, storage.Options{}, &taskInputs) if err != nil { t.Fatalf("Failed to write data. Error: %v", err) } @@ -335,8 +350,24 @@ func TestGetTaskExecutionData(t *testing.T) { t.Fatalf("Failed to construct data reference. Error: %v", err) } - dataToStore = "task execution output data" - err = store.WriteRaw(ctx, outputRef, int64(len(dataToStore)), storage.Options{}, strings.NewReader(dataToStore)) + taskOutputs := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "bar": { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "bar", + }, + }, + }, + }, + }, + }, + }, + } + err = store.WriteProtobuf(ctx, outputRef, storage.Options{}, &taskOutputs) if err != nil { t.Fatalf("Failed to write data. Error: %v", err) } @@ -375,8 +406,8 @@ func TestGetTaskExecutionData(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, resp) - assert.NotEmpty(t, resp.Inputs.Url) - assert.Equal(t, int64(25), resp.Inputs.Bytes) - assert.NotEmpty(t, resp.Outputs.Url) - assert.Equal(t, int64(26), resp.Outputs.Bytes) + assert.Empty(t, resp.Inputs) + assert.NotEmpty(t, resp.FullInputs) + assert.Empty(t, resp.Outputs) + assert.NotEmpty(t, resp.FullOutputs) }