diff --git a/agent/app/agent.go b/agent/app/agent.go
index 757a8e14d97..9b481df8482 100644
--- a/agent/app/agent.go
+++ b/agent/app/agent.go
@@ -616,26 +616,26 @@ func (agent *ecsAgent) startAsyncRoutines(

     // Start automatic spot instance draining poller routine
     if agent.cfg.SpotInstanceDrainingEnabled {
-        go agent.startSpotInstanceDrainingPoller(client)
+        go agent.startSpotInstanceDrainingPoller(agent.ctx, client)
     }

     go agent.terminationHandler(stateManager, taskEngine)

     // Agent introspection api
-    go handlers.ServeIntrospectionHTTPEndpoint(&agent.containerInstanceARN, taskEngine, agent.cfg)
+    go handlers.ServeIntrospectionHTTPEndpoint(agent.ctx, &agent.containerInstanceARN, taskEngine, agent.cfg)

     statsEngine := stats.NewDockerStatsEngine(agent.cfg, agent.dockerClient, containerChangeEventStream)

     // Start serving the endpoint to fetch IAM Role credentials and other task metadata
     if agent.cfg.TaskMetadataAZDisabled {
         // send empty availability zone
-        go handlers.ServeTaskHTTPEndpoint(credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, "")
+        go handlers.ServeTaskHTTPEndpoint(agent.ctx, credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, "")
     } else {
-        go handlers.ServeTaskHTTPEndpoint(credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, agent.availabilityZone)
+        go handlers.ServeTaskHTTPEndpoint(agent.ctx, credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, agent.availabilityZone)
     }

     // Start sending events to the backend
-    go eventhandler.HandleEngineEvents(taskEngine, client, taskHandler, attachmentEventHandler)
+    go eventhandler.HandleEngineEvents(agent.ctx, taskEngine, client, taskHandler, attachmentEventHandler)

     telemetrySessionParams := tcshandler.TelemetrySessionParams{
         Ctx: agent.ctx,
@@ -652,9 +652,14 @@ func (agent *ecsAgent) startAsyncRoutines(
     go tcshandler.StartMetricsSession(&telemetrySessionParams)
 }

-func (agent *ecsAgent) startSpotInstanceDrainingPoller(client api.ECSClient) {
+func (agent *ecsAgent) startSpotInstanceDrainingPoller(ctx context.Context, client api.ECSClient) {
     for !agent.spotInstanceDrainingPoller(client) {
-        time.Sleep(time.Second)
+        select {
+        case <-ctx.Done():
+            return
+        default:
+            time.Sleep(time.Second)
+        }
     }
 }

diff --git a/agent/eventhandler/handler.go b/agent/eventhandler/handler.go
index 4af5af1d32b..6d5ec5a79b6 100644
--- a/agent/eventhandler/handler.go
+++ b/agent/eventhandler/handler.go
@@ -14,6 +14,7 @@
 package eventhandler

 import (
+    "context"
     "fmt"

     "github.com/aws/amazon-ecs-agent/agent/api"
@@ -24,13 +25,20 @@ import (

 // HandleEngineEvents handles state change events from the state change event channel by sending it to
 // responsible event handler
-func HandleEngineEvents(taskEngine engine.TaskEngine, client api.ECSClient, taskHandler *TaskHandler,
+func HandleEngineEvents(
+    ctx context.Context,
+    taskEngine engine.TaskEngine,
+    client api.ECSClient,
+    taskHandler *TaskHandler,
     attachmentEventHandler *AttachmentEventHandler) {
     for {
         stateChangeEvents := taskEngine.StateChangeEvents()

         for stateChangeEvents != nil {
             select {
+            case <-ctx.Done():
+                seelog.Infof("Exiting the engine event handler.")
+                return
             case event, ok := <-stateChangeEvents:
                 if !ok {
                     stateChangeEvents = nil
diff --git a/agent/handlers/introspection_server_setup.go b/agent/handlers/introspection_server_setup.go
index 42c7b0c6aaf..07f138a2237 100644
--- a/agent/handlers/introspection_server_setup.go
+++ b/agent/handlers/introspection_server_setup.go
@@ -15,10 +15,10 @@ package handlers

 import (
+    "context"
     "encoding/json"
     "net/http"
     "strconv"
-    "sync"
     "time"

     "github.com/aws/amazon-ecs-agent/agent/config"
@@ -78,22 +78,29 @@ func v1HandlersSetup(serverMux *http.ServeMux,
 // ServeIntrospectionHTTPEndpoint serves information about this agent/containerInstance and tasks
 // running on it. "V1" here indicates the hostname version of this server instead
 // of the handler versions, i.e. "V1" server can include "V1" and "V2" handlers.
-func ServeIntrospectionHTTPEndpoint(containerInstanceArn *string, taskEngine engine.TaskEngine, cfg *config.Config) {
+func ServeIntrospectionHTTPEndpoint(ctx context.Context, containerInstanceArn *string, taskEngine engine.TaskEngine, cfg *config.Config) {
     // Is this the right level to type assert, assuming we'd abstract multiple taskengines here?
     // Revisit if we ever add another type..
     dockerTaskEngine := taskEngine.(*engine.DockerTaskEngine)

     server := introspectionServerSetup(containerInstanceArn, dockerTaskEngine, cfg)
+
+    go func() {
+        <-ctx.Done()
+        if err := server.Shutdown(context.Background()); err != nil {
+            // Error from closing listeners, or context timeout:
+            seelog.Infof("HTTP server Shutdown: %v", err)
+        }
+    }()
+
     for {
-        once := sync.Once{}
         retry.RetryWithBackoff(retry.NewExponentialBackoff(time.Second, time.Minute, 0.2, 2), func() error {
-            // TODO, make this cancellable and use the passed in context; for
-            // now, not critical if this gets interrupted
-            err := server.ListenAndServe()
-            once.Do(func() {
-                seelog.Error("Error running http api", "err", err)
-            })
-            return err
+            if err := server.ListenAndServe(); err != http.ErrServerClosed {
+                seelog.Errorf("Error running introspection endpoint: %v", err)
+                return err
+            }
+            // server was cleanly closed via context
+            return nil
         })
     }
 }
diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go
index f0beca5b22e..f4bba9953c0 100644
--- a/agent/handlers/task_server_setup.go
+++ b/agent/handlers/task_server_setup.go
@@ -14,6 +14,7 @@
 package handlers

 import (
+    "context"
     "net/http"
     "strconv"
     "time"
@@ -153,7 +154,9 @@ func v4HandlersSetup(muxRouter *mux.Router,

 // ServeTaskHTTPEndpoint serves task/container metadata, task/container stats, and IAM Role Credentials
 // for tasks being managed by the agent.
-func ServeTaskHTTPEndpoint(credentialsManager credentials.Manager,
+func ServeTaskHTTPEndpoint(
+    ctx context.Context,
+    credentialsManager credentials.Manager,
     state dockerstate.TaskEngineState,
     ecsClient api.ECSClient,
     containerInstanceArn string,
@@ -161,7 +164,6 @@ func ServeTaskHTTPEndpoint(credentialsManager credentials.Manager,
     statsEngine stats.Engine,
     availabilityZone string) {
     // Create and initialize the audit log
-    // TODO Use seelog's programmatic configuration instead of xml.
     logger, err := seelog.LoggerFromConfigAsString(audit.AuditLoggerConfig(cfg))
     if err != nil {
         seelog.Errorf("Error initializing the audit log: %v", err)
@@ -174,14 +176,22 @@ func ServeTaskHTTPEndpoint(credentialsManager credentials.Manager,

     server := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, statsEngine,
         cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate, availabilityZone, containerInstanceArn)
+    go func() {
+        <-ctx.Done()
+        if err := server.Shutdown(context.Background()); err != nil {
+            // Error from closing listeners, or context timeout:
+            seelog.Infof("HTTP server Shutdown: %v", err)
+        }
+    }()
+
     for {
         retry.RetryWithBackoff(retry.NewExponentialBackoff(time.Second, time.Minute, 0.2, 2), func() error {
-            // TODO, make this cancellable and use the passed in context;
-            err := server.ListenAndServe()
-            if err != nil {
-                seelog.Errorf("Error running http api: %v", err)
+            if err := server.ListenAndServe(); err != http.ErrServerClosed {
+                seelog.Errorf("Error running task api: %v", err)
+                return err
             }
-            return err
+            // server was cleanly closed via context
+            return nil
         })
     }
 }
diff --git a/agent/tcs/handler/handler.go b/agent/tcs/handler/handler.go
index 81ab189681e..a0924cd4187 100644
--- a/agent/tcs/handler/handler.go
+++ b/agent/tcs/handler/handler.go
@@ -14,6 +14,7 @@
 package tcshandler

 import (
+    "context"
     "io"
     "net/url"
     "strings"
@@ -85,6 +86,7 @@ func StartSession(params *TelemetrySessionParams, statsEngine stats.Engine) erro
         }
         select {
         case <-params.Ctx.Done():
+            seelog.Info("TCS session exited cleanly.")
             return nil
         default:
         }
@@ -98,12 +100,14 @@ func startTelemetrySession(params *TelemetrySessionParams, statsEngine stats.Eng
         return err
     }
     url := formatURL(tcsEndpoint, params.Cfg.Cluster, params.ContainerInstanceArn, params.TaskEngine)
-    return startSession(url, params.Cfg, params.CredentialProvider, statsEngine,
+    return startSession(params.Ctx, url, params.Cfg, params.CredentialProvider, statsEngine,
         defaultHeartbeatTimeout, defaultHeartbeatJitter,
         config.DefaultContainerMetricsPublishInterval, params.DeregisterInstanceEventStream)
 }

-func startSession(url string,
+func startSession(
+    ctx context.Context,
+    url string,
     cfg *config.Config,
     credentialProvider *credentials.Credentials,
     statsEngine stats.Engine,
@@ -129,18 +133,27 @@ func startSession(url string,
     // start a timer and listens for tcs heartbeats/acks. The timer is reset when
     // we receive a heartbeat from the server or when a publish metrics message
     // is acked.
-    timer := time.AfterFunc(retry.AddJitter(heartbeatTimeout, heartbeatJitter), func() {
-        // Close the connection if there haven't been any messages received from backend
-        // for a long time.
- seelog.Info("TCS Connection hasn't had any activity for too long; disconnecting") - client.Disconnect() - }) + timer := time.NewTimer(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) defer timer.Stop() client.AddRequestHandler(heartbeatHandler(timer)) client.AddRequestHandler(ackPublishMetricHandler(timer)) client.AddRequestHandler(ackPublishHealthMetricHandler(timer)) client.SetAnyRequestHandler(anyMessageHandler(client)) - return client.Serve() + serveC := make(chan error) + go func() { + serveC <- client.Serve() + }() + select { + case <-ctx.Done(): + // outer context done, agent is exiting + client.Disconnect() + case <-timer.C: + seelog.Info("TCS Connection hasn't had any activity for too long; disconnecting") + client.Disconnect() + case err := <-serveC: + return err + } + return nil } // heartbeatHandler resets the heartbeat timer when HeartbeatMessage message is received from tcs. diff --git a/agent/tcs/handler/handler_test.go b/agent/tcs/handler/handler_test.go index fcf2f001145..52d67c25d74 100644 --- a/agent/tcs/handler/handler_test.go +++ b/agent/tcs/handler/handler_test.go @@ -133,7 +133,7 @@ func TestStartSession(t *testing.T) { deregisterInstanceEventStream := eventstream.NewEventStream("Deregister_Instance", context.Background()) // Start a session with the test server. - go startSession(server.URL, testCfg, testCreds, &mockStatsEngine{}, + go startSession(ctx, server.URL, testCfg, testCreds, &mockStatsEngine{}, defaultHeartbeatTimeout, defaultHeartbeatJitter, testPublishMetricsInterval, deregisterInstanceEventStream) @@ -197,7 +197,7 @@ func TestSessionConnectionClosedByRemote(t *testing.T) { defer cancel() // Start a session with the test server. - err = startSession(server.URL, testCfg, testCreds, &mockStatsEngine{}, + err = startSession(ctx, server.URL, testCfg, testCreds, &mockStatsEngine{}, defaultHeartbeatTimeout, defaultHeartbeatJitter, testPublishMetricsInterval, deregisterInstanceEventStream) @@ -234,11 +234,11 @@ func TestConnectionInactiveTimeout(t *testing.T) { deregisterInstanceEventStream.StartListening() defer cancel() // Start a session with the test server. - err = startSession(server.URL, testCfg, testCreds, &mockStatsEngine{}, + err = startSession(ctx, server.URL, testCfg, testCreds, &mockStatsEngine{}, 50*time.Millisecond, 100*time.Millisecond, testPublishMetricsInterval, deregisterInstanceEventStream) // if we are not blocked here, then the test pass as it will reconnect in StartSession - assert.Error(t, err, "Close the connection should cause the tcs client return error") + assert.NoError(t, err, "Close the connection should cause the tcs client return error") assert.True(t, websocket.IsCloseError(<-serverErr, websocket.CloseAbnormalClosure), "Read from closed connection should produce an io.EOF error")