Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add shutdown_delay option to executor & gRPC GracefulStop #3711

Merged
merged 3 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 83 additions & 33 deletions executor/cmd/executor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ package main
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path"
"sync"
"syscall"
"time"

Expand All @@ -37,8 +40,8 @@ import (
"go.uber.org/automaxprocs/maxprocs"
"go.uber.org/zap"
"google.golang.org/grpc/reflection"
zapf "sigs.k8s.io/controller-runtime/pkg/log/zap"
logf "sigs.k8s.io/controller-runtime/pkg/log"
zapf "sigs.k8s.io/controller-runtime/pkg/log/zap"
)

const (
Expand All @@ -62,6 +65,7 @@ var (
httpPort = flag.Int("http_port", 8080, "Executor http port")
grpcPort = flag.Int("grpc_port", 5000, "Executor grpc port")
wait = flag.Duration("graceful_timeout", time.Second*15, "Graceful shutdown secs")
delay = flag.Duration("shutdown_delay", 0, "Shutdown delay secs")
protocol = flag.String("protocol", "seldon", "The payload protocol")
transport = flag.String("transport", "rest", "The network transport mechanism rest, grpc")
filename = flag.String("file", "", "Load graph from file")
Expand All @@ -73,7 +77,7 @@ var (
kafkaTopicOut = flag.String("kafka_output_topic", "", "The kafka output topic")
kafkaFullGraph = flag.Bool("kafka_full_graph", false, "Use kafka for internal graph processing")
kafkaWorkers = flag.Int("kafka_workers", 4, "Number of kafka workers")
logKafkaBroker = flag.String("log_kafka_broker", "", "The kafka log broker")
logKafkaBroker = flag.String("log_kafka_broker", "", "The kafka log broker")
logKafkaTopic = flag.String("log_kafka_topic", "", "The kafka log topic")
debug = flag.Bool(
"debug",
Expand All @@ -95,46 +99,42 @@ func getServerUrl(hostname string, port int) (*url.URL, error) {
return url.Parse(fmt.Sprintf("http://%s:%d/", hostname, port))
}

func runHttpServer(lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, port int,
probesOnly bool, serverUrl *url.URL, namespace string, protocol string, deploymentName string, prometheusPath string) {
func runHttpServer(wg *sync.WaitGroup, shutdown chan bool, lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, port int, probesOnly bool, serverUrl *url.URL, namespace string, protocol string, deploymentName string, prometheusPath string) {
wg.Add(1)
defer wg.Done()
defer lis.Close()

// Create REST API
seldonRest := rest.NewServerRestApi(predictor, client, probesOnly, serverUrl, namespace, protocol, deploymentName, prometheusPath)
seldonRest.Initialise()
srv := seldonRest.CreateHttpServer(port)

var isShutdown bool
go func() {
logger.Info("http server started")
if err := srv.Serve(lis); err != nil {
logger.Error(err, "Server error")
if !(isShutdown && errors.Is(err, http.ErrServerClosed)) {
logger.Error(err, "http server error")
}
}
logger.Info("server started")
}()

c := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) and SIGTERM
// SIGKILL, SIGQUIT will not be caught.
signal.Notify(c, syscall.SIGINT)
signal.Notify(c, syscall.SIGTERM)

// Block until we receive our signal.
<-c

// Create a deadline to wait for.
ctx, cancel := context.WithTimeout(context.Background(), *wait)
defer cancel()
// Doesn't block if no connections, but will otherwise wait
// until the timeout deadline.
srv.Shutdown(ctx)
// Optionally, you could run srv.Shutdown in a goroutine and block on
// <-ctx.Done() if your application should wait for other services
// to finalize based on context cancellation.
logger.Info("shutting down")
os.Exit(0)

isShutdown = <-shutdown

// Doesn't block if no connections, otherwise this waits for context
// deadline. As shutdown is coordinated across multiple server by the
// caller, we just use a background context and wait indefinitely if there
// are active connections.
if err := srv.Shutdown(context.Background()); err != nil {
logger.Error(err, "http server shutdown error")
}
logger.Info("http server shutdown")
}

func runGrpcServer(lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, serverUrl *url.URL, namespace string, protocol string, deploymentName string, annotations map[string]string) {
func runGrpcServer(wg *sync.WaitGroup, shutdown chan bool, lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, serverUrl *url.URL, namespace string, protocol string, deploymentName string, annotations map[string]string) {
wg.Add(1)
defer wg.Done()
defer lis.Close()
grpcServer, err := grpc.CreateGrpcServer(predictor, deploymentName, annotations, logger)
if err != nil {
Expand All @@ -154,9 +154,53 @@ func runGrpcServer(lis net.Listener, logger logr.Logger, predictor *v1.Predictor
kfservingGrpcServer := kfserving.NewGrpcKFServingServer(predictor, client, serverUrl, namespace)
kfproto.RegisterGRPCInferenceServiceServer(grpcServer, kfservingGrpcServer)
}
err = grpcServer.Serve(lis)
if err != nil {
logger.Error(err, "gRPC server error")

go func() {
logger.Info("gRPC server started")
if err := grpcServer.Serve(lis); err != nil {
logger.Error(err, "gRPC server error")
}
}()

<-shutdown // wait for signal
grpcServer.GracefulStop()
logger.Info("gRPC server shutdown")
}

func waitForShutdown(logger logr.Logger, wg *sync.WaitGroup, chs ...chan bool) {
// wait for and then signal servers have exited
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()

c := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) and SIGTERM
// SIGKILL, SIGQUIT will not be caught.
signal.Notify(c, syscall.SIGINT)
signal.Notify(c, syscall.SIGTERM)

// Block until we receive our signal.
sig := <-c
logger.Info("shutdown signal received", "signal", sig, "shutdown_delay", *delay)
time.Sleep(*delay) // shutdown_delay

// Create a deadline to wait for graceful shutdown.
ctx, cancel := context.WithTimeout(context.Background(), *wait)
defer cancel()

// send signals to server channels to initiate shutdown.
for _, ch := range chs {
ch <- true
}

select {
case <-done:
return // servers shutdown
case <-ctx.Done():
logger.Error(ctx.Err(), "graceful_timeout exceeded, exiting immediately", "graceful_timeout", *wait)
os.Exit(1)
}
}

Expand Down Expand Up @@ -263,7 +307,9 @@ func main() {
logger := logf.Log.WithName("entrypoint")

// Set runtime.GOMAXPROCS to respect container limits if the env var GOMAXPROCS is not set or is invalid, preventing CPU throttling.
undo, err := maxprocs.Set(maxprocs.Logger(logger.Info))
undo, err := maxprocs.Set(maxprocs.Logger(func(format string, a ...interface{}) {
logger.WithName("maxprocs").Info(fmt.Sprintf(format, a...))
}))
defer undo()
if err != nil {
logger.Error(err, "failed to set GOMAXPROCS")
Expand Down Expand Up @@ -342,11 +388,15 @@ func main() {
log.Fatalf("Failed to create grpc client. Unknown protocol %s: %v", *protocol, err)
}

wg := sync.WaitGroup{}
logger.Info("Running http server ", "port", *httpPort)
go runHttpServer(createListener(*httpPort, logger), logger, predictor, clientRest, *httpPort, false, serverUrl, *namespace, *protocol, *sdepName, *prometheusPath)
httpStop := make(chan bool, 1)
go runHttpServer(&wg, httpStop, createListener(*httpPort, logger), logger, predictor, clientRest, *httpPort, false, serverUrl, *namespace, *protocol, *sdepName, *prometheusPath)

logger.Info("Running grpc server ", "port", *grpcPort)
runGrpcServer(createListener(*grpcPort, logger), logger, predictor, clientGrpc, serverUrl, *namespace, *protocol, *sdepName, annotations)
grpcStop := make(chan bool, 1)
go runGrpcServer(&wg, grpcStop, createListener(*grpcPort, logger), logger, predictor, clientGrpc, serverUrl, *namespace, *protocol, *sdepName, annotations)
waitForShutdown(logger, &wg, httpStop, grpcStop)
}

func createListener(port int, logger logr.Logger) net.Listener {
Expand Down
7 changes: 6 additions & 1 deletion executor/predictor/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"strings"
)

const EnvKeyEnginePredictor = "ENGINE_PREDICTOR"

func GetPredictor(predictorName, filename, sdepName, namespace string, configPath *string) (*v1.PredictorSpec, error) {
if filename != "" {
predictor, err := getPredictorFromFile(predictorName, filename)
Expand All @@ -25,7 +27,10 @@ func GetPredictor(predictorName, filename, sdepName, namespace string, configPat
}

func getPredictorFromEnv() (*v1.PredictorSpec, error) {
b64Predictor := os.Getenv("ENGINE_PREDICTOR")
b64Predictor, ok := os.LookupEnv(EnvKeyEnginePredictor)
if !ok {
return nil, fmt.Errorf("Predictor not found, enviroment variable %s not set", EnvKeyEnginePredictor)
}
if b64Predictor != "" {
bytes, err := base64.StdEncoding.DecodeString(b64Predictor)
if err != nil {
Expand Down
73 changes: 73 additions & 0 deletions executor/predictor/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package predictor

import (
"encoding/base64"
"fmt"
"os"
"testing"

. "github.com/onsi/gomega"
v1 "github.com/seldonio/seldon-core/operator/apis/machinelearning.seldon.io/v1"
)

func TestGetPredictorFromEnv(t *testing.T) {
t.Logf("Started")
g := NewGomegaWithT(t)
var b64Error base64.CorruptInputError

tests := []struct {
name string
key string
val string
predictor *v1.PredictorSpec
err error
}{
{
name: "missing env var",
err: fmt.Errorf("Predictor not found, enviroment variable %s not set", EnvKeyEnginePredictor),
},
{
name: "empty value",
key: EnvKeyEnginePredictor,
},
{
name: "non-b64 value",
key: EnvKeyEnginePredictor,
val: ":;,",
err: b64Error,
},
}

// unset existing env var & reset at end of test
val, ok := os.LookupEnv(EnvKeyEnginePredictor)
if ok {
if err := os.Unsetenv(EnvKeyEnginePredictor); err != nil {
t.Fatalf("failed to unset env var %v: %v", EnvKeyEnginePredictor, err)
}
defer func() {
os.Setenv(EnvKeyEnginePredictor, val)
}()
}

setenv := func(key, val string) {
if key == "" {
return
}
if err := os.Setenv(key, val); err != nil {
t.Fatalf("failed to set env var: %v", err)
}
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setenv(tt.key, tt.val)
got, err := getPredictorFromEnv()
g.Expect(got).Should(Equal(tt.predictor))
if tt.err == nil {
g.Expect(err).Should(BeNil())
} else {
g.Expect(err).Should(Equal(tt.err))
}
})
}
}