diff --git a/README.md b/README.md
index aaea5bd84..71c82476a 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ This operator coordinates the rollout of pods between different StatefulSets wit
 
 ## How updates work
 
-The operator coordinates the rollout of pods belonging to `StatefulSets` with the `rollout-group` label and updates strategy set to `OnDelete`. The label value should identify the group of StatefulSets to which the StatefulSet belongs to. Make sure the statefulset has a label `name` in its `spec.template`, the operator uses it to find pods belonging to it.
+The operator coordinates the rollout of pods belonging to `StatefulSets` with the `rollout-group` label and update strategy set to `OnDelete`. The label value should identify the group of StatefulSets to which the StatefulSet belongs. Make sure the StatefulSet has a label `name` in its `spec.template`, as the operator uses it to find pods belonging to it.
 
 For example, given the following StatefulSets in a namespace:
 - `ingester-zone-a` with `rollout-group: ingester`
@@ -25,9 +25,9 @@ For each **rollout group**, the operator **guarantees**:
 1. Pods in a StatefulSet are rolled out if and only if all pods in all other StatefulSets of the same group are `Ready` (otherwise it will start or continue the rollout once this check is satisfied)
 1. Pods are rolled out if and only if all StatefulSets in the same group have `OnDelete` update strategy (otherwise the operator will skip the group and log an error)
 1. The maximum number of not-Ready pods in a StatefulSet doesn't exceed the value configured in the `rollout-max-unavailable` annotation (if not set, it defaults to `1`). Values:
-  - `<= 0`: invalid (will default to `1` and log a warning)
-  - `1`: pods are rolled out sequentially
-  - `> 1`: pods are rolled out in parallel (honoring the configured number of max unavailable pods)
+   - `<= 0`: invalid (will default to `1` and log a warning)
+   - `1`: pods are rolled out sequentially
+   - `> 1`: pods are rolled out in parallel (honoring the configured number of max unavailable pods)
 
 ## How scaling up and down works
 
diff --git a/pkg/admission/prep_downscale.go b/pkg/admission/prep_downscale.go
index 2318436ff..140055688 100644
--- a/pkg/admission/prep_downscale.go
+++ b/pkg/admission/prep_downscale.go
@@ -90,6 +90,7 @@ func prepareDownscale(ctx context.Context, l log.Logger, ar v1.AdmissionReview,
 		return allowWarn(logger, fmt.Sprintf("%s, allowing the change", err))
 	}
 
+	// Since it's a downscale, check if the resource has the label that indicates it needs to be prepared to be downscaled.
 	if lbls[config.PrepareDownscaleLabelKey] != config.PrepareDownscaleLabelValue {
 		// Not labeled, nothing to do.
 		return &v1.AdmissionResponse{Allowed: true}
@@ -152,7 +153,6 @@ func prepareDownscale(ctx context.Context, l log.Logger, ar v1.AdmissionReview,
 		}
 	}
 
-	// Since it's a downscale, check if the resource has the label that indicates it needs to be prepared to be downscaled.
 	// Create a slice of endpoint addresses for pods to send HTTP POST requests to and to fail if any don't return 200
 	eps := createEndpoints(ar, oldInfo, newInfo, port, path)
 
@@ -460,6 +460,47 @@ func createEndpoints(ar v1.AdmissionReview, oldInfo, newInfo *objectInfo, port,
 	return eps
 }
 
+func invokePrepareShutdown(ctx context.Context, method string, parentLogger log.Logger, client httpClient, ep endpoint) error {
+	span := "admission.PreparePodForShutdown"
+	if method == http.MethodDelete {
+		span = "admission.UnpreparePodForShutdown"
+	}
+
+	logger, ctx := spanlogger.New(ctx, parentLogger, span, tenantResolver)
+	defer logger.Span.Finish()
+
+	logger.SetSpanAndLogTag("url", ep.url)
+	logger.SetSpanAndLogTag("index", ep.index)
+	logger.SetSpanAndLogTag("method", method)
+
+	req, err := http.NewRequestWithContext(ctx, method, "http://"+ep.url, nil)
+	if err != nil {
+		level.Error(logger).Log("msg", fmt.Sprintf("error creating HTTP %s request", method), "err", err)
+		return err
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req, ht := nethttp.TraceRequest(opentracing.GlobalTracer(), req)
+	defer ht.Finish()
+
+	resp, err := client.Do(req)
+	if err != nil {
+		level.Error(logger).Log("msg", fmt.Sprintf("error sending HTTP %s request", method), "err", err)
+		return err
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode/100 != 2 {
+		err := fmt.Errorf("HTTP %s request returned non-2xx status code", method)
+		body, readError := io.ReadAll(resp.Body)
+		level.Error(logger).Log("msg", "error received from shutdown endpoint", "err", err, "status", resp.StatusCode, "response_body", string(body))
+		return errors.Join(err, readError)
+	}
+	level.Debug(logger).Log("msg", "pod prepare-shutdown handler called", "method", method, "url", ep.url)
+	return nil
+}
+
 func sendPrepareShutdownRequests(ctx context.Context, logger log.Logger, client httpClient, eps []endpoint) error {
 	if len(eps) == 0 {
 		return nil
@@ -468,44 +509,42 @@ func sendPrepareShutdownRequests(ctx context.Context, logger log.Logger, client
 	span, ctx := opentracing.StartSpanFromContext(ctx, "admission.sendPrepareShutdownRequests()")
 	defer span.Finish()
 
-	g, _ := errgroup.WithContext(ctx)
-	for _, ep := range eps {
-		ep := ep // https://golang.org/doc/faq#closures_and_goroutines
-		g.Go(func() error {
-			logger, ctx := spanlogger.New(ctx, logger, "admission.PreparePodForShutdown", tenantResolver)
-			defer logger.Span.Finish()
+	// Attempt to POST to every prepare-shutdown endpoint. If any fail, we'll
+	// undo them all with a DELETE.
-			logger.SetSpanAndLogTag("url", ep.url)
-			logger.SetSpanAndLogTag("index", ep.index)
-
-			req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://"+ep.url, nil)
-			if err != nil {
-				level.Error(logger).Log("msg", "error creating HTTP POST request", "err", err)
-			}
+	const maxGoroutines = 32
 
-			req.Header.Set("Content-Type", "application/json")
-			req, ht := nethttp.TraceRequest(opentracing.GlobalTracer(), req)
-			defer ht.Finish()
-
-			resp, err := client.Do(req)
-			if err != nil {
-				level.Error(logger).Log("msg", "error sending HTTP POST request", "err", err)
+	g, ectx := errgroup.WithContext(ctx)
+	g.SetLimit(maxGoroutines)
+	for _, ep := range eps {
+		ep := ep
+		g.Go(func() error {
+			if err := ectx.Err(); err != nil {
 				return err
 			}
-
-			defer resp.Body.Close()
-
-			if resp.StatusCode/100 != 2 {
-				err := errors.New("HTTP POST request returned non-2xx status code")
-				body, readError := io.ReadAll(resp.Body)
-				level.Error(logger).Log("msg", "error received from shutdown endpoint", "err", err, "status", resp.StatusCode, "response_body", string(body))
-				return errors.Join(err, readError)
-			}
-			level.Debug(logger).Log("msg", "pod prepared for shutdown")
-			return nil
+			return invokePrepareShutdown(ectx, http.MethodPost, logger, client, ep)
 		})
 	}
-	return g.Wait()
+
+	err := g.Wait()
+	if err != nil {
+		// At least one failed. Undo them all.
+		level.Warn(logger).Log("msg", "failed to prepare hosts for shutdown. unpreparing...", "err", err)
+		undoGroup, _ := errgroup.WithContext(ctx)
+		undoGroup.SetLimit(maxGoroutines)
+		for _, ep := range eps {
+			ep := ep
+			undoGroup.Go(func() error {
+				if err := invokePrepareShutdown(ctx, http.MethodDelete, logger, client, ep); err != nil {
+					level.Warn(logger).Log("msg", "failed to undo prepare shutdown request", "url", ep.url, "err", err)
+				}
+				return nil
+			})
+		}
+		_ = undoGroup.Wait()
+	}
+
+	return err
 }
 
 var tenantResolver spanlogger.TenantResolver = noTenantResolver{}
diff --git a/pkg/admission/prep_downscale_test.go b/pkg/admission/prep_downscale_test.go
index b266f0b0b..45251fe6c 100644
--- a/pkg/admission/prep_downscale_test.go
+++ b/pkg/admission/prep_downscale_test.go
@@ -3,12 +3,15 @@ package admission
 import (
 	"bytes"
 	"context"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"os"
 	"reflect"
+	"strings"
+	"sync/atomic"
 	"testing"
 	"text/template"
 	"time"
@@ -123,9 +126,14 @@ type templateParams struct {
 
 type fakeHttpClient struct {
 	statusCode int
+	mockDo     func(*http.Request) (*http.Response, error)
 }
 
 func (f *fakeHttpClient) Do(req *http.Request) (resp *http.Response, err error) {
+	if f.mockDo != nil {
+		return f.mockDo(req)
+	}
+
 	return &http.Response{
 		StatusCode: f.statusCode,
 		Body:       io.NopCloser(bytes.NewBuffer([]byte(""))),
@@ -269,6 +277,107 @@ func testPrepDownscaleWebhook(t *testing.T, oldReplicas, newReplicas int, option
 	}
 }
 
+func TestSendPrepareShutdown(t *testing.T) {
+	cases := map[string]struct {
+		numEndpoints  int
+		lastPostsFail int
+		expectDeletes int
+		deletesFail   bool
+		expectErr     bool
+	}{
+		"no endpoints": {
+			numEndpoints:  0,
+			lastPostsFail: 0,
+			expectDeletes: 0,
+			expectErr:     false,
+		},
+		"all posts succeed": {
+			numEndpoints:  64,
+			lastPostsFail: 0,
+			expectDeletes: 0,
+			expectErr:     false,
+		},
+		"all posts fail": {
+			numEndpoints:  64,
+			lastPostsFail: 64,
+			expectDeletes: 64,
+			expectErr:     true,
+		},
+		"last post fails": {
+			numEndpoints:  64,
+			lastPostsFail: 1,
+			expectDeletes: 64,
+			expectErr:     true,
+		},
+		"delete failures should still call all deletes": {
+			numEndpoints:  64,
+			lastPostsFail: 64,
+			expectDeletes: 64,
+			deletesFail:   true,
+			expectErr:     true,
+		},
+	}
+
+	for n, c := range cases {
+		t.Run(n, func(t *testing.T) {
+			var postCalls atomic.Int32
+			var deleteCalls atomic.Int32
+
+			errResponse := &http.Response{
+				StatusCode: http.StatusInternalServerError,
+				Body:       io.NopCloser(strings.NewReader("we've had a problem")),
+			}
+			successResponse := &http.Response{
+				StatusCode: http.StatusOK,
+				Body:       io.NopCloser(strings.NewReader("good")),
+			}
+
+			httpClient := &fakeHttpClient{
+				mockDo: func(r *http.Request) (*http.Response, error) {
+					if r.Method == http.MethodPost {
+						calls := postCalls.Add(1)
+						if calls > int32(c.numEndpoints-c.lastPostsFail) {
+							return errResponse, nil
+						} else {
+							return successResponse, nil
+						}
+					} else if r.Method == http.MethodDelete {
+						deleteCalls.Add(1)
+						if c.deletesFail {
+							return errResponse, nil
+						} else {
+							return successResponse, nil
+						}
+					}
+					panic("unexpected method")
+				},
+			}
+
+			endpoints := make([]endpoint, 0, c.numEndpoints)
+			for i := range cap(endpoints) {
+				endpoints = append(endpoints, endpoint{
+					url:   fmt.Sprintf("url-something.foo.%d.example.biz", i),
+					index: i,
+				})
+			}
+
+			err := sendPrepareShutdownRequests(context.Background(), log.NewNopLogger(), httpClient, endpoints)
+			if c.expectErr {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+			}
+			if c.lastPostsFail > 0 {
+				// (It's >= because the goroutines may already have in-flight POSTs before realizing the context is canceled.)
+				assert.GreaterOrEqual(t, postCalls.Load(), int32(c.numEndpoints-c.lastPostsFail), "at least |e|-|fails| should have been sent a POST")
+			} else {
+				assert.Equal(t, int32(c.numEndpoints), postCalls.Load(), "all endpoints should have been sent a POST")
+			}
+			assert.Equal(t, int32(c.expectDeletes), deleteCalls.Load())
+		})
+	}
+}
+
 func TestFindStatefulSetWithNonUpdatedReplicas(t *testing.T) {
 	namespace := "test"
 	rolloutGroup := "ingester"
diff --git a/pkg/controller/replicas.go b/pkg/controller/replicas.go
index cbe66546b..0340d4da0 100644
--- a/pkg/controller/replicas.go
+++ b/pkg/controller/replicas.go
@@ -106,7 +106,6 @@ func minimumTimeHasElapsed(follower *v1.StatefulSet, all []*v1.StatefulSet, logg
 	)
 
 	return timeSinceDownscale > minTimeSinceDownscale, nil
-
 }
 
 // getMostRecentDownscale gets the time of the most recent downscale of any statefulset besides
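
Reviewer note (not part of the patch): the README hunk above describes the contract a StatefulSet must satisfy — a `rollout-group` label, the `OnDelete` update strategy, a `name` label in `spec.template`, and optionally the `rollout-max-unavailable` annotation. The following is a minimal Go sketch of such an object using the `k8s.io/api` types; the value `"2"` is illustrative, required fields such as the selector, service name, and containers are omitted for brevity, and placing the annotation on the StatefulSet's own metadata is an assumption based on the README wording.

```go
package main

import (
	"fmt"

	appsv1 "k8s.io/api/apps/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// exampleStatefulSet builds a StatefulSet shaped the way the README describes:
// rollout-group label, OnDelete update strategy, a `name` label in spec.template
// so the operator can find its pods, and an optional rollout-max-unavailable
// annotation to allow parallel rollouts.
func exampleStatefulSet() *appsv1.StatefulSet {
	return &appsv1.StatefulSet{
		ObjectMeta: metav1.ObjectMeta{
			Name:   "ingester-zone-a",
			Labels: map[string]string{"rollout-group": "ingester"},
			// Assumed placement: annotation set on the StatefulSet itself; "2" is just an example value.
			Annotations: map[string]string{"rollout-max-unavailable": "2"},
		},
		Spec: appsv1.StatefulSetSpec{
			UpdateStrategy: appsv1.StatefulSetUpdateStrategy{
				Type: appsv1.OnDeleteStatefulSetStrategyType,
			},
			Template: corev1.PodTemplateSpec{
				ObjectMeta: metav1.ObjectMeta{
					Labels: map[string]string{"name": "ingester-zone-a"},
				},
				// Selector, ServiceName and container spec omitted for brevity.
			},
		},
	}
}

func main() {
	sts := exampleStatefulSet()
	fmt.Println(sts.Name, sts.Labels["rollout-group"], sts.Spec.UpdateStrategy.Type)
}
```

A sibling `ingester-zone-b` object with the same `rollout-group: ingester` label would form the second member of the rollout group described in the README.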
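For reference, a condensed sketch of the control flow the `prep_downscale.go` hunk introduces: bounded-concurrency POSTs via `errgroup.SetLimit`, then a best-effort DELETE pass over every endpoint if any POST failed. This strips the tracing and logging from the patch and uses illustrative names (`prepareAll`, `call`); it is a sketch of the pattern, not the patch itself.

```go
package main

import (
	"context"
	"fmt"
	"net/http"

	"golang.org/x/sync/errgroup"
)

// prepareAll POSTs to every endpoint with at most maxGoroutines in flight.
// If any POST fails, it issues a best-effort DELETE to every endpoint so no
// pod is left prepared for a shutdown that will not happen.
func prepareAll(ctx context.Context, client *http.Client, urls []string) error {
	const maxGoroutines = 32

	call := func(ctx context.Context, method, url string) error {
		req, err := http.NewRequestWithContext(ctx, method, url, nil)
		if err != nil {
			return err
		}
		resp, err := client.Do(req)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		if resp.StatusCode/100 != 2 {
			return fmt.Errorf("%s %s: unexpected status %d", method, url, resp.StatusCode)
		}
		return nil
	}

	g, ectx := errgroup.WithContext(ctx)
	g.SetLimit(maxGoroutines)
	for _, u := range urls {
		u := u
		g.Go(func() error {
			// Skip remaining work once a sibling call has already failed.
			if err := ectx.Err(); err != nil {
				return err
			}
			return call(ectx, http.MethodPost, u)
		})
	}

	err := g.Wait()
	if err != nil {
		// Undo pass: goroutines always return nil so every DELETE is attempted,
		// mirroring the "delete failures should still call all deletes" test case.
		undo, _ := errgroup.WithContext(ctx)
		undo.SetLimit(maxGoroutines)
		for _, u := range urls {
			u := u
			undo.Go(func() error {
				_ = call(ctx, http.MethodDelete, u)
				return nil
			})
		}
		_ = undo.Wait()
	}
	return err
}
```

Note the undo loop uses the parent `ctx` rather than the errgroup's (already canceled) context, matching the patch, so the DELETEs are not aborted by the failed POST.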