Skip to content

Commit

Permalink
Fix EC2 Presigned URL customization (#4808)
Browse files Browse the repository at this point in the history
In some cases, the Presigned URL was NOT being sent when calling EC2 CopySnapshot.

This PR fixes the customization responsible for generating the Presigned URL.
  • Loading branch information
Steven Yuan authored Apr 20, 2023
1 parent b1a6e5f commit 42260bb
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 5 deletions.
22 changes: 17 additions & 5 deletions service/ec2/customizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
)

const (
// ec2CopySnapshotPresignedUrlCustomization handler name
ec2CopySnapshotPresignedUrlCustomization = "ec2CopySnapshotPresignedUrl"

// customRetryerMinRetryDelay sets min retry delay
customRetryerMinRetryDelay = 1 * time.Second

Expand All @@ -21,7 +24,10 @@ const (
func init() {
initRequest = func(r *request.Request) {
if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter
r.Handlers.Build.PushFront(fillPresignedURL)
r.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: ec2CopySnapshotPresignedUrlCustomization,
Fn: fillPresignedURL,
})
}

// only set the retryer on request if config doesn't have a retryer
Expand All @@ -48,13 +54,15 @@ func fillPresignedURL(r *request.Request) {

origParams := r.Params.(*CopySnapshotInput)

// Stop if PresignedURL/DestinationRegion is set
if origParams.PresignedUrl != nil || origParams.DestinationRegion != nil {
// Stop if PresignedURL is set
if origParams.PresignedUrl != nil {
return
}

// Always use config region as destination region for SDKs
origParams.DestinationRegion = r.Config.Region
newParams := awsutil.CopyOf(r.Params).(*CopySnapshotInput)

newParams := awsutil.CopyOf(origParams).(*CopySnapshotInput)

// Create a new request based on the existing request. We will use this to
// presign the CopySnapshot request against the source region.
Expand Down Expand Up @@ -82,8 +90,12 @@ func fillPresignedURL(r *request.Request) {
clientInfo.Endpoint = resolved.URL
clientInfo.SigningRegion = resolved.SigningRegion

// Copy handlers without Presigned URL customization to avoid an infinite loop
handlersWithoutPresignCustomization := r.Handlers.Copy()
handlersWithoutPresignCustomization.Build.RemoveByName(ec2CopySnapshotPresignedUrlCustomization)

// Presign a CopySnapshot request with modified params
req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data)
req := request.New(*cfg, clientInfo, handlersWithoutPresignCustomization, r.Retryer, r.Operation, newParams, r.Data)
url, err := req.Presign(5 * time.Minute) // 5 minutes should be enough.
if err != nil { // bubble error back up to original request
r.Error = err
Expand Down
186 changes: 186 additions & 0 deletions service/ec2/customizations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ package ec2_test
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strconv"
"testing"

"github.com/aws/aws-sdk-go/aws"
sdkclient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/ec2"
Expand Down Expand Up @@ -55,6 +58,189 @@ func TestCopySnapshotPresignedURL(t *testing.T) {
}
}

func TestCopySnapshotPresignedURLConfig(t *testing.T) {
const (
inputKmsKeyId = "KMS_KEY_ID"
inputSnapshotId = "SNAPSHOT_ID"
clientRegion = endpoints.UsEast1RegionID
inputSourceRegion = endpoints.UsWest2RegionID
)
cases := map[string]struct {
Encrypted bool
DestinationRegion string
KmsKeyId string
}{
// Not Encrypted
"Not Encrypted": {},
// Not Encrypted with KmsKeyId
"Not Encrypted with KmsKeyId": {
KmsKeyId: inputKmsKeyId,
},
// Not Encrypted with DestinationRegion
"Not Encrypted with DestinationRegion": {
DestinationRegion: endpoints.UsEast2RegionID,
},
// Not Encrypted with KmsKeyId and DestinationRegion
"Not Encrypted with KmsKeyId and DestinationRegion": {
KmsKeyId: inputKmsKeyId,
DestinationRegion: endpoints.UsEast2RegionID,
},
// Encrypted
"Encrypted": {
Encrypted: true,
},
// Encrypted with KmsKeyId
"Encrypted with KmsKeyId": {
Encrypted: true,
KmsKeyId: inputKmsKeyId,
},
// Encrypted with DestinationRegion
"Encrypted with DestinationRegion": {
Encrypted: true,
DestinationRegion: endpoints.UsEast2RegionID,
},
// Encrypted with KmsKeyId and DestinationRegion
"Encrypted with KmsKeyId and DestinationRegion": {
Encrypted: true,
KmsKeyId: inputKmsKeyId,
DestinationRegion: endpoints.UsEast2RegionID,
},
}

for name, config := range cases {
t.Run(name, func(t *testing.T) {
t.Log(name)

// Set up new client
svc := ec2.New(unit.Session, &aws.Config{
Region: aws.String(clientRegion),
})

// Base input
input := ec2.CopySnapshotInput{
SourceRegion: aws.String(inputSourceRegion),
SourceSnapshotId: aws.String(inputSnapshotId),
}

// Add input from test case config
if config.Encrypted != false {
input.Encrypted = &config.Encrypted
}
if config.DestinationRegion != "" {
input.DestinationRegion = &config.DestinationRegion
}
if config.KmsKeyId != "" {
input.KmsKeyId = &config.KmsKeyId
}

// Execute request
req, _ := svc.CopySnapshotRequest(&input)
req.Sign()

// Parse request
body, _ := ioutil.ReadAll(req.HTTPRequest.Body)
query, _ := url.ParseQuery(string(body))

// Test Body SourceRegion
sourceRegion := query.Get("SourceRegion")
if sourceRegion == "" {
t.Errorf("SourceRegion should always be sent in the request")
}
if sourceRegion != inputSourceRegion {
t.Errorf("SourceRegion should be `%v`, but found `%v`", inputSourceRegion, sourceRegion)
}
// Test Body SourceSnapshotId
sourceSnapshotId := query.Get("SourceSnapshotId")
if sourceSnapshotId == "" {
t.Errorf("SourceSnapshotId should always be sent in the request")
}
if sourceSnapshotId != inputSnapshotId {
t.Errorf("SourceSnapshotId should be `%v`, but found `%v`", inputSnapshotId, sourceSnapshotId)
}
// Test Body Encrypted
encrypted := query.Get("Encrypted")
if config.Encrypted && strconv.FormatBool(config.Encrypted) != encrypted {
t.Errorf("Encrypted should be `%v`, but found `%v`", config.Encrypted, encrypted)
}
if !config.Encrypted && encrypted != "" {
t.Errorf("Encrypted should be empty, but found `%v`", encrypted)
}
// Test Body DestinationRegion
destinationRegion := query.Get("DestinationRegion")
if destinationRegion != clientRegion {
t.Errorf("DestinationRegion should always be equal to the client region `%v`, but found `%v`", clientRegion, destinationRegion)
}
if destinationRegion == "" {
t.Errorf("DestinationRegion should never empty")
}
// Test Body KmsKeyId
kmsKeyId := query.Get("KmsKeyId")
if config.KmsKeyId != "" && config.KmsKeyId != kmsKeyId {
t.Errorf("KmsKeyId should be `%v`, but found `%v`", config.KmsKeyId, kmsKeyId)
}
if config.KmsKeyId == "" && kmsKeyId != "" {
t.Errorf("KmsKeyId should be empty, but found `%v`", kmsKeyId)
}

// Assert PresignedUrl
presignedUrl, _ := url.QueryUnescape(query.Get("PresignedUrl"))
if presignedUrl == "" {
t.Errorf("PresignedUrl should always be sent in the request")
}
// Test PresignedUrl EC2 URL
baseEc2UrlRegex := regexp.MustCompile(fmt.Sprintf(`^https://ec2\.%s\.amazonaws\.com/`, inputSourceRegion))
if !baseEc2UrlRegex.MatchString(presignedUrl) {
t.Errorf("Expected PresignedUrl to match `%v`, but found `%v`", baseEc2UrlRegex.String(), presignedUrl)
}

presignedUrlQuery, _ := url.ParseQuery(presignedUrl)
// Test PresignedUrl SourceRegion
presignedUrlSourceRegion := presignedUrlQuery.Get("SourceRegion")
if presignedUrlSourceRegion == "" {
t.Errorf("PresignedUrl SourceRegion should always be sent in the request")
}
if presignedUrlSourceRegion != inputSourceRegion {
t.Errorf("PresignedUrl SourceRegion should be `%v`, but found `%v`", inputSourceRegion, presignedUrlSourceRegion)
}
// Test PresignedUrl SourceSnapshotId
presignedUrlSourceSnapshotId := presignedUrlQuery.Get("SourceSnapshotId")
if presignedUrlSourceSnapshotId == "" {
t.Errorf("PresignedUrl SourceSnapshotId should always be sent in the request")
}
if presignedUrlSourceSnapshotId != inputSnapshotId {
t.Errorf("PresignedUrl SourceSnapshotId should be `%v`, but found `%v`", inputSnapshotId, presignedUrlSourceSnapshotId)
}
// Test PresignedUrl Encrypted
presignedUrlEncrypted := query.Get("Encrypted")
if config.Encrypted && strconv.FormatBool(config.Encrypted) != presignedUrlEncrypted {
t.Errorf("PresignedUrl Encrypted should be `%v`, but found `%v`", config.Encrypted, presignedUrlEncrypted)
}
if !config.Encrypted && presignedUrlEncrypted != "" {
t.Errorf("PresignedUrl Encrypted should be empty, but found `%v`", presignedUrlEncrypted)
}
// Test PresignedUrl DestinationRegion
presignedUrlDestinationRegion := presignedUrlQuery.Get("DestinationRegion")
if presignedUrlDestinationRegion != clientRegion {
t.Errorf("PresignedUrl DestinationRegion should always be equal to the client region `%v`, but found `%v`", clientRegion, presignedUrlDestinationRegion)
}
// Test PresignedUrl KmsKeyId
presignedUrlKmsKeyId := query.Get("KmsKeyId")
if config.KmsKeyId != "" && config.KmsKeyId != presignedUrlKmsKeyId {
t.Errorf("PresignedUrl KmsKeyId should be `%v`, but found `%v`", config.KmsKeyId, presignedUrlKmsKeyId)
}
if config.KmsKeyId == "" && presignedUrlKmsKeyId != "" {
t.Errorf("PresignedUrl KmsKeyId should be empty, but found `%v`", presignedUrlKmsKeyId)
}
// Test PresignedUrl X-Amz-Credential
presignedUrlAmzCredential := presignedUrlQuery.Get("X-Amz-Credential")
amzCredentialRegex := regexp.MustCompile(fmt.Sprintf(`^\w{4}/\d{8}/%s/ec2/aws4_request$`, inputSourceRegion))
if !amzCredentialRegex.MatchString(presignedUrlAmzCredential) {
t.Errorf("Expected PresignedUrl X-Amz-Credential to match `%v`, but found `%v`", amzCredentialRegex.String(), presignedUrlAmzCredential)
}
})
}
}

func TestNoCustomRetryerWithMaxRetries(t *testing.T) {
cases := map[string]struct {
Config aws.Config
Expand Down

0 comments on commit 42260bb

Please sign in to comment.