Skip to content

Commit

Permalink
Enabling race detector and fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
petderek committed Jul 13, 2017
1 parent 4f1b919 commit fee2053
Show file tree
Hide file tree
Showing 29 changed files with 472 additions and 502 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ misc/certs/ca-certificates.crt:
docker run "amazon/amazon-ecs-agent-cert-source:make" cat /etc/ssl/certs/ca-certificates.crt > misc/certs/ca-certificates.crt

test:
. ./scripts/shared_env && go test -timeout=25s -v -cover $(shell go list ./agent/... | grep -v /vendor/)
. ./scripts/shared_env && go test -race -timeout=25s -v -cover $(shell go list ./agent/... | grep -v /vendor/)

benchmark-test:
. ./scripts/shared_env && go test -run=XX -bench=. $(shell go list ./agent/... | grep -v /vendor/)
Expand Down
298 changes: 122 additions & 176 deletions agent/acs/client/acs_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ package acsclient
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"reflect"
"sync"
"testing"
"time"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
"github.com/aws/amazon-ecs-agent/agent/wsclient/mock"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/golang/mock/gomock"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

const sampleCredentialsMessage = `
Expand All @@ -50,146 +52,102 @@ const sampleCredentialsMessage = `
}
`

type messageLogger struct {
writes [][]byte
reads [][]byte
closed bool
}
const (
TestClusterArn = "arn:aws:ec2:123:container/cluster:123456"
TestInstanceArn = "arn:aws:ec2:123:container/containerInstance/12345678"
)

var testCfg = &config.Config{
AcceptInsecureCert: true,
AWSRegion: "us-east-1",
}

func (ml *messageLogger) WriteMessage(_ int, data []byte) error {
if ml.closed {
return errors.New("can't write to closed ws")
}
ml.writes = append(ml.writes, data)
return nil
}

func (ml *messageLogger) Close() error {
ml.closed = true
return nil
}

func (ml *messageLogger) ReadMessage() (int, []byte, error) {
for len(ml.reads) == 0 && !ml.closed {
time.Sleep(1 * time.Millisecond)
}
if ml.closed {
return 0, []byte{}, errors.New("can't read from a closed websocket")
}
read := ml.reads[len(ml.reads)-1]
ml.reads = ml.reads[0 : len(ml.reads)-1]
return websocket.TextMessage, read, nil
}
func TestMakeUnrecognizedRequest(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

func testCS() (wsclient.ClientServer, *messageLogger) {
testCreds := credentials.AnonymousCredentials
conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().Close()

cs := New("localhost:443", testCfg, testCreds).(*clientServer)
ml := &messageLogger{make([][]byte, 0), make([][]byte, 0), false}
cs.SetConnection(ml)
return cs, ml
}

func TestMakeUnrecognizedRequest(t *testing.T) {
cs, _ := testCS()
cs := testCS(conn)
defer cs.Close()
// 'testing.T' should not be a known type ;)
err := cs.MakeRequest(t)
if _, ok := err.(*wsclient.UnrecognizedWSRequestType); !ok {
t.Fatal("Expected unrecognized request type")
}
_ = err.Error() // This is one of those times when 100% coverage is silly
cs.Close()
}

func strptr(s string) *string {
return &s
}

func TestWriteAckRequest(t *testing.T) {
cs, ml := testCS()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()

// capture bytes written
var writes []byte
conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Do(func(_ int, data []byte) {
writes = data
})

req := ecsacs.AckRequest{Cluster: strptr("default"), ContainerInstance: strptr("testCI"), MessageId: strptr("messageID")}
err := cs.MakeRequest(&req)
if err != nil {
t.Fatal(err)
}
// send request
err := cs.MakeRequest(&ecsacs.AckRequest{})
assert.NoError(t, err)

write := ml.writes[0]
writtenReq := struct {
Type string
Message ecsacs.AckRequest
}{}
err = json.Unmarshal(write, &writtenReq)
if err != nil {
t.Fatal("Unable to unmarshal written", err)
}
msg := writtenReq.Message
if *msg.Cluster != "default" || *msg.ContainerInstance != "testCI" || *msg.MessageId != "messageID" {
t.Error("Did not write what we expected")
}
cs.Close()
// unmarshal bytes written to the socket
msg := &wsclient.RequestMessage{}
err = json.Unmarshal(writes, msg)
assert.NoError(t, err)
assert.Equal(t, "AckRequest", msg.Type)
}

func TestPayloadHandlerCalled(t *testing.T) {
cs, ml := testCS()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

var handledPayload *ecsacs.PayloadMessage
conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`), nil)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()

messageChannel := make(chan *ecsacs.PayloadMessage)
reqHandler := func(payload *ecsacs.PayloadMessage) {
handledPayload = payload
messageChannel <- payload
}
cs.AddRequestHandler(reqHandler)
go cs.Serve()

ml.reads = [][]byte{[]byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`)}

var isClosed bool
go func() {
err := cs.Serve()
if !isClosed && err != nil {
t.Fatal("Premature end of serving", err)
}
}()

time.Sleep(1 * time.Millisecond)
if handledPayload == nil {
t.Fatal("Handler was not called")
expectedMessage := &ecsacs.PayloadMessage{
Tasks: []*ecsacs.Task{{
Arn: aws.String("arn"),
}},
}

if len(handledPayload.Tasks) != 1 || *handledPayload.Tasks[0].Arn != "arn" {
t.Error("Unmarshalled data did not contain expected values")
}

isClosed = true
cs.Close()
assert.Equal(t, expectedMessage, <-messageChannel)
}

func TestRefreshCredentialsHandlerCalled(t *testing.T) {
cs, ml := testCS()
ctrl := gomock.NewController(t)
defer ctrl.Finish()


wait := sync.WaitGroup{}
wait.Add(1)
var handledMessage *ecsacs.IAMRoleCredentialsMessage
conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(sampleCredentialsMessage), nil)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()

messageChannel := make(chan *ecsacs.IAMRoleCredentialsMessage)
reqHandler := func(message *ecsacs.IAMRoleCredentialsMessage) {
wait.Done()
handledMessage = message
messageChannel <- message
}
cs.AddRequestHandler(reqHandler)

ml.reads = [][]byte{[]byte(sampleCredentialsMessage)}

var isClosed bool
go func() {
err := cs.Serve()
if !isClosed && err != nil {
t.Fatal("Premature end of serving", err)
}
}()

wait.Wait()
go cs.Serve()

expectedMessage := &ecsacs.IAMRoleCredentialsMessage{
MessageId: aws.String("123"),
Expand All @@ -203,79 +161,24 @@ func TestRefreshCredentialsHandlerCalled(t *testing.T) {
SessionToken: aws.String("token"),
},
}

if !reflect.DeepEqual(handledMessage, expectedMessage) {
t.Error("Unmarshalled credential message did not contain expected values")
}

isClosed = true
cs.Close()
assert.Equal(t, <-messageChannel, expectedMessage)
}

func TestClosingConnection(t *testing.T) {
cs, ml := testCS()
closedChan := make(chan error)
var expectedClosed bool
go func() {
err := cs.Serve()
if !expectedClosed {
t.Fatal("Serve ended early")
}
closedChan <- err
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

expectedClosed = true
ml.Close()
err := <-closedChan
if err == nil {
t.Error("Closing was expected to result in an error")
}

req := ecsacs.AckRequest{Cluster: strptr("default"), ContainerInstance: strptr("testCI"), MessageId: strptr("messageID")}
err = cs.MakeRequest(&req)
if err == nil {
t.Error("Cannot request over closed connection")
}
}

const (
TestClusterArn = "arn:aws:ec2:123:container/cluster:123456"
TestInstanceArn = "arn:aws:ec2:123:container/containerInstance/12345678"
)
// Returning EOF tells the ClientServer that the connection is closed
conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().ReadMessage().Return(0, nil, io.EOF)
conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(io.EOF)
cs := testCS(conn)

func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) {
serverChan := make(chan string)
requestsChan := make(chan string)
errChan := make(chan error)
serveErr := cs.Serve()
assert.Error(t, serveErr)

upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
go func() {
<-closeWS
ws.Close()
}()
if err != nil {
errChan <- err
}
go func() {
_, msg, err := ws.ReadMessage()
if err != nil {
errChan <- err
} else {
requestsChan <- string(msg)
}
}()
for str := range serverChan {
err := ws.WriteMessage(websocket.TextMessage, []byte(str))
if err != nil {
errChan <- err
}
}
})

server := httptest.NewTLSServer(handler)
return server, serverChan, requestsChan, errChan, nil
err := cs.MakeRequest(&ecsacs.AckRequest{})
assert.Error(t, err)
}

func TestConnect(t *testing.T) {
Expand Down Expand Up @@ -362,7 +265,50 @@ func TestConnectClientError(t *testing.T) {

cs := New(testServer.URL, testCfg, credentials.AnonymousCredentials)
err := cs.Connect()
if _, ok := err.(*wsclient.WSError); !ok || err.Error() != "InvalidClusterException: Invalid cluster" {
t.Error("Did not get correctly typed error: " + err.Error())
}
_, ok := err.(*wsclient.WSError)
assert.True(t, ok)
assert.EqualError(t, err, "InvalidClusterException: Invalid cluster")
}

func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer {
testCreds := credentials.AnonymousCredentials
cs := New("localhost:443", testCfg, testCreds).(*clientServer)
cs.SetConnection(conn)
return cs
}

// TODO: replace with gomock
func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) {
serverChan := make(chan string)
requestsChan := make(chan string)
errChan := make(chan error)

upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
go func() {
<-closeWS
ws.Close()
}()
if err != nil {
errChan <- err
}
go func() {
_, msg, err := ws.ReadMessage()
if err != nil {
errChan <- err
} else {
requestsChan <- string(msg)
}
}()
for str := range serverChan {
err := ws.WriteMessage(websocket.TextMessage, []byte(str))
if err != nil {
errChan <- err
}
}
})

server := httptest.NewTLSServer(handler)
return server, serverChan, requestsChan, errChan, nil
}
Loading

0 comments on commit fee2053

Please sign in to comment.