From d6edb8d34a0d51e59106fc9b71eaaaae326efdf6 Mon Sep 17 00:00:00 2001 From: asobrien Date: Tue, 2 Nov 2021 09:42:58 -0400 Subject: [PATCH] add shutdown_delay option to executor & gRPC GracefulStop (#3711) * shutdown delay & gRPC graceful shutdown * fix maxprocs logger * getPredictorFromEnv missing key error & tests --- executor/cmd/executor/main.go | 116 ++++++++++++++++++++++--------- executor/predictor/utils.go | 7 +- executor/predictor/utils_test.go | 73 +++++++++++++++++++ 3 files changed, 162 insertions(+), 34 deletions(-) create mode 100644 executor/predictor/utils_test.go diff --git a/executor/cmd/executor/main.go b/executor/cmd/executor/main.go index 3bd96a60c3..bd01041451 100644 --- a/executor/cmd/executor/main.go +++ b/executor/cmd/executor/main.go @@ -3,14 +3,17 @@ package main import ( "context" "crypto/tls" + "errors" "flag" "fmt" "log" "net" + "net/http" "net/url" "os" "os/signal" "path" + "sync" "syscall" "time" @@ -37,8 +40,8 @@ import ( "go.uber.org/automaxprocs/maxprocs" "go.uber.org/zap" "google.golang.org/grpc/reflection" - zapf "sigs.k8s.io/controller-runtime/pkg/log/zap" logf "sigs.k8s.io/controller-runtime/pkg/log" + zapf "sigs.k8s.io/controller-runtime/pkg/log/zap" ) const ( @@ -62,6 +65,7 @@ var ( httpPort = flag.Int("http_port", 8080, "Executor http port") grpcPort = flag.Int("grpc_port", 5000, "Executor grpc port") wait = flag.Duration("graceful_timeout", time.Second*15, "Graceful shutdown secs") + delay = flag.Duration("shutdown_delay", 0, "Shutdown delay secs") protocol = flag.String("protocol", "seldon", "The payload protocol") transport = flag.String("transport", "rest", "The network transport mechanism rest, grpc") filename = flag.String("file", "", "Load graph from file") @@ -73,7 +77,7 @@ var ( kafkaTopicOut = flag.String("kafka_output_topic", "", "The kafka output topic") kafkaFullGraph = flag.Bool("kafka_full_graph", false, "Use kafka for internal graph processing") kafkaWorkers = flag.Int("kafka_workers", 4, "Number of kafka workers") - logKafkaBroker = flag.String("log_kafka_broker", "", "The kafka log broker") + logKafkaBroker = flag.String("log_kafka_broker", "", "The kafka log broker") logKafkaTopic = flag.String("log_kafka_topic", "", "The kafka log topic") debug = flag.Bool( "debug", @@ -95,8 +99,9 @@ func getServerUrl(hostname string, port int) (*url.URL, error) { return url.Parse(fmt.Sprintf("http://%s:%d/", hostname, port)) } -func runHttpServer(lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, port int, - probesOnly bool, serverUrl *url.URL, namespace string, protocol string, deploymentName string, prometheusPath string) { +func runHttpServer(wg *sync.WaitGroup, shutdown chan bool, lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, port int, probesOnly bool, serverUrl *url.URL, namespace string, protocol string, deploymentName string, prometheusPath string) { + wg.Add(1) + defer wg.Done() defer lis.Close() // Create REST API @@ -104,37 +109,32 @@ func runHttpServer(lis net.Listener, logger logr.Logger, predictor *v1.Predictor seldonRest.Initialise() srv := seldonRest.CreateHttpServer(port) + var isShutdown bool go func() { + logger.Info("http server started") if err := srv.Serve(lis); err != nil { - logger.Error(err, "Server error") + if !(isShutdown && errors.Is(err, http.ErrServerClosed)) { + logger.Error(err, "http server error") + } } - logger.Info("server started") }() - c := make(chan os.Signal, 1) - // We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) and SIGTERM - // SIGKILL, SIGQUIT will not be caught. - signal.Notify(c, syscall.SIGINT) - signal.Notify(c, syscall.SIGTERM) - // Block until we receive our signal. - <-c - - // Create a deadline to wait for. - ctx, cancel := context.WithTimeout(context.Background(), *wait) - defer cancel() - // Doesn't block if no connections, but will otherwise wait - // until the timeout deadline. - srv.Shutdown(ctx) - // Optionally, you could run srv.Shutdown in a goroutine and block on - // <-ctx.Done() if your application should wait for other services - // to finalize based on context cancellation. - logger.Info("shutting down") - os.Exit(0) - + isShutdown = <-shutdown + + // Doesn't block if no connections, otherwise this waits for context + // deadline. As shutdown is coordinated across multiple server by the + // caller, we just use a background context and wait indefinitely if there + // are active connections. + if err := srv.Shutdown(context.Background()); err != nil { + logger.Error(err, "http server shutdown error") + } + logger.Info("http server shutdown") } -func runGrpcServer(lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, serverUrl *url.URL, namespace string, protocol string, deploymentName string, annotations map[string]string) { +func runGrpcServer(wg *sync.WaitGroup, shutdown chan bool, lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, serverUrl *url.URL, namespace string, protocol string, deploymentName string, annotations map[string]string) { + wg.Add(1) + defer wg.Done() defer lis.Close() grpcServer, err := grpc.CreateGrpcServer(predictor, deploymentName, annotations, logger) if err != nil { @@ -154,9 +154,53 @@ func runGrpcServer(lis net.Listener, logger logr.Logger, predictor *v1.Predictor kfservingGrpcServer := kfserving.NewGrpcKFServingServer(predictor, client, serverUrl, namespace) kfproto.RegisterGRPCInferenceServiceServer(grpcServer, kfservingGrpcServer) } - err = grpcServer.Serve(lis) - if err != nil { - logger.Error(err, "gRPC server error") + + go func() { + logger.Info("gRPC server started") + if err := grpcServer.Serve(lis); err != nil { + logger.Error(err, "gRPC server error") + } + }() + + <-shutdown // wait for signal + grpcServer.GracefulStop() + logger.Info("gRPC server shutdown") +} + +func waitForShutdown(logger logr.Logger, wg *sync.WaitGroup, chs ...chan bool) { + // wait for and then signal servers have exited + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + c := make(chan os.Signal, 1) + // We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) and SIGTERM + // SIGKILL, SIGQUIT will not be caught. + signal.Notify(c, syscall.SIGINT) + signal.Notify(c, syscall.SIGTERM) + + // Block until we receive our signal. + sig := <-c + logger.Info("shutdown signal received", "signal", sig, "shutdown_delay", *delay) + time.Sleep(*delay) // shutdown_delay + + // Create a deadline to wait for graceful shutdown. + ctx, cancel := context.WithTimeout(context.Background(), *wait) + defer cancel() + + // send signals to server channels to initiate shutdown. + for _, ch := range chs { + ch <- true + } + + select { + case <-done: + return // servers shutdown + case <-ctx.Done(): + logger.Error(ctx.Err(), "graceful_timeout exceeded, exiting immediately", "graceful_timeout", *wait) + os.Exit(1) } } @@ -263,7 +307,9 @@ func main() { logger := logf.Log.WithName("entrypoint") // Set runtime.GOMAXPROCS to respect container limits if the env var GOMAXPROCS is not set or is invalid, preventing CPU throttling. - undo, err := maxprocs.Set(maxprocs.Logger(logger.Info)) + undo, err := maxprocs.Set(maxprocs.Logger(func(format string, a ...interface{}) { + logger.WithName("maxprocs").Info(fmt.Sprintf(format, a...)) + })) defer undo() if err != nil { logger.Error(err, "failed to set GOMAXPROCS") @@ -342,11 +388,15 @@ func main() { log.Fatalf("Failed to create grpc client. Unknown protocol %s: %v", *protocol, err) } + wg := sync.WaitGroup{} logger.Info("Running http server ", "port", *httpPort) - go runHttpServer(createListener(*httpPort, logger), logger, predictor, clientRest, *httpPort, false, serverUrl, *namespace, *protocol, *sdepName, *prometheusPath) + httpStop := make(chan bool, 1) + go runHttpServer(&wg, httpStop, createListener(*httpPort, logger), logger, predictor, clientRest, *httpPort, false, serverUrl, *namespace, *protocol, *sdepName, *prometheusPath) logger.Info("Running grpc server ", "port", *grpcPort) - runGrpcServer(createListener(*grpcPort, logger), logger, predictor, clientGrpc, serverUrl, *namespace, *protocol, *sdepName, annotations) + grpcStop := make(chan bool, 1) + go runGrpcServer(&wg, grpcStop, createListener(*grpcPort, logger), logger, predictor, clientGrpc, serverUrl, *namespace, *protocol, *sdepName, annotations) + waitForShutdown(logger, &wg, httpStop, grpcStop) } func createListener(port int, logger logr.Logger) net.Listener { diff --git a/executor/predictor/utils.go b/executor/predictor/utils.go index 82571ce0c1..5024b79de7 100644 --- a/executor/predictor/utils.go +++ b/executor/predictor/utils.go @@ -11,6 +11,8 @@ import ( "strings" ) +const EnvKeyEnginePredictor = "ENGINE_PREDICTOR" + func GetPredictor(predictorName, filename, sdepName, namespace string, configPath *string) (*v1.PredictorSpec, error) { if filename != "" { predictor, err := getPredictorFromFile(predictorName, filename) @@ -25,7 +27,10 @@ func GetPredictor(predictorName, filename, sdepName, namespace string, configPat } func getPredictorFromEnv() (*v1.PredictorSpec, error) { - b64Predictor := os.Getenv("ENGINE_PREDICTOR") + b64Predictor, ok := os.LookupEnv(EnvKeyEnginePredictor) + if !ok { + return nil, fmt.Errorf("Predictor not found, enviroment variable %s not set", EnvKeyEnginePredictor) + } if b64Predictor != "" { bytes, err := base64.StdEncoding.DecodeString(b64Predictor) if err != nil { diff --git a/executor/predictor/utils_test.go b/executor/predictor/utils_test.go new file mode 100644 index 0000000000..dbcaca8a98 --- /dev/null +++ b/executor/predictor/utils_test.go @@ -0,0 +1,73 @@ +package predictor + +import ( + "encoding/base64" + "fmt" + "os" + "testing" + + . "github.com/onsi/gomega" + v1 "github.com/seldonio/seldon-core/operator/apis/machinelearning.seldon.io/v1" +) + +func TestGetPredictorFromEnv(t *testing.T) { + t.Logf("Started") + g := NewGomegaWithT(t) + var b64Error base64.CorruptInputError + + tests := []struct { + name string + key string + val string + predictor *v1.PredictorSpec + err error + }{ + { + name: "missing env var", + err: fmt.Errorf("Predictor not found, enviroment variable %s not set", EnvKeyEnginePredictor), + }, + { + name: "empty value", + key: EnvKeyEnginePredictor, + }, + { + name: "non-b64 value", + key: EnvKeyEnginePredictor, + val: ":;,", + err: b64Error, + }, + } + + // unset existing env var & reset at end of test + val, ok := os.LookupEnv(EnvKeyEnginePredictor) + if ok { + if err := os.Unsetenv(EnvKeyEnginePredictor); err != nil { + t.Fatalf("failed to unset env var %v: %v", EnvKeyEnginePredictor, err) + } + defer func() { + os.Setenv(EnvKeyEnginePredictor, val) + }() + } + + setenv := func(key, val string) { + if key == "" { + return + } + if err := os.Setenv(key, val); err != nil { + t.Fatalf("failed to set env var: %v", err) + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setenv(tt.key, tt.val) + got, err := getPredictorFromEnv() + g.Expect(got).Should(Equal(tt.predictor)) + if tt.err == nil { + g.Expect(err).Should(BeNil()) + } else { + g.Expect(err).Should(Equal(tt.err)) + } + }) + } +}