Skip to content

Commit

Permalink
[v16] [access-graph] extract S3 bucket tags (#45551)
Browse files Browse the repository at this point in the history
* [access-graph] extract S3 bucket tags

This PR extracts the AWS bucket tags by calling `GetBucketTagging`.

Part of gravitational/access-graph#898

Signed-off-by: Tiago Silva <[email protected]>

* add tests

* handle comments

---------

Signed-off-by: Tiago Silva <[email protected]>
  • Loading branch information
tigrato authored Aug 26, 2024
1 parent ac46592 commit bdbc8c6
Show file tree
Hide file tree
Showing 8 changed files with 654 additions and 316 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ The IAM policy includes the following directives:
"s3:GetBucketPolicy",
"s3:ListBucket",
"s3:GetBucketLocation",
"s3:GetBucketTagging",

"iam:ListUsers",
"iam:GetUser",
Expand Down
637 changes: 325 additions & 312 deletions gen/proto/go/accessgraph/v1alpha/aws.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions lib/cloud/aws/policy_statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ func StatementAccessGraphAWSSync() *Statement {
"s3:GetBucketPolicy",
"s3:ListBucket",
"s3:GetBucketLocation",
"s3:GetBucketTagging",
// IAM IAM
"iam:ListUsers",
"iam:GetUser",
Expand Down
96 changes: 96 additions & 0 deletions lib/cloud/mocks/aws_s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package mocks

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/gravitational/trace"
)

// S3Mock mocks AWS S3 API.
type S3Mock struct {
s3iface.S3API
Buckets []*s3.Bucket
BucketPolicy map[string]string
BucketPolicyStatus map[string]*s3.PolicyStatus
BucketACL map[string][]*s3.Grant
BucketTags map[string][]*s3.Tag
}

func (m *S3Mock) ListBucketsWithContext(_ aws.Context, _ *s3.ListBucketsInput, _ ...request.Option) (*s3.ListBucketsOutput, error) {
return &s3.ListBucketsOutput{
Buckets: m.Buckets,
}, nil
}

func (m *S3Mock) GetBucketPolicyWithContext(_ aws.Context, input *s3.GetBucketPolicyInput, _ ...request.Option) (*s3.GetBucketPolicyOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
policy, ok := m.BucketPolicy[aws.StringValue(input.Bucket)]
if !ok {
return nil, trace.NotFound("bucket %v not found", aws.StringValue(input.Bucket))
}
return &s3.GetBucketPolicyOutput{
Policy: aws.String(policy),
}, nil
}

func (m *S3Mock) GetBucketPolicyStatusWithContext(_ aws.Context, input *s3.GetBucketPolicyStatusInput, _ ...request.Option) (*s3.GetBucketPolicyStatusOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
policyStatus, ok := m.BucketPolicyStatus[aws.StringValue(input.Bucket)]
if !ok {
return nil, trace.NotFound("bucket %v not found", aws.StringValue(input.Bucket))
}
return &s3.GetBucketPolicyStatusOutput{
PolicyStatus: policyStatus,
}, nil
}

func (m *S3Mock) GetBucketAclWithContext(_ aws.Context, input *s3.GetBucketAclInput, _ ...request.Option) (*s3.GetBucketAclOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
grants, ok := m.BucketACL[aws.StringValue(input.Bucket)]
if !ok {
return nil, trace.NotFound("bucket %v not found", aws.StringValue(input.Bucket))
}
return &s3.GetBucketAclOutput{
Grants: grants,
}, nil
}

func (m *S3Mock) GetBucketTaggingWithContext(_ aws.Context, input *s3.GetBucketTaggingInput, _ ...request.Option) (*s3.GetBucketTaggingOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
tags, ok := m.BucketTags[aws.StringValue(input.Bucket)]
if !ok {
return nil, awserr.New("NoSuchTagSet", "The specified bucket does not have a tag set", nil)
}
return &s3.GetBucketTaggingOutput{
TagSet: tags,
}, nil
}
1 change: 0 additions & 1 deletion lib/srv/discovery/access_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ func (s *Server) initializeAndWatchAccessGraph(ctx context.Context, reloadCh <-c
SemaphoreKind: types.KindAccessGraph,
SemaphoreName: semaphoreName,
MaxLeases: 1,
Expires: s.clock.Now().Add(semaphoreExpiration),
Holder: s.Config.ServerID,
},
Expiry: semaphoreExpiration,
Expand Down
32 changes: 29 additions & 3 deletions lib/srv/discovery/fetchers/aws-sync/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ package aws_sync

import (
"context"
"errors"
"sync"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -83,7 +85,9 @@ func (a *awsFetcher) fetchS3Buckets(ctx context.Context) ([]*accessgraphv1alpha.
ctx,
&s3.ListBucketsInput{},
)

if err != nil {
return nil, trace.Wrap(err)
}
for _, bucket := range rsp.Buckets {
bucket := bucket
eG.Go(func() error {
Expand All @@ -107,8 +111,22 @@ func (a *awsFetcher) fetchS3Buckets(ctx context.Context) ([]*accessgraphv1alpha.
if err != nil {
collect(nil, trace.Wrap(err, "failed to fetch bucket %q acls policies", aws.ToString(bucket.Name)))
}

tagsOutput, err := s3Client.GetBucketTaggingWithContext(ctx, &s3.GetBucketTaggingInput{
Bucket: bucket.Name,
})
var awsErr awserr.Error
const noSuchTagSet = "NoSuchTagSet" // error code when there are no tags or the bucket does not support them
if errors.As(err, &awsErr) && awsErr.Code() == noSuchTagSet {
// If there are no tags, set the error to nil.
err = nil
}
if err != nil {
collect(nil, trace.Wrap(err, "failed to fetch bucket %q tags", aws.ToString(bucket.Name)))
}

collect(
awsS3Bucket(aws.ToString(bucket.Name), policy, policyStatus, acls, a.AccountID),
awsS3Bucket(aws.ToString(bucket.Name), policy, policyStatus, acls, tagsOutput, a.AccountID),
nil)
return nil
})
Expand All @@ -119,7 +137,7 @@ func (a *awsFetcher) fetchS3Buckets(ctx context.Context) ([]*accessgraphv1alpha.
return s3s, trace.Wrap(err)
}

func awsS3Bucket(name string, policy *s3.GetBucketPolicyOutput, policyStatus *s3.GetBucketPolicyStatusOutput, acls *s3.GetBucketAclOutput, accountID string) *accessgraphv1alpha.AWSS3BucketV1 {
func awsS3Bucket(name string, policy *s3.GetBucketPolicyOutput, policyStatus *s3.GetBucketPolicyStatusOutput, acls *s3.GetBucketAclOutput, tags *s3.GetBucketTaggingOutput, accountID string) *accessgraphv1alpha.AWSS3BucketV1 {
s3 := &accessgraphv1alpha.AWSS3BucketV1{
Name: name,
AccountId: accountID,
Expand All @@ -133,6 +151,14 @@ func awsS3Bucket(name string, policy *s3.GetBucketPolicyOutput, policyStatus *s3
if acls != nil {
s3.Acls = awsACLsToProtoACLs(acls.Grants)
}
if tags != nil {
for _, tag := range tags.TagSet {
s3.Tags = append(s3.Tags, &accessgraphv1alpha.AWSTag{
Key: aws.ToString(tag.Key),
Value: strPtrToWrapper(tag.Value),
})
}
}
return s3
}

Expand Down
200 changes: 200 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/s3_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package aws_sync

import (
"context"
"sort"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"

accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
)

func TestPollAWSS3(t *testing.T) {
sortSlice := func(buckets []*accessgraphv1alpha.AWSS3BucketV1) {
sort.Slice(buckets, func(i, j int) bool {
return buckets[i].Name < buckets[j].Name
})
}

const (
accountID = "12345678"
bucketName = "bucket1"
otherBucketName = "bucket2"
)
var (
regions = []string{"eu-west-1"}
mockedClients = &cloud.TestCloudClients{
S3: &mocks.S3Mock{
Buckets: s3Buckets(bucketName, otherBucketName),
BucketPolicy: map[string]string{
bucketName: "policy",
otherBucketName: "otherPolicy",
},
BucketPolicyStatus: map[string]*s3.PolicyStatus{
bucketName: {
IsPublic: aws.Bool(true),
},
otherBucketName: {
IsPublic: aws.Bool(false),
},
},
BucketACL: map[string][]*s3.Grant{
bucketName: {
{
Grantee: &s3.Grantee{
ID: aws.String("id"),
},
Permission: aws.String("READ"),
},
},
otherBucketName: {
{
Grantee: &s3.Grantee{
ID: aws.String("id"),
},
Permission: aws.String("READ"),
},
},
},
BucketTags: map[string][]*s3.Tag{
bucketName: {
{
Key: aws.String("tag"),
Value: aws.String("val"),
},
},
},
},
}
)

tests := []struct {
name string
want *Resources
}{
{
name: "poll s3",
want: &Resources{
S3Buckets: []*accessgraphv1alpha.AWSS3BucketV1{
{
Name: bucketName,
AccountId: accountID,
PolicyDocument: []byte("policy"),
IsPublic: true,
Acls: []*accessgraphv1alpha.AWSS3BucketACL{
{
Grantee: &accessgraphv1alpha.AWSS3BucketACLGrantee{
Id: "id",
},
Permission: "READ",
},
},
Tags: []*accessgraphv1alpha.AWSTag{
{
Key: "tag",
Value: strPtrToWrapper(aws.String("val")),
},
},
},
{
Name: otherBucketName,
AccountId: accountID,
PolicyDocument: []byte("otherPolicy"),
IsPublic: false,
Acls: []*accessgraphv1alpha.AWSS3BucketACL{
{
Grantee: &accessgraphv1alpha.AWSS3BucketACLGrantee{
Id: "id",
},
Permission: "READ",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

var (
errs []error
mu sync.Mutex
)

collectErr := func(err error) {
mu.Lock()
defer mu.Unlock()
errs = append(errs, err)
}
a := &awsFetcher{
Config: Config{
AccountID: accountID,
CloudClients: mockedClients,
Regions: regions,
Integration: accountID,
},
}
result := &Resources{}
execFunc := a.pollAWSS3Buckets(context.Background(), result, collectErr)
require.NoError(t, execFunc())
require.NoError(t, trace.NewAggregate(errs...))

sortSlice(tt.want.S3Buckets)
sortSlice(result.S3Buckets)
require.Empty(t, cmp.Diff(
tt.want,
result,
protocmp.Transform(),
// tags originate from a map so we must sort them before comparing.
protocmp.SortRepeated(
func(a, b *accessgraphv1alpha.AWSTag) bool {
return a.Key < b.Key
},
),
),
)

})
}
}

func s3Buckets(bucketNames ...string) []*s3.Bucket {
var output []*s3.Bucket
for _, name := range bucketNames {
output = append(output, &s3.Bucket{
Name: aws.String(name),
CreationDate: aws.Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
})

}
return output
}
Loading

0 comments on commit bdbc8c6

Please sign in to comment.