Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[access-graph] extract S3 bucket tags #45364

Merged
merged 3 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ The IAM policy includes the following directives:
"s3:GetBucketPolicy",
"s3:ListBucket",
"s3:GetBucketLocation",
"s3:GetBucketTagging",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changelog should probably mention something about this new requirement, otherwise self-hosted users won't know they need to update their permissions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is yet being used but added a note to changelog.


"iam:ListUsers",
"iam:GetUser",
Expand Down
637 changes: 325 additions & 312 deletions gen/proto/go/accessgraph/v1alpha/aws.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions lib/cloud/aws/policy_statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ func StatementAccessGraphAWSSync() *Statement {
"s3:GetBucketPolicy",
"s3:ListBucket",
"s3:GetBucketLocation",
"s3:GetBucketTagging",
// IAM IAM
"iam:ListUsers",
"iam:GetUser",
Expand Down
96 changes: 96 additions & 0 deletions lib/cloud/mocks/aws_s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package mocks

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/gravitational/trace"
)

// S3Mock mocks AWS S3 API.
type S3Mock struct {
s3iface.S3API
Buckets []*s3.Bucket
BucketPolicy map[string]string
BucketPolicyStatus map[string]*s3.PolicyStatus
BucketACL map[string][]*s3.Grant
BucketTags map[string][]*s3.Tag
}

func (m *S3Mock) ListBucketsWithContext(_ aws.Context, _ *s3.ListBucketsInput, _ ...request.Option) (*s3.ListBucketsOutput, error) {
return &s3.ListBucketsOutput{
Buckets: m.Buckets,
}, nil
}

func (m *S3Mock) GetBucketPolicyWithContext(_ aws.Context, input *s3.GetBucketPolicyInput, _ ...request.Option) (*s3.GetBucketPolicyOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
policy, ok := m.BucketPolicy[aws.StringValue(input.Bucket)]
if !ok {
return nil, trace.NotFound("bucket %v not found", aws.StringValue(input.Bucket))
}
return &s3.GetBucketPolicyOutput{
Policy: aws.String(policy),
}, nil
}

func (m *S3Mock) GetBucketPolicyStatusWithContext(_ aws.Context, input *s3.GetBucketPolicyStatusInput, _ ...request.Option) (*s3.GetBucketPolicyStatusOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
policyStatus, ok := m.BucketPolicyStatus[aws.StringValue(input.Bucket)]
if !ok {
return nil, trace.NotFound("bucket %v not found", aws.StringValue(input.Bucket))
}
return &s3.GetBucketPolicyStatusOutput{
PolicyStatus: policyStatus,
}, nil
}

func (m *S3Mock) GetBucketAclWithContext(_ aws.Context, input *s3.GetBucketAclInput, _ ...request.Option) (*s3.GetBucketAclOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
grants, ok := m.BucketACL[aws.StringValue(input.Bucket)]
if !ok {
return nil, trace.NotFound("bucket %v not found", aws.StringValue(input.Bucket))
}
return &s3.GetBucketAclOutput{
Grants: grants,
}, nil
}

func (m *S3Mock) GetBucketTaggingWithContext(_ aws.Context, input *s3.GetBucketTaggingInput, _ ...request.Option) (*s3.GetBucketTaggingOutput, error) {
if aws.StringValue(input.Bucket) == "" {
return nil, trace.BadParameter("incorrect bucket name")
}
tags, ok := m.BucketTags[aws.StringValue(input.Bucket)]
if !ok {
return nil, awserr.New("NoSuchTagSet", "The specified bucket does not have a tag set", nil)
}
return &s3.GetBucketTaggingOutput{
TagSet: tags,
}, nil
}
32 changes: 29 additions & 3 deletions lib/srv/discovery/fetchers/aws-sync/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ package aws_sync

import (
"context"
"errors"
"sync"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -83,7 +85,9 @@ func (a *awsFetcher) fetchS3Buckets(ctx context.Context) ([]*accessgraphv1alpha.
ctx,
&s3.ListBucketsInput{},
)

if err != nil {
return nil, trace.Wrap(err)
}
for _, bucket := range rsp.Buckets {
bucket := bucket
eG.Go(func() error {
Expand All @@ -107,8 +111,22 @@ func (a *awsFetcher) fetchS3Buckets(ctx context.Context) ([]*accessgraphv1alpha.
if err != nil {
collect(nil, trace.Wrap(err, "failed to fetch bucket %q acls policies", aws.ToString(bucket.Name)))
}

tagsOutput, err := s3Client.GetBucketTaggingWithContext(ctx, &s3.GetBucketTaggingInput{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we test this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in f0dcf3c

Bucket: bucket.Name,
})
var awsErr awserr.Error
const noSuchTagSet = "NoSuchTagSet" // error code when there are no tags or the bucket does not support them
if errors.As(err, &awsErr) && awsErr.Code() == noSuchTagSet {
// If there are no tags, set the error to nil.
err = nil
}
if err != nil {
collect(nil, trace.Wrap(err, "failed to fetch bucket %q tags", aws.ToString(bucket.Name)))
}

collect(
awsS3Bucket(aws.ToString(bucket.Name), policy, policyStatus, acls, a.AccountID),
awsS3Bucket(aws.ToString(bucket.Name), policy, policyStatus, acls, tagsOutput, a.AccountID),
nil)
return nil
})
Expand All @@ -119,7 +137,7 @@ func (a *awsFetcher) fetchS3Buckets(ctx context.Context) ([]*accessgraphv1alpha.
return s3s, trace.Wrap(err)
}

func awsS3Bucket(name string, policy *s3.GetBucketPolicyOutput, policyStatus *s3.GetBucketPolicyStatusOutput, acls *s3.GetBucketAclOutput, accountID string) *accessgraphv1alpha.AWSS3BucketV1 {
func awsS3Bucket(name string, policy *s3.GetBucketPolicyOutput, policyStatus *s3.GetBucketPolicyStatusOutput, acls *s3.GetBucketAclOutput, tags *s3.GetBucketTaggingOutput, accountID string) *accessgraphv1alpha.AWSS3BucketV1 {
s3 := &accessgraphv1alpha.AWSS3BucketV1{
Name: name,
AccountId: accountID,
Expand All @@ -133,6 +151,14 @@ func awsS3Bucket(name string, policy *s3.GetBucketPolicyOutput, policyStatus *s3
if acls != nil {
s3.Acls = awsACLsToProtoACLs(acls.Grants)
}
if tags != nil {
for _, tag := range tags.TagSet {
s3.Tags = append(s3.Tags, &accessgraphv1alpha.AWSTag{
Key: aws.ToString(tag.Key),
Value: strPtrToWrapper(tag.Value),
})
}
}
return s3
}

Expand Down
198 changes: 198 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/s3_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package aws_sync

import (
"context"
"sort"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"

accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
)

func TestPollAWSS3(t *testing.T) {
const (
accountID = "12345678"
bucketName = "bucket1"
otherBucketName = "bucket2"
)
var (
regions = []string{"eu-west-1"}
)

tests := []struct {
name string
want *Resources
}{
{
name: "poll s3",
want: &Resources{
S3Buckets: []*accessgraphv1alpha.AWSS3BucketV1{
{
Name: bucketName,
AccountId: accountID,
PolicyDocument: []byte("policy"),
IsPublic: true,
Acls: []*accessgraphv1alpha.AWSS3BucketACL{
{
Grantee: &accessgraphv1alpha.AWSS3BucketACLGrantee{
Id: "id",
},
Permission: "READ",
},
},
Tags: []*accessgraphv1alpha.AWSTag{
{
Key: "tag",
Value: strPtrToWrapper(aws.String("val")),
},
},
},
{
Name: otherBucketName,
AccountId: accountID,
PolicyDocument: []byte("otherPolicy"),
IsPublic: false,
Acls: []*accessgraphv1alpha.AWSS3BucketACL{
{
Grantee: &accessgraphv1alpha.AWSS3BucketACLGrantee{
Id: "id",
},
Permission: "READ",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockedClients := &cloud.TestCloudClients{
S3: &mocks.S3Mock{
Buckets: s3Buckets(bucketName, otherBucketName),
BucketPolicy: map[string]string{
bucketName: "policy",
otherBucketName: "otherPolicy",
},
BucketPolicyStatus: map[string]*s3.PolicyStatus{
bucketName: {
IsPublic: aws.Bool(true),
},
otherBucketName: {
IsPublic: aws.Bool(false),
},
},
BucketACL: map[string][]*s3.Grant{
bucketName: {
{
Grantee: &s3.Grantee{
ID: aws.String("id"),
},
Permission: aws.String("READ"),
},
},
otherBucketName: {
{
Grantee: &s3.Grantee{
ID: aws.String("id"),
},
Permission: aws.String("READ"),
},
},
},
BucketTags: map[string][]*s3.Tag{
bucketName: {
{
Key: aws.String("tag"),
Value: aws.String("val"),
},
},
},
},
}
tigrato marked this conversation as resolved.
Show resolved Hide resolved

var (
errs []error
mu sync.Mutex
)

collectErr := func(err error) {
mu.Lock()
defer mu.Unlock()
errs = append(errs, err)
}
a := &awsFetcher{
Config: Config{
AccountID: accountID,
CloudClients: mockedClients,
Regions: regions,
Integration: accountID,
},
}
result := &Resources{}
execFunc := a.pollAWSS3Buckets(context.Background(), result, collectErr)
require.NoError(t, execFunc())
require.NoError(t, trace.NewAggregate(errs...))

sort.Slice(result.S3Buckets, func(i, j int) bool {
return result.S3Buckets[i].Name < result.S3Buckets[j].Name
})
sort.Slice(tt.want.S3Buckets, func(i, j int) bool {
return result.S3Buckets[i].Name < result.S3Buckets[j].Name
tigrato marked this conversation as resolved.
Show resolved Hide resolved
})
require.Empty(t, cmp.Diff(
tt.want,
result,
protocmp.Transform(),
// tags originate from a map so we must sort them before comparing.
protocmp.SortRepeated(
func(a, b *accessgraphv1alpha.AWSTag) bool {
return a.Key < b.Key
},
),
),
)

})
}
}

func s3Buckets(bucketNames ...string) []*s3.Bucket {
var output []*s3.Bucket
for _, name := range bucketNames {
output = append(output, &s3.Bucket{
Name: aws.String(name),
CreationDate: aws.Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
})

}
return output
}
2 changes: 2 additions & 0 deletions proto/accessgraph/v1alpha/aws.proto
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ message AWSS3BucketV1 {
bool is_public = 4;
// acl is the ACL of the S3 bucket.
repeated AWSS3BucketACL acls = 5;
// tags is the list of tags that are attached to the S3 bucket.
repeated AWSTag tags = 6;
}

// AWSS3BucketACL is the ACL of an AWS S3 bucket.
Expand Down
Loading