Invoke DELETE on pod prepare-downscale path if any POSTs failed (#146)
This addresses a bug in rollout-operator where:

1. Kubernetes receives a request to downscale a StatefulSet by `X` hosts.
2. The prepare-downscale admission webhook attempts to prepare `X` pods for shutdown by sending an HTTP `POST` to their handler identified by the `grafana.com/prepare-downscale-http-path` and `-port` annotations.
3. At least one of these requests fails. The admission webhook returns an error to Kubernetes, so the downscale is not approved.
4. 💥 But some hosts may have been prepared for downscale. 💥 

This PR adds cleanup logic to issue `DELETE` requests on all involved pods if any of the `POST`s failed. Notes:
* `DELETE` calls are attempted once.
* `DELETE` failures are logged but otherwise ignored.
* For simplicity, we'll invoke `DELETE` on all of the pods involved in the scale-down operation, not just the ones that received a `POST`.

This doesn't fix the related issue where a replica count change from 10 -> 9 -> 10 leaves one pod prepared for shutdown. (A fix for that is in the works.)
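For reference, here is a minimal sketch of the pattern this change implements: `POST` to every endpoint concurrently, and if anything fails, issue one best-effort `DELETE` per pod. The helper names (`preparePods`, `callPod`) are illustrative only; the real implementation in `pkg/admission/prep_downscale.go` below also adds tracing, span logging, and a concurrency limit.

```go
// Sketch of the prepare/undo pattern described above. Names are
// illustrative and not part of the rollout-operator codebase.
package sketch

import (
	"context"
	"fmt"
	"log"
	"net/http"

	"golang.org/x/sync/errgroup"
)

// callPod sends one HTTP request (POST or DELETE) to a pod's
// prepare-shutdown endpoint and treats any non-2xx status as an error.
func callPod(ctx context.Context, client *http.Client, method, url string) error {
	req, err := http.NewRequestWithContext(ctx, method, url, nil)
	if err != nil {
		return err
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 != 2 {
		return fmt.Errorf("%s %s: unexpected status %d", method, url, resp.StatusCode)
	}
	return nil
}

// preparePods POSTs to every endpoint concurrently. If any POST fails, it
// sends a single best-effort DELETE to every endpoint (not just the ones
// that were POSTed), logs DELETE failures, and returns the original error.
func preparePods(ctx context.Context, client *http.Client, urls []string) error {
	g, gctx := errgroup.WithContext(ctx)
	for _, u := range urls {
		u := u
		g.Go(func() error { return callPod(gctx, client, http.MethodPost, u) })
	}
	err := g.Wait()
	if err == nil {
		return nil
	}
	// At least one POST failed: undo everything, exactly once, ignoring errors.
	for _, u := range urls {
		if derr := callPod(ctx, client, http.MethodDelete, u); derr != nil {
			log.Printf("failed to undo prepare-shutdown for %s: %v", u, derr)
		}
	}
	return err
}
```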
seizethedave authored May 15, 2024
1 parent ea17193 commit 33c4fcf
Showing 4 changed files with 185 additions and 38 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -4,7 +4,7 @@ This operator coordinates the rollout of pods between different StatefulSets wit

## How updates work

The operator coordinates the rollout of pods belonging to `StatefulSets` with the `rollout-group` label and updates strategy set to `OnDelete`. The label value should identify the group of StatefulSets to which the StatefulSet belongs to. Make sure the statefulset has a label `name` in its `spec.template`, the operator uses it to find pods belonging to it.
The operator coordinates the rollout of pods belonging to `StatefulSets` with the `rollout-group` label and updates strategy set to `OnDelete`. The label value should identify the group of StatefulSets to which the StatefulSet belongs. Make sure the StatefulSet has a label `name` in its `spec.template`, as the operator uses it to find pods belonging to it.

For example, given the following StatefulSets in a namespace:
- `ingester-zone-a` with `rollout-group: ingester`
@@ -25,9 +25,9 @@ For each **rollout group**, the operator **guarantees**:
1. Pods in a StatefulSet are rolled out if and only if all pods in all other StatefulSets of the same group are `Ready` (otherwise it will start or continue the rollout once this check is satisfied)
1. Pods are rolled out if and only if all StatefulSets in the same group have `OnDelete` update strategy (otherwise the operator will skip the group and log an error)
1. The maximum number of not-Ready pods in a StatefulSet doesn't exceed the value configured in the `rollout-max-unavailable` annotation (if not set, it defaults to `1`). Values:
- `<= 0`: invalid (will default to `1` and log a warning)
- `1`: pods are rolled out sequentially
- `> 1`: pods are rolled out in parallel (honoring the configured number of max unavailable pods)
- `<= 0`: invalid (will default to `1` and log a warning)
- `1`: pods are rolled out sequentially
- `> 1`: pods are rolled out in parallel (honoring the configured number of max unavailable pods)

## How scaling up and down works

105 changes: 72 additions & 33 deletions pkg/admission/prep_downscale.go
@@ -90,6 +90,7 @@ func prepareDownscale(ctx context.Context, l log.Logger, ar v1.AdmissionReview,
return allowWarn(logger, fmt.Sprintf("%s, allowing the change", err))
}

// Since it's a downscale, check if the resource has the label that indicates it needs to be prepared to be downscaled.
if lbls[config.PrepareDownscaleLabelKey] != config.PrepareDownscaleLabelValue {
// Not labeled, nothing to do.
return &v1.AdmissionResponse{Allowed: true}
@@ -152,7 +153,6 @@ }
}
}

// Since it's a downscale, check if the resource has the label that indicates it needs to be prepared to be downscaled.
// Create a slice of endpoint addresses for pods to send HTTP POST requests to and to fail if any don't return 200
eps := createEndpoints(ar, oldInfo, newInfo, port, path)

@@ -460,6 +460,47 @@ func createEndpoints(ar v1.AdmissionReview, oldInfo, newInfo *objectInfo, port,
return eps
}

func invokePrepareShutdown(ctx context.Context, method string, parentLogger log.Logger, client httpClient, ep endpoint) error {
span := "admission.PreparePodForShutdown"
if method == http.MethodDelete {
span = "admission.UnpreparePodForShutdown"
}

logger, ctx := spanlogger.New(ctx, parentLogger, span, tenantResolver)
defer logger.Span.Finish()

logger.SetSpanAndLogTag("url", ep.url)
logger.SetSpanAndLogTag("index", ep.index)
logger.SetSpanAndLogTag("method", method)

req, err := http.NewRequestWithContext(ctx, method, "http://"+ep.url, nil)
if err != nil {
level.Error(logger).Log("msg", fmt.Sprintf("error creating HTTP %s request", method), "err", err)
return err
}

req.Header.Set("Content-Type", "application/json")
req, ht := nethttp.TraceRequest(opentracing.GlobalTracer(), req)
defer ht.Finish()

resp, err := client.Do(req)
if err != nil {
level.Error(logger).Log("msg", fmt.Sprintf("error sending HTTP %s request", method), "err", err)
return err
}

defer resp.Body.Close()

if resp.StatusCode/100 != 2 {
err := fmt.Errorf("HTTP %s request returned non-2xx status code", method)
body, readError := io.ReadAll(resp.Body)
level.Error(logger).Log("msg", "error received from shutdown endpoint", "err", err, "status", resp.StatusCode, "response_body", string(body))
return errors.Join(err, readError)
}
level.Debug(logger).Log("msg", "pod prepare-shutdown handler called", "method", method, "url", ep.url)
return nil
}

func sendPrepareShutdownRequests(ctx context.Context, logger log.Logger, client httpClient, eps []endpoint) error {
if len(eps) == 0 {
return nil
@@ -468,44 +509,42 @@ func sendPrepareShutdownRequests(ctx context.Context, logger log.Logger, client
span, ctx := opentracing.StartSpanFromContext(ctx, "admission.sendPrepareShutdownRequests()")
defer span.Finish()

g, _ := errgroup.WithContext(ctx)
for _, ep := range eps {
ep := ep // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
logger, ctx := spanlogger.New(ctx, logger, "admission.PreparePodForShutdown", tenantResolver)
defer logger.Span.Finish()
// Attempt to POST to every prepare-shutdown endpoint. If any fail, we'll
// undo them all with a DELETE.

logger.SetSpanAndLogTag("url", ep.url)
logger.SetSpanAndLogTag("index", ep.index)

req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://"+ep.url, nil)
if err != nil {
level.Error(logger).Log("msg", "error creating HTTP POST request", "err", err)
}
const maxGoroutines = 32

req.Header.Set("Content-Type", "application/json")
req, ht := nethttp.TraceRequest(opentracing.GlobalTracer(), req)
defer ht.Finish()

resp, err := client.Do(req)
if err != nil {
level.Error(logger).Log("msg", "error sending HTTP POST request", "err", err)
g, ectx := errgroup.WithContext(ctx)
g.SetLimit(maxGoroutines)
for _, ep := range eps {
ep := ep
g.Go(func() error {
if err := ectx.Err(); err != nil {
return err
}

defer resp.Body.Close()

if resp.StatusCode/100 != 2 {
err := errors.New("HTTP POST request returned non-2xx status code")
body, readError := io.ReadAll(resp.Body)
level.Error(logger).Log("msg", "error received from shutdown endpoint", "err", err, "status", resp.StatusCode, "response_body", string(body))
return errors.Join(err, readError)
}
level.Debug(logger).Log("msg", "pod prepared for shutdown")
return nil
return invokePrepareShutdown(ectx, http.MethodPost, logger, client, ep)
})
}
return g.Wait()

err := g.Wait()
if err != nil {
// At least one failed. Undo them all.
level.Warn(logger).Log("msg", "failed to prepare hosts for shutdown. unpreparing...", "err", err)
undoGroup, _ := errgroup.WithContext(ctx)
undoGroup.SetLimit(maxGoroutines)
for _, ep := range eps {
ep := ep
undoGroup.Go(func() error {
if err := invokePrepareShutdown(ctx, http.MethodDelete, logger, client, ep); err != nil {
level.Warn(logger).Log("msg", "failed to undo prepare shutdown request", "url", ep.url, "err", err)
}
return nil
})
}
_ = undoGroup.Wait()
}

return err
}

var tenantResolver spanlogger.TenantResolver = noTenantResolver{}
109 changes: 109 additions & 0 deletions pkg/admission/prep_downscale_test.go
@@ -3,12 +3,15 @@ package admission
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"strings"
"sync/atomic"
"testing"
"text/template"
"time"
@@ -123,9 +126,14 @@ type templateParams struct {

type fakeHttpClient struct {
statusCode int
mockDo func(*http.Request) (*http.Response, error)
}

func (f *fakeHttpClient) Do(req *http.Request) (resp *http.Response, err error) {
if f.mockDo != nil {
return f.mockDo(req)
}

return &http.Response{
StatusCode: f.statusCode,
Body: io.NopCloser(bytes.NewBuffer([]byte(""))),
@@ -269,6 +277,107 @@ func testPrepDownscaleWebhook(t *testing.T, oldReplicas, newReplicas int, option
}
}

func TestSendPrepareShutdown(t *testing.T) {
cases := map[string]struct {
numEndpoints int
lastPostsFail int
expectDeletes int
deletesFail bool
expectErr bool
}{
"no endpoints": {
numEndpoints: 0,
lastPostsFail: 0,
expectDeletes: 0,
expectErr: false,
},
"all posts succeed": {
numEndpoints: 64,
lastPostsFail: 0,
expectDeletes: 0,
expectErr: false,
},
"all posts fail": {
numEndpoints: 64,
lastPostsFail: 64,
expectDeletes: 64,
expectErr: true,
},
"last post fails": {
numEndpoints: 64,
lastPostsFail: 1,
expectDeletes: 64,
expectErr: true,
},
"delete failures should still call all deletes": {
numEndpoints: 64,
lastPostsFail: 64,
expectDeletes: 64,
deletesFail: true,
expectErr: true,
},
}

for n, c := range cases {
t.Run(n, func(t *testing.T) {
var postCalls atomic.Int32
var deleteCalls atomic.Int32

errResponse := &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("we've had a problem")),
}
successResponse := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("good")),
}

httpClient := &fakeHttpClient{
mockDo: func(r *http.Request) (*http.Response, error) {
if r.Method == http.MethodPost {
calls := postCalls.Add(1)
if calls > int32(c.numEndpoints-c.lastPostsFail) {
return errResponse, nil
} else {
return successResponse, nil
}
} else if r.Method == http.MethodDelete {
deleteCalls.Add(1)
if c.deletesFail {
return errResponse, nil
} else {
return successResponse, nil
}
}
panic("unexpected method")
},
}

endpoints := make([]endpoint, 0, c.numEndpoints)
for i := range cap(endpoints) {
endpoints = append(endpoints, endpoint{
url: fmt.Sprintf("url-something.foo.%d.example.biz", i),
index: i,
})
}

err := sendPrepareShutdownRequests(context.Background(), log.NewNopLogger(), httpClient, endpoints)
if c.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
if c.lastPostsFail > 0 {
// (It's >= because the goroutines may already have in-flight POSTs before realizing the context is canceled.)
assert.GreaterOrEqual(t, postCalls.Load(), int32(c.numEndpoints-c.lastPostsFail), "at least |e|-|fails| should have been sent a POST")
} else {
assert.Equal(t, int32(c.numEndpoints), postCalls.Load(), "all endpoints should have been sent a POST")
}
assert.Equal(t, int32(c.expectDeletes), deleteCalls.Load())
})
}
}

func TestFindStatefulSetWithNonUpdatedReplicas(t *testing.T) {
namespace := "test"
rolloutGroup := "ingester"
1 change: 0 additions & 1 deletion pkg/controller/replicas.go
@@ -106,7 +106,6 @@ func minimumTimeHasElapsed(follower *v1.StatefulSet, all []*v1.StatefulSet, logg
)

return timeSinceDownscale > minTimeSinceDownscale, nil

}

// getMostRecentDownscale gets the time of the most recent downscale of any statefulset besides
