Skip to content

Commit

Permalink
Fix race condition in endpoint discovery (#4180)
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail authored Nov 23, 2021
1 parent cd7ead6 commit 35b0995
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
### SDK Enhancements

### SDK Bugs
* `aws/crr`: Fixed a race condition that caused concurrent calls relying on endpoint discovery to share the same `url.URL` reference in their operation's `http.Request`.
5 changes: 4 additions & 1 deletion aws/crr/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
return Endpoint{}, false
}

c.endpoints.Store(endpointKey, endpoint)
ev := endpoint.(Endpoint)
ev.Prune()

c.endpoints.Store(endpointKey, ev)
return endpoint.(Endpoint), true
}

Expand Down
40 changes: 40 additions & 0 deletions aws/crr/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net/url"
"reflect"
"testing"
"time"
)

func urlParse(uri string) *url.URL {
Expand Down Expand Up @@ -450,3 +451,42 @@ func TestCacheGet(t *testing.T) {
}
}
}

func TestEndpointCache_Get_prune(t *testing.T) {
c := NewEndpointCache(2)
c.Add(Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: &url.URL{
Host: "foo.amazonaws.com",
},
Expired: time.Now().Add(5 * time.Minute),
},
{
URL: &url.URL{
Host: "bar.amazonaws.com",
},
Expired: time.Now().Add(5 * -time.Minute),
},
},
})

load, _ := c.endpoints.Load("foo")
if ev := load.(Endpoint); len(ev.Addresses) != 2 {
t.Errorf("expected two weighted addresses")
}

weightedAddress, err := c.Get(nil, "foo", false)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "foo.amazonaws.com", weightedAddress.URL.Host; e != a {
t.Errorf("expect %v, got %v", e, a)
}

load, _ = c.endpoints.Load("foo")
if ev := load.(Endpoint); len(ev.Addresses) != 1 {
t.Errorf("expected one weighted address")
}
}
33 changes: 33 additions & 0 deletions aws/crr/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,32 @@ func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
continue
}

we.URL = cloneURL(we.URL)

return we, true
}

return WeightedAddress{}, false
}

// Prune will prune the expired addresses from the endpoint by allocating a new []WeightAddress.
// This is not concurrent safe, and should be called from a single owning thread.
func (e *Endpoint) Prune() bool {
validLen := e.Len()
if validLen == len(e.Addresses) {
return false
}
wa := make([]WeightedAddress, 0, validLen)
for i := range e.Addresses {
if e.Addresses[i].HasExpired() {
continue
}
wa = append(wa, e.Addresses[i])
}
e.Addresses = wa
return true
}

// Discoverer is an interface used to discovery which endpoint hit. This
// allows for specifics about what parameters need to be used to be contained
// in the Discoverer implementor.
Expand Down Expand Up @@ -97,3 +117,16 @@ func BuildEndpointKey(params map[string]*string) string {

return strings.Join(values, ".")
}

func cloneURL(u *url.URL) (clone *url.URL) {
clone = &url.URL{}

*clone = *u

if u.User != nil {
user := *u.User
clone.User = &user
}

return clone
}
126 changes: 126 additions & 0 deletions aws/crr/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
//go:build go1.16
// +build go1.16

package crr

import (
"net/url"
"reflect"
"strconv"
"testing"
"time"
)

func Test_cloneURL(t *testing.T) {
tests := []struct {
value *url.URL
wantClone *url.URL
}{
{
value: &url.URL{
Scheme: "https",
Opaque: "foo",
User: nil,
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
wantClone: &url.URL{
Scheme: "https",
Opaque: "foo",
User: nil,
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
},
{
value: &url.URL{
Scheme: "https",
Opaque: "foo",
User: url.UserPassword("NOT", "VALID"),
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
wantClone: &url.URL{
Scheme: "https",
Opaque: "foo",
User: url.UserPassword("NOT", "VALID"),
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
gotClone := cloneURL(tt.value)
if gotClone == tt.value {
t.Errorf("expct clone URL to not be same pointer address")
}
if tt.value.User != nil {
if tt.value.User == gotClone.User {
t.Errorf("expct cloned Userinfo to not be same pointer address")
}
}
if !reflect.DeepEqual(gotClone, tt.wantClone) {
t.Errorf("cloneURL() = %v, want %v", gotClone, tt.wantClone)
}
})
}
}

func TestEndpoint_Prune(t *testing.T) {
endpoint := Endpoint{}

endpoint.Add(WeightedAddress{
URL: &url.URL{},
Expired: time.Now().Add(5 * time.Minute),
})

initial := endpoint.Addresses

if e, a := false, endpoint.Prune(); e != a {
t.Errorf("expect prune %v, got %v", e, a)
}

if e, a := &initial[0], &endpoint.Addresses[0]; e != a {
t.Errorf("expect slice address to be same")
}

endpoint.Add(WeightedAddress{
URL: &url.URL{},
Expired: time.Now().Add(5 * -time.Minute),
})

initial = endpoint.Addresses

if e, a := true, endpoint.Prune(); e != a {
t.Errorf("expect prune %v, got %v", e, a)
}

if e, a := &initial[0], &endpoint.Addresses[0]; e == a {
t.Errorf("expect slice address to be different")
}

if e, a := 1, endpoint.Len(); e != a {
t.Errorf("expect slice length %v, got %v", e, a)
}
}

0 comments on commit 35b0995

Please sign in to comment.