Skip to content

Commit

Permalink
Merge pull request #2051 from aws/add100ContinueCustomization
Browse files Browse the repository at this point in the history
Port v1 SDK customization for s3 HTTP PUT request
  • Loading branch information
wty-Bryant authored Mar 21, 2023
2 parents 3497eac + c01aac6 commit c93b5cc
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .changelog/bbab7da0e2504bebb9d999dacb2de133.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"id": "bbab7da0-e250-4beb-b9d9-99dacb2de133",
"type": "feature",
"description": "port v1 sdk 100-continue http header customization for s3 PutObject/UploadPart request and enable user config",
"modules": [
"service/internal/s3shared",
"service/s3"
]
}
1 change: 1 addition & 0 deletions aws/signer/internal/v4/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ var IgnoredHeaders = Rules{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
"Expect": struct{}{},
},
},
}
Expand Down
28 changes: 28 additions & 0 deletions aws/signer/internal/v4/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,31 @@ func TestAllowedQueryHoisting(t *testing.T) {
})
}
}

func TestIgnoredHeaders(t *testing.T) {
cases := map[string]struct {
Header string
ExpectIgnored bool
}{
"expect": {
Header: "Expect",
ExpectIgnored: true,
},
"authorization": {
Header: "Authorization",
ExpectIgnored: true,
},
"X-AMZ header": {
Header: "X-Amz-Content-Sha256",
ExpectIgnored: false,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
if e, a := c.ExpectIgnored, IgnoredHeaders.IsValid(c.Header); e == a {
t.Errorf("expect ignored %v, was %v", e, a)
}
})
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package software.amazon.smithy.aws.go.codegen.customization;

import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.ConfigField;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.traits.HttpTrait;
import software.amazon.smithy.utils.ListUtils;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Add middleware, which adds {Expect: 100-continue} header for s3 client HTTP PUT request larger than 2MB
* or with unknown size streaming bodies, during operation builder step
*/
public class S3100Continue implements GoIntegration {
private static final String ADD_100Continue_Header = "add100Continue";
private static final String ADD_100Continue_Header_INTERNAL = "Add100Continue";
private static final String Continue_Client_Option = "ContinueHeaderThresholdBytes";
private static final Set<String> Put_Op_ShapeId_Set = new HashSet<>(Arrays.asList("com.amazonaws.s3#PutObject", "com.amazonaws.s3#UploadPart"));

/**
* Return true if service is Amazon S3.
*
* @param model is the generation model.
* @param service is the service shape being audited.
*/
private static boolean isS3Service(Model model, ServiceShape service) {
return S3ModelUtils.isServiceS3(model, service);
}

/**
* Gets the sort order of the customization from -128 to 127, with lowest
* executed first.
*
* @return Returns the sort order, defaults to -40.
*/
@Override
public byte getOrder() {
return 126;
}

@Override
public void writeAdditionalFiles(
GoSettings settings,
Model model,
SymbolProvider symbolProvider,
GoDelegator goDelegator
) {
ServiceShape service = settings.getService(model);
if (!isS3Service(model, service)) {
return;
}

goDelegator.useShapeWriter(service, this::writeMiddlewareHelper);
}

private void writeMiddlewareHelper(GoWriter writer) {
writer.openBlock("func $L(stack *middleware.Stack, options Options) error {", "}", ADD_100Continue_Header, () -> {
writer.write("return $T(stack, options.ContinueHeaderThresholdBytes)",
SymbolUtils.createValueSymbolBuilder(ADD_100Continue_Header_INTERNAL,
AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION).build()
);
});
writer.insertTrailingNewline();
}

@Override
public List<RuntimeClientPlugin> getClientPlugins() {
return ListUtils.of(
RuntimeClientPlugin.builder()
.operationPredicate((model, service, operation) ->
isS3Service(model, service) && Put_Op_ShapeId_Set.contains(operation.getId().toString())
)
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_100Continue_Header).build())
.useClientOptions()
.build()
)
.build(),
RuntimeClientPlugin.builder()
.servicePredicate(S3100Continue::isS3Service)
.configFields(ListUtils.of(
ConfigField.builder()
.name(Continue_Client_Option)
.type(SymbolUtils.createValueSymbolBuilder("int64")
.putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true)
.build())
.documentation("The threshold ContentLength in bytes for HTTP PUT request to receive {Expect: 100-continue} header. " +
"Setting to -1 will disable adding the Expect header to requests; setting to 0 will set the threshold " +
"to default 2MB")
.build()
))
.build()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ software.amazon.smithy.aws.go.codegen.customization.BackfillEc2UnboxedToBoxedSha
software.amazon.smithy.aws.go.codegen.customization.AdjustAwsRestJsonContentType
software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum
software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator
software.amazon.smithy.aws.go.codegen.customization.S3100Continue
54 changes: 54 additions & 0 deletions service/internal/s3shared/s3100continue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package s3shared

import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)

const s3100ContinueID = "S3100Continue"
const default100ContinueThresholdBytes int64 = 1024 * 1024 * 2

// Add100Continue add middleware, which adds {Expect: 100-continue} header for s3 client HTTP PUT request larger than 2MB
// or with unknown size streaming bodies, during operation builder step
func Add100Continue(stack *middleware.Stack, continueHeaderThresholdBytes int64) error {
return stack.Build.Add(&s3100Continue{
continueHeaderThresholdBytes: continueHeaderThresholdBytes,
}, middleware.After)
}

type s3100Continue struct {
continueHeaderThresholdBytes int64
}

// ID returns the middleware identifier
func (m *s3100Continue) ID() string {
return s3100ContinueID
}

func (m *s3100Continue) HandleBuild(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
sizeLimit := default100ContinueThresholdBytes
switch {
case m.continueHeaderThresholdBytes == -1:
return next.HandleBuild(ctx, in)
case m.continueHeaderThresholdBytes > 0:
sizeLimit = m.continueHeaderThresholdBytes
default:
}

req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}

if req.ContentLength == -1 || (req.ContentLength == 0 && req.Body != nil) || req.ContentLength >= sizeLimit {
req.Header.Set("Expect", "100-continue")
}

return next.HandleBuild(ctx, in)
}
96 changes: 96 additions & 0 deletions service/internal/s3shared/s3100continue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package s3shared

import (
"context"
"github.com/aws/aws-sdk-go-v2/internal/awstesting"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"testing"
)

// unit test for service/internal/s3shared/s3100continue.go
func TestAdd100ContinueHttpHeader(t *testing.T) {
const HeaderKey = "Expect"
HeaderValue := "100-continue"

cases := map[string]struct {
ContentLength int64
Body *awstesting.ReadCloser
ExpectValueFound string
ContinueHeaderThresholdBytes int64
}{
"http request smaller than default 2MB": {
ContentLength: 1,
Body: &awstesting.ReadCloser{Size: 1},
ExpectValueFound: "",
},
"http request smaller than configured threshold": {
ContentLength: 1024 * 1024 * 2,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 2},
ExpectValueFound: "",
ContinueHeaderThresholdBytes: 1024 * 1024 * 3,
},
"http request larger than default 2MB": {
ContentLength: 1024 * 1024 * 3,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3},
ExpectValueFound: HeaderValue,
},
"http request larger than configured threshold": {
ContentLength: 1024 * 1024 * 4,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 4},
ExpectValueFound: HeaderValue,
ContinueHeaderThresholdBytes: 1024 * 1024 * 3,
},
"http put request with unknown -1 ContentLength": {
ContentLength: -1,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 10},
ExpectValueFound: HeaderValue,
},
"http put request with 0 ContentLength but unknown non-nil body": {
ContentLength: 0,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3},
ExpectValueFound: HeaderValue,
},
"http put request with unknown -1 ContentLength and configured threshold": {
ContentLength: -1,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3},
ExpectValueFound: HeaderValue,
ContinueHeaderThresholdBytes: 1024 * 1024 * 10,
},
"http put request with continue header disabled": {
ContentLength: 1024 * 1024 * 3,
Body: &awstesting.ReadCloser{Size: 1024 * 1024 * 3},
ExpectValueFound: "",
ContinueHeaderThresholdBytes: -1,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := smithyhttp.NewStackRequest().(*smithyhttp.Request)

req.ContentLength = c.ContentLength
req.Body = c.Body
var updatedRequest *smithyhttp.Request
m := s3100Continue{
continueHeaderThresholdBytes: c.ContinueHeaderThresholdBytes,
}
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
updatedRequest = input.Request.(*smithyhttp.Request)
return out, metadata, nil
}),
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.ExpectValueFound, updatedRequest.Header.Get(HeaderKey); e != a {
t.Errorf("expect header value %v found, got %v", e, a)
}
})
}
}
9 changes: 9 additions & 0 deletions service/s3/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/s3/api_op_PutObject.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions service/s3/api_op_UploadPart.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c93b5cc

Please sign in to comment.