diff --git a/Makefile b/Makefile index ecd7083ee0b..b11578cea9c 100644 --- a/Makefile +++ b/Makefile @@ -63,7 +63,7 @@ misc/certs/ca-certificates.crt: docker run "amazon/amazon-ecs-agent-cert-source:make" cat /etc/ssl/certs/ca-certificates.crt > misc/certs/ca-certificates.crt test: - . ./scripts/shared_env && go test -timeout=25s -v -cover $(shell go list ./agent/... | grep -v /vendor/) + . ./scripts/shared_env && go test -race -timeout=25s -v -cover $(shell go list ./agent/... | grep -v /vendor/) benchmark-test: . ./scripts/shared_env && go test -run=XX -bench=. $(shell go list ./agent/... | grep -v /vendor/) diff --git a/agent/acs/client/acs_client_test.go b/agent/acs/client/acs_client_test.go index f763e21fce7..d2326451478 100644 --- a/agent/acs/client/acs_client_test.go +++ b/agent/acs/client/acs_client_test.go @@ -16,19 +16,21 @@ package acsclient import ( "encoding/json" "errors" + "io" "net/http" "net/http/httptest" - "reflect" - "sync" "testing" "time" "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/wsclient" + "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/golang/mock/gomock" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" ) const sampleCredentialsMessage = ` @@ -50,146 +52,102 @@ const sampleCredentialsMessage = ` } ` -type messageLogger struct { - writes [][]byte - reads [][]byte - closed bool -} +const ( + TestClusterArn = "arn:aws:ec2:123:container/cluster:123456" + TestInstanceArn = "arn:aws:ec2:123:container/containerInstance/12345678" +) var testCfg = &config.Config{ AcceptInsecureCert: true, AWSRegion: "us-east-1", } -func (ml *messageLogger) WriteMessage(_ int, data []byte) error { - if ml.closed { - return errors.New("can't write to closed ws") - } - ml.writes = append(ml.writes, data) - return nil -} - -func (ml *messageLogger) Close() error { - ml.closed = true - return nil -} - -func (ml *messageLogger) ReadMessage() (int, []byte, error) { - for len(ml.reads) == 0 && !ml.closed { - time.Sleep(1 * time.Millisecond) - } - if ml.closed { - return 0, []byte{}, errors.New("can't read from a closed websocket") - } - read := ml.reads[len(ml.reads)-1] - ml.reads = ml.reads[0 : len(ml.reads)-1] - return websocket.TextMessage, read, nil -} +func TestMakeUnrecognizedRequest(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() -func testCS() (wsclient.ClientServer, *messageLogger) { - testCreds := credentials.AnonymousCredentials + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().Close() - cs := New("localhost:443", testCfg, testCreds).(*clientServer) - ml := &messageLogger{make([][]byte, 0), make([][]byte, 0), false} - cs.SetConnection(ml) - return cs, ml -} - -func TestMakeUnrecognizedRequest(t *testing.T) { - cs, _ := testCS() + cs := testCS(conn) + defer cs.Close() // 'testing.T' should not be a known type ;) err := cs.MakeRequest(t) if _, ok := err.(*wsclient.UnrecognizedWSRequestType); !ok { t.Fatal("Expected unrecognized request type") } - _ = err.Error() // This is one of those times when 100% coverage is silly - cs.Close() -} - -func strptr(s string) *string { - return &s } func TestWriteAckRequest(t *testing.T) { - cs, ml := testCS() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().Close() + cs := testCS(conn) + defer cs.Close() + + // capture bytes written + var writes []byte + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Do(func(_ int, data []byte) { + writes = data + }) - req := ecsacs.AckRequest{Cluster: strptr("default"), ContainerInstance: strptr("testCI"), MessageId: strptr("messageID")} - err := cs.MakeRequest(&req) - if err != nil { - t.Fatal(err) - } + // send request + err := cs.MakeRequest(&ecsacs.AckRequest{}) + assert.NoError(t, err) - write := ml.writes[0] - writtenReq := struct { - Type string - Message ecsacs.AckRequest - }{} - err = json.Unmarshal(write, &writtenReq) - if err != nil { - t.Fatal("Unable to unmarshal written", err) - } - msg := writtenReq.Message - if *msg.Cluster != "default" || *msg.ContainerInstance != "testCI" || *msg.MessageId != "messageID" { - t.Error("Did not write what we expected") - } - cs.Close() + // unmarshal bytes written to the socket + msg := &wsclient.RequestMessage{} + err = json.Unmarshal(writes, msg) + assert.NoError(t, err) + assert.Equal(t, "AckRequest", msg.Type) } func TestPayloadHandlerCalled(t *testing.T) { - cs, ml := testCS() + ctrl := gomock.NewController(t) + defer ctrl.Finish() - var handledPayload *ecsacs.PayloadMessage + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`), nil) + conn.EXPECT().Close() + cs := testCS(conn) + defer cs.Close() + + messageChannel := make(chan *ecsacs.PayloadMessage) reqHandler := func(payload *ecsacs.PayloadMessage) { - handledPayload = payload + messageChannel <- payload } cs.AddRequestHandler(reqHandler) + go cs.Serve() - ml.reads = [][]byte{[]byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`)} - - var isClosed bool - go func() { - err := cs.Serve() - if !isClosed && err != nil { - t.Fatal("Premature end of serving", err) - } - }() - - time.Sleep(1 * time.Millisecond) - if handledPayload == nil { - t.Fatal("Handler was not called") + expectedMessage := &ecsacs.PayloadMessage{ + Tasks: []*ecsacs.Task{{ + Arn: aws.String("arn"), + }}, } - if len(handledPayload.Tasks) != 1 || *handledPayload.Tasks[0].Arn != "arn" { - t.Error("Unmarshalled data did not contain expected values") - } - - isClosed = true - cs.Close() + assert.Equal(t, expectedMessage, <-messageChannel) } func TestRefreshCredentialsHandlerCalled(t *testing.T) { - cs, ml := testCS() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + - wait := sync.WaitGroup{} - wait.Add(1) - var handledMessage *ecsacs.IAMRoleCredentialsMessage + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(sampleCredentialsMessage), nil) + conn.EXPECT().Close() + cs := testCS(conn) + defer cs.Close() + + messageChannel := make(chan *ecsacs.IAMRoleCredentialsMessage) reqHandler := func(message *ecsacs.IAMRoleCredentialsMessage) { - wait.Done() - handledMessage = message + messageChannel <- message } cs.AddRequestHandler(reqHandler) - ml.reads = [][]byte{[]byte(sampleCredentialsMessage)} - - var isClosed bool - go func() { - err := cs.Serve() - if !isClosed && err != nil { - t.Fatal("Premature end of serving", err) - } - }() - - wait.Wait() + go cs.Serve() expectedMessage := &ecsacs.IAMRoleCredentialsMessage{ MessageId: aws.String("123"), @@ -203,79 +161,24 @@ func TestRefreshCredentialsHandlerCalled(t *testing.T) { SessionToken: aws.String("token"), }, } - - if !reflect.DeepEqual(handledMessage, expectedMessage) { - t.Error("Unmarshalled credential message did not contain expected values") - } - - isClosed = true - cs.Close() + assert.Equal(t, <-messageChannel, expectedMessage) } func TestClosingConnection(t *testing.T) { - cs, ml := testCS() - closedChan := make(chan error) - var expectedClosed bool - go func() { - err := cs.Serve() - if !expectedClosed { - t.Fatal("Serve ended early") - } - closedChan <- err - }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() - expectedClosed = true - ml.Close() - err := <-closedChan - if err == nil { - t.Error("Closing was expected to result in an error") - } - - req := ecsacs.AckRequest{Cluster: strptr("default"), ContainerInstance: strptr("testCI"), MessageId: strptr("messageID")} - err = cs.MakeRequest(&req) - if err == nil { - t.Error("Cannot request over closed connection") - } -} - -const ( - TestClusterArn = "arn:aws:ec2:123:container/cluster:123456" - TestInstanceArn = "arn:aws:ec2:123:container/containerInstance/12345678" -) + // Returning EOF tells the ClientServer that the connection is closed + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().ReadMessage().Return(0, nil, io.EOF) + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(io.EOF) + cs := testCS(conn) -func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { - serverChan := make(chan string) - requestsChan := make(chan string) - errChan := make(chan error) + serveErr := cs.Serve() + assert.Error(t, serveErr) - upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ws, err := upgrader.Upgrade(w, r, nil) - go func() { - <-closeWS - ws.Close() - }() - if err != nil { - errChan <- err - } - go func() { - _, msg, err := ws.ReadMessage() - if err != nil { - errChan <- err - } else { - requestsChan <- string(msg) - } - }() - for str := range serverChan { - err := ws.WriteMessage(websocket.TextMessage, []byte(str)) - if err != nil { - errChan <- err - } - } - }) - - server := httptest.NewTLSServer(handler) - return server, serverChan, requestsChan, errChan, nil + err := cs.MakeRequest(&ecsacs.AckRequest{}) + assert.Error(t, err) } func TestConnect(t *testing.T) { @@ -362,7 +265,50 @@ func TestConnectClientError(t *testing.T) { cs := New(testServer.URL, testCfg, credentials.AnonymousCredentials) err := cs.Connect() - if _, ok := err.(*wsclient.WSError); !ok || err.Error() != "InvalidClusterException: Invalid cluster" { - t.Error("Did not get correctly typed error: " + err.Error()) - } + _, ok := err.(*wsclient.WSError) + assert.True(t, ok) + assert.EqualError(t, err, "InvalidClusterException: Invalid cluster") +} + +func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer { + testCreds := credentials.AnonymousCredentials + cs := New("localhost:443", testCfg, testCreds).(*clientServer) + cs.SetConnection(conn) + return cs +} + +// TODO: replace with gomock +func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { + serverChan := make(chan string) + requestsChan := make(chan string) + errChan := make(chan error) + + upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := upgrader.Upgrade(w, r, nil) + go func() { + <-closeWS + ws.Close() + }() + if err != nil { + errChan <- err + } + go func() { + _, msg, err := ws.ReadMessage() + if err != nil { + errChan <- err + } else { + requestsChan <- string(msg) + } + }() + for str := range serverChan { + err := ws.WriteMessage(websocket.TextMessage, []byte(str)) + if err != nil { + errChan <- err + } + } + }) + + server := httptest.NewTLSServer(handler) + return server, serverChan, requestsChan, errChan, nil } diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 29315674d7e..9a7e4632c05 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -976,7 +976,7 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { t.Errorf("Mismatch between expected and added credentials id for task, expected: %s, added: %s", credentialsIdInRefreshMessage, credentialsIdFromTask) } - go server.Close() + server.Close() // Cancel context should close the session <-ended } @@ -1060,6 +1060,7 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) } } +// TODO: replace with gomock func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { serverChan := make(chan string, 1) requestsChan := make(chan string, 1) diff --git a/agent/api/container.go b/agent/api/container.go index dfdf44526f0..99c18230c81 100644 --- a/agent/api/container.go +++ b/agent/api/container.go @@ -62,6 +62,9 @@ type Container struct { DockerConfig DockerConfig `json:"dockerConfig"` RegistryAuthentication *RegistryAuthenticationData `json:"registryAuthentication"` + // lock is used for fields that are accessed and updated concurrently + lock sync.RWMutex + // DesiredStatusUnsafe represents the state where the container should go. Generally, // the desired status is informed by the ECS backend as a result of either // API calls made to ECS or decisions made by the ECS service scheduler, @@ -74,7 +77,6 @@ type Container struct { // setter/getter. When this is done, we need to ensure that the UnmarshalJSON // is handled properly so that the state storage continues to work. DesiredStatusUnsafe ContainerStatus `json:"desiredStatus"` - desiredStatusLock sync.RWMutex // KnownStatusUnsafe represents the state where the container is. // NOTE: Do not access `KnownStatusUnsafe` directly. Instead, use `GetKnownStatus` @@ -83,7 +85,6 @@ type Container struct { // setter/getter. When this is done, we need to ensure that the UnmarshalJSON // is handled properly so that the state storage continues to work. KnownStatusUnsafe ContainerStatus `json:"KnownStatus"` - knownStatusLock sync.RWMutex // RunDependencies is a list of containers that must be run before // this one is created @@ -106,9 +107,8 @@ type Container struct { // setter/getter. When this is done, we need to ensure that the UnmarshalJSON is // handled properly so that the state storage continues to work. SentStatusUnsafe ContainerStatus `json:"SentStatus"` - sentStatusLock sync.RWMutex - KnownExitCode *int + knownExitCode *int KnownPortBindings []PortBinding } @@ -156,57 +156,69 @@ func (c *Container) DesiredTerminal() bool { // GetKnownStatus returns the known status of the container func (c *Container) GetKnownStatus() ContainerStatus { - c.knownStatusLock.RLock() - defer c.knownStatusLock.RUnlock() + c.lock.RLock() + defer c.lock.RUnlock() return c.KnownStatusUnsafe } // SetKnownStatus sets the known status of the container func (c *Container) SetKnownStatus(status ContainerStatus) { - c.knownStatusLock.Lock() - defer c.knownStatusLock.Unlock() + c.lock.Lock() + defer c.lock.Unlock() c.KnownStatusUnsafe = status } // GetDesiredStatus gets the desired status of the container func (c *Container) GetDesiredStatus() ContainerStatus { - c.desiredStatusLock.RLock() - defer c.desiredStatusLock.RUnlock() + c.lock.RLock() + defer c.lock.RUnlock() return c.DesiredStatusUnsafe } // SetDesiredStatus sets the desired status of the container func (c *Container) SetDesiredStatus(status ContainerStatus) { - c.desiredStatusLock.Lock() - defer c.desiredStatusLock.Unlock() + c.lock.Lock() + defer c.lock.Unlock() c.DesiredStatusUnsafe = status } // GetSentStatus safely returns the SentStatusUnsafe of the container func (c *Container) GetSentStatus() ContainerStatus { - c.sentStatusLock.RLock() - defer c.sentStatusLock.RUnlock() + c.lock.RLock() + defer c.lock.RUnlock() return c.SentStatusUnsafe } // SetSentStatus safely sets the SentStatusUnsafe of the container func (c *Container) SetSentStatus(status ContainerStatus) { - c.sentStatusLock.Lock() - defer c.sentStatusLock.Unlock() + c.lock.Lock() + defer c.lock.Unlock() c.SentStatusUnsafe = status } +func (c *Container) SetKnownExitCode(i *int) { + c.lock.Lock() + defer c.lock.Unlock() + c.knownExitCode = i +} + +func (c *Container) GetKnownExitCode() *int { + c.lock.RLock() + defer c.lock.RUnlock() + return c.knownExitCode +} + // String returns a human readable string representation of this object func (c *Container) String() string { ret := fmt.Sprintf("%s(%s) (%s->%s)", c.Name, c.Image, c.GetKnownStatus().String(), c.GetDesiredStatus().String()) - if c.KnownExitCode != nil { - ret += " - Exit: " + strconv.Itoa(*c.KnownExitCode) + if c.GetKnownExitCode() != nil { + ret += " - Exit: " + strconv.Itoa(*c.GetKnownExitCode()) } return ret } diff --git a/agent/api/testutils/container_equal.go b/agent/api/testutils/container_equal.go index e8be638426e..9241992d263 100644 --- a/agent/api/testutils/container_equal.go +++ b/agent/api/testutils/container_equal.go @@ -82,14 +82,8 @@ func ContainersEqual(lhs, rhs *api.Container) bool { if lhs.AppliedStatus != rhs.AppliedStatus { return false } - if lhs.KnownExitCode == nil || rhs.KnownExitCode == nil { - if lhs.KnownExitCode != nil || rhs.KnownExitCode != nil { - return false - } - } else { - if *lhs.KnownExitCode != *rhs.KnownExitCode { - return false - } + if !reflect.DeepEqual(lhs.GetKnownExitCode(), rhs.GetKnownExitCode()) { + return false } return true diff --git a/agent/api/testutils/container_equal_test.go b/agent/api/testutils/container_equal_test.go index bb943a89edc..9e1f9397e61 100644 --- a/agent/api/testutils/container_equal_test.go +++ b/agent/api/testutils/container_equal_test.go @@ -14,81 +14,77 @@ package testutils import ( + "fmt" "testing" . "github.com/aws/amazon-ecs-agent/agent/api" + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" ) func TestContainerEqual(t *testing.T) { - one := 1 - onePtr := &one - anotherOne := 1 - anotherOnePtr := &anotherOne - two := 2 - twoPtr := &two - equalPairs := []Container{ - {Name: "name"}, {Name: "name"}, - {Image: "nginx"}, {Image: "nginx"}, - {Command: []string{"c"}}, {Command: []string{"c"}}, - {CPU: 1}, {CPU: 1}, - {Memory: 1}, {Memory: 1}, - {Links: []string{"1", "2"}}, {Links: []string{"1", "2"}}, - {Links: []string{"1", "2"}}, {Links: []string{"2", "1"}}, - {VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, {VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, - {VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, {VolumesFrom: []VolumeFrom{{"2", true}, {"1", false}}}, - {Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, {Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, - {Essential: true}, {Essential: true}, - {EntryPoint: nil}, {EntryPoint: nil}, - {EntryPoint: &[]string{"1", "2"}}, {EntryPoint: &[]string{"1", "2"}}, - {Environment: map[string]string{}}, {Environment: map[string]string{}}, - {Environment: map[string]string{"a": "b", "c": "d"}}, {Environment: map[string]string{"c": "d", "a": "b"}}, - {DesiredStatusUnsafe: ContainerRunning}, {DesiredStatusUnsafe: ContainerRunning}, - {AppliedStatus: ContainerRunning}, {AppliedStatus: ContainerRunning}, - {KnownStatusUnsafe: ContainerRunning}, {KnownStatusUnsafe: ContainerRunning}, - {KnownExitCode: nil}, {KnownExitCode: nil}, - {KnownExitCode: onePtr}, {KnownExitCode: anotherOnePtr}, - } - unequalPairs := []Container{ - {Name: "name"}, {Name: "名前"}, - {Image: "nginx"}, {Image: "えんじんえっくす"}, - {Command: []string{"c"}}, {Command: []string{"し"}}, - {Command: []string{"c", "b"}}, {Command: []string{"b", "c"}}, - {CPU: 1}, {CPU: 2e2}, - {Memory: 1}, {Memory: 2e2}, - {Links: []string{"1", "2"}}, {Links: []string{"1", "二"}}, - {VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, {VolumesFrom: []VolumeFrom{{"1", false}, {"二", false}}}, - {Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, {Ports: []PortBinding{{1, 2, "二", TransportProtocolTCP}}}, - {Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, {Ports: []PortBinding{{1, 22, "1", TransportProtocolTCP}}}, - {Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, {Ports: []PortBinding{{1, 2, "1", TransportProtocolUDP}}}, - {Essential: true}, {Essential: false}, - {EntryPoint: nil}, {EntryPoint: &[]string{"nonnil"}}, - {EntryPoint: &[]string{"1", "2"}}, {EntryPoint: &[]string{"2", "1"}}, - {EntryPoint: &[]string{"1", "2"}}, {EntryPoint: &[]string{"1", "二"}}, - {Environment: map[string]string{"a": "b", "c": "d"}}, {Environment: map[string]string{"し": "d", "a": "b"}}, - {DesiredStatusUnsafe: ContainerRunning}, {DesiredStatusUnsafe: ContainerStopped}, - {AppliedStatus: ContainerRunning}, {AppliedStatus: ContainerStopped}, - {KnownStatusUnsafe: ContainerRunning}, {KnownStatusUnsafe: ContainerStopped}, - {KnownExitCode: nil}, {KnownExitCode: onePtr}, - {KnownExitCode: onePtr}, {KnownExitCode: twoPtr}, + + exitCodeContainer := func(p *int) Container { + c := Container{} + c.SetKnownExitCode(p) + return c } - for i := 0; i < len(equalPairs); i += 2 { - if !ContainersEqual(&equalPairs[i], &equalPairs[i+1]) { - t.Error(i, equalPairs[i], " should equal ", equalPairs[i+1]) - } - // Should be symetric - if !ContainersEqual(&equalPairs[i+1], &equalPairs[i]) { - t.Error(i, "(symetric)", equalPairs[i+1], " should equal ", equalPairs[i]) - } + testCases := []struct { + lhs Container + rhs Container + shouldBeEqual bool + }{ + // Equal Pairs + {Container{Name: "name"}, Container{Name: "name"}, true}, + {Container{Image: "nginx"}, Container{Image: "nginx"}, true}, + {Container{Command: []string{"c"}}, Container{Command: []string{"c"}}, true}, + {Container{CPU: 1}, Container{CPU: 1}, true}, + {Container{Memory: 1}, Container{Memory: 1}, true}, + {Container{Links: []string{"1", "2"}}, Container{Links: []string{"1", "2"}}, true}, + {Container{Links: []string{"1", "2"}}, Container{Links: []string{"2", "1"}}, true}, + {Container{VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, Container{VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, true}, + {Container{VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, Container{VolumesFrom: []VolumeFrom{{"2", true}, {"1", false}}}, true}, + {Container{Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, Container{Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, true}, + {Container{Essential: true}, Container{Essential: true}, true}, + {Container{EntryPoint: nil}, Container{EntryPoint: nil}, true}, + {Container{EntryPoint: &[]string{"1", "2"}}, Container{EntryPoint: &[]string{"1", "2"}}, true}, + {Container{Environment: map[string]string{}}, Container{Environment: map[string]string{}}, true}, + {Container{Environment: map[string]string{"a": "b", "c": "d"}}, Container{Environment: map[string]string{"c": "d", "a": "b"}}, true}, + {Container{DesiredStatusUnsafe: ContainerRunning}, Container{DesiredStatusUnsafe: ContainerRunning}, true}, + {Container{AppliedStatus: ContainerRunning}, Container{AppliedStatus: ContainerRunning}, true}, + {Container{KnownStatusUnsafe: ContainerRunning}, Container{KnownStatusUnsafe: ContainerRunning}, true}, + {exitCodeContainer(aws.Int(1)), exitCodeContainer(aws.Int(1)), true}, + {exitCodeContainer(nil), exitCodeContainer(nil), true}, + // Unequal Pairs + {Container{Name: "name"}, Container{Name: "名前"}, false}, + {Container{Image: "nginx"}, Container{Image: "えんじんえっくす"}, false}, + {Container{Command: []string{"c"}}, Container{Command: []string{"し"}}, false}, + {Container{Command: []string{"c", "b"}}, Container{Command: []string{"b", "c"}}, false}, + {Container{CPU: 1}, Container{CPU: 2e2}, false}, + {Container{Memory: 1}, Container{Memory: 2e2}, false}, + {Container{Links: []string{"1", "2"}}, Container{Links: []string{"1", "二"}}, false}, + {Container{VolumesFrom: []VolumeFrom{{"1", false}, {"2", true}}}, Container{VolumesFrom: []VolumeFrom{{"1", false}, {"二", false}}}, false}, + {Container{Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, Container{Ports: []PortBinding{{1, 2, "二", TransportProtocolTCP}}}, false}, + {Container{Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, Container{Ports: []PortBinding{{1, 22, "1", TransportProtocolTCP}}}, false}, + {Container{Ports: []PortBinding{{1, 2, "1", TransportProtocolTCP}}}, Container{Ports: []PortBinding{{1, 2, "1", TransportProtocolUDP}}}, false}, + {Container{Essential: true}, Container{Essential: false}, false}, + {Container{EntryPoint: nil}, Container{EntryPoint: &[]string{"nonnil"}}, false}, + {Container{EntryPoint: &[]string{"1", "2"}}, Container{EntryPoint: &[]string{"2", "1"}}, false}, + {Container{EntryPoint: &[]string{"1", "2"}}, Container{EntryPoint: &[]string{"1", "二"}}, false}, + {Container{Environment: map[string]string{"a": "b", "c": "d"}}, Container{Environment: map[string]string{"し": "d", "a": "b"}}, false}, + {Container{DesiredStatusUnsafe: ContainerRunning}, Container{DesiredStatusUnsafe: ContainerStopped}, false}, + {Container{AppliedStatus: ContainerRunning}, Container{AppliedStatus: ContainerStopped}, false}, + {Container{KnownStatusUnsafe: ContainerRunning}, Container{KnownStatusUnsafe: ContainerStopped}, false}, + {exitCodeContainer(aws.Int(0)), exitCodeContainer(aws.Int(42)), false}, + {exitCodeContainer(nil), exitCodeContainer(aws.Int(12)), false}, } - for i := 0; i < len(unequalPairs); i += 2 { - if ContainersEqual(&unequalPairs[i], &unequalPairs[i+1]) { - t.Error(i, unequalPairs[i], " shouldn't equal ", unequalPairs[i+1]) - } - //symetric - if ContainersEqual(&unequalPairs[i+1], &unequalPairs[i]) { - t.Error(i, "(symetric)", unequalPairs[i+1], " shouldn't equal ", unequalPairs[i]) - } + for index, tc := range testCases { + t.Run(fmt.Sprintf("index %d expected %t", index, tc.shouldBeEqual), func(t *testing.T) { + assert.Equal(t, ContainersEqual(&tc.lhs, &tc.rhs), tc.shouldBeEqual, "ContainersEqual not working as expected. Check index failure.") + // Symetric + assert.Equal(t, ContainersEqual(&tc.rhs, &tc.lhs), tc.shouldBeEqual, "Symetric equality check failed. Check index failure.") + }) } } diff --git a/agent/api/testutils/task_equal_test.go b/agent/api/testutils/task_equal_test.go index 1228c687e82..632eaea23fb 100644 --- a/agent/api/testutils/task_equal_test.go +++ b/agent/api/testutils/task_equal_test.go @@ -14,47 +14,42 @@ package testutils import ( + "fmt" "testing" . "github.com/aws/amazon-ecs-agent/agent/api" + "github.com/stretchr/testify/assert" ) func TestTaskEqual(t *testing.T) { - equalPairs := []Task{ - {Arn: "a"}, {Arn: "a"}, - {Family: "a"}, {Family: "a"}, - {Version: "a"}, {Version: "a"}, - {Containers: []*Container{{Name: "a"}}}, {Containers: []*Container{{Name: "a"}}}, - {DesiredStatusUnsafe: TaskRunning}, {DesiredStatusUnsafe: TaskRunning}, - {KnownStatusUnsafe: TaskRunning}, {KnownStatusUnsafe: TaskRunning}, - } - unequalPairs := []Task{ - {Arn: "a"}, {Arn: "あ"}, - {Family: "a"}, {Family: "あ"}, - {Version: "a"}, {Version: "あ"}, - {Containers: []*Container{{Name: "a"}}}, {Containers: []*Container{{Name: "あ"}}}, - {DesiredStatusUnsafe: TaskRunning}, {DesiredStatusUnsafe: TaskStopped}, - {KnownStatusUnsafe: TaskRunning}, {KnownStatusUnsafe: TaskStopped}, - } + testCases := []struct { + rhs Task + lhs Task + shouldBeEqual bool + }{ + // Equal Pairs + {Task{Arn: "a"}, Task{Arn: "a"}, true}, + {Task{Family: "a"}, Task{Family: "a"}, true}, + {Task{Version: "a"}, Task{Version: "a"}, true}, + {Task{Containers: []*Container{{Name: "a"}}}, Task{Containers: []*Container{{Name: "a"}}}, true}, + {Task{DesiredStatusUnsafe: TaskRunning}, Task{DesiredStatusUnsafe: TaskRunning}, true}, + {Task{KnownStatusUnsafe: TaskRunning}, Task{KnownStatusUnsafe: TaskRunning}, true}, - for i := 0; i < len(equalPairs); i += 2 { - if !TasksEqual(&equalPairs[i], &equalPairs[i+1]) { - t.Error(i, equalPairs[i], " should equal ", equalPairs[i+1]) - } - // Should be symetric - if !TasksEqual(&equalPairs[i+1], &equalPairs[i]) { - t.Error(i, "(symetric)", equalPairs[i+1], " should equal ", equalPairs[i]) - } + // Unequal Pairs + {Task{Arn: "a"}, Task{Arn: "あ"}, false}, + {Task{Family: "a"}, Task{Family: "あ"}, false}, + {Task{Version: "a"}, Task{Version: "あ"}, false}, + {Task{Containers: []*Container{{Name: "a"}}}, Task{Containers: []*Container{{Name: "あ"}}}, false}, + {Task{DesiredStatusUnsafe: TaskRunning}, Task{DesiredStatusUnsafe: TaskStopped}, false}, + {Task{KnownStatusUnsafe: TaskRunning}, Task{KnownStatusUnsafe: TaskStopped}, false}, } - for i := 0; i < len(unequalPairs); i += 2 { - if TasksEqual(&unequalPairs[i], &unequalPairs[i+1]) { - t.Error(i, unequalPairs[i], " shouldn't equal ", unequalPairs[i+1]) - } - //symetric - if TasksEqual(&unequalPairs[i+1], &unequalPairs[i]) { - t.Error(i, "(symetric)", unequalPairs[i+1], " shouldn't equal ", unequalPairs[i]) - } + for index, tc := range testCases { + t.Run(fmt.Sprintf("index %d expected %t", index, tc.shouldBeEqual), func(t *testing.T) { + assert.Equal(t, TasksEqual(&tc.lhs, &tc.rhs), tc.shouldBeEqual, "TasksEqual not working as expected. Check index failure.") + // Symetric + assert.Equal(t, TasksEqual(&tc.rhs, &tc.lhs), tc.shouldBeEqual, "Symetric equality check failed. Check index failure.") + }) } } diff --git a/agent/engine/docker_container_engine_test.go b/agent/engine/docker_container_engine_test.go index fc7395977d0..ac1dd8c7f43 100644 --- a/agent/engine/docker_container_engine_test.go +++ b/agent/engine/docker_container_engine_test.go @@ -720,6 +720,7 @@ func TestListContainersTimeout(t *testing.T) { if response.Error.(api.NamedError).ErrorName() != "DockerTimeoutError" { t.Error("Wrong error type") } + <-warp wait.Done() } diff --git a/agent/engine/docker_image_manager.go b/agent/engine/docker_image_manager.go index 5521cd5622c..53e7c17f627 100644 --- a/agent/engine/docker_image_manager.go +++ b/agent/engine/docker_image_manager.go @@ -84,6 +84,12 @@ func (imageManager *dockerImageManager) AddAllImageStates(imageStates []*image.I } } +func (imageManager *dockerImageManager) GetImageStatesCount() int { + imageManager.updateLock.RLock() + defer imageManager.updateLock.RUnlock() + return len(imageManager.imageStates) +} + // RecordContainerReference adds container reference to the corresponding imageState object func (imageManager *dockerImageManager) RecordContainerReference(container *api.Container) error { // the image state has been updated, save the new state diff --git a/agent/engine/docker_image_manager_test.go b/agent/engine/docker_image_manager_test.go index 29c899306e5..d49f1a641ea 100644 --- a/agent/engine/docker_image_manager_test.go +++ b/agent/engine/docker_image_manager_test.go @@ -728,10 +728,10 @@ func TestImageCleanupHappyPath(t *testing.T) { go imageManager.performPeriodicImageCleanup(ctx, 2*time.Millisecond) time.Sleep(1 * time.Second) cancel() - if len(imageState.Image.Names) != 0 { + if imageState.GetImageNamesCount() != 0 { t.Error("Error removing image name from state after the image is removed") } - if len(imageManager.imageStates) != 0 { + if imageManager.GetImageStatesCount() != 0 { t.Error("Error removing image state after the image is removed") } } diff --git a/agent/engine/docker_task_engine.go b/agent/engine/docker_task_engine.go index f05583d02d8..89bbb576bfa 100644 --- a/agent/engine/docker_task_engine.go +++ b/agent/engine/docker_task_engine.go @@ -371,7 +371,7 @@ func (engine *DockerTaskEngine) emitContainerEvent(task *api.Task, cont *api.Con TaskArn: task.Arn, ContainerName: cont.Name, Status: contKnownStatus, - ExitCode: cont.KnownExitCode, + ExitCode: cont.GetKnownExitCode(), PortBindings: cont.KnownPortBindings, Reason: reason, Container: cont, diff --git a/agent/engine/engine_integ_test.go b/agent/engine/engine_integ_test.go index dc9c5f20d34..c166abd7769 100644 --- a/agent/engine/engine_integ_test.go +++ b/agent/engine/engine_integ_test.go @@ -141,8 +141,8 @@ func TestHostVolumeMount(t *testing.T) { event = <-stateChangeEvents assert.Equal(t, event.(api.TaskStateChange).Status, api.TaskStopped, "Expected task to be STOPPED") - assert.NotNil(t, testTask.Containers[0].KnownExitCode, "No exit code found") - assert.Equal(t, 42, *testTask.Containers[0].KnownExitCode, "Wrong exit code") + assert.NotNil(t, testTask.Containers[0].GetKnownExitCode(), "No exit code found") + assert.Equal(t, 42, *testTask.Containers[0].GetKnownExitCode(), "Wrong exit code") data, err := ioutil.ReadFile(filepath.Join(tmpPath, "hello-from-container")) assert.Nil(t, err, "Unexpected error") @@ -178,8 +178,8 @@ func TestEmptyHostVolumeMount(t *testing.T) { event = <-stateChangeEvents assert.Equal(t, event.(api.TaskStateChange).Status, api.TaskStopped, "Expected task to be STOPPED") - assert.NotNil(t, testTask.Containers[0].KnownExitCode, "No exit code found") - assert.Equal(t, 42, *testTask.Containers[0].KnownExitCode, "Wrong exit code, file probably wasn't present") + assert.NotNil(t, testTask.Containers[0].GetKnownExitCode(), "No exit code found") + assert.Equal(t, 42, *testTask.Containers[0].GetKnownExitCode(), "Wrong exit code, file probably wasn't present") } func TestSweepContainer(t *testing.T) { diff --git a/agent/engine/engine_unix_integ_test.go b/agent/engine/engine_unix_integ_test.go index 212b45c657f..8e55e125763 100644 --- a/agent/engine/engine_unix_integ_test.go +++ b/agent/engine/engine_unix_integ_test.go @@ -621,13 +621,13 @@ func TestVolumesFromRO(t *testing.T) { verifyTaskIsStopped(stateChangeEvents, testTask) - if testTask.Containers[1].KnownExitCode == nil || *testTask.Containers[1].KnownExitCode != 42 { - t.Error("Didn't exit due to failure to touch ro fs as expected: ", *testTask.Containers[1].KnownExitCode) + if testTask.Containers[1].GetKnownExitCode() == nil || *testTask.Containers[1].GetKnownExitCode() != 42 { + t.Error("Didn't exit due to failure to touch ro fs as expected: ", *testTask.Containers[1].GetKnownExitCode()) } - if testTask.Containers[2].KnownExitCode == nil || *testTask.Containers[2].KnownExitCode != 0 { + if testTask.Containers[2].GetKnownExitCode() == nil || *testTask.Containers[2].GetKnownExitCode() != 0 { t.Error("Couldn't touch with default of rw") } - if testTask.Containers[3].KnownExitCode == nil || *testTask.Containers[3].KnownExitCode != 0 { + if testTask.Containers[3].GetKnownExitCode() == nil || *testTask.Containers[3].GetKnownExitCode() != 0 { t.Error("Couldn't touch with explicit rw") } } @@ -782,7 +782,7 @@ check_events: event = <-stateChangeEvents assert.Equal(t, event.(api.TaskStateChange).Status, api.TaskStopped, "Expected task to be STOPPED") - if testTask.Containers[0].KnownExitCode == nil || *testTask.Containers[0].KnownExitCode != 42 { + if testTask.Containers[0].GetKnownExitCode() == nil || *testTask.Containers[0].GetKnownExitCode() != 42 { t.Error("Wrong exit code; file probably wasn't present") } } diff --git a/agent/engine/image/types.go b/agent/engine/image/types.go index 5083e01964e..246c70a5a23 100644 --- a/agent/engine/image/types.go +++ b/agent/engine/image/types.go @@ -60,6 +60,12 @@ func (imageState *ImageState) AddImageName(imageName string) { } } +func (imageState *ImageState) GetImageNamesCount() int { + imageState.updateLock.RLock() + defer imageState.updateLock.RUnlock() + return len(imageState.Image.Names) +} + func (imageState *ImageState) HasNoAssociatedContainers() bool { return len(imageState.Containers) == 0 } diff --git a/agent/engine/task_manager.go b/agent/engine/task_manager.go index 35d9616ee6b..174e0a63f80 100644 --- a/agent/engine/task_manager.go +++ b/agent/engine/task_manager.go @@ -291,8 +291,8 @@ func (mtask *managedTask) handleContainerChange(containerChange dockerContainerC seelog.Warnf("Failed to write container change event to event stream, err %v", err) } - if event.ExitCode != nil && event.ExitCode != container.KnownExitCode { - container.KnownExitCode = event.ExitCode + if event.ExitCode != nil && event.ExitCode != container.GetKnownExitCode() { + container.SetKnownExitCode(event.ExitCode) } if event.PortBindings != nil { container.KnownPortBindings = event.PortBindings diff --git a/agent/eventhandler/handler_test.go b/agent/eventhandler/handler_test.go index 78faba5bcc2..49b6344b966 100644 --- a/agent/eventhandler/handler_test.go +++ b/agent/eventhandler/handler_test.go @@ -18,14 +18,12 @@ import ( "strconv" "sync" "testing" - "time" "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/api/mocks" "github.com/aws/amazon-ecs-agent/agent/statechange" "github.com/aws/amazon-ecs-agent/agent/utils" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" ) func containerEvent(arn string) statechange.Event { @@ -100,41 +98,33 @@ func TestSendsEventsConcurrentLimit(t *testing.T) { handler := NewTaskHandler() - contCalled := make(chan struct{}, concurrentEventCalls+1) completeStateChange := make(chan bool, concurrentEventCalls+1) - count := 0 - countLock := &sync.Mutex{} + var wg sync.WaitGroup + client.EXPECT().SubmitContainerStateChange(gomock.Any()).Times(concurrentEventCalls + 1).Do(func(interface{}) { - countLock.Lock() - count++ - countLock.Unlock() + wg.Done() <-completeStateChange - contCalled <- struct{}{} }) + // Test concurrency; ensure it doesn't attempt to send more than // concurrentEventCalls at once + wg.Add(concurrentEventCalls) + // Put on N+1 events for i := 0; i < concurrentEventCalls+1; i++ { handler.AddStateChangeEvent(containerEvent("concurrent_"+strconv.Itoa(i)), client) } - time.Sleep(10 * time.Millisecond) + wg.Wait() - // N events should be waiting for potential errors since we havent started completing state changes - assert.Equal(t, concurrentEventCalls, count, "Too many event calls got through concurrently") - // Let one state change finish + //Let one change through + wg.Add(1) completeStateChange <- true - <-contCalled - time.Sleep(10 * time.Millisecond) - - assert.Equal(t, concurrentEventCalls+1, count, "Another concurrent call didn't start when expected") + wg.Wait() // ensure the remaining requests are completed for i := 0; i < concurrentEventCalls; i++ { completeStateChange <- true - <-contCalled } - time.Sleep(5 * time.Millisecond) - assert.Equal(t, concurrentEventCalls+1, count, "Extra concurrent calls appeared from nowhere") } func TestSendsEventsContainerDifferences(t *testing.T) { diff --git a/agent/eventhandler/task_handler.go b/agent/eventhandler/task_handler.go index e8f01c97cd8..498a1634b45 100644 --- a/agent/eventhandler/task_handler.go +++ b/agent/eventhandler/task_handler.go @@ -159,7 +159,7 @@ func (handler *TaskHandler) SubmitTaskEvents(taskEvents *eventList, client api.E err = client.SubmitContainerStateChange(event.containerChange) if err == nil { // submitted; ensure we don't retry it - event.containerSent = true + event.setSent() if event.containerChange.Container != nil { event.containerChange.Container.SetSentStatus(event.containerChange.Status) } @@ -175,7 +175,7 @@ func (handler *TaskHandler) SubmitTaskEvents(taskEvents *eventList, client api.E err = client.SubmitTaskStateChange(event.taskChange) if err == nil { // submitted or can't be retried; ensure we don't retry it - event.taskSent = true + event.setSent() if event.taskChange.Task != nil { event.taskChange.Task.SetSentStatus(event.taskChange.Status) } diff --git a/agent/eventhandler/task_handler_types.go b/agent/eventhandler/task_handler_types.go index 6afea56159f..a21796d8cb0 100644 --- a/agent/eventhandler/task_handler_types.go +++ b/agent/eventhandler/task_handler_types.go @@ -15,6 +15,7 @@ package eventhandler import ( "github.com/aws/amazon-ecs-agent/agent/api" + "sync" ) // a state change that may have a container and, optionally, a task event to @@ -28,9 +29,13 @@ type sendableEvent struct { taskSent bool taskChange api.TaskStateChange + + lock sync.RWMutex } -func (event sendableEvent) String() string { +func (event *sendableEvent) String() string { + event.lock.RLock() + defer event.lock.RUnlock() if event.isContainerEvent { return "ContainerChange: " + event.containerChange.String() } else { @@ -62,6 +67,8 @@ func (event *sendableEvent) taskArn() string { } func (event *sendableEvent) taskShouldBeSent() bool { + event.lock.RLock() + defer event.lock.RUnlock() if event.isContainerEvent { return false } @@ -76,6 +83,8 @@ func (event *sendableEvent) taskShouldBeSent() bool { } func (event *sendableEvent) containerShouldBeSent() bool { + event.lock.RLock() + defer event.lock.RUnlock() if !event.isContainerEvent { return false } @@ -85,3 +94,13 @@ func (event *sendableEvent) containerShouldBeSent() bool { } return true } + +func (event *sendableEvent) setSent() { + event.lock.Lock() + defer event.lock.Unlock() + if event.isContainerEvent { + event.containerSent = true + } else { + event.taskSent = true + } +} diff --git a/agent/eventstream/eventstream.go b/agent/eventstream/eventstream.go index 3660471fd66..3d7ba8c620d 100644 --- a/agent/eventstream/eventstream.go +++ b/agent/eventstream/eventstream.go @@ -16,11 +16,11 @@ package eventstream import ( + "context" "fmt" "sync" "github.com/cihub/seelog" - "golang.org/x/net/context" ) type eventHandler func(...interface{}) error diff --git a/agent/eventstream/eventstream_test.go b/agent/eventstream/eventstream_test.go index 90816c8b751..e55678e7769 100644 --- a/agent/eventstream/eventstream_test.go +++ b/agent/eventstream/eventstream_test.go @@ -13,92 +13,61 @@ package eventstream import ( + "context" + "sync" "testing" "time" - "golang.org/x/net/context" + "github.com/stretchr/testify/assert" ) -type eventListener struct { - called bool -} - -func (listener *eventListener) recordCall(...interface{}) error { - listener.called = true - return nil -} - // TestSubscribe tests the listener subscribed to the // event stream will be notified func TestSubscribe(t *testing.T) { - listener := eventListener{called: false} + waiter, listener := setupWaitGroupAndListener() ctx, cancel := context.WithCancel(context.Background()) - eventStream := NewEventStream("TestSubscribe", ctx) - eventStream.Subscribe("listener", listener.recordCall) + defer cancel() + waiter.Add(1) + eventStream := NewEventStream("TestSubscribe", ctx) + eventStream.Subscribe("listener", listener) eventStream.StartListening() err := eventStream.WriteToEventStream(struct{}{}) - if err != nil { - t.Errorf("Write to event stream failed, err: %v", err) - } - - time.Sleep(1 * time.Second) - - cancel() - if !listener.called { - t.Error("Listener was not invoked") - } + assert.NoError(t, err) + waiter.Wait() } // TestUnsubscribe tests the listener unsubscribed from the // event steam will not be notified func TestUnsubscribe(t *testing.T) { - listener1 := eventListener{called: false} - listener2 := eventListener{called: false} + waiter1, listener1 := setupWaitGroupAndListener() + waiter2, listener2 := setupWaitGroupAndListener() ctx, cancel := context.WithCancel(context.Background()) - eventStream := NewEventStream("TestUnsubscribe", ctx) - - eventStream.Subscribe("listener1", listener1.recordCall) - eventStream.Subscribe("listener2", listener2.recordCall) + defer cancel() + eventStream := NewEventStream("TestUnsubscribe", ctx) + eventStream.Subscribe("listener1", listener1) + eventStream.Subscribe("listener2", listener2) eventStream.StartListening() + waiter1.Add(1) + waiter2.Add(1) err := eventStream.WriteToEventStream(struct{}{}) - if err != nil { - t.Errorf("Write to event stream failed, err: %v", err) - } + assert.NoError(t, err) + waiter1.Wait() + waiter2.Wait() - time.Sleep(1 * time.Second) - if !listener1.called { - t.Error("Listener 1 was not invoked") - } - if !listener2.called { - t.Error("Listener 2 was not invoked") - } - - listener1.called = false - listener2.called = false eventStream.Unsubscribe("listener1") + waiter2.Add(1) err = eventStream.WriteToEventStream(struct{}{}) - if err != nil { - t.Errorf("Write to event stream failed, err: %v", err) - } - - // wait for the event stream to broadcast - time.Sleep(1 * time.Second) - - if listener1.called { - t.Error("Unsubscribed handler shouldn't be called") - } + assert.NoError(t, err) - if !listener2.called { - t.Error("Listener 2 was not invoked without unsubscribing") - } - cancel() + waiter1.Wait() + waiter2.Wait() } // TestCancelEventStream tests the event stream can @@ -106,22 +75,25 @@ func TestUnsubscribe(t *testing.T) { func TestCancelEventStream(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) eventStream := NewEventStream("TestCancelEventStream", ctx) + _, listener := setupWaitGroupAndListener() - listener := eventListener{called: false} - - eventStream.Subscribe("listener", listener.recordCall) - + eventStream.Subscribe("listener", listener) eventStream.StartListening() cancel() - // wait for the event stream to handle cancel time.Sleep(1 * time.Second) err := eventStream.WriteToEventStream(struct{}{}) - if err == nil { - t.Error("Write to closed event stream should return an error") - } - if listener.called { - t.Error("Cancelled events context, handler should not be called") + assert.Error(t, err) +} + +// setupWaitGroupAndListener creates a waitgroup and a function +// that decrements a WaitGroup when called +func setupWaitGroupAndListener() (*sync.WaitGroup, func(...interface{}) error) { + waiter := &sync.WaitGroup{} + listener := func(...interface{}) error { + waiter.Done() + return nil } + return waiter, listener } diff --git a/agent/stats/container.go b/agent/stats/container.go index c3c467d8e82..48136a078e6 100644 --- a/agent/stats/container.go +++ b/agent/stats/container.go @@ -14,6 +14,7 @@ package stats import ( + "errors" "time" ecsengine "github.com/aws/amazon-ecs-agent/agent/engine" @@ -96,6 +97,9 @@ func (container *StatsContainer) collect() { func (container *StatsContainer) processStatsStream() error { dockerID := container.containerMetadata.DockerID seelog.Debugf("Collecting stats for container %s", dockerID) + if container.client == nil { + return errors.New("container processStatsStream: Client is not set.") + } dockerStats, err := container.client.Stats(dockerID, container.ctx) if err != nil { return err diff --git a/agent/tcs/client/client_test.go b/agent/tcs/client/client_test.go index f4f4943aa3c..23b59899074 100644 --- a/agent/tcs/client/client_test.go +++ b/agent/tcs/client/client_test.go @@ -20,21 +20,19 @@ package tcsclient import ( - "errors" "fmt" "strconv" "testing" "time" "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/eventstream" "github.com/aws/amazon-ecs-agent/agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/agent/wsclient" + "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/gorilla/websocket" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "golang.org/x/net/context" ) const ( @@ -44,37 +42,6 @@ const ( testContainerInstance = "containerInstance" ) -type messageLogger struct { - writes [][]byte - reads [][]byte - closed bool -} - -func (ml *messageLogger) WriteMessage(_ int, data []byte) error { - if ml.closed { - return errors.New("can't write to closed ws") - } - ml.writes = append(ml.writes, data) - return nil -} - -func (ml *messageLogger) Close() error { - ml.closed = true - return nil -} - -func (ml *messageLogger) ReadMessage() (int, []byte, error) { - for len(ml.reads) == 0 && !ml.closed { - time.Sleep(1 * time.Millisecond) - } - if ml.closed { - return 0, []byte{}, errors.New("can't read from a closed websocket") - } - read := ml.reads[len(ml.reads)-1] - ml.reads = ml.reads[0 : len(ml.reads)-1] - return websocket.TextMessage, read, nil -} - type mockStatsEngine struct{} func (engine *mockStatsEngine) GetInstanceMetrics() (*ecstcs.MetricsMetadata, []*ecstcs.TaskMetric, error) { @@ -124,7 +91,14 @@ func newNonIdleStatsEngine(numTasks int) *nonIdleStatsEngine { } func TestPayloadHandlerCalled(t *testing.T) { - cs, ml := testCS() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + cs := testCS(conn) + + conn.EXPECT().ReadMessage().AnyTimes().Return(1, []byte(`{"type":"AckPublishMetric","message":{}}`), nil) + conn.EXPECT().Close() handledPayload := make(chan *ecstcs.AckPublishMetric) @@ -133,30 +107,30 @@ func TestPayloadHandlerCalled(t *testing.T) { } cs.AddRequestHandler(reqHandler) - ml.reads = [][]byte{[]byte(`{"type":"AckPublishMetric","message":{}}`)} - - var isClosed bool - go func() { - err := cs.Serve() - if !isClosed && err != nil { - t.Fatal("Premature end of serving", err) - } - }() + go cs.Serve() + defer cs.Close() t.Log("Waiting for handler to return payload.") <-handledPayload - isClosed = true - cs.Close() } func TestPublishMetricsRequest(t *testing.T) { - cs, _ := testCS() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().Close() + // TODO: should use explicit values + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()) + + cs := testCS(conn) + defer cs.Close() + err := cs.MakeRequest(&ecstcs.PublishMetricsRequest{}) if err != nil { t.Fatal(err) } - cs.Close() } func TestPublishMetricsOnceEmptyStatsError(t *testing.T) { cs := clientServer{ @@ -221,47 +195,34 @@ func TestPublishOnceNonIdleStatsEngine(t *testing.T) { } } -func testCS() (wsclient.ClientServer, *messageLogger) { +func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer { testCreds := credentials.AnonymousCredentials cfg := &config.Config{ AWSRegion: "us-east-1", AcceptInsecureCert: true, } cs := New("localhost:443", cfg, testCreds, &mockStatsEngine{}, testPublishMetricsInterval).(*clientServer) - ml := &messageLogger{make([][]byte, 0), make([][]byte, 0), false} - cs.SetConnection(ml) - return cs, ml + cs.SetConnection(conn) + return cs } -// TestDeregisterInstanceStream tests the ws connection will be closed by tcs client when +// TestCloseClientServer tests the ws connection will be closed by tcs client when // received the deregisterInstanceStream -func TestDeregisterInstanceStream(t *testing.T) { - cs, ml := testCS() +func TestCloseClientServer(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - ctx, cancel := context.WithCancel(context.Background()) - deregisterInstanceEventStream := eventstream.NewEventStream("TestDeregisterInstanceStream", ctx) - deregisterInstanceEventStream.StartListening() - defer cancel() + conn := mock_wsclient.NewMockWebsocketConn(ctrl) + cs := testCS(conn) - err := deregisterInstanceEventStream.Subscribe("TestDeregisterContainerInstanceHandler", cs.Disconnect) - if err != nil { - t.Errorf("Error subscribing to event stream, err %v", err) - } + gomock.InOrder( + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()), + conn.EXPECT().Close(), + ) - err = cs.MakeRequest(&ecstcs.PublishMetricsRequest{}) - if err != nil { - t.Errorf("Error making client request: %v", err) - } - if ml.closed { - t.Error("Connection closed before send the deregister event") - } - err = deregisterInstanceEventStream.WriteToEventStream(struct{}{}) - if err != nil { - t.Errorf("Failed to write to event stream, err %v", err) - } - // wait for the handler to run - time.Sleep(1 * time.Second) - if !ml.closed { - t.Error("Client should be closed after receiving deregister event") - } + err := cs.MakeRequest(&ecstcs.PublishMetricsRequest{}) + assert.Nil(t, err) + + err = cs.Disconnect() + assert.Nil(t, err) } diff --git a/agent/tcs/handler/handler_test.go b/agent/tcs/handler/handler_test.go index 3354df499d6..aef175ae724 100644 --- a/agent/tcs/handler/handler_test.go +++ b/agent/tcs/handler/handler_test.go @@ -29,7 +29,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/tcs/client" "github.com/aws/amazon-ecs-agent/agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/agent/wsclient" - "github.com/aws/amazon-ecs-agent/agent/wsclient/mock/utils" + wsmock "github.com/aws/amazon-ecs-agent/agent/wsclient/mock/utils" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/golang/mock/gomock" "github.com/gorilla/websocket" @@ -82,15 +82,15 @@ func TestFormatURL(t *testing.T) { func TestStartSession(t *testing.T) { // Start test server. closeWS := make(chan []byte) - server, serverChan, requestChan, serverErr, err := mockwsutils.StartMockServer(t, closeWS) + server, serverChan, requestChan, serverErr, err := wsmock.StartMockServer(t, closeWS) defer server.Close() if err != nil { t.Fatal(err) } wait := &sync.WaitGroup{} ctx, cancel := context.WithCancel(context.Background()) + wait.Add(1) go func() { - wait.Add(1) select { case sErr := <-serverErr: t.Error(sErr) @@ -142,7 +142,7 @@ func TestStartSession(t *testing.T) { func TestSessionConnectionClosedByRemote(t *testing.T) { // Start test server. closeWS := make(chan []byte) - server, serverChan, _, serverErr, err := mockwsutils.StartMockServer(t, closeWS) + server, serverChan, _, serverErr, err := wsmock.StartMockServer(t, closeWS) defer server.Close() if err != nil { t.Fatal(err) @@ -181,7 +181,7 @@ func TestSessionConnectionClosedByRemote(t *testing.T) { func TestConnectionInactiveTimeout(t *testing.T) { // Start test server. closeWS := make(chan []byte) - server, _, requestChan, serverErr, err := mockwsutils.StartMockServer(t, closeWS) + server, _, requestChan, serverErr, err := wsmock.StartMockServer(t, closeWS) defer server.Close() if err != nil { t.Fatal(err) diff --git a/agent/utils/backoff.go b/agent/utils/backoff.go index 104f65d94fd..5e14ba7dfa9 100644 --- a/agent/utils/backoff.go +++ b/agent/utils/backoff.go @@ -15,7 +15,8 @@ package utils import ( "math" - mathrand "math/rand" + "math/rand" + "sync" "time" ) @@ -30,6 +31,7 @@ type SimpleBackoff struct { max time.Duration jitterMultiple float64 multiple float64 + mu sync.Mutex } // NewSimpleBackoff creates a Backoff which ranges from min to max increasing by @@ -50,6 +52,8 @@ func NewSimpleBackoff(min, max time.Duration, jitterMultiple, multiple float64) } func (sb *SimpleBackoff) Duration() time.Duration { + sb.mu.Lock() + defer sb.mu.Unlock() ret := sb.current sb.current = time.Duration(math.Min(float64(sb.max.Nanoseconds()), float64(float64(sb.current.Nanoseconds())*sb.multiple))) @@ -57,6 +61,8 @@ func (sb *SimpleBackoff) Duration() time.Duration { } func (sb *SimpleBackoff) Reset() { + sb.mu.Lock() + defer sb.mu.Unlock() sb.current = sb.start } @@ -67,7 +73,7 @@ func AddJitter(duration time.Duration, jitter time.Duration) time.Duration { if jitter.Nanoseconds() == 0 { randJitter = 0 } else { - randJitter = mathrand.Int63n(jitter.Nanoseconds()) + randJitter = rand.Int63n(jitter.Nanoseconds()) } return time.Duration(duration.Nanoseconds() + randJitter) } diff --git a/agent/wsclient/client.go b/agent/wsclient/client.go index 145779bee03..47153fb7697 100644 --- a/agent/wsclient/client.go +++ b/agent/wsclient/client.go @@ -104,7 +104,7 @@ type ClientServer interface { io.Closer } -//go:generate go run ../../scripts/generate/mockgen.go github.com/aws/amazon-ecs-agent/agent/wsclient ClientServer mock/$GOFILE +//go:generate go run ../../scripts/generate/mockgen.go github.com/aws/amazon-ecs-agent/agent/wsclient ClientServer,WebsocketConn mock/$GOFILE // ClientServerImpl wraps commonly used methods defined in ClientServer interface. type ClientServerImpl struct { @@ -135,6 +135,8 @@ type ClientServerImpl struct { // 'MakeRequest' can be made after calling this, but responss will not be // receivable until 'Serve' is also called. func (cs *ClientServerImpl) Connect() error { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() parsedURL, err := url.Parse(cs.URL) if err != nil { return err @@ -201,6 +203,8 @@ func (cs *ClientServerImpl) Connect() error { } func (cs *ClientServerImpl) IsReady() bool { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() return cs.conn != nil } @@ -210,6 +214,9 @@ func (cs *ClientServerImpl) SetConnection(conn WebsocketConn) { // Disconnect disconnects the connection func (cs *ClientServerImpl) Disconnect(...interface{}) error { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + if cs.conn != nil { return cs.conn.Close() } diff --git a/agent/wsclient/client_test.go b/agent/wsclient/client_test.go index f9cc4241208..3773a8ab368 100644 --- a/agent/wsclient/client_test.go +++ b/agent/wsclient/client_test.go @@ -39,7 +39,7 @@ func TestConcurrentWritesDontPanic(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, requests, _, _ := mockwsutils.StartMockServer(t, closeWS) + mockServer, _, requests, _, _ := utils.StartMockServer(t, closeWS) defer mockServer.Close() req := ecsacs.AckRequest{Cluster: aws.String("test"), ContainerInstance: aws.String("test"), MessageId: aws.String("test")} @@ -70,7 +70,7 @@ func TestProxyVariableCustomValue(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, _, _, _ := mockwsutils.StartMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS) defer mockServer.Close() testString := "Custom no proxy string" @@ -86,7 +86,7 @@ func TestProxyVariableDefaultValue(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, _, _, _ := mockwsutils.StartMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS) defer mockServer.Close() os.Unsetenv("NO_PROXY") @@ -104,7 +104,7 @@ func TestHandleMessagePermissibleCloseCode(t *testing.T) { defer close(closeWS) messageError := make(chan error) - mockServer, _, _, _, _ := mockwsutils.StartMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS) cs := getClientServer(mockServer.URL) cs.Connect() @@ -123,7 +123,7 @@ func TestHandleMessageUnexpectedCloseCode(t *testing.T) { defer close(closeWS) messageError := make(chan error) - mockServer, _, _, _, _ := mockwsutils.StartMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.StartMockServer(t, closeWS) cs := getClientServer(mockServer.URL) cs.Connect() diff --git a/agent/wsclient/mock/client.go b/agent/wsclient/mock/client.go index 1f5246d49fe..a3bc2f55b96 100644 --- a/agent/wsclient/mock/client.go +++ b/agent/wsclient/mock/client.go @@ -12,7 +12,7 @@ // permissions and limitations under the License. // Automatically generated by MockGen. DO NOT EDIT! -// Source: github.com/aws/amazon-ecs-agent/agent/wsclient (interfaces: ClientServer) +// Source: github.com/aws/amazon-ecs-agent/agent/wsclient (interfaces: ClientServer,WebsocketConn) package mock_wsclient @@ -139,3 +139,56 @@ func (_m *MockClientServer) WriteMessage(_param0 []byte) error { func (_mr *_MockClientServerRecorder) WriteMessage(arg0 interface{}) *gomock.Call { return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessage", arg0) } + +// Mock of WebsocketConn interface +type MockWebsocketConn struct { + ctrl *gomock.Controller + recorder *_MockWebsocketConnRecorder +} + +// Recorder for MockWebsocketConn (not exported) +type _MockWebsocketConnRecorder struct { + mock *MockWebsocketConn +} + +func NewMockWebsocketConn(ctrl *gomock.Controller) *MockWebsocketConn { + mock := &MockWebsocketConn{ctrl: ctrl} + mock.recorder = &_MockWebsocketConnRecorder{mock} + return mock +} + +func (_m *MockWebsocketConn) EXPECT() *_MockWebsocketConnRecorder { + return _m.recorder +} + +func (_m *MockWebsocketConn) Close() error { + ret := _m.ctrl.Call(_m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockWebsocketConnRecorder) Close() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Close") +} + +func (_m *MockWebsocketConn) ReadMessage() (int, []byte, error) { + ret := _m.ctrl.Call(_m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +func (_mr *_MockWebsocketConnRecorder) ReadMessage() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessage") +} + +func (_m *MockWebsocketConn) WriteMessage(_param0 int, _param1 []byte) error { + ret := _m.ctrl.Call(_m, "WriteMessage", _param0, _param1) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockWebsocketConnRecorder) WriteMessage(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "WriteMessage", arg0, arg1) +} diff --git a/agent/wsclient/mock/utils/utils.go b/agent/wsclient/mock/utils/utils.go index 0fa1dc4c839..2995cb3cef7 100644 --- a/agent/wsclient/mock/utils/utils.go +++ b/agent/wsclient/mock/utils/utils.go @@ -11,7 +11,7 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. -package mockwsutils +package utils import ( "net/http" @@ -23,6 +23,7 @@ import ( ) // StartMockServer starts a mock websocket server. +// TODO replace with gomock func StartMockServer(t *testing.T, closeWS <-chan []byte) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { serverChan := make(chan string) requestsChan := make(chan string) diff --git a/appveyor.yml b/appveyor.yml index 1de157e9648..484bab969df 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -13,4 +13,4 @@ install: build_script: - go build ./agent test_script: - - go test -short -timeout=40s ./agent/... + - go test -short -race -timeout=40s ./agent/...