diff --git a/cmd/serve.go b/cmd/serve.go index b69135ccb..da14c80a6 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -22,7 +22,11 @@ import ( "io" "log/slog" "net/http" + "os" + "os/signal" "strings" + "syscall" + "time" "code.cloudfoundry.org/lager/v3" osbapiBroker "github.com/cloudfoundry/cloud-service-broker/v2/brokerapi/broker" @@ -55,6 +59,8 @@ const ( tlsKeyProp = "api.tlsKey" encryptionPasswords = "db.encryption.passwords" encryptionEnabled = "db.encryption.enabled" + + shutdownTimeout = time.Hour ) var cfCompatibilityToggle = toggles.Features.Toggle("enable-cf-sharing", false, `Set all services to have the Sharable flag so they can be shared @@ -225,24 +231,41 @@ func startServer(registry pakBroker.BrokerRegistry, db *sql.DB, brokerapi http.H Addr: fmt.Sprintf("%s:%s", host, port), Handler: router, } - var err error - if tlsCertCaBundleFilePath != "" && tlsKeyFilePath != "" { - err = httpServer.ListenAndServeTLS(tlsCertCaBundleFilePath, tlsKeyFilePath) - } else { - err = httpServer.ListenAndServe() - } - // when the server is receiving a signal, we probably do not want to panic. - if err != http.ErrServerClosed { - logger.Fatal("Failed to start broker", err) - } -} + go func() { + var err error + if tlsCertCaBundleFilePath != "" && tlsKeyFilePath != "" { + err = httpServer.ListenAndServeTLS(tlsCertCaBundleFilePath, tlsKeyFilePath) + } else { + err = httpServer.ListenAndServe() + } + if err == http.ErrServerClosed { + logger.Info("shutting down csb") + } else { + logger.Fatal("Failed to start broker", err) + } + }() -func labelName(label string) string { - switch label { - case "": - return "none" + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM) + + signalReceived := <-sigChan + + switch signalReceived { + + case syscall.SIGTERM: + shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), shutdownTimeout) + if err := httpServer.Shutdown(shutdownCtx); err != nil { + logger.Fatal("shutdown error: %v", err) + } + logger.Info("received SIGTERM, server is shutting down gracefully allowing for in flight work to finish") + defer shutdownRelease() + for store.LockFilesExist() { + logger.Info("draining csb in progress") + time.Sleep(time.Second * 1) + } + logger.Info("draining complete") default: - return label + logger.Info(fmt.Sprintf("csb does not handle the %s interrupt signal", signalReceived)) } } @@ -288,3 +311,12 @@ func importStateHandler(store *storage.Storage) http.Handler { } }) } + +func labelName(label string) string { + switch label { + case "": + return "none" + default: + return label + } +} diff --git a/integrationtest/import_state_test.go b/integrationtest/import_state_test.go index f7e453c4c..36ab237d7 100644 --- a/integrationtest/import_state_test.go +++ b/integrationtest/import_state_test.go @@ -31,7 +31,7 @@ var _ = Describe("Import State", func() { broker = must(testdrive.StartBroker(csb, brokerpak, database)) DeferCleanup(func() { - Expect(broker.Stop()).To(Succeed()) + Expect(broker.Terminate()).To(Succeed()) cleanup(brokerpak) }) }) diff --git a/internal/testdrive/broker.go b/internal/testdrive/broker.go index 416d70deb..f2259483b 100644 --- a/internal/testdrive/broker.go +++ b/internal/testdrive/broker.go @@ -23,6 +23,15 @@ func (b *Broker) Stop() error { case b == nil, b.runner == nil: return nil default: - return b.runner.stop() + return b.runner.gracefullStop() + } +} + +func (b *Broker) Terminate() error { + switch { + case b == nil, b.runner == nil: + return nil + default: + return b.runner.forceStop() } } diff --git a/internal/testdrive/broker_start.go b/internal/testdrive/broker_start.go index 7831b189f..d71986553 100644 --- a/internal/testdrive/broker_start.go +++ b/internal/testdrive/broker_start.go @@ -115,7 +115,7 @@ func StartBroker(csbPath, bpk, db string, opts ...StartBrokerOption) (*Broker, e case err == nil && response.StatusCode == http.StatusOK: return &broker, nil case time.Since(start) > time.Minute: - if err := broker.runner.stop(); err != nil { + if err := broker.runner.forceStop(); err != nil { return nil, err } return nil, fmt.Errorf("timed out after %s waiting for broker to start: %s\n%s", time.Since(start), stdout.String(), stderr.String()) diff --git a/internal/testdrive/runner.go b/internal/testdrive/runner.go index 1d4736e3b..68aa89481 100644 --- a/internal/testdrive/runner.go +++ b/internal/testdrive/runner.go @@ -2,6 +2,7 @@ package testdrive import ( "os/exec" + "syscall" "time" ) @@ -26,12 +27,29 @@ func (r *runner) error(err error) *runner { return r } -func (r *runner) stop() error { +func (r *runner) gracefullStop() error { if r.exited { return nil } if r.cmd != nil && r.cmd.Process != nil { - if err := r.cmd.Process.Kill(); err != nil { + if err := r.cmd.Process.Signal(syscall.SIGTERM); err != nil { + return err + } + } + + for !r.exited { + time.Sleep(time.Millisecond) + } + + return nil +} + +func (r *runner) forceStop() error { + if r.exited { + return nil + } + if r.cmd != nil && r.cmd.Process != nil { + if err := r.cmd.Process.Signal(syscall.SIGKILL); err != nil { return err } }