diff --git a/cluster-autoscaler/cloudprovider/aws/aws_manager.go b/cluster-autoscaler/cloudprovider/aws/aws_manager.go index f5928b88029..044b5cf5cda 100644 --- a/cluster-autoscaler/cloudprovider/aws/aws_manager.go +++ b/cluster-autoscaler/cloudprovider/aws/aws_manager.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/ec2" @@ -67,6 +68,88 @@ type asgTemplate struct { Tags []*autoscaling.TagDescription } +// AutoscalerCloudConfig defines new type so new methods can be defined +type AutoscalerCloudConfig provider_aws.CloudConfig + +func (cfg *AutoscalerCloudConfig) validateOverrides() error { + if len(cfg.ServiceOverride) == 0 { + return nil + } + set := make(map[string]bool) + for onum, ovrd := range cfg.ServiceOverride { + // Note: gcfg does not space trim, so we have to when comparing to empty string "" + name := strings.TrimSpace(ovrd.Service) + if name == "" { + return fmt.Errorf("service name is missing [Service is \"\"] in override %s", onum) + } + // insure the map service name is space trimmed + ovrd.Service = name + + region := strings.TrimSpace(ovrd.Region) + if region == "" { + return fmt.Errorf("service region is missing [Region is \"\"] in override %s", onum) + } + // insure the map region is space trimmed + ovrd.Region = region + + url := strings.TrimSpace(ovrd.URL) + if url == "" { + return fmt.Errorf("url is missing [URL is \"\"] in override %s", onum) + } + signingRegion := strings.TrimSpace(ovrd.SigningRegion) + if signingRegion == "" { + return fmt.Errorf("signingRegion is missing [SigningRegion is \"\"] in override %s", onum) + } + signature := name + "_" + region + if set[signature] { + return fmt.Errorf("duplicate entry found for service override [%s] (%s in %s)", onum, name, region) + } + set[signature] = true + } + return nil +} + +func (cfg *AutoscalerCloudConfig) getResolver() endpoints.ResolverFunc { + defaultResolver := endpoints.DefaultResolver() + defaultResolverFn := func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + return defaultResolver.EndpointFor(service, region, optFns...) + } + if len(cfg.ServiceOverride) == 0 { + return defaultResolverFn + } + + return func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + for _, override := range cfg.ServiceOverride { + if override.Service == service && override.Region == region { + return endpoints.ResolvedEndpoint{ + URL: override.URL, + SigningRegion: override.SigningRegion, + SigningMethod: override.SigningMethod, + SigningName: override.SigningName, + }, nil + } + } + return defaultResolver.EndpointFor(service, region, optFns...) + } +} + +// Interface to make the CloudConfig immutable for awsSDKProvider +type awsCloudConfigProvider interface { + getResolver() endpoints.ResolverFunc +} + +type awsSDKProvider struct { + cfg awsCloudConfigProvider +} + +func newAWSSDKProvider(cfg *AutoscalerCloudConfig) *awsSDKProvider { + return &awsSDKProvider{ + cfg: cfg, + } +} + // getRegion deduces the current AWS Region. func getRegion(cfg ...*aws.Config) string { region, present := os.LookupEnv("AWS_REGION") @@ -93,16 +176,22 @@ func createAWSManagerInternal( autoScalingService *autoScalingWrapper, ec2Service *ec2Wrapper, ) (*AwsManager, error) { - if configReader != nil { - var cfg provider_aws.CloudConfig - if err := gcfg.ReadInto(&cfg, configReader); err != nil { - klog.Errorf("Couldn't read config: %v", err) - return nil, err - } + + cfg, err := readAWSCloudConfig(configReader) + if err != nil { + klog.Errorf("Couldn't read config: %v", err) + return nil, err + } + + if err = cfg.validateOverrides(); err != nil { + klog.Errorf("Unable to validate custom endpoint overrides: %v", err) + return nil, err } if autoScalingService == nil || ec2Service == nil { - sess := session.New(aws.NewConfig().WithRegion(getRegion())) + awsSdkProvider := newAWSSDKProvider(cfg) + sess := session.New(aws.NewConfig().WithRegion(getRegion()). + WithEndpointResolver(awsSdkProvider.cfg.getResolver())) if autoScalingService == nil { autoScalingService = &autoScalingWrapper{autoscaling.New(sess)} @@ -136,6 +225,21 @@ func createAWSManagerInternal( return manager, nil } +// readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. +func readAWSCloudConfig(config io.Reader) (*AutoscalerCloudConfig, error) { + var cfg AutoscalerCloudConfig + var err error + + if config != nil { + err = gcfg.ReadInto(&cfg, config) + if err != nil { + return nil, err + } + } + + return &cfg, nil +} + // CreateAwsManager constructs awsManager object. func CreateAwsManager(configReader io.Reader, discoveryOpts cloudprovider.NodeGroupDiscoveryOptions) (*AwsManager, error) { return createAWSManagerInternal(configReader, discoveryOpts, nil, nil) diff --git a/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go b/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go index 0f3f775a145..85c18ed9ddb 100644 --- a/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go +++ b/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go @@ -18,7 +18,7 @@ package aws import ( "fmt" - + "io" "net/http" "net/http/httptest" "os" @@ -35,6 +35,7 @@ import ( apiv1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider" + provider_aws "k8s.io/kubernetes/pkg/cloudprovider/providers/aws" kubeletapis "k8s.io/kubernetes/pkg/kubelet/apis" ) @@ -416,6 +417,276 @@ func TestFetchAutoAsgs(t *testing.T) { assert.Empty(t, m.asgCache.Get()) } +type ServiceDescriptor struct { + name string + region string + signingRegion, signingMethod string + signingName string +} + +func TestOverridesActiveConfig(t *testing.T) { + tests := []struct { + name string + + reader io.Reader + aws provider_aws.Services + + expectError bool + active bool + servicesOverridden []ServiceDescriptor + }{ + { + "No overrides", + strings.NewReader(` + [global] + `), + nil, + false, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Name", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Region=sregion + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Region", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing URL", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service="s3" + Region=sregion + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Signing Region", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + Region=sregion + URL=https://s3.foo.bar + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Active Overrides", + strings.NewReader(` + [Global] + [ServiceOverride "1"] + Service = "s3 " + Region = sregion + URL = https://s3.foo.bar + SigningRegion = sregion + SigningMethod = v4 + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "sregion", signingRegion: "sregion", signingMethod: "v4"}}, + }, + { + "Multiple Overridden Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v4 + [ServiceOverride "2"] + Service=ec2 + Region=sregion2 + URL=https://ec2.foo.bar + SigningRegion=sregion2 + SigningMethod = v4`), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "sregion1", signingRegion: "sregion1", signingMethod: "v4"}, + {name: "ec2", region: "sregion2", signingRegion: "sregion2", signingMethod: "v4"}}, + }, + { + "Duplicate Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + [ServiceOverride "2"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign`), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Multiple Overridden Services in Multiple regions", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + Region=region1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + [ServiceOverride "2"] + Service=ec2 + Region=region2 + URL=https://ec2.foo.bar + SigningRegion=sregion + SigningMethod = v4 + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: ""}, + {name: "ec2", region: "region2", signingRegion: "sregion", signingMethod: "v4"}}, + }, + { + "Multiple regions, Same Service", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + Region=region1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v3 + [ServiceOverride "2"] + Service=s3 + Region=region2 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v4 + SigningName = "name" + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: "v3"}, + {name: "s3", region: "region2", signingRegion: "sregion1", signingMethod: "v4", signingName: "name"}}, + }, + } + + for _, test := range tests { + t.Logf("Running test case %s", test.name) + cfg, err := readAWSCloudConfig(test.reader) + if err == nil { + err = cfg.validateOverrides() + } + if test.expectError { + if err == nil { + t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) + } + } else { + if err != nil { + t.Errorf("Should succeed for case: %s, got %v", test.name, err) + } + + if len(cfg.ServiceOverride) != len(test.servicesOverridden) { + t.Errorf("Expected %d overridden services, received %d for case %s", + len(test.servicesOverridden), len(cfg.ServiceOverride), test.name) + } else { + for _, sd := range test.servicesOverridden { + var found *struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + } + for _, v := range cfg.ServiceOverride { + if v.Service == sd.name && v.Region == sd.region { + found = v + break + } + } + if found == nil { + t.Errorf("Missing override for service %s in case %s", + sd.name, test.name) + } else { + if found.SigningRegion != sd.signingRegion { + t.Errorf("Expected signing region '%s', received '%s' for case %s", + sd.signingRegion, found.SigningRegion, test.name) + } + if found.SigningMethod != sd.signingMethod { + t.Errorf("Expected signing method '%s', received '%s' for case %s", + sd.signingMethod, found.SigningRegion, test.name) + } + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if found.URL != targetName { + t.Errorf("Expected Endpoint '%s', received '%s' for case %s", + targetName, found.URL, test.name) + } + if found.SigningName != sd.signingName { + t.Errorf("Expected signing name '%s', received '%s' for case %s", + sd.signingName, found.SigningName, test.name) + } + + fn := cfg.getResolver() + ep1, e := fn(sd.name, sd.region, nil) + if e != nil { + t.Errorf("Expected a valid endpoint for %s in case %s", + sd.name, test.name) + } else { + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if ep1.URL != targetName { + t.Errorf("Expected endpoint url: %s, received %s in case %s", + targetName, ep1.URL, test.name) + } + if ep1.SigningRegion != sd.signingRegion { + t.Errorf("Expected signing region '%s', received '%s' in case %s", + sd.signingRegion, ep1.SigningRegion, test.name) + } + if ep1.SigningMethod != sd.signingMethod { + t.Errorf("Expected signing method '%s', received '%s' in case %s", + sd.signingMethod, ep1.SigningRegion, test.name) + } + } + } + } + } + } + } +} + func tagsMatcher(expected *autoscaling.DescribeTagsInput) func(*autoscaling.DescribeTagsInput) bool { return func(actual *autoscaling.DescribeTagsInput) bool { expectedTags := flatTagSlice(expected.Filters)