[chore] - minor cleanup S3 source (#3554)
* cleanup s3 source

* revert
ahrav authored Nov 4, 2024
1 parent 741146f commit 6923843
Showing 1 changed file with 77 additions and 55 deletions.
pkg/sources/s3/s3.go
@@ -15,15 +15,14 @@ import (
     "github.com/aws/aws-sdk-go/service/s3/s3manager"
     "github.com/aws/aws-sdk-go/service/sts"
     "github.com/go-errors/errors"
-    "github.com/go-logr/logr"
+    "github.com/trufflesecurity/trufflehog/v3/pkg/log"
     "golang.org/x/sync/errgroup"
     "google.golang.org/protobuf/proto"
     "google.golang.org/protobuf/types/known/anypb"

     "github.com/trufflesecurity/trufflehog/v3/pkg/common"
     "github.com/trufflesecurity/trufflehog/v3/pkg/context"
     "github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
-    "github.com/trufflesecurity/trufflehog/v3/pkg/log"
     "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
     "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
     "github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer"
@@ -40,16 +39,17 @@ const (

 type Source struct {
     name string
-    sourceId sources.SourceID
-    jobId sources.JobID
+    sourceID sources.SourceID
+    jobID sources.JobID
     verify bool
     concurrency int
-    log logr.Logger
     sources.Progress
-    conn *sourcespb.S3

     errorCount *sync.Map
+    conn *sourcespb.S3
     jobPool *errgroup.Group
     maxObjectSize int64

     sources.CommonSourceUnitUnmarshaller
 }

@@ -59,43 +59,41 @@ var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
 var _ sources.Validator = (*Source)(nil)

 // Type returns the type of source
-func (s *Source) Type() sourcespb.SourceType {
-    return SourceType
-}
+func (s *Source) Type() sourcespb.SourceType { return SourceType }

-func (s *Source) SourceID() sources.SourceID {
-    return s.sourceId
-}
+func (s *Source) SourceID() sources.SourceID { return s.sourceID }

-func (s *Source) JobID() sources.JobID {
-    return s.jobId
-}
+func (s *Source) JobID() sources.JobID { return s.jobID }

 // Init returns an initialized AWS source
-func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
-    s.log = context.WithValues(aCtx, "source", s.Type(), "name", name).Logger()
-
+func (s *Source) Init(
+    _ context.Context,
+    name string,
+    jobID sources.JobID,
+    sourceID sources.SourceID,
+    verify bool,
+    connection *anypb.Any,
+    concurrency int,
+) error {
     s.name = name
-    s.sourceId = sourceId
-    s.jobId = jobId
+    s.sourceID = sourceID
+    s.jobID = jobID
     s.verify = verify
     s.concurrency = concurrency
     s.errorCount = &sync.Map{}
-    s.log = aCtx.Logger()
     s.jobPool = &errgroup.Group{}
     s.jobPool.SetLimit(concurrency)

     var conn sourcespb.S3
-    err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
-    if err != nil {
-        return errors.WrapPrefix(err, "error unmarshalling connection", 0)
+    if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil {
+        return fmt.Errorf("error unmarshalling connection: %w", err)
     }
     s.conn = &conn

     s.setMaxObjectSize(conn.GetMaxObjectSize())

-    if len(conn.Buckets) > 0 && len(conn.IgnoreBuckets) > 0 {
-        return fmt.Errorf("either a bucket include list or a bucket ignore list can be specified, but not both")
+    if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 {
+        return errors.New("either a bucket include list or a bucket ignore list can be specified, but not both")
     }

     return nil
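As the hunk above shows, `Init` now wraps the unmarshal failure with `fmt.Errorf` and the `%w` verb instead of `errors.WrapPrefix` from go-errors. A minimal sketch of why that matters: `%w` keeps the cause in the error chain so callers can match it with the standard library (the `errUnmarshal` sentinel below is illustrative, not from this file):

```go
package main

import (
	"errors"
	"fmt"
)

// errUnmarshal stands in for the error returned by anypb.UnmarshalTo.
var errUnmarshal = errors.New("bad connection bytes")

func initSource() error {
	// %w preserves the original error in the chain, unlike a plain %v.
	return fmt.Errorf("error unmarshalling connection: %w", errUnmarshal)
}

func main() {
	err := initSource()
	fmt.Println(errors.Is(err, errUnmarshal)) // true: callers can still match the cause
}
```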
@@ -110,8 +108,7 @@ func (s *Source) Validate(ctx context.Context) []error {
         }
     }

-    err := s.visitRoles(ctx, visitor)
-    if err != nil {
+    if err := s.visitRoles(ctx, visitor); err != nil {
         errs = append(errs, err)
     }

@@ -136,11 +133,15 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {

     switch cred := s.conn.GetCredential().(type) {
     case *sourcespb.S3_SessionToken:
-        cfg.Credentials = credentials.NewStaticCredentials(cred.SessionToken.Key, cred.SessionToken.Secret, cred.SessionToken.SessionToken)
+        cfg.Credentials = credentials.NewStaticCredentials(
+            cred.SessionToken.GetKey(),
+            cred.SessionToken.GetSecret(),
+            cred.SessionToken.GetSessionToken(),
+        )
         log.RedactGlobally(cred.SessionToken.GetSecret())
         log.RedactGlobally(cred.SessionToken.GetSessionToken())
     case *sourcespb.S3_AccessKey:
-        cfg.Credentials = credentials.NewStaticCredentials(cred.AccessKey.Key, cred.AccessKey.Secret, "")
+        cfg.Credentials = credentials.NewStaticCredentials(cred.AccessKey.GetKey(), cred.AccessKey.GetSecret(), "")
         log.RedactGlobally(cred.AccessKey.GetSecret())
     case *sourcespb.S3_Unauthenticated:
         cfg.Credentials = credentials.AnonymousCredentials
@@ -174,12 +175,12 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {

 // IAM identity needs s3:ListBuckets permission
 func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
-    if len(s.conn.Buckets) > 0 {
-        return s.conn.Buckets, nil
+    if buckets := s.conn.GetBuckets(); len(buckets) > 0 {
+        return buckets, nil
     }

-    ignore := make(map[string]struct{}, len(s.conn.IgnoreBuckets))
-    for _, bucket := range s.conn.IgnoreBuckets {
+    ignore := make(map[string]struct{}, len(s.conn.GetIgnoreBuckets()))
+    for _, bucket := range s.conn.GetIgnoreBuckets() {
         ignore[bucket] = struct{}{}
     }

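A recurring change in these hunks is swapping direct protobuf field access (`s.conn.Buckets`) for the generated getters (`s.conn.GetBuckets()`). The getters are nil-safe: they return a zero value when the receiver is nil. A sketch of the convention with a hand-rolled type (`Conn` here is illustrative, not the generated `sourcespb.S3`):

```go
package main

import "fmt"

// Conn mimics a protobuf-generated message.
type Conn struct {
	Buckets []string
}

// GetBuckets follows the generated-code convention: it is safe to call
// on a nil receiver and returns the zero value instead of panicking.
func (c *Conn) GetBuckets() []string {
	if c == nil {
		return nil
	}
	return c.Buckets
}

func main() {
	var c *Conn                      // nil message
	fmt.Println(len(c.GetBuckets())) // 0: no panic, unlike c.Buckets
}
```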
@@ -198,53 +199,62 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
     return bucketsToScan, nil
 }

-func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) {
-    objectCount := uint64(0)
+func (s *Source) scanBuckets(
+    ctx context.Context,
+    client *s3.S3,
+    role string,
+    bucketsToScan []string,
+    chunksChan chan *sources.Chunk,
+) {
+    var objectCount uint64

-    logger := s.log
     if role != "" {
-        logger = logger.WithValues("roleArn", role)
+        ctx = context.WithValue(ctx, "role", role)
     }

     for i, bucket := range bucketsToScan {
-        logger := logger.WithValues("bucket", bucket)
+        ctx := context.WithValue(ctx, "bucket", bucket)

         if common.IsDone(ctx) {
             return
         }

         s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
-        logger.V(3).Info("Scanning bucket")
+        ctx.Logger().V(3).Info("Scanning bucket")

         regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
         if err != nil {
-            logger.Error(err, "could not get regional client for bucket")
+            ctx.Logger().Error(err, "could not get regional client for bucket")
             continue
         }

         errorCount := sync.Map{}

         err = regionalClient.ListObjectsV2PagesWithContext(
             ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
-            func(page *s3.ListObjectsV2Output, last bool) bool {
+            func(page *s3.ListObjectsV2Output, _ bool) bool {
                 s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
                 return true
             })

         if err != nil {
             if role == "" {
-                logger.Error(err, "could not list objects in bucket")
+                ctx.Logger().Error(err, "could not list objects in bucket")
             } else {
                 // Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
                 // often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
                 // account, but the role probably doesn't have access to all of them). This makes it expected behavior
                 // and therefore not an error.
-                logger.V(3).Info("could not list objects in bucket",
-                    "err", err)
+                ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
             }
         }
     }
-    s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")
+    s.SetProgressComplete(
+        len(bucketsToScan),
+        len(bucketsToScan),
+        fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
+        "",
+    )
 }

 // Chunks emits chunks of bytes over a channel.
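The hunk above drops the hand-threaded `logr.Logger` in favor of trufflehog's context package, which carries a logger with the context: values attached via `context.WithValue` surface on every downstream `ctx.Logger()` call. A rough sketch of the pattern, assuming the trufflehog module is on the module path (`Background()` is assumed to mirror the standard library; `WithValue`, `Logger()`, and `V(3).Info` appear in the diff itself):

```go
package main

import "github.com/trufflesecurity/trufflehog/v3/pkg/context"

func main() {
	ctx := context.Background()
	// Attach scan-scoped values once...
	ctx = context.WithValue(ctx, "role", "arn:aws:iam::123456789012:role/scanner")
	ctx = context.WithValue(ctx, "bucket", "example-bucket")
	// ...and every log line derived from ctx carries role=... and bucket=...
	// without threading a logger through each function.
	ctx.Logger().V(3).Info("Scanning bucket")
}
```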
@@ -256,10 +266,15 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
     return s.visitRoles(ctx, visitor)
 }

-func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionClient *s3.S3, role, bucket string) (*s3.S3, error) {
+func (s *Source) getRegionalClientForBucket(
+    ctx context.Context,
+    defaultRegionClient *s3.S3,
+    role string,
+    bucket string,
+) (*s3.S3, error) {
     region, err := s3manager.GetBucketRegionWithClient(ctx, defaultRegionClient, bucket)
     if err != nil {
-        return nil, errors.WrapPrefix(err, "could not get s3 region for bucket", 0)
+        return nil, fmt.Errorf("could not get s3 region for bucket: %s", bucket)
     }

     if region == defaultAWSRegion {
@@ -268,7 +283,7 @@ func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionCl

     regionalClient, err := s.newClient(region, role)
     if err != nil {
-        return nil, errors.WrapPrefix(err, "could not create regional s3 client", 0)
+        return nil, fmt.Errorf("could not create regional s3 client for bucket %s: %w", bucket, err)
     }

     return regionalClient, nil
@@ -448,7 +463,6 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
     }

     _, err = regionalClient.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: &bucket})
-
     if err == nil {
         wasAbleToListAnyBucket = true
     } else if shouldHaveAccessToAllBuckets {
@@ -458,7 +472,7 @@

     if !wasAbleToListAnyBucket {
         if roleArn == "" {
-            errs = append(errs, fmt.Errorf("could not list objects in any bucket"))
+            errs = append(errs, errors.New("could not list objects in any bucket"))
         } else {
             errs = append(errs, fmt.Errorf("role %q could not list objects in any bucket", roleArn))
Expand All @@ -467,16 +481,24 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
return errs
}

func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string)) error {
roles := s.conn.Roles
// visitRoles iterates over the configured AWS roles and calls the provided function
// for each role, passing in the default S3 client, the role ARN, and the list of
// buckets to scan.
//
// If no roles are configured, it will call the function with an empty role ARN.
func (s *Source) visitRoles(
ctx context.Context,
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
) error {
roles := s.conn.GetRoles()
if len(roles) == 0 {
roles = []string{""}
}

for _, role := range roles {
client, err := s.newClient(defaultAWSRegion, role)
if err != nil {
return errors.WrapPrefix(err, "could not create s3 client", 0)
return fmt.Errorf("could not create s3 client: %w", err)
}

bucketsToScan, err := s.getBucketsToScan(client)
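The new doc comment describes a small visitor pattern: role iteration is decoupled from the per-role work, and both `Validate` and `Chunks` supply their own callback. A standalone sketch of that control flow (the stub client type and bucket names are invented for illustration):

```go
package main

import "fmt"

type s3Client struct{ roleArn string }

// visitRoles mirrors the control flow documented above: visit each configured
// role, or visit once with an empty role ARN when none are configured.
func visitRoles(roles []string, f func(client *s3Client, roleArn string, buckets []string)) error {
	if len(roles) == 0 {
		roles = []string{""}
	}
	for _, role := range roles {
		f(&s3Client{roleArn: role}, role, []string{"bucket-a", "bucket-b"})
	}
	return nil
}

func main() {
	_ = visitRoles([]string{"arn:aws:iam::123456789012:role/scanner"},
		func(client *s3Client, roleArn string, buckets []string) {
			fmt.Printf("role=%q buckets=%v\n", roleArn, buckets)
		})
}
```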
@@ -493,7 +515,7 @@ func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defau
 // S3 links currently have the general format of:
 // https://[bucket].s3[.region unless us-east-1].amazonaws.com/[key]
 func makeS3Link(bucket, region, key string) string {
-    if region == "us-east-1" {
+    if region == defaultAWSRegion {
         region = ""
     } else {
         region = "." + region