Skip to content

Commit

Permalink
Merge pull request #2490 from aws/feat-aid-endpoints
Browse files Browse the repository at this point in the history
Support aws account id in endpoint2.0 routing
  • Loading branch information
wty-Bryant authored Jun 17, 2024
2 parents aa796dc + 3133994 commit 374440d
Show file tree
Hide file tree
Showing 1,593 changed files with 21,682 additions and 5,800 deletions.
8 changes: 8 additions & 0 deletions .changelog/783a73b97a9843d991c055e06537c43c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "783a73b9-7a98-43d9-91c0-55e06537c43c",
"type": "feature",
"description": "Support accountID-based endpoint routing.",
"modules": [
"."
]
}
18 changes: 18 additions & 0 deletions aws/accountid_endpoint_mode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package aws

// AccountIDEndpointMode controls how a resolved AWS account ID is handled for endpoint routing.
type AccountIDEndpointMode string

const (
// AccountIDEndpointModeUnset indicates the AWS account ID will not be used for endpoint routing
AccountIDEndpointModeUnset AccountIDEndpointMode = ""

// AccountIDEndpointModePreferred indicates the AWS account ID will be used for endpoint routing if present
AccountIDEndpointModePreferred = "preferred"

// AccountIDEndpointModeRequired indicates an error will be returned if the AWS account ID is not resolved from identity
AccountIDEndpointModeRequired = "required"

// AccountIDEndpointModeDisabled indicates the AWS account ID will be ignored during endpoint routing
AccountIDEndpointModeDisabled = "disabled"
)
3 changes: 3 additions & 0 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ type Config struct {
// This variable is sourced from environment variable AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES or
// the shared config profile attribute request_min_compression_size_bytes
RequestMinCompressSizeBytes int64

// Controls how a resolved AWS account ID is handled for endpoint routing.
AccountIDEndpointMode AccountIDEndpointMode
}

// NewConfig returns a new Config pointer that can be chained with builder
Expand Down
3 changes: 3 additions & 0 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ type Credentials struct {
// The time the credentials will expire at. Should be ignored if CanExpire
// is false.
Expires time.Time

// The ID of the account for the credentials.
AccountID string
}

// Expired returns if the credentials have expired.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ public class AddAwsConfigFields implements GoIntegration {

private static final String REQUEST_MIN_COMPRESSION_SIZE_BYTES = "RequestMinCompressSizeBytes";

private static final String SDK_ACCOUNTID_ENDPOINT_MODE = "AccountIDEndpointMode";

private static final List<AwsConfigField> AWS_CONFIG_FIELDS = ListUtils.of(
AwsConfigField.builder()
.name(REGION_CONFIG_NAME)
Expand Down Expand Up @@ -235,6 +237,11 @@ public class AddAwsConfigFields implements GoIntegration {
.documentation("The inclusive min request body size to be compressed.")
.servicePredicate(RequestCompression::isRequestCompressionService)
.generatedOnClient(false)
.build(),
AwsConfigField.builder()
.name(SDK_ACCOUNTID_ENDPOINT_MODE)
.type(SdkGoTypes.Aws.AccountIDEndpointMode)
.documentation("Indicates how aws account ID is applied in endpoint2.0 routing")
.build()
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ public static final class Aws {

public static final Symbol IsCredentialsProvider = AwsGoDependency.AWS_CORE.valueSymbol("IsCredentialsProvider");
public static final Symbol AnonymousCredentials = AwsGoDependency.AWS_CORE.pointableSymbol("AnonymousCredentials");
public static final Symbol AccountIDEndpointMode = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointMode");
public static final Symbol AccountIDEndpointModeUnset = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointModeUnset");
public static final Symbol AccountIDEndpointModePreferred = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointModePreferred");
public static final Symbol AccountIDEndpointModeRequired = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointModeRequired");
public static final Symbol AccountIDEndpointModeDisabled = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointModeDisabled");


public static final class Middleware {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package software.amazon.smithy.aws.go.codegen.customization;

import software.amazon.smithy.aws.go.codegen.SdkGoTypes;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait;
import software.amazon.smithy.utils.MapUtils;

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;

public class AccountIDEndpointRouting implements GoIntegration {
@Override
public void renderPreEndpointResolutionHook(GoSettings settings, GoWriter writer, Model model) {
writer.write("""
if err := checkAccountID(getIdentity(ctx), m.options.AccountIDEndpointMode); err != nil {
return out, metadata, $T("invalid accountID set: %w", err)
}
""",
GoStdlibTypes.Fmt.Errorf);
}

@Override
public void writeAdditionalFiles(
GoSettings settings,
Model model,
SymbolProvider symbolProvider,
GoDelegator goDelegator
) {
if (!settings.getService(model).hasTrait(EndpointRuleSetTrait.class)) {
return;
}
goDelegator.useShapeWriter(settings.getService(model), goTemplate("""
func checkAccountID(identity $auth:T, mode $accountIDEndpointMode:T) error {
switch mode {
case $aidModeUnset:T:
case $aidModePreferred:T:
case $aidModeDisabled:T:
case $aidModeRequired:T:
if ca, ok := identity.(*$credentialsAdapter:T); !ok {
return $errorf:T("accountID is required but not set")
} else if ca.Credentials.AccountID == "" {
return $errorf:T("accountID is required but not set")
}
// default check in case invalid mode is configured through request config
default:
return $errorf:T("invalid accountID endpoint mode %s, must be preferred/required/disabled", mode)
}
return nil
}
""",
MapUtils.of(
"auth", SmithyGoTypes.Auth.Identity,
"accountIDEndpointMode", SdkGoTypes.Aws.AccountIDEndpointMode,
"credentialsAdapter", SdkGoTypes.Internal.Auth.Smithy.CredentialsAdapter,
"aidModePreferred", SdkGoTypes.Aws.AccountIDEndpointModePreferred,
"aidModeRequired", SdkGoTypes.Aws.AccountIDEndpointModeRequired,
"aidModeUnset", SdkGoTypes.Aws.AccountIDEndpointModeUnset,
"aidModeDisabled", SdkGoTypes.Aws.AccountIDEndpointModeDisabled,
"errorf", GoStdlibTypes.Fmt.Errorf
)
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import software.amazon.smithy.aws.traits.auth.SigV4Trait;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SymbolUtils;
Expand Down Expand Up @@ -94,17 +95,17 @@ private boolean isSigV4Service(Model model, ServiceShape service) {

private GoWriter.Writable writeRegionResolver() {
return goTemplate("""
func bindAuthParamsRegion(params $P, _ interface{}, options Options) {
func bindAuthParamsRegion( _ interface{}, params $P, _ interface{}, options Options) {
params.Region = options.Region
}
""", AuthParametersGenerator.STRUCT_SYMBOL);
}

private GoWriter.Writable writeEndpointParamResolver() {
return goTemplate("""
func bindAuthEndpointParams(params $P, input interface{}, options Options) {
params.endpointParams = bindEndpointParams(input, options)
func bindAuthEndpointParams(ctx $P, params $P, input interface{}, options Options) {
params.endpointParams = bindEndpointParams(ctx, input, options)
}
""", AuthParametersGenerator.STRUCT_SYMBOL);
""", GoStdlibTypes.Context.Context, AuthParametersGenerator.STRUCT_SYMBOL);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.MapUtils;

import java.util.List;

Expand Down Expand Up @@ -36,6 +39,8 @@ public class AwsEndpointBuiltins implements GoIntegration {
goTemplate("$T(options.UseARNRegion)", SdkGoTypes.Aws.Bool);
private static final GoWriter.Writable BindAwsS3DisableMultiRegionAccessPoints =
goTemplate("$T(options.DisableMultiRegionAccessPoints)", SdkGoTypes.Aws.Bool);
private static final GoWriter.Writable BindAccountID =
goTemplate("resolveAccountID(getIdentity(ctx), options.AccountIDEndpointMode)");

@Override
public List<RuntimeClientPlugin> getClientPlugins() {
Expand All @@ -49,12 +54,38 @@ public List<RuntimeClientPlugin> getClientPlugins() {
.addEndpointBuiltinBinding("AWS::S3::UseArnRegion", BindAwsS3UseArnRegion)
.addEndpointBuiltinBinding("AWS::S3::DisableMultiRegionAccessPoints", BindAwsS3DisableMultiRegionAccessPoints)
.addEndpointBuiltinBinding("AWS::S3Control::UseArnRegion", BindAwsS3UseArnRegion)
.addEndpointBuiltinBinding("AWS::Auth::AccountId", BindAccountID)
.build());
}

@Override
public void writeAdditionalFiles(GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator) {
goDelegator.useFileWriter("endpoints.go", settings.getModuleName(), builtinBindingSource());
if (!settings.getService(model).hasTrait(EndpointRuleSetTrait.class)) {
return;
}
goDelegator.useShapeWriter(settings.getService(model), goTemplate("""
func resolveAccountID(identity $auth:T, mode $accountIDEndpointMode:T) *string {
if mode == $aidModeDisabled:T {
return nil
}
if ca, ok := identity.(*$credentialsAdapter:T); ok && ca.Credentials.AccountID != "" {
return $string:T(ca.Credentials.AccountID)
}
return nil
}
""",
MapUtils.of(
"auth", SmithyGoTypes.Auth.Identity,
"accountIDEndpointMode", SdkGoTypes.Aws.AccountIDEndpointMode,
"aidModeUnset", SdkGoTypes.Aws.AccountIDEndpointModeUnset,
"aidModeDisabled", SdkGoTypes.Aws.AccountIDEndpointModeDisabled,
"credentialsAdapter", SdkGoTypes.Internal.Auth.Smithy.CredentialsAdapter,
"string", SdkGoTypes.Aws.String
)
));
}

private GoWriter.Writable builtinBindingSource() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ software.amazon.smithy.aws.go.codegen.customization.CloudFrontKVSSigV4a
software.amazon.smithy.aws.go.codegen.customization.BackfillProtocolTestServiceTrait
software.amazon.smithy.go.codegen.integration.MiddlewareStackSnapshotTests
software.amazon.smithy.aws.go.codegen.customization.s3.S3ExpiresShapeCustomization
software.amazon.smithy.aws.go.codegen.ClockSkewGenerator
software.amazon.smithy.aws.go.codegen.ClockSkewGenerator
software.amazon.smithy.aws.go.codegen.customization.AccountIDEndpointRouting
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ var defaultAWSConfigResolvers = []awsConfigResolver{

// Sets the RequestMinCompressSizeBytes if present in env var or shared config profile
resolveRequestMinCompressSizeBytes,

// Sets the AccountIDEndpointMode if present in env var or shared config profile
resolveAccountIDEndpointMode,
}

// A Config represents a generic configuration value or set of values. This type
Expand Down
37 changes: 37 additions & 0 deletions config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ const (
awsRequestMinCompressionSizeBytes = "AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES"

awsS3DisableExpressSessionAuthEnv = "AWS_S3_DISABLE_EXPRESS_SESSION_AUTH"

awsAccountIDEnv = "AWS_ACCOUNT_ID"
awsAccountIDEndpointModeEnv = "AWS_ACCOUNT_ID_ENDPOINT_MODE"
)

var (
Expand Down Expand Up @@ -290,6 +293,9 @@ type EnvConfig struct {
// will only bypass the modified endpoint routing and signing behaviors
// associated with the feature.
S3DisableExpressAuth *bool

// Indicates whether account ID will be required/ignored in endpoint2.0 routing
AccountIDEndpointMode aws.AccountIDEndpointMode
}

// loadEnvConfig reads configuration values from the OS's environment variables.
Expand All @@ -309,6 +315,7 @@ func NewEnvConfig() (EnvConfig, error) {
setStringFromEnvVal(&creds.AccessKeyID, credAccessEnvKeys)
setStringFromEnvVal(&creds.SecretAccessKey, credSecretEnvKeys)
if creds.HasKeys() {
creds.AccountID = os.Getenv(awsAccountIDEnv)
creds.SessionToken = os.Getenv(awsSessionTokenEnvVar)
cfg.Credentials = creds
}
Expand Down Expand Up @@ -389,6 +396,10 @@ func NewEnvConfig() (EnvConfig, error) {
return cfg, err
}

if err := setAIDEndPointModeFromEnvVal(&cfg.AccountIDEndpointMode, []string{awsAccountIDEndpointModeEnv}); err != nil {
return cfg, err
}

return cfg, nil
}

Expand Down Expand Up @@ -417,6 +428,10 @@ func (c EnvConfig) getRequestMinCompressSizeBytes(context.Context) (int64, bool,
return *c.RequestMinCompressSizeBytes, true, nil
}

func (c EnvConfig) getAccountIDEndpointMode(context.Context) (aws.AccountIDEndpointMode, bool, error) {
return c.AccountIDEndpointMode, len(c.AccountIDEndpointMode) > 0, nil
}

// GetRetryMaxAttempts returns the value of AWS_MAX_ATTEMPTS if was specified,
// and not 0.
func (c EnvConfig) GetRetryMaxAttempts(ctx context.Context) (int, bool, error) {
Expand Down Expand Up @@ -491,6 +506,28 @@ func setEC2IMDSEndpointMode(mode *imds.EndpointModeState, keys []string) error {
return nil
}

func setAIDEndPointModeFromEnvVal(m *aws.AccountIDEndpointMode, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue
}

switch value {
case "preferred":
*m = aws.AccountIDEndpointModePreferred
case "required":
*m = aws.AccountIDEndpointModeRequired
case "disabled":
*m = aws.AccountIDEndpointModeDisabled
default:
return fmt.Errorf("invalid value for environment variable, %s=%s, must be preferred/required/disabled", k, value)
}
break
}
return nil
}

// GetRegion returns the AWS Region if set in the environment. Returns an empty
// string if not set.
func (c EnvConfig) getRegion(ctx context.Context) (string, bool, error) {
Expand Down
27 changes: 27 additions & 0 deletions config/env_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ func TestNewEnvConfig_Creds(t *testing.T) {
Source: CredentialsSourceName,
},
},
{
Env: map[string]string{
"AWS_ACCESS_KEY_ID": "AKID",
"AWS_SECRET_ACCESS_KEY": "SECRET",
"AWS_ACCOUNT_ID": "012345678901",
},
Val: aws.Credentials{
AccessKeyID: "AKID", SecretAccessKey: "SECRET", AccountID: "012345678901",
Source: CredentialsSourceName,
},
},
}

for i, c := range cases {
Expand Down Expand Up @@ -496,6 +507,22 @@ func TestNewEnvConfig(t *testing.T) {
},
WantErr: true,
},
46: {
Env: map[string]string{
"AWS_ACCOUNT_ID_ENDPOINT_MODE": "required",
},
Config: EnvConfig{
AccountIDEndpointMode: aws.AccountIDEndpointModeRequired,
},
WantErr: false,
},
47: {
Env: map[string]string{
"AWS_ACCOUNT_ID_ENDPOINT_MODE": "blabla",
},
Config: EnvConfig{},
WantErr: true,
},
}

for i, c := range cases {
Expand Down
Loading

0 comments on commit 374440d

Please sign in to comment.