diff --git a/cmd/status/root_test.go b/cmd/status/root_test.go index 01d794c65..35edf545d 100644 --- a/cmd/status/root_test.go +++ b/cmd/status/root_test.go @@ -46,7 +46,7 @@ func TestStatusCmd(t *testing.T) { }) t.Run("case=block", func(t *testing.T) { - ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, time.Millisecond) + ctx := context.WithValue(context.Background(), client.ContextKeyTimeout, 100*time.Millisecond) l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/internal/driver/daemon.go b/internal/driver/daemon.go index 2f2d020fb..89fce833a 100644 --- a/internal/driver/daemon.go +++ b/internal/driver/daemon.go @@ -107,7 +107,7 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { innerCtx, cancel := context.WithCancel(ctx) defer cancel() - serveFuncs := []func(context.Context, chan<- struct{}) error{ + serveFuncs := []func(context.Context, chan<- struct{}) func() error{ r.serveRead, r.serveWrite, r.serveOPLSyntax, @@ -149,88 +149,94 @@ func (r *RegistryDefault) ServeAll(ctx context.Context) error { // We need to separate the setup (invoking the functions that return the serve functions) from running the serve // functions to mitigate race conditions in the HTTP router. for _, serve := range serveFuncs { - eg.Go(func() error { - return serve(innerCtx, doneShutdown) - }) + eg.Go(serve(innerCtx, doneShutdown)) } return eg.Wait() } -func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveRead(ctx context.Context, done chan<- struct{}) func() error { rt, s := r.ReadRouter(ctx), r.ReadGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - addr, listenFile := r.Config(ctx).ReadAPIListenOn() - return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done) + return func() error { + addr, listenFile := r.Config(ctx).ReadAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "read"), addr, listenFile, rt, s, done) + } } -func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveWrite(ctx context.Context, done chan<- struct{}) func() error { rt, s := r.WriteRouter(ctx), r.WriteGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - addr, listenFile := r.Config(ctx).WriteAPIListenOn() - return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done) + return func() error { + addr, listenFile := r.Config(ctx).WriteAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "write"), addr, listenFile, rt, s, done) + } } -func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveOPLSyntax(ctx context.Context, done chan<- struct{}) func() error { rt, s := r.OPLSyntaxRouter(ctx), r.OplGRPCServer(ctx) if tracer := r.Tracer(ctx); tracer.IsLoaded() { rt = otelx.TraceHandler(rt, otelhttp.WithTracerProvider(tracer.Provider())) } - addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn() - return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done) + return func() error { + addr, listenFile := r.Config(ctx).OPLSyntaxAPIListenOn() + return multiplexPort(ctx, r.Logger().WithField("endpoint", "opl"), addr, listenFile, rt, s, done) + } } -func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) error { +func (r *RegistryDefault) serveMetrics(ctx context.Context, done chan<- struct{}) func() error { ctx, cancel := context.WithCancel(ctx) - defer cancel() - - addr, listenFile := r.Config(ctx).MetricsListenOn() - l, err := listenAndWriteFile(ctx, addr, listenFile) - if err != nil { - return err - } //nolint:gosec // graceful.WithDefaults already sets a timeout s := graceful.WithDefaults(&http.Server{ Handler: r.metricsRouter(ctx), }) - eg := &errgroup.Group{} + return func() error { + defer cancel() + eg := &errgroup.Group{} - eg.Go(func() error { - if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) { - return errors.WithStack(err) + addr, listenFile := r.Config(ctx).MetricsListenOn() + l, err := listenAndWriteFile(ctx, addr, listenFile) + if err != nil { + return err } - return nil - }) - eg.Go(func() (err error) { - defer func() { - l := r.Logger().WithField("endpoint", "metrics") - if err != nil { - l.WithError(err).Error("graceful shutdown failed") - } else { - l.Info("gracefully shutdown server") + + eg.Go(func() error { + if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) { + return errors.WithStack(err) } - done <- struct{}{} - }() + return nil + }) + eg.Go(func() (err error) { + defer func() { + l := r.Logger().WithField("endpoint", "metrics") + if err != nil { + l.WithError(err).Error("graceful shutdown failed") + } else { + l.Info("gracefully shutdown server") + } + done <- struct{}{} + }() - <-ctx.Done() - ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) - defer cancel() - return s.Shutdown(ctx) - }) + <-ctx.Done() + ctx, cancel := context.WithTimeout(context.Background(), graceful.DefaultShutdownTimeout) + defer cancel() + return s.Shutdown(ctx) + }) - return eg.Wait() + return eg.Wait() + } } func multiplexPort(ctx context.Context, log *logrusx.Logger, addr, listenFile string, router http.Handler, grpcS *grpc.Server, done chan<- struct{}) error { diff --git a/internal/driver/daemon_test.go b/internal/driver/daemon_test.go index e84fb3031..0b4666d63 100644 --- a/internal/driver/daemon_test.go +++ b/internal/driver/daemon_test.go @@ -36,12 +36,8 @@ func TestScrapingEndpoint(t *testing.T) { eg := errgroup.Group{} doneShutdown := make(chan struct{}) - eg.Go(func() error { - return r.serveWrite(ctx, doneShutdown) - }) - eg.Go(func() error { - return r.serveMetrics(ctx, doneShutdown) - }) + eg.Go(r.serveWrite(ctx, doneShutdown)) + eg.Go(r.serveMetrics(ctx, doneShutdown)) _, writePort, _ := getAddr(t, "write") _, metricsPort, _ := getAddr(t, "metrics") @@ -104,9 +100,7 @@ func TestPanicRecovery(t *testing.T) { eg := errgroup.Group{} doneShutdown := make(chan struct{}) - eg.Go(func() error { - return r.serveWrite(ctx, doneShutdown) - }) + eg.Go(r.serveWrite(ctx, doneShutdown)) _, port, _ := getAddr(t, "write") diff --git a/internal/e2e/full_suit_test.go b/internal/e2e/full_suit_test.go index ab381729c..d16eeac49 100644 --- a/internal/e2e/full_suit_test.go +++ b/internal/e2e/full_suit_test.go @@ -53,7 +53,6 @@ const ( func Test(t *testing.T) { t.Parallel() for _, dsn := range dbx.GetDSNs(t, false) { - dsn := dsn t.Run(fmt.Sprintf("dsn=%s", dsn.Name), func(t *testing.T) { t.Parallel() @@ -98,7 +97,6 @@ func Test(t *testing.T) { syntaxRemote: oplAddr, }, } { - cl := cl t.Run(fmt.Sprintf("client=%T", cl), runCases(cl, namespaceTestMgr)) if tc, ok := cl.(transactClient); ok {