Skip to content

Commit

Permalink
add shutdown_delay option to executor & gRPC GracefulStop (SeldonIO#3711
Browse files Browse the repository at this point in the history
)

* shutdown delay & gRPC graceful shutdown

* fix maxprocs logger

* getPredictorFromEnv missing key error & tests
  • Loading branch information
asobrien authored Nov 2, 2021
1 parent c39ed39 commit d6edb8d
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 34 deletions.
116 changes: 83 additions & 33 deletions executor/cmd/executor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ package main
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path"
"sync"
"syscall"
"time"

Expand All @@ -37,8 +40,8 @@ import (
"go.uber.org/automaxprocs/maxprocs"
"go.uber.org/zap"
"google.golang.org/grpc/reflection"
zapf "sigs.k8s.io/controller-runtime/pkg/log/zap"
logf "sigs.k8s.io/controller-runtime/pkg/log"
zapf "sigs.k8s.io/controller-runtime/pkg/log/zap"
)

const (
Expand All @@ -62,6 +65,7 @@ var (
httpPort = flag.Int("http_port", 8080, "Executor http port")
grpcPort = flag.Int("grpc_port", 5000, "Executor grpc port")
wait = flag.Duration("graceful_timeout", time.Second*15, "Graceful shutdown secs")
delay = flag.Duration("shutdown_delay", 0, "Shutdown delay secs")
protocol = flag.String("protocol", "seldon", "The payload protocol")
transport = flag.String("transport", "rest", "The network transport mechanism rest, grpc")
filename = flag.String("file", "", "Load graph from file")
Expand All @@ -73,7 +77,7 @@ var (
kafkaTopicOut = flag.String("kafka_output_topic", "", "The kafka output topic")
kafkaFullGraph = flag.Bool("kafka_full_graph", false, "Use kafka for internal graph processing")
kafkaWorkers = flag.Int("kafka_workers", 4, "Number of kafka workers")
logKafkaBroker = flag.String("log_kafka_broker", "", "The kafka log broker")
logKafkaBroker = flag.String("log_kafka_broker", "", "The kafka log broker")
logKafkaTopic = flag.String("log_kafka_topic", "", "The kafka log topic")
debug = flag.Bool(
"debug",
Expand All @@ -95,46 +99,42 @@ func getServerUrl(hostname string, port int) (*url.URL, error) {
return url.Parse(fmt.Sprintf("http://%s:%d/", hostname, port))
}

func runHttpServer(lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, port int,
probesOnly bool, serverUrl *url.URL, namespace string, protocol string, deploymentName string, prometheusPath string) {
func runHttpServer(wg *sync.WaitGroup, shutdown chan bool, lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, port int, probesOnly bool, serverUrl *url.URL, namespace string, protocol string, deploymentName string, prometheusPath string) {
wg.Add(1)
defer wg.Done()
defer lis.Close()

// Create REST API
seldonRest := rest.NewServerRestApi(predictor, client, probesOnly, serverUrl, namespace, protocol, deploymentName, prometheusPath)
seldonRest.Initialise()
srv := seldonRest.CreateHttpServer(port)

var isShutdown bool
go func() {
logger.Info("http server started")
if err := srv.Serve(lis); err != nil {
logger.Error(err, "Server error")
if !(isShutdown && errors.Is(err, http.ErrServerClosed)) {
logger.Error(err, "http server error")
}
}
logger.Info("server started")
}()

c := make(chan os.Signal, 1)
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) and SIGTERM
// SIGKILL, SIGQUIT will not be caught.
signal.Notify(c, syscall.SIGINT)
signal.Notify(c, syscall.SIGTERM)

// Block until we receive our signal.
<-c

// Create a deadline to wait for.
ctx, cancel := context.WithTimeout(context.Background(), *wait)
defer cancel()
// Doesn't block if no connections, but will otherwise wait
// until the timeout deadline.
srv.Shutdown(ctx)
// Optionally, you could run srv.Shutdown in a goroutine and block on
// <-ctx.Done() if your application should wait for other services
// to finalize based on context cancellation.
logger.Info("shutting down")
os.Exit(0)

isShutdown = <-shutdown

// Doesn't block if no connections, otherwise this waits for context
// deadline. As shutdown is coordinated across multiple servers by the
// caller, we just use a background context and wait indefinitely if there
// are active connections.
if err := srv.Shutdown(context.Background()); err != nil {
logger.Error(err, "http server shutdown error")
}
logger.Info("http server shutdown")
}

func runGrpcServer(lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, serverUrl *url.URL, namespace string, protocol string, deploymentName string, annotations map[string]string) {
func runGrpcServer(wg *sync.WaitGroup, shutdown chan bool, lis net.Listener, logger logr.Logger, predictor *v1.PredictorSpec, client seldonclient.SeldonApiClient, serverUrl *url.URL, namespace string, protocol string, deploymentName string, annotations map[string]string) {
wg.Add(1)
defer wg.Done()
defer lis.Close()
grpcServer, err := grpc.CreateGrpcServer(predictor, deploymentName, annotations, logger)
if err != nil {
Expand All @@ -154,9 +154,53 @@ func runGrpcServer(lis net.Listener, logger logr.Logger, predictor *v1.Predictor
kfservingGrpcServer := kfserving.NewGrpcKFServingServer(predictor, client, serverUrl, namespace)
kfproto.RegisterGRPCInferenceServiceServer(grpcServer, kfservingGrpcServer)
}
err = grpcServer.Serve(lis)
if err != nil {
logger.Error(err, "gRPC server error")

go func() {
logger.Info("gRPC server started")
if err := grpcServer.Serve(lis); err != nil {
logger.Error(err, "gRPC server error")
}
}()

<-shutdown // wait for signal
grpcServer.GracefulStop()
logger.Info("gRPC server shutdown")
}

// waitForShutdown blocks until SIGINT or SIGTERM is received, sleeps for the
// configured shutdown_delay, tells every server to stop by sending on each of
// chs, and then waits for wg to drain.
//
// If the servers have not finished within graceful_timeout (measured from the
// end of the shutdown delay) the process exits with status 1. On a clean
// shutdown the function simply returns.
func waitForShutdown(logger logr.Logger, wg *sync.WaitGroup, chs ...chan bool) {
	c := make(chan os.Signal, 1)
	// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) and SIGTERM.
	// SIGKILL, SIGQUIT will not be caught.
	signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)

	// Block until we receive our signal.
	sig := <-c
	logger.Info("shutdown signal received", "signal", sig, "shutdown_delay", *delay)
	time.Sleep(*delay) // shutdown_delay

	// Create a deadline to wait for graceful shutdown.
	ctx, cancel := context.WithTimeout(context.Background(), *wait)
	defer cancel()

	// abort reports the exceeded deadline and terminates the process.
	abort := func() {
		logger.Error(ctx.Err(), "graceful_timeout exceeded, exiting immediately", "graceful_timeout", *wait)
		os.Exit(1)
	}

	// Send signals to server channels to initiate shutdown. The send is bounded
	// by the deadline so that an unbuffered or already-full channel whose
	// receiver is gone cannot block the process past graceful_timeout.
	for _, ch := range chs {
		select {
		case ch <- true:
		case <-ctx.Done():
			abort()
		}
	}

	// Only start watching the WaitGroup now. The servers register themselves
	// with wg.Add from inside their own goroutines, so calling wg.Wait as soon
	// as this function is entered could observe a zero counter, report "done"
	// before any server had registered, and skip the graceful wait entirely.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return // servers shutdown
	case <-ctx.Done():
		abort()
	}
}

Expand Down Expand Up @@ -263,7 +307,9 @@ func main() {
logger := logf.Log.WithName("entrypoint")

// Set runtime.GOMAXPROCS to respect container limits if the env var GOMAXPROCS is not set or is invalid, preventing CPU throttling.
undo, err := maxprocs.Set(maxprocs.Logger(logger.Info))
undo, err := maxprocs.Set(maxprocs.Logger(func(format string, a ...interface{}) {
logger.WithName("maxprocs").Info(fmt.Sprintf(format, a...))
}))
defer undo()
if err != nil {
logger.Error(err, "failed to set GOMAXPROCS")
Expand Down Expand Up @@ -342,11 +388,15 @@ func main() {
log.Fatalf("Failed to create grpc client. Unknown protocol %s: %v", *protocol, err)
}

wg := sync.WaitGroup{}
logger.Info("Running http server ", "port", *httpPort)
go runHttpServer(createListener(*httpPort, logger), logger, predictor, clientRest, *httpPort, false, serverUrl, *namespace, *protocol, *sdepName, *prometheusPath)
httpStop := make(chan bool, 1)
go runHttpServer(&wg, httpStop, createListener(*httpPort, logger), logger, predictor, clientRest, *httpPort, false, serverUrl, *namespace, *protocol, *sdepName, *prometheusPath)

logger.Info("Running grpc server ", "port", *grpcPort)
runGrpcServer(createListener(*grpcPort, logger), logger, predictor, clientGrpc, serverUrl, *namespace, *protocol, *sdepName, annotations)
grpcStop := make(chan bool, 1)
go runGrpcServer(&wg, grpcStop, createListener(*grpcPort, logger), logger, predictor, clientGrpc, serverUrl, *namespace, *protocol, *sdepName, annotations)
waitForShutdown(logger, &wg, httpStop, grpcStop)
}

func createListener(port int, logger logr.Logger) net.Listener {
Expand Down
7 changes: 6 additions & 1 deletion executor/predictor/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"strings"
)

// EnvKeyEnginePredictor is the name of the environment variable that
// getPredictorFromEnv reads the base64-encoded predictor from.
const EnvKeyEnginePredictor = "ENGINE_PREDICTOR"

func GetPredictor(predictorName, filename, sdepName, namespace string, configPath *string) (*v1.PredictorSpec, error) {
if filename != "" {
predictor, err := getPredictorFromFile(predictorName, filename)
Expand All @@ -25,7 +27,10 @@ func GetPredictor(predictorName, filename, sdepName, namespace string, configPat
}

func getPredictorFromEnv() (*v1.PredictorSpec, error) {
b64Predictor := os.Getenv("ENGINE_PREDICTOR")
b64Predictor, ok := os.LookupEnv(EnvKeyEnginePredictor)
if !ok {
return nil, fmt.Errorf("Predictor not found, enviroment variable %s not set", EnvKeyEnginePredictor)
}
if b64Predictor != "" {
bytes, err := base64.StdEncoding.DecodeString(b64Predictor)
if err != nil {
Expand Down
73 changes: 73 additions & 0 deletions executor/predictor/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package predictor

import (
"encoding/base64"
"fmt"
"os"
"testing"

. "github.com/onsi/gomega"
v1 "github.com/seldonio/seldon-core/operator/apis/machinelearning.seldon.io/v1"
)

func TestGetPredictorFromEnv(t *testing.T) {
t.Logf("Started")
g := NewGomegaWithT(t)
var b64Error base64.CorruptInputError

tests := []struct {
name string
key string
val string
predictor *v1.PredictorSpec
err error
}{
{
name: "missing env var",
err: fmt.Errorf("Predictor not found, enviroment variable %s not set", EnvKeyEnginePredictor),
},
{
name: "empty value",
key: EnvKeyEnginePredictor,
},
{
name: "non-b64 value",
key: EnvKeyEnginePredictor,
val: ":;,",
err: b64Error,
},
}

// unset existing env var & reset at end of test
val, ok := os.LookupEnv(EnvKeyEnginePredictor)
if ok {
if err := os.Unsetenv(EnvKeyEnginePredictor); err != nil {
t.Fatalf("failed to unset env var %v: %v", EnvKeyEnginePredictor, err)
}
defer func() {
os.Setenv(EnvKeyEnginePredictor, val)
}()
}

setenv := func(key, val string) {
if key == "" {
return
}
if err := os.Setenv(key, val); err != nil {
t.Fatalf("failed to set env var: %v", err)
}
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setenv(tt.key, tt.val)
got, err := getPredictorFromEnv()
g.Expect(got).Should(Equal(tt.predictor))
if tt.err == nil {
g.Expect(err).Should(BeNil())
} else {
g.Expect(err).Should(Equal(tt.err))
}
})
}
}

0 comments on commit d6edb8d

Please sign in to comment.