From f356b1b8510dc8b0c8c6648c76fa6fb3557d889f Mon Sep 17 00:00:00 2001 From: Kevin Conner Date: Sun, 1 Sep 2024 20:27:42 -0700 Subject: [PATCH] fix(aws): handle ECR repositories in different regions (#6217) Signed-off-by: Kevin Conner --- pkg/fanal/image/registry/azure/azure.go | 30 ++++---- pkg/fanal/image/registry/azure/azure_test.go | 2 +- pkg/fanal/image/registry/ecr/ecr.go | 54 +++++++++++---- pkg/fanal/image/registry/ecr/ecr_test.go | 68 +++++++++++++++++-- pkg/fanal/image/registry/google/google.go | 18 +++-- .../image/registry/google/google_test.go | 12 ++-- pkg/fanal/image/registry/intf/registry.go | 15 ++++ pkg/fanal/image/registry/token.go | 14 ++-- 8 files changed, 159 insertions(+), 54 deletions(-) create mode 100644 pkg/fanal/image/registry/intf/registry.go diff --git a/pkg/fanal/image/registry/azure/azure.go b/pkg/fanal/image/registry/azure/azure.go index 3203829f3d3d..67368c8bb3c8 100644 --- a/pkg/fanal/image/registry/azure/azure.go +++ b/pkg/fanal/image/registry/azure/azure.go @@ -14,15 +14,19 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "github.com/aquasecurity/trivy/pkg/fanal/types" ) -type Registry struct { +type RegistryClient struct { domain string scope string cloud cloud.Configuration } +type Registry struct { +} + const ( azureURL = ".azurecr.io" chinaAzureURL = ".azurecr.cn" @@ -31,23 +35,25 @@ const ( scheme = "https" ) -func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) error { +func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) (intf.RegistryClient, error) { if strings.HasSuffix(domain, azureURL) { - r.domain = domain - r.scope = scope - r.cloud = cloud.AzurePublic - return nil + return &RegistryClient{ + domain: domain, + scope: scope, + cloud: cloud.AzurePublic, + }, nil } else if strings.HasSuffix(domain, chinaAzureURL) { - r.domain = domain - r.scope = chinaScope - r.cloud = cloud.AzureChina - return nil + return &RegistryClient{ + domain: domain, + scope: scope, + cloud: cloud.AzureChina, + }, nil } - return xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern) + return nil, xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern) } -func (r *Registry) GetCredential(ctx context.Context) (string, string, error) { +func (r *RegistryClient) GetCredential(ctx context.Context) (string, string, error) { opts := azcore.ClientOptions{Cloud: r.cloud} cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ClientOptions: opts}) if err != nil { diff --git a/pkg/fanal/image/registry/azure/azure_test.go b/pkg/fanal/image/registry/azure/azure_test.go index 0fb4839e8fee..5be56f777a51 100644 --- a/pkg/fanal/image/registry/azure/azure_test.go +++ b/pkg/fanal/image/registry/azure/azure_test.go @@ -38,7 +38,7 @@ func TestRegistry_CheckOptions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := azure.Registry{} - err := r.CheckOptions(tt.domain, types.RegistryOptions{}) + _, err := r.CheckOptions(tt.domain, types.RegistryOptions{}) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) } else { diff --git a/pkg/fanal/image/registry/ecr/ecr.go b/pkg/fanal/image/registry/ecr/ecr.go index 261030d1a584..c9bb0dea5330 100644 --- a/pkg/fanal/image/registry/ecr/ecr.go +++ b/pkg/fanal/image/registry/ecr/ecr.go @@ -3,6 +3,7 @@ package ecr import ( "context" "encoding/base64" + "regexp" "strings" "github.com/aws/aws-sdk-go-v2/aws" @@ -11,48 +12,73 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecr" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/log" ) -const ecrURLSuffix = ".amazonaws.com" -const ecrURLPartial = ".dkr.ecr" - type ecrAPI interface { GetAuthorizationToken(ctx context.Context, params *ecr.GetAuthorizationTokenInput, optFns ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error) } type ECR struct { +} + +type ECRClient struct { Client ecrAPI } -func getSession(option types.RegistryOptions) (aws.Config, error) { +func getSession(domain, region string, option types.RegistryOptions) (aws.Config, error) { // create custom credential information if option is valid if option.AWSSecretKey != "" && option.AWSAccessKey != "" && option.AWSRegion != "" { + if region != option.AWSRegion { + log.Warnf("The region from AWS_REGION (%s) is being overridden. The region from domain (%s) was used.", option.AWSRegion, domain) + } return config.LoadDefaultConfig( context.TODO(), - config.WithRegion(option.AWSRegion), + config.WithRegion(region), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(option.AWSAccessKey, option.AWSSecretKey, option.AWSSessionToken)), ) } - return config.LoadDefaultConfig(context.TODO()) + return config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) } -func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) error { - if !strings.HasSuffix(domain, ecrURLSuffix) && !strings.Contains(domain, ecrURLPartial) { - return xerrors.Errorf("ECR : %w", types.InvalidURLPattern) +func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) { + region := determineRegion(domain) + if region == "" { + return nil, xerrors.Errorf("ECR : %w", types.InvalidURLPattern) } - cfg, err := getSession(option) + cfg, err := getSession(domain, region, option) if err != nil { - return err + return nil, err } svc := ecr.NewFromConfig(cfg) - e.Client = svc - return nil + return &ECRClient{Client: svc}, nil +} + +// Endpoints take the form +// .dkr.ecr..amazonaws.com +// .dkr.ecr-fips..amazonaws.com +// .dkr.ecr..amazonaws.com.cn +// .dkr.ecr..sc2s.sgov.gov +// .dkr.ecr..c2s.ic.gov +// see +// - https://docs.aws.amazon.com/general/latest/gr/ecr.html +// - https://docs.amazonaws.cn/en_us/aws/latest/userguide/endpoints-arns.html +// - https://github.com/boto/botocore/blob/1.34.51/botocore/data/endpoints.json +var ecrEndpointMatch = regexp.MustCompile(`^[^.]+\.dkr\.ecr(?:-fips)?\.([^.]+)\.(?:amazonaws\.com(?:\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`) + +func determineRegion(domain string) string { + matches := ecrEndpointMatch.FindStringSubmatch(domain) + if matches != nil { + return matches[1] + } + return "" } -func (e *ECR) GetCredential(ctx context.Context) (username, password string, err error) { +func (e *ECRClient) GetCredential(ctx context.Context) (username, password string, err error) { input := &ecr.GetAuthorizationTokenInput{} result, err := e.Client.GetAuthorizationToken(ctx, input) if err != nil { diff --git a/pkg/fanal/image/registry/ecr/ecr_test.go b/pkg/fanal/image/registry/ecr/ecr_test.go index 68b1870f07da..321de7f8639d 100644 --- a/pkg/fanal/image/registry/ecr/ecr_test.go +++ b/pkg/fanal/image/registry/ecr/ecr_test.go @@ -8,14 +8,20 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecr" awstypes "github.com/aws/aws-sdk-go-v2/service/ecr/types" + "github.com/stretchr/testify/require" "github.com/aquasecurity/trivy/pkg/fanal/types" ) +type testECRClient interface { + Options() ecr.Options +} + func TestCheckOptions(t *testing.T) { var tests = map[string]struct { - domain string - wantErr error + domain string + expectedRegion string + wantErr error }{ "InvalidURL": { domain: "alpine:3.9", @@ -30,19 +36,71 @@ func TestCheckOptions(t *testing.T) { wantErr: types.InvalidURLPattern, }, "NoOption": { - domain: "xxx.ecr.ap-northeast-1.amazonaws.com", + domain: "xxx.dkr.ecr.ap-northeast-1.amazonaws.com", + expectedRegion: "ap-northeast-1", + }, + "region-1": { + domain: "xxx.dkr.ecr.region-1.amazonaws.com", + expectedRegion: "region-1", + }, + "region-2": { + domain: "xxx.dkr.ecr.region-2.amazonaws.com", + expectedRegion: "region-2", + }, + "fips-region-1": { + domain: "xxx.dkr.ecr-fips.fips-region.amazonaws.com", + expectedRegion: "fips-region", + }, + "cn-region-1": { + domain: "xxx.dkr.ecr.region-1.amazonaws.com.cn", + expectedRegion: "region-1", + }, + "cn-region-2": { + domain: "xxx.dkr.ecr.region-2.amazonaws.com.cn", + expectedRegion: "region-2", + }, + "sc2s-region-1": { + domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov.gov", + expectedRegion: "sc2s-region", + }, + "c2s-region-1": { + domain: "xxx.dkr.ecr.c2s-region.c2s.ic.gov", + expectedRegion: "c2s-region", + }, + "invalid-ecr": { + domain: "xxx.dkrecr.region-1.amazonaws.com", + wantErr: types.InvalidURLPattern, + }, + "invalid-fips": { + domain: "xxx.dkr.ecrfips.fips-region.amazonaws.com", + wantErr: types.InvalidURLPattern, + }, + "invalid-cn": { + domain: "xxx.dkr.ecr.region-2.amazonaws.cn", + wantErr: types.InvalidURLPattern, + }, + "invalid-sc2s": { + domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov", + wantErr: types.InvalidURLPattern, + }, + "invalid-cs2": { + domain: "xxx.dkr.ecr.c2s-region.c2s.ic", + wantErr: types.InvalidURLPattern, }, } for testname, v := range tests { a := &ECR{} - err := a.CheckOptions(v.domain, types.RegistryOptions{}) + ecrClient, err := a.CheckOptions(v.domain, types.RegistryOptions{}) if err != nil { if !errors.Is(err, v.wantErr) { t.Errorf("[%s]\nexpected error based on %v\nactual : %v", testname, v.wantErr, err) } continue } + + client := (ecrClient.(*ECRClient)).Client.(testECRClient) + require.Equal(t, v.expectedRegion, client.Options().Region) } } @@ -90,7 +148,7 @@ func TestECRGetCredential(t *testing.T) { } for i, c := range cases { - e := ECR{ + e := ECRClient{ Client: mockedECR{Resp: c.Resp}, } username, password, err := e.GetCredential(context.Background()) diff --git a/pkg/fanal/image/registry/google/google.go b/pkg/fanal/image/registry/google/google.go index fe52c85f7493..3c58f0de005f 100644 --- a/pkg/fanal/image/registry/google/google.go +++ b/pkg/fanal/image/registry/google/google.go @@ -9,14 +9,18 @@ import ( "github.com/GoogleCloudPlatform/docker-credential-gcr/store" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "github.com/aquasecurity/trivy/pkg/fanal/types" ) -type Registry struct { +type GoogleRegistryClient struct { Store store.GCRCredStore domain string } +type Registry struct { +} + // Google container registry const gcrURLDomain = "gcr.io" const gcrURLSuffix = ".gcr.io" @@ -24,18 +28,18 @@ const gcrURLSuffix = ".gcr.io" // Google artifact registry const garURLSuffix = "-docker.pkg.dev" -func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) error { +func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) { if domain != gcrURLDomain && !strings.HasSuffix(domain, gcrURLSuffix) && !strings.HasSuffix(domain, garURLSuffix) { - return xerrors.Errorf("Google registry: %w", types.InvalidURLPattern) + return nil, xerrors.Errorf("Google registry: %w", types.InvalidURLPattern) } - g.domain = domain + client := GoogleRegistryClient{domain: domain} if option.GCPCredPath != "" { - g.Store = store.NewGCRCredStore(option.GCPCredPath) + client.Store = store.NewGCRCredStore(option.GCPCredPath) } - return nil + return &client, nil } -func (g *Registry) GetCredential(_ context.Context) (username, password string, err error) { +func (g *GoogleRegistryClient) GetCredential(_ context.Context) (username, password string, err error) { var credStore store.GCRCredStore if g.Store == nil { credStore, err = store.DefaultGCRCredStore() diff --git a/pkg/fanal/image/registry/google/google_test.go b/pkg/fanal/image/registry/google/google_test.go index 2e4ba64d663e..1bce72897d46 100644 --- a/pkg/fanal/image/registry/google/google_test.go +++ b/pkg/fanal/image/registry/google/google_test.go @@ -14,7 +14,7 @@ func TestCheckOptions(t *testing.T) { var tests = map[string]struct { domain string opt types.RegistryOptions - gcr *Registry + grc *GoogleRegistryClient wantErr error }{ "InvalidURL": { @@ -27,12 +27,12 @@ func TestCheckOptions(t *testing.T) { }, "NoOption": { domain: "gcr.io", - gcr: &Registry{domain: "gcr.io"}, + grc: &GoogleRegistryClient{domain: "gcr.io"}, }, "CredOption": { domain: "gcr.io", opt: types.RegistryOptions{GCPCredPath: "/path/to/file.json"}, - gcr: &Registry{ + grc: &GoogleRegistryClient{ domain: "gcr.io", Store: store.NewGCRCredStore("/path/to/file.json"), }, @@ -41,7 +41,7 @@ func TestCheckOptions(t *testing.T) { for testname, v := range tests { g := &Registry{} - err := g.CheckOptions(v.domain, v.opt) + grc, err := g.CheckOptions(v.domain, v.opt) if v.wantErr != nil { if err == nil { t.Errorf("%s : expected error but no error", testname) @@ -52,8 +52,8 @@ func TestCheckOptions(t *testing.T) { } continue } - if !reflect.DeepEqual(v.gcr, g) { - t.Errorf("[%s]\nexpected : %v\nactual : %v", testname, v.gcr, g) + if !reflect.DeepEqual(v.grc, grc) { + t.Errorf("[%s]\nexpected : %v\nactual : %v", testname, v.grc, grc) } } } diff --git a/pkg/fanal/image/registry/intf/registry.go b/pkg/fanal/image/registry/intf/registry.go new file mode 100644 index 000000000000..da8c5d3c1789 --- /dev/null +++ b/pkg/fanal/image/registry/intf/registry.go @@ -0,0 +1,15 @@ +package intf + +import ( + "context" + + "github.com/aquasecurity/trivy/pkg/fanal/types" +) + +type RegistryClient interface { + GetCredential(ctx context.Context) (string, string, error) +} + +type Registry interface { + CheckOptions(domain string, option types.RegistryOptions) (RegistryClient, error) +} diff --git a/pkg/fanal/image/registry/token.go b/pkg/fanal/image/registry/token.go index b959c6cc7bbc..72c569890196 100644 --- a/pkg/fanal/image/registry/token.go +++ b/pkg/fanal/image/registry/token.go @@ -8,12 +8,13 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/image/registry/azure" "github.com/aquasecurity/trivy/pkg/fanal/image/registry/ecr" "github.com/aquasecurity/trivy/pkg/fanal/image/registry/google" + "github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf" "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" ) var ( - registries []Registry + registries []intf.Registry ) func init() { @@ -22,23 +23,18 @@ func init() { RegisterRegistry(&azure.Registry{}) } -type Registry interface { - CheckOptions(domain string, option types.RegistryOptions) error - GetCredential(ctx context.Context) (string, string, error) -} - -func RegisterRegistry(registry Registry) { +func RegisterRegistry(registry intf.Registry) { registries = append(registries, registry) } func GetToken(ctx context.Context, domain string, opt types.RegistryOptions) (auth authn.Basic) { // check registry which particular to get credential for _, registry := range registries { - err := registry.CheckOptions(domain, opt) + client, err := registry.CheckOptions(domain, opt) if err != nil { continue } - username, password, err := registry.GetCredential(ctx) + username, password, err := client.GetCredential(ctx) if err != nil { // only skip check registry if error occurred log.Debug("Credential error", log.Err(err))