Pass agent context to subroutines #2463

Merged
merged 1 commit on May 29, 2020

19 changes: 12 additions & 7 deletions agent/app/agent.go
@@ -616,26 +616,26 @@ func (agent *ecsAgent) startAsyncRoutines(

// Start automatic spot instance draining poller routine
if agent.cfg.SpotInstanceDrainingEnabled {
go agent.startSpotInstanceDrainingPoller(client)
go agent.startSpotInstanceDrainingPoller(agent.ctx, client)
}

go agent.terminationHandler(stateManager, taskEngine)

// Agent introspection api
go handlers.ServeIntrospectionHTTPEndpoint(&agent.containerInstanceARN, taskEngine, agent.cfg)
go handlers.ServeIntrospectionHTTPEndpoint(agent.ctx, &agent.containerInstanceARN, taskEngine, agent.cfg)

statsEngine := stats.NewDockerStatsEngine(agent.cfg, agent.dockerClient, containerChangeEventStream)

// Start serving the endpoint to fetch IAM Role credentials and other task metadata
if agent.cfg.TaskMetadataAZDisabled {
// send empty availability zone
go handlers.ServeTaskHTTPEndpoint(credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, "")
go handlers.ServeTaskHTTPEndpoint(agent.ctx, credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, "")
} else {
go handlers.ServeTaskHTTPEndpoint(credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, agent.availabilityZone)
go handlers.ServeTaskHTTPEndpoint(agent.ctx, credentialsManager, state, client, agent.containerInstanceARN, agent.cfg, statsEngine, agent.availabilityZone)
}

// Start sending events to the backend
go eventhandler.HandleEngineEvents(taskEngine, client, taskHandler, attachmentEventHandler)
go eventhandler.HandleEngineEvents(agent.ctx, taskEngine, client, taskHandler, attachmentEventHandler)

telemetrySessionParams := tcshandler.TelemetrySessionParams{
Ctx: agent.ctx,
@@ -652,9 +652,14 @@ func (agent *ecsAgent) startAsyncRoutines(
go tcshandler.StartMetricsSession(&telemetrySessionParams)
}

func (agent *ecsAgent) startSpotInstanceDrainingPoller(client api.ECSClient) {
func (agent *ecsAgent) startSpotInstanceDrainingPoller(ctx context.Context, client api.ECSClient) {
for !agent.spotInstanceDrainingPoller(client) {
time.Sleep(time.Second)
select {
case <-ctx.Done():
return
default:
time.Sleep(time.Second)
}
}
}

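The pattern introduced here — re-checking ctx.Done() on every iteration of a polling loop so the goroutine stops as soon as the agent shuts down — is easy to lift out on its own. A minimal standalone sketch (the poll function and names below are illustrative, not from the agent codebase):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// pollUntilDone calls poll once per second until it reports success or the
// context is cancelled, mirroring the select-on-ctx.Done() shape above.
func pollUntilDone(ctx context.Context, poll func() bool) {
	for !poll() {
		select {
		case <-ctx.Done():
			return // stop polling when the caller cancels
		default:
			time.Sleep(time.Second)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	attempts := 0
	pollUntilDone(ctx, func() bool {
		attempts++
		fmt.Println("poll attempt", attempts)
		return false // never succeeds; the loop exits via the context instead
	})
	fmt.Println("poller exited after", attempts, "attempts")
}
```

Note that the default branch still sleeps a full second before the next check, so cancellation is observed with up to a second of delay; using `case <-time.After(time.Second):` instead of the sleep would make the wait itself cancellable.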
10 changes: 9 additions & 1 deletion agent/eventhandler/handler.go
@@ -14,6 +14,7 @@
package eventhandler

import (
"context"
"fmt"

"github.com/aws/amazon-ecs-agent/agent/api"
@@ -24,13 +25,20 @@ import (

// HandleEngineEvents handles state change events from the state change event channel by sending it to
// responsible event handler
func HandleEngineEvents(taskEngine engine.TaskEngine, client api.ECSClient, taskHandler *TaskHandler,
func HandleEngineEvents(
ctx context.Context,
taskEngine engine.TaskEngine,
client api.ECSClient,
taskHandler *TaskHandler,
attachmentEventHandler *AttachmentEventHandler) {
for {
stateChangeEvents := taskEngine.StateChangeEvents()

for stateChangeEvents != nil {
select {
case <-ctx.Done():
seelog.Infof("Exiting the engine event handler.")
return
case event, ok := <-stateChangeEvents:
if !ok {
stateChangeEvents = nil
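The event handler change follows the standard shape for draining a channel while honoring cancellation: select on ctx.Done() and on the channel, and treat a closed channel (`ok == false`) as the end of the stream by setting it to nil. A rough self-contained sketch of that shape (the event type and names are made up for illustration):

```go
package main

import (
	"context"
	"fmt"
)

// consumeEvents reads events until the channel is closed or ctx is cancelled.
func consumeEvents(ctx context.Context, events <-chan string) {
	for events != nil {
		select {
		case <-ctx.Done():
			fmt.Println("exiting event handler:", ctx.Err())
			return
		case ev, ok := <-events:
			if !ok {
				events = nil // channel closed; leave the loop
				continue
			}
			fmt.Println("handling event:", ev)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	events := make(chan string, 2)
	events <- "task STOPPED"
	events <- "container RUNNING"
	close(events)

	consumeEvents(ctx, events) // drains both events, then returns on close
}
```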
27 changes: 17 additions & 10 deletions agent/handlers/introspection_server_setup.go
@@ -15,10 +15,10 @@
package handlers

import (
"context"
"encoding/json"
"net/http"
"strconv"
"sync"
"time"

"github.com/aws/amazon-ecs-agent/agent/config"
@@ -78,22 +78,29 @@ func v1HandlersSetup(serverMux *http.ServeMux,
// ServeIntrospectionHTTPEndpoint serves information about this agent/containerInstance and tasks
// running on it. "V1" here indicates the hostname version of this server instead
// of the handler versions, i.e. "V1" server can include "V1" and "V2" handlers.
func ServeIntrospectionHTTPEndpoint(containerInstanceArn *string, taskEngine engine.TaskEngine, cfg *config.Config) {
func ServeIntrospectionHTTPEndpoint(ctx context.Context, containerInstanceArn *string, taskEngine engine.TaskEngine, cfg *config.Config) {
// Is this the right level to type assert, assuming we'd abstract multiple taskengines here?
// Revisit if we ever add another type..
dockerTaskEngine := taskEngine.(*engine.DockerTaskEngine)

server := introspectionServerSetup(containerInstanceArn, dockerTaskEngine, cfg)

go func() {
<-ctx.Done()
if err := server.Shutdown(context.Background()); err != nil {
// Error from closing listeners, or context timeout:
seelog.Infof("HTTP server Shutdown: %v", err)
}
}()

for {
once := sync.Once{}
retry.RetryWithBackoff(retry.NewExponentialBackoff(time.Second, time.Minute, 0.2, 2), func() error {
// TODO, make this cancellable and use the passed in context; for
// now, not critical if this gets interrupted
err := server.ListenAndServe()
once.Do(func() {
seelog.Error("Error running http api", "err", err)
})
return err
if err := server.ListenAndServe(); err != http.ErrServerClosed {
seelog.Errorf("Error running introspection endpoint: %v", err)
return err
}
// server was cleanly closed via context
return nil
})
}
}
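The introspection server's lifetime is now tied to the agent context: one goroutine waits on ctx.Done() and calls server.Shutdown, while the serving side treats http.ErrServerClosed as a clean exit rather than an error to be retried. A minimal sketch of that wiring using only the standard library (the address, route, and timeout below are placeholders, and the agent's retry/backoff wrapper is omitted):

```go
package main

import (
	"context"
	"log"
	"net/http"
	"time"
)

// serveWithContext runs an HTTP server until ctx is cancelled, then shuts it
// down gracefully. http.ErrServerClosed is treated as a clean exit.
func serveWithContext(ctx context.Context, addr string, handler http.Handler) error {
	server := &http.Server{Addr: addr, Handler: handler}

	go func() {
		<-ctx.Done()
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(shutdownCtx); err != nil {
			log.Printf("HTTP server Shutdown: %v", err)
		}
	}()

	if err := server.ListenAndServe(); err != http.ErrServerClosed {
		return err
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	mux := http.NewServeMux()
	mux.HandleFunc("/v1/metadata", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(`{"ok":true}`))
	})

	if err := serveWithContext(ctx, "127.0.0.1:0", mux); err != nil {
		log.Fatalf("introspection-style server failed: %v", err)
	}
	log.Println("server stopped cleanly")
}
```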
24 changes: 17 additions & 7 deletions agent/handlers/task_server_setup.go
@@ -14,6 +14,7 @@
package handlers

import (
"context"
"net/http"
"strconv"
"time"
@@ -153,15 +154,16 @@ func v4HandlersSetup(muxRouter *mux.Router,

// ServeTaskHTTPEndpoint serves task/container metadata, task/container stats, and IAM Role Credentials
// for tasks being managed by the agent.
func ServeTaskHTTPEndpoint(credentialsManager credentials.Manager,
func ServeTaskHTTPEndpoint(
ctx context.Context,
credentialsManager credentials.Manager,
state dockerstate.TaskEngineState,
ecsClient api.ECSClient,
containerInstanceArn string,
cfg *config.Config,
statsEngine stats.Engine,
availabilityZone string) {
// Create and initialize the audit log
// TODO Use seelog's programmatic configuration instead of xml.
logger, err := seelog.LoggerFromConfigAsString(audit.AuditLoggerConfig(cfg))
if err != nil {
seelog.Errorf("Error initializing the audit log: %v", err)
@@ -174,14 +176,22 @@ func ServeTaskHTTPEndpoint(credentialsManager credentials.Manager,
server := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, statsEngine,
cfg.TaskMetadataSteadyStateRate, cfg.TaskMetadataBurstRate, availabilityZone, containerInstanceArn)

go func() {
<-ctx.Done()
if err := server.Shutdown(context.Background()); err != nil {
// Error from closing listeners, or context timeout:
seelog.Infof("HTTP server Shutdown: %v", err)
}
}()

for {
retry.RetryWithBackoff(retry.NewExponentialBackoff(time.Second, time.Minute, 0.2, 2), func() error {
// TODO, make this cancellable and use the passed in context;
err := server.ListenAndServe()
if err != nil {
seelog.Errorf("Error running http api: %v", err)
if err := server.ListenAndServe(); err != http.ErrServerClosed {
seelog.Errorf("Error running task api: %v", err)
return err
}
return err
// server was cleanly closed via context
return nil
})
}
}
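The task metadata endpoint gets the same context-driven shutdown, with one extra wrinkle: ListenAndServe runs inside retry.RetryWithBackoff, so returning nil when the error is http.ErrServerClosed tells the backoff helper not to treat a deliberate shutdown as a failure. A simplified stand-in for that loop (the backoff here is hand-rolled; the agent uses its own retry package, and the port and timings are arbitrary):

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
	"time"
)

// retryWithBackoff keeps calling fn until it returns nil, doubling the delay
// between attempts up to max. It stands in for the agent's
// retry.RetryWithBackoff / NewExponentialBackoff helpers.
func retryWithBackoff(start, max time.Duration, fn func() error) {
	delay := start
	for {
		if err := fn(); err == nil {
			return
		}
		time.Sleep(delay)
		if delay *= 2; delay > max {
			delay = max
		}
	}
}

func main() {
	server := &http.Server{Addr: "127.0.0.1:0"}
	go func() {
		time.Sleep(100 * time.Millisecond)
		server.Close() // simulate a context-driven shutdown
	}()

	retryWithBackoff(time.Second, time.Minute, func() error {
		if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
			fmt.Println("error running task api:", err)
			return err // real failure: retry with backoff
		}
		return nil // clean shutdown: stop retrying
	})
	fmt.Println("task endpoint stopped")
}
```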
31 changes: 22 additions & 9 deletions agent/tcs/handler/handler.go
@@ -14,6 +14,7 @@
package tcshandler

import (
"context"
"io"
"net/url"
"strings"
@@ -85,6 +86,7 @@ func StartSession(params *TelemetrySessionParams, statsEngine stats.Engine) error {
}
select {
case <-params.Ctx.Done():
seelog.Info("TCS session exited cleanly.")
return nil
default:
}
@@ -98,12 +100,14 @@ func startTelemetrySession(params *TelemetrySessionParams, statsEngine stats.Engine) error {
return err
}
url := formatURL(tcsEndpoint, params.Cfg.Cluster, params.ContainerInstanceArn, params.TaskEngine)
return startSession(url, params.Cfg, params.CredentialProvider, statsEngine,
return startSession(params.Ctx, url, params.Cfg, params.CredentialProvider, statsEngine,
defaultHeartbeatTimeout, defaultHeartbeatJitter, config.DefaultContainerMetricsPublishInterval,
params.DeregisterInstanceEventStream)
}

func startSession(url string,
func startSession(
ctx context.Context,
url string,
cfg *config.Config,
credentialProvider *credentials.Credentials,
statsEngine stats.Engine,
@@ -129,18 +133,27 @@ func startSession(url string,
// start a timer and listens for tcs heartbeats/acks. The timer is reset when
// we receive a heartbeat from the server or when a publish metrics message
// is acked.
timer := time.AfterFunc(retry.AddJitter(heartbeatTimeout, heartbeatJitter), func() {
// Close the connection if there haven't been any messages received from backend
// for a long time.
seelog.Info("TCS Connection hasn't had any activity for too long; disconnecting")
client.Disconnect()
})
timer := time.NewTimer(retry.AddJitter(heartbeatTimeout, heartbeatJitter))
defer timer.Stop()
client.AddRequestHandler(heartbeatHandler(timer))
client.AddRequestHandler(ackPublishMetricHandler(timer))
client.AddRequestHandler(ackPublishHealthMetricHandler(timer))
client.SetAnyRequestHandler(anyMessageHandler(client))
return client.Serve()
serveC := make(chan error)
go func() {
serveC <- client.Serve()
}()
select {
case <-ctx.Done():
// outer context done, agent is exiting
client.Disconnect()
case <-timer.C:
seelog.Info("TCS Connection hasn't had any activity for too long; disconnecting")
client.Disconnect()
case err := <-serveC:
return err
}
return nil
}

// heartbeatHandler resets the heartbeat timer when HeartbeatMessage message is received from tcs.
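startSession now has three ways for a telemetry session to end: the agent context is cancelled, the inactivity timer fires, or Serve itself returns an error. Running the blocking Serve call in its own goroutine and selecting over the three channels is the usual Go shape for this. A stripped-down sketch with a fake connection (the client type and its Serve/Disconnect methods below stand in for the real websocket client):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// fakeClient stands in for the TCS websocket client: Serve blocks until the
// connection ends, Disconnect forces it to end.
type fakeClient struct{ done chan struct{} }

func (c *fakeClient) Serve() error { <-c.done; return errors.New("connection closed") }
func (c *fakeClient) Disconnect()  { close(c.done) }

// runSession mirrors the structure of startSession: Serve runs in its own
// goroutine, and the session ends on context cancellation, inactivity, or a
// Serve error, whichever comes first.
func runSession(ctx context.Context, client *fakeClient, heartbeatTimeout time.Duration) error {
	timer := time.NewTimer(heartbeatTimeout)
	defer timer.Stop()

	serveC := make(chan error, 1)
	go func() { serveC <- client.Serve() }()

	select {
	case <-ctx.Done():
		client.Disconnect() // agent is exiting
	case <-timer.C:
		fmt.Println("no activity for too long; disconnecting")
		client.Disconnect()
	case err := <-serveC:
		return err
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel immediately: the session should exit cleanly

	err := runSession(ctx, &fakeClient{done: make(chan struct{})}, time.Minute)
	fmt.Println("session ended, err =", err) // nil on context/timer exit
}
```

Buffering serveC by one element lets the Serve goroutine deliver its result and exit even when the session ends via the context or the timer rather than via Serve.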
8 changes: 4 additions & 4 deletions agent/tcs/handler/handler_test.go
@@ -133,7 +133,7 @@ func TestStartSession(t *testing.T) {

deregisterInstanceEventStream := eventstream.NewEventStream("Deregister_Instance", context.Background())
// Start a session with the test server.
go startSession(server.URL, testCfg, testCreds, &mockStatsEngine{},
go startSession(ctx, server.URL, testCfg, testCreds, &mockStatsEngine{},
defaultHeartbeatTimeout, defaultHeartbeatJitter,
testPublishMetricsInterval, deregisterInstanceEventStream)

@@ -197,7 +197,7 @@ func TestSessionConnectionClosedByRemote(t *testing.T) {
defer cancel()

// Start a session with the test server.
err = startSession(server.URL, testCfg, testCreds, &mockStatsEngine{},
err = startSession(ctx, server.URL, testCfg, testCreds, &mockStatsEngine{},
defaultHeartbeatTimeout, defaultHeartbeatJitter,
testPublishMetricsInterval, deregisterInstanceEventStream)

@@ -234,11 +234,11 @@ func TestConnectionInactiveTimeout(t *testing.T) {
deregisterInstanceEventStream.StartListening()
defer cancel()
// Start a session with the test server.
err = startSession(server.URL, testCfg, testCreds, &mockStatsEngine{},
err = startSession(ctx, server.URL, testCfg, testCreds, &mockStatsEngine{},
50*time.Millisecond, 100*time.Millisecond,
testPublishMetricsInterval, deregisterInstanceEventStream)
// if we are not blocked here, then the test pass as it will reconnect in StartSession
assert.Error(t, err, "Close the connection should cause the tcs client return error")
assert.NoError(t, err, "Close the connection should cause the tcs client return error")

assert.True(t, websocket.IsCloseError(<-serverErr, websocket.CloseAbnormalClosure),
"Read from closed connection should produce an io.EOF error")
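The test updates follow from that new control flow: when the inactivity timer triggers the disconnect inside startSession, the function now returns nil instead of surfacing the client's close error, hence the switch from assert.Error to assert.NoError. A tiny self-contained test in the same spirit, using a stand-in session function rather than the agent's real startSession:

```go
package session_test

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// runSession is a stand-in for startSession: it blocks until the inactivity
// timer fires or the context is cancelled, then exits cleanly with nil.
func runSession(ctx context.Context, heartbeatTimeout time.Duration) error {
	timer := time.NewTimer(heartbeatTimeout)
	defer timer.Stop()
	select {
	case <-ctx.Done():
	case <-timer.C:
	}
	return nil
}

func TestSessionExitsCleanlyOnInactivity(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// With a very short timeout the session gives up quickly, and the
	// timeout path is treated as a clean exit rather than an error.
	err := runSession(ctx, 50*time.Millisecond)
	assert.NoError(t, err, "inactivity timeout should end the session without an error")
}
```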