Skip to content

Commit

Permalink
s
Browse files Browse the repository at this point in the history
  • Loading branch information
louisruch committed Nov 23, 2024
1 parent cc9b2e9 commit 17dae9d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 56 deletions.
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/hashicorp/boundary-plugin-aws

go 1.21
go 1.23.3

require (
github.com/aws/aws-sdk-go-v2 v1.32.5
Expand All @@ -13,7 +13,7 @@ require (
github.com/google/uuid v1.4.0
github.com/hashicorp/boundary/sdk v0.0.43-0.20240717182311-a20aae98794a
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.0
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.1-0.20241123041515-c99d3d7bba9a
github.com/hashicorp/terraform-json v0.14.0
github.com/mitchellh/mapstructure v1.5.0
github.com/stretchr/testify v1.8.4
Expand Down Expand Up @@ -43,7 +43,7 @@ require (
github.com/hashicorp/eventlogger v0.2.6-0.20231025104552-802587e608f0 // indirect
github.com/hashicorp/eventlogger/filters/encrypt v0.1.8-0.20231025104552-802587e608f0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-hclog v1.5.0 // indirect
github.com/hashicorp/go-hclog v1.6.3 // indirect
github.com/hashicorp/go-kms-wrapping/v2 v2.0.14 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/hashicorp/go-version v1.5.0 // indirect
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ github.com/hashicorp/eventlogger/filters/encrypt v0.1.8-0.20231025104552-802587e
github.com/hashicorp/eventlogger/filters/encrypt v0.1.8-0.20231025104552-802587e608f0/go.mod h1:tMywUTIvdB/FXhwm6HMTt61C8/eODY6gitCHhXtyojg=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c=
github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-kms-wrapping/plugin/v2 v2.0.5 h1:jrnDfQm2hCQ0/hEselgqzV4fK16gpZoY0OWGZpVPNHM=
github.com/hashicorp/go-kms-wrapping/plugin/v2 v2.0.5/go.mod h1:psh1qKep5ukvuNobFY/hCybuudlkkACpmazOsCgX5Rg=
github.com/hashicorp/go-kms-wrapping/v2 v2.0.14 h1:1ZuhfnZgRnLK8S0KovJkoTCRIQId5pv3sDR7pG5VQBw=
Expand All @@ -94,8 +94,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-plugin v1.5.2 h1:aWv8eimFqWlsEiMrYZdPYl+FdHaBJSN4AWwGWfT1G2Y=
github.com/hashicorp/go-plugin v1.5.2/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4=
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.0 h1:ca5TSI4AgaOncPpyzLDtCGjVEtKukONpeM95vFxXCOQ=
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.0/go.mod h1:7CUvZtfTp2U0CYQCLzMtS2ngckjAZePSfwrE2aeDP1M=
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.1-0.20241123041515-c99d3d7bba9a h1:/CCTVDc9Q8GGIB/cCImf+x3XtL7qySfogGAPY8XVseA=
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.1-0.20241123041515-c99d3d7bba9a/go.mod h1:OeRwM2eWNW62L1Z+8GvoZM5nQJMRWBewHSoo77qmb4Y=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
Expand Down
115 changes: 79 additions & 36 deletions plugin/service/storage/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,17 @@ func New() *StoragePlugin {
}
}

func (p *StoragePlugin) deleteClient(storageBucketId string) {
p.clients.Lock()
defer p.clients.Unlock()
delete(p.clients.cache, storageBucketId)
}

// GetClient returns an S3API client for the given storage bucket id.
func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, storageState *awsStoragePersistedState, opts ...s3Option) (S3API, error) {
if storageBucketId == "" {
// No storage bucket ID to key cache on, create and return client
func (p *StoragePlugin) getClient(ctx context.Context, storageBucketId string, storageState *awsStoragePersistedState, opts ...s3Option) (S3API, error) {
roleArn := storageState.CredentialsConfig.RoleARN
if storageBucketId == "" || roleArn == "" {
// No storage bucket ID to key cache on or not a roleArn cred type
client, err := storageState.S3Client(ctx, opts...)
if err != nil {
return nil, errors.BadRequestStatus("error creating S3 client: %s", err)
Expand All @@ -78,11 +85,13 @@ func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, s
p.clients.RLock()
client, ok := p.clients.cache[storageBucketId]
p.clients.RUnlock()
if !ok {

if !ok || client.Credentials().Expired() || client.RoleArn() != roleArn {
p.clients.Lock()
// Check again in case another caller created it since we got lock

// Check again in case another caller updated it since we got the lock
client, ok = p.clients.cache[storageBucketId]
if ok {
if ok && !client.Credentials().Expired() && client.RoleArn() == roleArn {
p.clients.Unlock()
return client, nil
}
Expand All @@ -98,6 +107,39 @@ func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, s
return client, nil
}

type s3Caller func(client S3API) (any, error)

func (p *StoragePlugin) call(
ctx context.Context,
fn s3Caller,
storageBucketId string,
storageState *awsStoragePersistedState,
opts ...s3Option) (any, error) {

var attempt int
for {
s3Client, err := p.getClient(ctx, storageBucketId, storageState, opts...)
if err != nil {
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
}
resp, err := fn(s3Client)
if err != nil {
st, _ := errors.ParseAWSError("", err)
if st.Code() == codes.PermissionDenied {
// We got a permission error, if this is the first time re-create client otherwise return error
if attempt != 0 || storageState.CredentialsConfig.RoleARN == "" {
return nil, err
}
attempt++
p.deleteClient(storageBucketId)
continue
}
return nil, err
}
return resp, nil
}
}

// OnCreateStorageBucket is called when a storage bucket is created.
func (p *StoragePlugin) OnCreateStorageBucket(ctx context.Context, req *pb.OnCreateStorageBucketRequest) (*pb.OnCreateStorageBucketResponse, error) {
bucket := req.GetBucket()
Expand Down Expand Up @@ -400,19 +442,20 @@ func (p *StoragePlugin) HeadObject(ctx context.Context, req *pb.HeadObjectReques
if storageAttributes.EndpointUrl != "" {
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
}
s3Client, err := storageState.S3Client(ctx, opts...)
if err != nil {
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
}

objectKey := path.Join(bucket.GetBucketPrefix(), req.GetKey())
resp, err := s3Client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(bucket.GetBucketName()),
Key: aws.String(objectKey),
})
headCall := func(s3Client S3API) (any, error) {
return s3Client.HeadObject(ctx, &s3.HeadObjectInput{
Bucket: aws.String(bucket.GetBucketName()),
Key: aws.String(objectKey),
})
}
headResp, err := p.call(ctx, headCall, bucket.GetId(), storageState, opts...)
if err != nil {
return nil, parseS3Error("head object", err, req).Err()
}
resp := headResp.(*s3.HeadObjectOutput)

return &pb.HeadObjectResponse{
ContentLength: aws.ToInt64(resp.ContentLength),
LastModified: timestamppb.New(*resp.LastModified),
Expand Down Expand Up @@ -526,19 +569,19 @@ func (p *StoragePlugin) GetObject(req *pb.GetObjectRequest, stream pb.StoragePlu
if storageAttributes.EndpointUrl != "" {
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
}
s3Client, err := storageState.S3Client(stream.Context(), opts...)
if err != nil {
return errors.BadRequestStatus("error getting S3 client: %s", err)
}

objectKey := path.Join(bucket.GetBucketPrefix(), req.GetKey())
resp, err := s3Client.GetObject(stream.Context(), &s3.GetObjectInput{
Bucket: aws.String(bucket.GetBucketName()),
Key: aws.String(objectKey),
})
getCall := func(s3Client S3API) (any, error) {
return s3Client.GetObject(stream.Context(), &s3.GetObjectInput{
Bucket: aws.String(bucket.GetBucketName()),
Key: aws.String(objectKey),
})
}
getResp, err := p.call(stream.Context(), getCall, bucket.GetId(), storageState, opts...)
if err != nil {
return parseS3Error("get object", err, req).Err()
}
resp := getResp.(*s3.GetObjectOutput)

defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
Expand Down Expand Up @@ -638,11 +681,6 @@ func (p *StoragePlugin) PutObject(ctx context.Context, req *pb.PutObjectRequest)
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
}

s3Client, err := p.GetClient(ctx, bucket.GetId(), storageState, opts...)
if err != nil {
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
}

hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return nil, errors.UnknownStatus("failed to calcualte hash")
Expand All @@ -653,18 +691,23 @@ func (p *StoragePlugin) PutObject(ctx context.Context, req *pb.PutObjectRequest)
if err != nil {
return nil, errors.UnknownStatus("failed to rewind file pointer")
}

objectKey := path.Join(bucket.GetBucketPrefix(), req.GetKey())
resp, err := s3Client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket.GetBucketName()),
Key: aws.String(objectKey),
Body: file,
ChecksumAlgorithm: types.ChecksumAlgorithmSha256,
ChecksumSHA256: aws.String(checksum),
})

putCall := func(s3Client S3API) (any, error) {
return s3Client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucket.GetBucketName()),
Key: aws.String(objectKey),
Body: file,
ChecksumAlgorithm: types.ChecksumAlgorithmSha256,
ChecksumSHA256: aws.String(checksum),
})
}
putResp, err := p.call(ctx, putCall, bucket.GetId(), storageState, opts...)
if err != nil {
return nil, parseS3Error("put object", err, req).Err()
}
resp := putResp.(*s3.PutObjectOutput)

if resp.ChecksumSHA256 == nil {
return nil, errors.UnknownStatus("missing checksum response from aws")
}
Expand Down Expand Up @@ -733,7 +776,7 @@ func (p *StoragePlugin) DeleteObjects(ctx context.Context, req *pb.DeleteObjects
if storageAttributes.EndpointUrl != "" {
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
}
client, err := storageState.S3Client(ctx, opts...)
client, err := p.getClient(ctx, bucket.GetId(), storageState, opts...)
if err != nil {
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
}
Expand Down
37 changes: 24 additions & 13 deletions plugin/service/storage/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
Expand All @@ -27,17 +24,22 @@ type S3API interface {
PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error)
DeleteObjects(ctx context.Context, params *s3.DeleteObjectsInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
Credentials() aws.Credentials
RoleArn() string
}

var customClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
},
type s3Client struct {
*s3.Client
creds aws.Credentials
arn string
}

func (c *s3Client) Credentials() aws.Credentials {
return c.creds
}

func (c *s3Client) RoleArn() string {
return c.arn
}

type awsStoragePersistedState struct {
Expand Down Expand Up @@ -114,7 +116,16 @@ func (s *awsStoragePersistedState) S3Client(ctx context.Context, opt ...s3Option
if s.testS3APIFunc != nil {
return s.testS3APIFunc(*awsCfg)
}
return s3.NewFromConfig(*awsCfg, s3Opts...), nil

// Retrieve the credentials provider from the client
credsProvider := awsCfg.Credentials
creds, err := credsProvider.Retrieve(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve credentials: %v\n", err)
}

c := s3.NewFromConfig(*awsCfg, s3Opts...)
return &s3Client{Client: c, creds: creds, arn: s.CredentialsConfig.RoleARN}, nil
}

// getOpts iterates the inbound s3Options and returns a struct
Expand Down

0 comments on commit 17dae9d

Please sign in to comment.