Cancel agent's top-level context on exit (#2462)
sparrc authored Jun 1, 2020
1 parent 602f8ef commit 3e3a675
Showing 12 changed files with 72 additions and 52 deletions.
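In outline: newAgent now creates the agent's top-level context and keeps its cancel function; the termination handler saves state and then calls cancel instead of os.Exit, and long-running goroutines (the ACS session, per-task managers) watch ctx.Done() and return cleanly. A minimal sketch of the pattern, not the agent's actual wiring (the worker below is illustrative only):

    package main

    import (
        "context"
        "fmt"
        "os"
        "os/signal"
        "syscall"
    )

    func main() {
        // Top-level context; its cancel func is handed to the signal
        // handler, mirroring how newAgent now owns ctx/cancel.
        ctx, cancel := context.WithCancel(context.Background())

        go func() {
            sigC := make(chan os.Signal, 2)
            signal.Notify(sigC, os.Interrupt, syscall.SIGTERM)
            <-sigC
            cancel() // instead of os.Exit: let goroutines unwind
        }()

        // A worker that exits cleanly when the top-level context is
        // cancelled.
        <-ctx.Done()
        fmt.Println("exiting cleanly")
    }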
12 changes: 9 additions & 3 deletions agent/acs/handler/acs_handler.go
@@ -194,6 +194,12 @@ func (acsSession *session) Start() error {
 			seelog.Debugf("Received connect to ACS message")
 			// Start a session with ACS
 			acsError := acsSession.startSessionOnce()
+			select {
+			case <-acsSession.ctx.Done():
+				// agent is shutting down, exiting cleanly
+				return nil
+			default:
+			}
 			// Session with ACS was stopped with some error, start processing the error
 			isInactiveInstance := isInactiveInstanceError(acsError)
 			if isInactiveInstance {
@@ -231,8 +237,8 @@ func (acsSession *session) Start() error {
 				}
 			}
 		case <-acsSession.ctx.Done():
-			seelog.Debugf("ACS session context cancelled")
-			return acsSession.ctx.Err()
+			// agent is shutting down, exiting cleanly
+			return nil
 		}

 	}
@@ -366,7 +372,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
 	case <-acsSession.ctx.Done():
 		// Stop receiving and sending messages from and to ACS when
 		// the context received from the main function is canceled
-		seelog.Infof("ACS session context cancelled.")
+		seelog.Infof("ACS session exited cleanly.")
 		return acsSession.ctx.Err()
 	case err := <-serveErr:
 		// Stop receiving and sending messages from and to ACS when
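The non-blocking select inserted after startSessionOnce distinguishes shutdown from a real session failure: cancelling the context generally surfaces as an error from the blocking call, and in that case Start should return nil rather than fall into error handling and backoff. A self-contained sketch of the idiom, with doBlockingWork as a hypothetical stand-in for startSessionOnce:

    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    // doBlockingWork stands in for startSessionOnce: it blocks until
    // the context is cancelled, then reports that as an error.
    func doBlockingWork(ctx context.Context) error {
        <-ctx.Done()
        return errors.New("connection closed")
    }

    // runOnce shows the idiom: after the blocking call returns, a
    // non-blocking select decides whether the error is shutdown noise.
    func runOnce(ctx context.Context) error {
        err := doBlockingWork(ctx)
        select {
        case <-ctx.Done():
            return nil // shutting down: ignore err, exit cleanly
        default:
        }
        return err // a real failure: caller can retry with backoff
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())
        cancel()                  // simulate SIGTERM arriving mid-call
        fmt.Println(runOnce(ctx)) // <nil>: treated as clean shutdown
    }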
16 changes: 8 additions & 8 deletions agent/app/agent.go
@@ -93,6 +93,7 @@ type agent interface {
 // the newAgent() method
 type ecsAgent struct {
 	ctx                   context.Context
+	cancel                context.CancelFunc
 	ec2MetadataClient     ec2.EC2MetadataClient
 	ec2Client             ec2.Client
 	cfg                   *config.Config
@@ -115,11 +116,8 @@ type ecsAgent struct {
 }

 // newAgent returns a new ecsAgent object, but does not start anything
-func newAgent(
-	ctx context.Context,
-	blackholeEC2Metadata bool,
-	acceptInsecureCert *bool) (agent, error) {
-
+func newAgent(blackholeEC2Metadata bool, acceptInsecureCert *bool) (agent, error) {
+	ctx, cancel := context.WithCancel(context.Background())
 	ec2MetadataClient := ec2.NewEC2MetadataClient(nil)
 	if blackholeEC2Metadata {
 		ec2MetadataClient = ec2.NewBlackholeEC2MetadataClient()
@@ -131,6 +129,7 @@ func newAgent(
 		// All required config values can be inferred from EC2 Metadata,
 		// so this error could be transient.
 		seelog.Criticalf("Error loading config: %v", err)
+		cancel()
 		return nil, err
 	}
 	cfg.AcceptInsecureCert = aws.BoolValue(acceptInsecureCert)
@@ -146,6 +145,7 @@ func newAgent(
 	if err != nil {
 		// This is also non terminal in the current config
 		seelog.Criticalf("Error creating Docker client: %v", err)
+		cancel()
 		return nil, err
 	}

@@ -160,6 +160,7 @@ func newAgent(
 	initialSeqNumber := int64(-1)
 	return &ecsAgent{
 		ctx:                   ctx,
+		cancel:                cancel,
 		ec2MetadataClient:     ec2MetadataClient,
 		ec2Client:             ec2Client,
 		cfg:                   cfg,
@@ -616,7 +617,7 @@ func (agent *ecsAgent) startAsyncRoutines(
 		go agent.startSpotInstanceDrainingPoller(agent.ctx, client)
 	}

-	go agent.terminationHandler(stateManager, taskEngine)
+	go agent.terminationHandler(stateManager, taskEngine, agent.cancel)

 	// Agent introspection api
 	go handlers.ServeIntrospectionHTTPEndpoint(agent.ctx, &agent.containerInstanceARN, taskEngine, agent.cfg)
@@ -727,8 +728,7 @@ func (agent *ecsAgent) startACSSession(
 		seelog.Criticalf("Unretriable error starting communicating with ACS: %v", err)
 		return exitcodes.ExitTerminal
 	}
-	seelog.Critical("ACS Session handler should never exit")
-	return exitcodes.ExitError
+	return exitcodes.ExitSuccess
 }

 // validateRequiredVersion validates docker version.
4 changes: 1 addition & 3 deletions agent/app/agent_integ_test.go
@@ -15,7 +15,6 @@
 package app

 import (
-	"context"
 	"os"
 	"testing"

@@ -25,11 +24,10 @@ import (
 )

 func TestNewAgent(t *testing.T) {
-	ctx := context.TODO()
 	os.Setenv("AWS_DEFAULT_REGION", "us-west-2")
 	defer os.Unsetenv("AWS_DEFAULT_REGION")

-	agent, err := newAgent(ctx, true, aws.Bool(true))
+	agent, err := newAgent(true, aws.Bool(true))

 	assert.NoError(t, err)
 	// printECSAttributes should ensure that agent's cfg is set with
2 changes: 1 addition & 1 deletion agent/app/agent_test.go
@@ -391,7 +391,7 @@ func TestDoStartRegisterAvailabilityZone(t *testing.T) {
 		credentialProvider: aws_credentials.NewCredentials(mockCredentialsProvider),
 		mobyPlugins:        mockMobyPlugins,
 		metadataManager:    containermetadata,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		ec2MetadataClient:  ec2MetadataClient,
 	}

14 changes: 7 additions & 7 deletions agent/app/agent_unix_test.go
@@ -117,7 +117,7 @@ func TestDoStartHappyPath(t *testing.T) {
 		credentialProvider: credentials.NewCredentials(mockCredentialsProvider),
 		dockerClient:       dockerClient,
 		pauseLoader:        mockPauseLoader,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		mobyPlugins:        mockMobyPlugins,
 		ec2MetadataClient:  ec2MetadataClient,
 	}
@@ -235,7 +235,7 @@ func TestDoStartTaskENIHappyPath(t *testing.T) {
 		pauseLoader:        mockPauseLoader,
 		cniClient:          cniClient,
 		ec2MetadataClient:  mockMetadata,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		mobyPlugins:        mockMobyPlugins,
 	}

@@ -550,7 +550,7 @@ func TestDoStartCgroupInitHappyPath(t *testing.T) {
 		credentialProvider: credentials.NewCredentials(mockCredentialsProvider),
 		pauseLoader:        mockPauseLoader,
 		dockerClient:       dockerClient,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		mobyPlugins:        mockMobyPlugins,
 		ec2MetadataClient:  ec2MetadataClient,
 		resourceFields: &taskresource.ResourceFields{
@@ -607,7 +607,7 @@ func TestDoStartCgroupInitErrorPath(t *testing.T) {
 		credentialProvider: credentials.NewCredentials(mockCredentialsProvider),
 		dockerClient:       dockerClient,
 		pauseLoader:        mockPauseLoader,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		resourceFields: &taskresource.ResourceFields{
 			Control: mockControl,
 		},
@@ -697,7 +697,7 @@ func TestDoStartGPUManagerHappyPath(t *testing.T) {
 		credentialProvider: credentials.NewCredentials(mockCredentialsProvider),
 		dockerClient:       dockerClient,
 		pauseLoader:        mockPauseLoader,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		mobyPlugins:        mockMobyPlugins,
 		ec2MetadataClient:  ec2MetadataClient,
 		resourceFields: &taskresource.ResourceFields{
@@ -752,7 +752,7 @@ func TestDoStartGPUManagerInitError(t *testing.T) {
 		credentialProvider: credentials.NewCredentials(mockCredentialsProvider),
 		dockerClient:       dockerClient,
 		pauseLoader:        mockPauseLoader,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		resourceFields: &taskresource.ResourceFields{
 			NvidiaGPUManager: mockGPUManager,
 		},
@@ -800,7 +800,7 @@ func TestDoStartTaskENIPauseError(t *testing.T) {
 		pauseLoader:        mockPauseLoader,
 		cniClient:          cniClient,
 		ec2MetadataClient:  mockMetadata,
-		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine) {},
+		terminationHandler: func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {},
 		mobyPlugins:        mockMobyPlugins,
 	}

2 changes: 1 addition & 1 deletion agent/app/agent_windows.go
@@ -154,7 +154,7 @@ func (h *handler) runAgent(ctx context.Context) uint32 {
 	agentCtx, cancel := context.WithCancel(ctx)
 	indicator := newTermHandlerIndicator()

-	terminationHandler := func(saver statemanager.Saver, taskEngine engine.TaskEngine) {
+	terminationHandler := func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
 		// We're using a custom indicator to record that the handler is scheduled to be executed (has been invoked) and
 		// to determine whether it should run (we skip when the agent engine has already exited). After recording to
 		// the indicator that the handler has been invoked, we wait on the context. When we wake up, we determine
7 changes: 4 additions & 3 deletions agent/app/agent_windows_test.go
@@ -105,16 +105,16 @@ func TestHandler_RunAgent_ForceSaveWithTerminationHandler(t *testing.T) {

 	agent := &mockAgent{}

+	ctx, cancel := context.WithCancel(context.TODO())
 	done := make(chan struct{})
 	defer func() { done <- struct{}{} }()
 	startFunc := func() int {
-		go agent.terminationHandler(stateManager, taskEngine)
+		go agent.terminationHandler(stateManager, taskEngine, cancel)
 		<-done // block until after the test ends so that we can test that runAgent returns when cancelled
 		return 0
 	}
 	agent.startFunc = startFunc
 	handler := &handler{agent}
-	ctx, cancel := context.WithCancel(context.TODO())
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 	go func() {
@@ -200,10 +200,11 @@ func TestHandler_Execute_WindowsStops(t *testing.T) {

 	agent := &mockAgent{}

+	_, cancel := context.WithCancel(context.TODO())
 	done := make(chan struct{})
 	defer func() { done <- struct{}{} }()
 	startFunc := func() int {
-		go agent.terminationHandler(stateManager, taskEngine)
+		go agent.terminationHandler(stateManager, taskEngine, cancel)
 		<-done // block until after the test ends so that we can test that Execute returns when Stopped
 		return 0
 	}
5 changes: 1 addition & 4 deletions agent/app/run.go
@@ -14,7 +14,6 @@
 package app

 import (
-	"context"
 	"time"

 	"github.com/aws/amazon-ecs-agent/agent/app/args"
@@ -50,9 +49,7 @@ func Run(arguments []string) int {
 	logger.SetLevel(*parsedArgs.LogLevel)

 	// Create an Agent object
-	agent, err := newAgent(context.Background(),
-		aws.BoolValue(parsedArgs.BlackholeEC2Metadata),
-		parsedArgs.AcceptInsecureCert)
+	agent, err := newAgent(aws.BoolValue(parsedArgs.BlackholeEC2Metadata), parsedArgs.AcceptInsecureCert)
 	if err != nil {
 		// Failure to initialize either the docker client or the EC2 metadata
 		// service client are non terminal errors as they could be transient
2 changes: 0 additions & 2 deletions agent/engine/docker_task_engine.go
@@ -233,10 +233,8 @@ func (engine *DockerTaskEngine) MarshalJSON() ([]byte, error) {
 // and operate normally.
 // This function must be called before any other function, except serializing and deserializing, can succeed without error.
 func (engine *DockerTaskEngine) Init(ctx context.Context) error {
-	// TODO, pass in a a context from main from background so that other things can stop us, not just the tests
 	derivedCtx, cancel := context.WithCancel(ctx)
 	engine.stopEngine = cancel
-
 	engine.ctx = derivedCtx

 	// Open the event stream before we sync state so that e.g. if a container
27 changes: 20 additions & 7 deletions agent/engine/task_manager.go
@@ -209,18 +209,19 @@ func (mtask *managedTask) overseeTask() {

 	// Main infinite loop. This is where we receive messages and dispatch work.
 	for {
-		select {
-		case <-mtask.ctx.Done():
-			seelog.Infof("Managed task [%s]: parent context cancelled, exit", mtask.Arn)
+		if mtask.shouldExit() {
 			return
-		default:
 		}

 		// If it's steadyState, just spin until we need to do work
 		for mtask.steadyState() {
 			mtask.waitSteady()
 		}

+		if mtask.shouldExit() {
+			return
+		}
+
 		if !mtask.GetKnownStatus().Terminal() {
 			// If we aren't terminal and we aren't steady state, we should be
 			// able to move some containers along.
@@ -258,6 +259,16 @@
 	mtask.cleanupTask(mtask.cfg.TaskCleanupWaitDuration)
 }

+// shouldExit checks if the task manager should exit, as the agent is exiting.
+func (mtask *managedTask) shouldExit() bool {
+	select {
+	case <-mtask.ctx.Done():
+		return true
+	default:
+		return false
+	}
+}
+
 // emitCurrentStatus emits a container event for every container and a task
 // event for the task
 func (mtask *managedTask) emitCurrentStatus() {
@@ -311,9 +322,12 @@ func (mtask *managedTask) waitSteady() {
 	timeoutCtx, cancel := context.WithTimeout(mtask.ctx, retry.AddJitter(mtask.steadyStatePollInterval, mtask.steadyStatePollIntervalJitter))
 	defer cancel()
 	timedOut := mtask.waitEvent(timeoutCtx.Done())
+	if mtask.shouldExit() {
+		return
+	}

 	if timedOut {
-		seelog.Debugf("Managed task [%s]: checking to make sure it's still at steadystate", mtask.Arn)
+		seelog.Infof("Managed task [%s]: checking to verify it's still at steady state.", mtask.Arn)
 		go mtask.engine.checkTaskState(mtask.Task)
 	}
 }
@@ -323,7 +337,7 @@ func (mtask *managedTask) waitSteady() {
 func (mtask *managedTask) steadyState() bool {
 	select {
 	case <-mtask.ctx.Done():
-		seelog.Info("Context expired. No longer steady.")
+		seelog.Infof("Managed task [%s]: agent task manager exiting.", mtask.Arn)
 		return false
 	default:
 		taskKnownStatus := mtask.GetKnownStatus()
@@ -360,7 +374,6 @@ func (mtask *managedTask) waitEvent(stopWaiting <-chan struct{}) bool {
 		mtask.handleResourceStateChange(resChange)
 		return false
 	case <-stopWaiting:
-		seelog.Infof("Managed task [%s]: no longer waiting", mtask.Arn)
 		return true
 	}
 }
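The shouldExit helper added above is a non-blocking poll: the default arm makes the select return immediately while the context is still live, so overseeTask can check for shutdown at the top of each loop iteration and again after waking from steady state, without ever blocking. A small standalone illustration of the semantics (doneNow is a hypothetical copy of the helper):

    package main

    import (
        "context"
        "fmt"
    )

    // doneNow mirrors shouldExit: it reports whether ctx has been
    // cancelled, returning immediately either way thanks to default.
    func doneNow(ctx context.Context) bool {
        select {
        case <-ctx.Done():
            return true
        default:
            return false
        }
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())
        fmt.Println(doneNow(ctx)) // false: context still live
        cancel()
        fmt.Println(doneNow(ctx)) // true: Done channel closed by cancel
    }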
17 changes: 10 additions & 7 deletions agent/sighandlers/termination_handler.go
@@ -19,6 +19,7 @@
 package sighandlers

 import (
+	"context"
 	"errors"
 	"os"
 	"os/signal"
@@ -39,23 +40,25 @@ const (
 )

 // TerminationHandler defines a handler used for terminating the agent
-type TerminationHandler func(saver statemanager.Saver, taskEngine engine.TaskEngine)
+type TerminationHandler func(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc)

 // StartDefaultTerminationHandler defines a default termination handler suitable for running in a process
-func StartDefaultTerminationHandler(saver statemanager.Saver, taskEngine engine.TaskEngine) {
-	signalChannel := make(chan os.Signal, 2)
-	signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM)
+func StartDefaultTerminationHandler(saver statemanager.Saver, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
+	// when we receive a termination signal, first save the state, then
+	// cancel the agent's context so other goroutines can exit cleanly.
+	signalC := make(chan os.Signal, 2)
+	signal.Notify(signalC, os.Interrupt, syscall.SIGTERM)

-	sig := <-signalChannel
-	seelog.Debugf("Termination handler received termination signal: %s", sig.String())
+	sig := <-signalC
+	seelog.Infof("Agent received termination signal: %s", sig.String())

 	err := FinalSave(saver, taskEngine)
 	if err != nil {
 		seelog.Criticalf("Error saving state before final shutdown: %v", err)
 		// Terminal because it's a sigterm; the user doesn't want it to restart
 		os.Exit(exitcodes.ExitTerminal)
 	}
-	os.Exit(exitcodes.ExitSuccess)
+	cancel()
 }

 // FinalSave should be called immediately before exiting, and only before
[Diff for 1 more changed file (10 additions, 6 deletions) did not load]