Skip to content

Commit

Permalink
Shutdown Subcommands with SIGTERM and SIGINT (#389)
Browse files Browse the repository at this point in the history
* Shutdown Subcommands with SIGTERM
  • Loading branch information
lawliet89 authored Nov 14, 2020
1 parent 5c9020a commit 239a2db
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 48 deletions.
16 changes: 11 additions & 5 deletions subcommand/inject-connect/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

connectinject "github.com/hashicorp/consul-k8s/connect-inject"
Expand Down Expand Up @@ -163,11 +164,11 @@ func (c *Command) init() {
flags.Merge(c.flagSet, c.http.Flags())
c.help = flags.Usage(help, c.flagSet)

// Wait on an interrupt for exit, be sure to init it before running
// Wait on an interrupt or terminate for exit, be sure to init it before running
// the controller so that we don't receive an interrupt before it's ready.
if c.sigCh == nil {
c.sigCh = make(chan os.Signal, 1)
signal.Notify(c.sigCh, os.Interrupt)
signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM)
}
}

Expand Down Expand Up @@ -390,8 +391,9 @@ func (c *Command) Run(args []string) int {
}()

select {
// Interrupted, gracefully exit.
case <-c.sigCh:
// Interrupted/terminated, gracefully exit.
case sig := <-c.sigCh:
c.UI.Info(fmt.Sprintf("%s received, shutting down", sig))
if err := server.Close(); err != nil {
c.UI.Error(fmt.Sprintf("shutting down server: %v", err))
return 1
Expand All @@ -417,7 +419,11 @@ func (c *Command) Run(args []string) int {
}

func (c *Command) interrupt() {
c.sigCh <- os.Interrupt
c.sendSignal(syscall.SIGINT)
}

func (c *Command) sendSignal(sig os.Signal) {
c.sigCh <- sig
}

func (c *Command) handleReady(rw http.ResponseWriter, req *http.Request) {
Expand Down
63 changes: 36 additions & 27 deletions subcommand/inject-connect/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package connectinject
import (
"fmt"
"os"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -206,38 +207,46 @@ func TestRun_CommandFailsWithInvalidListener(t *testing.T) {
require.Contains(t, ui.ErrorWriter.String(), "Error listening: listen tcp: address 999999: missing port in address")
}

// Test that when healthchecks are enabled that SIGINT exits the
// Test that when healthchecks are enabled that SIGINT/SIGTERM exits the
// command cleanly.
func TestRun_CommandExitsCleanlyAfterSigInt(t *testing.T) {
k8sClient := fake.NewSimpleClientset()
ui := cli.NewMockUi()
cmd := Command{
UI: ui,
clientset: k8sClient,
}
ports := freeport.MustTake(1)
func TestRun_CommandExitsCleanlyAfterSignal(t *testing.T) {

// NOTE: This url doesn't matter because Consul is never called.
os.Setenv(api.HTTPAddrEnvName, "http://0.0.0.0:9999")
defer os.Unsetenv(api.HTTPAddrEnvName)
t.Run("SIGINT", testSignalHandling(syscall.SIGINT))
t.Run("SIGTERM", testSignalHandling(syscall.SIGTERM))
}

// Start the command asynchronously and then we'll send an interrupt.
exitChan := runCommandAsynchronously(&cmd, []string{
"-consul-k8s-image", "hashicorp/consul-k8s",
"-enable-health-checks-controller=true",
"-listen", fmt.Sprintf(":%d", ports[0]),
})
func testSignalHandling(sig os.Signal) func(*testing.T) {
return func(t *testing.T) {
k8sClient := fake.NewSimpleClientset()
ui := cli.NewMockUi()
cmd := Command{
UI: ui,
clientset: k8sClient,
}
ports := freeport.MustTake(1)

// NOTE: This url doesn't matter because Consul is never called.
os.Setenv(api.HTTPAddrEnvName, "http://0.0.0.0:9999")
defer os.Unsetenv(api.HTTPAddrEnvName)

// Start the command asynchronously and then we'll send an interrupt.
exitChan := runCommandAsynchronously(&cmd, []string{
"-consul-k8s-image", "hashicorp/consul-k8s",
"-enable-health-checks-controller=true",
"-listen", fmt.Sprintf(":%d", ports[0]),
})

// Send the interrupt.
cmd.interrupt()
// Send the signal
cmd.sendSignal(sig)

// Assert that it exits cleanly or timeout.
select {
case exitCode := <-exitChan:
require.Equal(t, 0, exitCode, ui.ErrorWriter.String())
case <-time.After(time.Second * 1):
// Fail if the stopCh was not caught.
require.Fail(t, "timeout waiting for command to exit")
// Assert that it exits cleanly or timeout.
select {
case exitCode := <-exitChan:
require.Equal(t, 0, exitCode, ui.ErrorWriter.String())
case <-time.After(time.Second * 1):
// Fail if the stopCh was not caught.
require.Fail(t, "timeout waiting for command to exit")
}
}
}

Expand Down
17 changes: 11 additions & 6 deletions subcommand/lifecycle-sidecar/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os/signal"
"strings"
"sync"
"syscall"
"time"

"github.com/hashicorp/consul-k8s/subcommand/flags"
Expand Down Expand Up @@ -47,12 +48,12 @@ func (c *Command) init() {
flags.Merge(c.flagSet, c.http.Flags())
c.help = flags.Usage(help, c.flagSet)

// Wait on an interrupt to exit. This channel must be initialized before
// Wait on an interrupt or terminate to exit. This channel must be initialized before
// Run() is called so that there are no race conditions where the channel
// is not defined.
if c.sigCh == nil {
c.sigCh = make(chan os.Signal, 1)
signal.Notify(c.sigCh, os.Interrupt)
signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM)
}
}

Expand Down Expand Up @@ -106,12 +107,12 @@ func (c *Command) Run(args []string) int {
logger.Info("successfully synced service", "output", strings.TrimSpace(string(output)))
}

// Re-loop after syncPeriod or exit if we receive an interrupt.
// Re-loop after syncPeriod or exit if we receive interrupt or terminate signals.
select {
case <-time.After(c.flagSyncPeriod):
continue
case <-c.sigCh:
logger.Info("SIGINT received, shutting down")
case sig := <-c.sigCh:
logger.Info(fmt.Sprintf("%s received, shutting down", sig))
return 0
}
}
Expand Down Expand Up @@ -164,7 +165,11 @@ func (c *Command) parseConsulFlags() []string {
// interrupt sends os.Interrupt signal to the command
// so it can exit gracefully. This function is needed for tests
func (c *Command) interrupt() {
c.sigCh <- os.Interrupt
c.sendSignal(syscall.SIGINT)
}

func (c *Command) sendSignal(sig os.Signal) {
c.sigCh <- sig
}

func (c *Command) Synopsis() string { return synopsis }
Expand Down
32 changes: 32 additions & 0 deletions subcommand/lifecycle-sidecar/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io/ioutil"
"os"
"path/filepath"
"syscall"
"testing"
"time"

Expand All @@ -25,6 +26,37 @@ func TestRun_Defaults(t *testing.T) {
require.Equal(t, "consul", cmd.flagConsulBinary)
}

func TestRun_ExitsCleanlyonSignals(t *testing.T) {
t.Run("SIGINT", testRunSignalHandling(syscall.SIGINT))
t.Run("SIGTERM", testRunSignalHandling(syscall.SIGTERM))
}

func testRunSignalHandling(sig os.Signal) func(*testing.T) {
return func(t *testing.T) {
tmpDir, configFile := createServicesTmpFile(t, servicesRegistration)
defer os.RemoveAll(tmpDir)

ui := cli.NewMockUi()
cmd := Command{
UI: ui,
}
// Run async because we need to kill it when the test is over.
exitChan := runCommandAsynchronously(&cmd, []string{
"-service-config", configFile,
})
cmd.sendSignal(sig)

// Assert that it exits cleanly or timeout.
select {
case exitCode := <-exitChan:
require.Equal(t, 0, exitCode, ui.ErrorWriter.String())
case <-time.After(time.Second * 1):
// Fail if the signal was not caught.
require.Fail(t, "timeout waiting for command to exit")
}
}
}

func TestRun_FlagValidation(t *testing.T) {
t.Parallel()
cases := []struct {
Expand Down
16 changes: 11 additions & 5 deletions subcommand/sync-catalog/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os/signal"
"regexp"
"sync"
"syscall"
"time"

"github.com/deckarep/golang-set"
Expand Down Expand Up @@ -145,12 +146,12 @@ func (c *Command) init() {

c.help = flags.Usage(help, c.flags)

// Wait on an interrupt to exit. This channel must be initialized before
// Wait on an interrupt or terminate to exit. This channel must be initialized before
// Run() is called so that there are no race conditions where the channel
// is not defined.
if c.sigCh == nil {
c.sigCh = make(chan os.Signal, 1)
signal.Notify(c.sigCh, os.Interrupt)
signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM)
}
}

Expand Down Expand Up @@ -345,8 +346,9 @@ func (c *Command) Run(args []string) int {
}
return 1

// Interrupted, gracefully exit
case <-c.sigCh:
// Interrupted/terminated, gracefully exit
case sig := <-c.sigCh:
c.logger.Info(fmt.Sprintf("%s received, shutting down", sig))
cancelF()
if toConsulCh != nil {
<-toConsulCh
Expand Down Expand Up @@ -379,7 +381,11 @@ func (c *Command) Help() string {
// interrupt sends os.Interrupt signal to the command
// so it can exit gracefully. This function is needed for tests
func (c *Command) interrupt() {
c.sigCh <- os.Interrupt
c.sendSignal(syscall.SIGINT)
}

func (c *Command) sendSignal(sig os.Signal) {
c.sigCh <- sig
}

func (c *Command) validateFlags() error {
Expand Down
44 changes: 44 additions & 0 deletions subcommand/sync-catalog/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package synccatalog

import (
"context"
"os"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -81,6 +83,48 @@ func TestRun_Defaults_SyncsConsulServiceToK8s(t *testing.T) {
})
}

// Test that the command exits cleanly on signals
func TestRun_ExitCleanlyOnSignals(t *testing.T) {
t.Run("SIGINT", testSignalHandling(syscall.SIGINT))
t.Run("SIGTERM", testSignalHandling(syscall.SIGTERM))
}

func testSignalHandling(sig os.Signal) func(*testing.T) {
return func(t *testing.T) {
k8s, testServer := completeSetup(t)
defer testServer.Stop()

// Run the command.
ui := cli.NewMockUi()
cmd := Command{
UI: ui,
clientset: k8s,
logger: hclog.New(&hclog.LoggerOptions{
Name: t.Name(),
Level: hclog.Debug,
}),
}

exitChan := runCommandAsynchronously(&cmd, []string{
"-http-addr", testServer.HTTPAddr,
})
cmd.sendSignal(sig)

// Assert that it exits cleanly or timeout.
select {
case exitCode := <-exitChan:
require.Equal(t, 0, exitCode, ui.ErrorWriter.String())

// For some reason, this command cannot exit within 1s,
// so it's set higher than other tests in other commands
// to allow it to exit properly
case <-time.After(time.Second * 5):
// Fail if the signal was not caught.
require.Fail(t, "timeout waiting for command to exit")
}
}
}

// Test that when -add-k8s-namespace-suffix flag is used
// k8s namespaces are appended to the service names synced to Consul
func TestRun_ToConsulWithAddK8SNamespaceSuffix(t *testing.T) {
Expand Down
16 changes: 11 additions & 5 deletions subcommand/webhook-cert-manager/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os/signal"
"strings"
"sync"
"syscall"
"time"

"github.com/hashicorp/consul-k8s/helper/cert"
Expand Down Expand Up @@ -65,12 +66,12 @@ func (c *Command) init() {
flags.Merge(c.flagSet, c.k8s.Flags())
c.help = flags.Usage(help, c.flagSet)

// Wait on an interrupt to exit. This channel must be initialized before
// Wait on an interrupt or terminate to exit. This channel must be initialized before
// Run() is called so that there are no race conditions where the channel
// is not defined.
if c.sigCh == nil {
c.sigCh = make(chan os.Signal, 1)
signal.Notify(c.sigCh, os.Interrupt)
signal.Notify(c.sigCh, syscall.SIGINT, syscall.SIGTERM)
}
}

Expand Down Expand Up @@ -167,11 +168,12 @@ func (c *Command) Run(args []string) int {

go c.certWatcher(ctx, certCh, c.clientset, c.logger)

// We define a signal handler for OS interrupts, and when an SIGINT is received,
// We define a signal handler for OS interrupts, and when an SIGINT or SIGTERM is received,
// we gracefully shut down, by first stopping our cert notifiers and then cancelling
// all the contexts that have been created by the process.
select {
case <-c.sigCh:
case sig := <-c.sigCh:
c.logger.Info(fmt.Sprintf("%s received, shutting down", sig))
cancelFunc()
for _, notifier := range notifiers {
notifier.Stop()
Expand Down Expand Up @@ -367,7 +369,11 @@ func (c *Command) Synopsis() string {
// interrupt sends os.Interrupt signal to the command
// so it can exit gracefully. This function is needed for tests
func (c *Command) interrupt() {
c.sigCh <- os.Interrupt
c.sendSignal(syscall.SIGINT)
}

func (c *Command) sendSignal(sig os.Signal) {
c.sigCh <- sig
}

const synopsis = "Starts the Consul Kubernetes webhook-cert-manager"
Expand Down
Loading

0 comments on commit 239a2db

Please sign in to comment.