Skip to content

Commit

Permalink
moves host prefix addition to a serialize step middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
skotambkar committed Nov 8, 2020
1 parent 43a13c1 commit 62c2a93
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.EndpointTrait;

/**
* EndpointHostPrefixMiddleware adds middlewares to identify
* a host prefix and mutate the request URL host if permitted.
**/
public class EndpointHostPrefixMiddleware implements GoIntegration {

private static final MiddlewareIdentifier MIDDLEWARE_ID = MiddlewareIdentifier.string("EndpointHostPrefix");
Expand Down Expand Up @@ -84,10 +88,9 @@ public void writeAdditionalFiles(
middlewareHelperName,
() -> {
writer.write(
"return stack.Serialize.Insert(&$L{}, `OperationSerializer`, middleware.Before)",
"return stack.Serialize.Insert(&$L{}, `OperationSerializer`, middleware.After)",
middlewareName);
});

});
});
}
Expand All @@ -109,7 +112,8 @@ private static void writeMiddleware(
writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT);
writer.addUseImports(SmithyGoDependency.FMT);

w.openBlock("if smithyhttp.GetHostnameImmutable(ctx) {", "}", () -> {
w.openBlock("if smithyhttp.GetHostnameImmutable(ctx) || "
+ "smithyhttp.IsEndpointHostPrefixDisabled(ctx) {", "}", () -> {
w.write("return next.$L(ctx, in)", generator.getHandleMethodName());
}).write("");

Expand All @@ -119,8 +123,7 @@ private static void writeMiddleware(
}).write("");

if (pattern.getLabels().isEmpty()) {
w.write("req.HostPrefix = $S", pattern.toString());

w.write("req.URL.Host = $S + req.URL.Host", pattern.toString());
} else {
// If the pattern has labels, we need to build up the host prefix using a string builder.
writer.addUseImports(SmithyGoDependency.STRINGS);
Expand Down Expand Up @@ -156,8 +159,7 @@ private static void writeMiddleware(
});
}
}

w.write("req.HostPrefix = prefix.String()");
w.write("req.URL.Host = prefix.String() + req.URL.Host");
}
w.write("");

Expand Down
15 changes: 15 additions & 0 deletions transport/http/middleware_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import "context"

type (
hostnameImmutableKey struct{}
hostPrefixDisableKey struct{}
)

// GetHostnameImmutable retrieves if the endpoint hostname should be considered
Expand All @@ -18,3 +19,17 @@ func GetHostnameImmutable(ctx context.Context) (v bool) {
func SetHostnameImmutable(ctx context.Context, value bool) context.Context {
return context.WithValue(ctx, hostnameImmutableKey{}, value)
}

// DisableEndpointHostPrefix sets or modifies if the request's endpoint host
// prefixing to be disabled. If value is set to true, endpoint host prefixing
// will be disabled.
func DisableEndpointHostPrefix(ctx context.Context, value bool) context.Context {
return context.WithValue(ctx, hostPrefixDisableKey{}, value)
}

// IsEndpointHostPrefixDisabled retrieves if the hostname prefixing
// is disabled.
func IsEndpointHostPrefixDisabled(ctx context.Context) (v bool) {
v, _ = ctx.Value(hostPrefixDisableKey{}).(bool)
return v
}
8 changes: 1 addition & 7 deletions transport/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ type Request struct {
stream io.Reader
isStreamSeekable bool
streamStartPos int64
HostPrefix string
}

// NewStackRequest returns an initialized request ready to populated with the
Expand Down Expand Up @@ -127,7 +126,7 @@ func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) {

// Build returns a build standard HTTP request value from the Smithy request.
// The request's stream is wrapped in a safe container that allows it to be
// reused for subsiquent attempts.
// reused for subsequent attempts.
func (r *Request) Build(ctx context.Context) *http.Request {
req := r.Request.Clone(ctx)

Expand All @@ -139,11 +138,6 @@ func (r *Request) Build(ctx context.Context) *http.Request {
req.ContentLength = 0
}

// Add the host prefix
if len(r.HostPrefix) != 0 {
req.URL.Host = r.HostPrefix + req.URL.Host
}

return req
}

Expand Down

0 comments on commit 62c2a93

Please sign in to comment.