diff --git a/cluster-autoscaler/cloudprovider/aws/aws_manager.go b/cluster-autoscaler/cloudprovider/aws/aws_manager.go index f5928b88029..32210f5440a 100644 --- a/cluster-autoscaler/cloudprovider/aws/aws_manager.go +++ b/cluster-autoscaler/cloudprovider/aws/aws_manager.go @@ -67,6 +67,16 @@ type asgTemplate struct { Tags []*autoscaling.TagDescription } +type awsSDKProvider struct { + cfg provider_aws.AwsCloudConfigProvider +} + +func newAWSSDKProvider(cfg *provider_aws.CloudConfig) *awsSDKProvider { + return &awsSDKProvider{ + cfg: cfg, + } +} + // getRegion deduces the current AWS Region. func getRegion(cfg ...*aws.Config) string { region, present := os.LookupEnv("AWS_REGION") @@ -93,16 +103,22 @@ func createAWSManagerInternal( autoScalingService *autoScalingWrapper, ec2Service *ec2Wrapper, ) (*AwsManager, error) { - if configReader != nil { - var cfg provider_aws.CloudConfig - if err := gcfg.ReadInto(&cfg, configReader); err != nil { - klog.Errorf("Couldn't read config: %v", err) - return nil, err - } + + cfg, err := readAWSCloudConfig(configReader) + if err != nil { + klog.Errorf("Couldn't read config: %v", err) + return nil, err + } + + if err = cfg.ValidateOverrides(); err != nil { + klog.Errorf("Unable to validate custom endpoint overrides: %v", err) + return nil, err } if autoScalingService == nil || ec2Service == nil { - sess := session.New(aws.NewConfig().WithRegion(getRegion())) + awsSdkProvider := newAWSSDKProvider(cfg) + sess := session.New(aws.NewConfig().WithRegion(getRegion()). + WithEndpointResolver(awsSdkProvider.cfg.GetResolver())) if autoScalingService == nil { autoScalingService = &autoScalingWrapper{autoscaling.New(sess)} @@ -136,6 +152,21 @@ func createAWSManagerInternal( return manager, nil } +// readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. +func readAWSCloudConfig(config io.Reader) (*provider_aws.CloudConfig, error) { + var cfg provider_aws.CloudConfig + var err error + + if config != nil { + err = gcfg.ReadInto(&cfg, config) + if err != nil { + return nil, err + } + } + + return &cfg, nil +} + // CreateAwsManager constructs awsManager object. func CreateAwsManager(configReader io.Reader, discoveryOpts cloudprovider.NodeGroupDiscoveryOptions) (*AwsManager, error) { return createAWSManagerInternal(configReader, discoveryOpts, nil, nil) diff --git a/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go b/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go index 0f3f775a145..a13a67341c7 100644 --- a/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go +++ b/cluster-autoscaler/cloudprovider/aws/aws_manager_test.go @@ -18,7 +18,7 @@ package aws import ( "fmt" - + "io" "net/http" "net/http/httptest" "os" @@ -35,6 +35,7 @@ import ( apiv1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider" + provider_aws "k8s.io/kubernetes/pkg/cloudprovider/providers/aws" kubeletapis "k8s.io/kubernetes/pkg/kubelet/apis" ) @@ -416,6 +417,276 @@ func TestFetchAutoAsgs(t *testing.T) { assert.Empty(t, m.asgCache.Get()) } +type ServiceDescriptor struct { + name string + region string + signingRegion, signingMethod string + signingName string +} + +func TestOverridesActiveConfig(t *testing.T) { + tests := []struct { + name string + + reader io.Reader + aws provider_aws.Services + + expectError bool + active bool + servicesOverridden []ServiceDescriptor + }{ + { + "No overrides", + strings.NewReader(` + [global] + `), + nil, + false, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Name", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Region=sregion + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Region", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing URL", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service="s3" + Region=sregion + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Signing Region", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + Region=sregion + URL=https://s3.foo.bar + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Active Overrides", + strings.NewReader(` + [Global] + [ServiceOverride "1"] + Service = "s3 " + Region = sregion + URL = https://s3.foo.bar + SigningRegion = sregion + SigningMethod = v4 + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "sregion", signingRegion: "sregion", signingMethod: "v4"}}, + }, + { + "Multiple Overridden Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v4 + [ServiceOverride "2"] + Service=ec2 + Region=sregion2 + URL=https://ec2.foo.bar + SigningRegion=sregion2 + SigningMethod = v4`), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "sregion1", signingRegion: "sregion1", signingMethod: "v4"}, + {name: "ec2", region: "sregion2", signingRegion: "sregion2", signingMethod: "v4"}}, + }, + { + "Duplicate Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + [ServiceOverride "2"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign`), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Multiple Overridden Services in Multiple regions", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + Region=region1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + [ServiceOverride "2"] + Service=ec2 + Region=region2 + URL=https://ec2.foo.bar + SigningRegion=sregion + SigningMethod = v4 + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: ""}, + {name: "ec2", region: "region2", signingRegion: "sregion", signingMethod: "v4"}}, + }, + { + "Multiple regions, Same Service", + strings.NewReader(` + [global] + [ServiceOverride "1"] + Service=s3 + Region=region1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v3 + [ServiceOverride "2"] + Service=s3 + Region=region2 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v4 + SigningName = "name" + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: "v3"}, + {name: "s3", region: "region2", signingRegion: "sregion1", signingMethod: "v4", signingName: "name"}}, + }, + } + + for _, test := range tests { + t.Logf("Running test case %s", test.name) + cfg, err := readAWSCloudConfig(test.reader) + if err == nil { + err = cfg.ValidateOverrides() + } + if test.expectError { + if err == nil { + t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) + } + } else { + if err != nil { + t.Errorf("Should succeed for case: %s, got %v", test.name, err) + } + + if len(cfg.ServiceOverride) != len(test.servicesOverridden) { + t.Errorf("Expected %d overridden services, received %d for case %s", + len(test.servicesOverridden), len(cfg.ServiceOverride), test.name) + } else { + for _, sd := range test.servicesOverridden { + var found *struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + } + for _, v := range cfg.ServiceOverride { + if v.Service == sd.name && v.Region == sd.region { + found = v + break + } + } + if found == nil { + t.Errorf("Missing override for service %s in case %s", + sd.name, test.name) + } else { + if found.SigningRegion != sd.signingRegion { + t.Errorf("Expected signing region '%s', received '%s' for case %s", + sd.signingRegion, found.SigningRegion, test.name) + } + if found.SigningMethod != sd.signingMethod { + t.Errorf("Expected signing method '%s', received '%s' for case %s", + sd.signingMethod, found.SigningRegion, test.name) + } + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if found.URL != targetName { + t.Errorf("Expected Endpoint '%s', received '%s' for case %s", + targetName, found.URL, test.name) + } + if found.SigningName != sd.signingName { + t.Errorf("Expected signing name '%s', received '%s' for case %s", + sd.signingName, found.SigningName, test.name) + } + + fn := cfg.GetResolver() + ep1, e := fn(sd.name, sd.region, nil) + if e != nil { + t.Errorf("Expected a valid endpoint for %s in case %s", + sd.name, test.name) + } else { + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if ep1.URL != targetName { + t.Errorf("Expected endpoint url: %s, received %s in case %s", + targetName, ep1.URL, test.name) + } + if ep1.SigningRegion != sd.signingRegion { + t.Errorf("Expected signing region '%s', received '%s' in case %s", + sd.signingRegion, ep1.SigningRegion, test.name) + } + if ep1.SigningMethod != sd.signingMethod { + t.Errorf("Expected signing method '%s', received '%s' in case %s", + sd.signingMethod, ep1.SigningRegion, test.name) + } + } + } + } + } + } + } +} + func tagsMatcher(expected *autoscaling.DescribeTagsInput) func(*autoscaling.DescribeTagsInput) bool { return func(actual *autoscaling.DescribeTagsInput) bool { expectedTags := flatTagSlice(expected.Filters) diff --git a/cluster-autoscaler/vendor/k8s.io/kubernetes/pkg/cloudprovider/providers/aws/aws.go b/cluster-autoscaler/vendor/k8s.io/kubernetes/pkg/cloudprovider/providers/aws/aws.go index da7787285a5..a915caba1b9 100644 --- a/cluster-autoscaler/vendor/k8s.io/kubernetes/pkg/cloudprovider/providers/aws/aws.go +++ b/cluster-autoscaler/vendor/k8s.io/kubernetes/pkg/cloudprovider/providers/aws/aws.go @@ -595,7 +595,7 @@ type CloudConfig struct { } } -func (cfg *CloudConfig) validateOverrides() error { +func (cfg *CloudConfig) ValidateOverrides() error { if len(cfg.ServiceOverride) == 0 { return nil } @@ -633,7 +633,7 @@ func (cfg *CloudConfig) validateOverrides() error { return nil } -func (cfg *CloudConfig) getResolver() endpoints.ResolverFunc { +func (cfg *CloudConfig) GetResolver() endpoints.ResolverFunc { defaultResolver := endpoints.DefaultResolver() defaultResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { @@ -665,13 +665,13 @@ type awsSdkEC2 struct { } // Interface to make the CloudConfig immutable for awsSDKProvider -type awsCloudConfigProvider interface { - getResolver() endpoints.ResolverFunc +type AwsCloudConfigProvider interface { + GetResolver() endpoints.ResolverFunc } type awsSDKProvider struct { creds *credentials.Credentials - cfg awsCloudConfigProvider + cfg AwsCloudConfigProvider mutex sync.Mutex regionDelayers map[string]*CrossRequestRetryDelay @@ -751,7 +751,7 @@ func (p *awsSDKProvider) Compute(regionName string) (EC2, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) + WithEndpointResolver(p.cfg.GetResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -773,7 +773,7 @@ func (p *awsSDKProvider) LoadBalancing(regionName string) (ELB, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) + WithEndpointResolver(p.cfg.GetResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -791,7 +791,7 @@ func (p *awsSDKProvider) LoadBalancingV2(regionName string) (ELBV2, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) + WithEndpointResolver(p.cfg.GetResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -810,7 +810,7 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) + WithEndpointResolver(p.cfg.GetResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -825,7 +825,7 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { func (p *awsSDKProvider) Metadata() (EC2Metadata, error) { sess, err := session.NewSession(&aws.Config{ - EndpointResolver: p.cfg.getResolver(), + EndpointResolver: p.cfg.GetResolver(), }) if err != nil { return nil, fmt.Errorf("unable to initialize AWS session: %v", err) @@ -841,7 +841,7 @@ func (p *awsSDKProvider) KeyManagement(regionName string) (KMS, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(p.cfg.getResolver()) + WithEndpointResolver(p.cfg.GetResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -1066,7 +1066,7 @@ func init() { return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) } - if err = cfg.validateOverrides(); err != nil { + if err = cfg.ValidateOverrides(); err != nil { return nil, fmt.Errorf("unable to validate custom endpoint overrides: %v", err) }