diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index ee30ed5cc8b4a..be13c50dcede1 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -143,13 +143,13 @@ var ( // pdHTTPRequest defines the interface to send a request to pd and return the result in bytes. type pdHTTPRequest func(ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) ([]byte, error) + cli *http.Client, method string, body []byte) ([]byte, error) // pdRequest is a func to send an HTTP to pd and return the result bytes. func pdRequest( ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) ([]byte, error) { + cli *http.Client, method string, body []byte) ([]byte, error) { _, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body) return respBody, err } @@ -157,7 +157,7 @@ func pdRequest( func pdRequestWithCode( ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) (int, []byte, error) { + cli *http.Client, method string, body []byte) (int, []byte, error) { u, err := url.Parse(addr) if err != nil { return 0, nil, errors.Trace(err) @@ -167,10 +167,13 @@ func pdRequestWithCode( req *http.Request resp *http.Response ) + if body == nil { + body = []byte("") + } count := 0 // the total retry duration: 120*1 = 2min for { - req, err = http.NewRequestWithContext(ctx, method, reqURL, body) + req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body)) if err != nil { return 0, nil, errors.Trace(err) } @@ -197,6 +200,8 @@ func pdRequestWithCode( (err != nil && !common.IsRetryableError(err)) { break } + log.Warn("request failed, will retry later", + zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err)) if resp != nil { _ = resp.Body.Close() } @@ -454,7 +459,11 @@ func (p *PdController) doPauseSchedulers(ctx context.Context, for _, scheduler := range schedulers { prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.getAllPDAddrs() { +<<<<<<< HEAD _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body)) +======= + _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if err == nil { removedSchedulers = append(removedSchedulers, scheduler) break @@ -537,7 +546,11 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str for _, scheduler := range schedulers { prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.getAllPDAddrs() { +<<<<<<< HEAD _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body)) +======= + _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if err == nil { break } @@ -626,7 +639,7 @@ func (p *PdController) doUpdatePDScheduleConfig( return errors.Trace(err) } _, e := post(ctx, addr, prefix, - p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + p.cli, http.MethodPost, reqData) if e == nil { return nil } @@ -883,7 +896,11 @@ func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error }) var err error for _, addr := range p.getAllPDAddrs() { +<<<<<<< HEAD _, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) +======= + _, e := pdRequest(ctx, addr, pdapi.BaseAllocID, p.cli, http.MethodPost, reqData) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if e != nil { log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e)) err = e @@ -907,7 +924,11 @@ func (p *PdController) ResetTS(ctx context.Context, ts uint64) error { }) var err error for _, addr := range p.getAllPDAddrs() { +<<<<<<< HEAD code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) +======= + code, _, e := pdRequestWithCode(ctx, addr, pdapi.ResetTS, p.cli, http.MethodPost, reqData) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if e != nil { // for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden. // it's not an error for br @@ -983,8 +1004,13 @@ func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule L var lastErr error addrs := p.getAllPDAddrs() for i, addr := range addrs { +<<<<<<< HEAD _, lastErr = pdRequest(ctx, addr, regionLabelPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) +======= + _, lastErr = pdRequest(ctx, addr, pdapi.RegionLabelRule, + p.cli, http.MethodPost, reqData) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if lastErr == nil { return nil } diff --git a/br/pkg/pdutil/pd_serial_test.go b/br/pkg/pdutil/pd_serial_test.go index 39c2fae8dd014..271ca8ee2ebae 100644 --- a/br/pkg/pdutil/pd_serial_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -3,7 +3,6 @@ package pdutil import ( - "bytes" "context" "encoding/hex" "encoding/json" @@ -31,7 +30,7 @@ func TestScheduler(t *testing.T) { defer cancel() scheduler := "balance-leader-scheduler" - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("failed") } schedulerPauseCh := make(chan struct{}) @@ -66,7 +65,7 @@ func TestScheduler(t *testing.T) { _, err = pdController.listSchedulersWith(ctx, mock) require.EqualError(t, err, "failed") - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } @@ -86,7 +85,7 @@ func TestScheduler(t *testing.T) { func TestGetClusterVersion(t *testing.T) { pdController := &PdController{addrs: []string{"", ""}} // two endpoints counter := 0 - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { counter++ if counter <= 1 { return nil, errors.New("mock error") @@ -99,7 +98,7 @@ func TestGetClusterVersion(t *testing.T) { require.NoError(t, err) require.Equal(t, "test", respString) - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("mock error") } _, err = pdController.getClusterVersionWith(ctx, mock) @@ -129,7 +128,7 @@ func TestRegionCount(t *testing.T) { require.Equal(t, 3, len(regions.Regions)) mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) u, e := url.Parse(query) @@ -180,6 +179,9 @@ func TestPDRequestRetry(t *testing.T) { count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ + bytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, "test", string(bytes)) if count <= pdRequestRetryTime-1 { w.WriteHeader(http.StatusGatewayTimeout) return @@ -195,8 +197,7 @@ func TestPDRequestRetry(t *testing.T) { cli.Transport.(*http.Transport).DisableKeepAlives = true taddr := ts.URL - body := bytes.NewBuffer([]byte("test")) - _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, body) + _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test")) require.NoError(t, reqErr) ts.Close() count = 0 @@ -268,7 +269,7 @@ func TestStoreInfo(t *testing.T) { }, } mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) require.Equal(t, "http://mock/pd/api/v1/store/1", query)