diff --git a/client/circuitbreaker/circuit_breaker.go b/client/circuitbreaker/circuit_breaker.go new file mode 100644 index 00000000000..b5a4c53ebb5 --- /dev/null +++ b/client/circuitbreaker/circuit_breaker.go @@ -0,0 +1,302 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package circuitbreaker + +import ( + "fmt" + "strings" + "sync" + "time" + + "github.com/tikv/pd/client/errs" + + "github.com/prometheus/client_golang/prometheus" + m "github.com/tikv/pd/client/metrics" + "go.uber.org/zap" + + "github.com/pingcap/log" +) + +// Overloading is a type describing service return value +type Overloading bool + +const ( + // No means the service is not overloaded + No = false + // Yes means the service is overloaded + Yes = true +) + +// Settings describes configuration for Circuit Breaker +type Settings struct { + // Defines the error rate threshold to trip the circuit breaker. + ErrorRateThresholdPct uint32 + // Defines the average qps over the `error_rate_window` that must be met before evaluating the error rate threshold. + MinQPSForOpen uint32 + // Defines how long to track errors before evaluating error_rate_threshold. + ErrorRateWindow time.Duration + // Defines how long to wait after circuit breaker is open before go to half-open state to send a probe request. + CoolDownInterval time.Duration + // Defines how many subsequent requests to test after cooldown period before fully close the circuit. + HalfOpenSuccessCount uint32 +} + +// AlwaysClosedSettings is a configuration that never trips the circuit breaker. +var AlwaysClosedSettings = Settings{ + ErrorRateThresholdPct: 0, // never trips + ErrorRateWindow: 10 * time.Second, // effectively results in testing for new settings every 10 seconds + MinQPSForOpen: 10, + CoolDownInterval: 10 * time.Second, + HalfOpenSuccessCount: 1, +} + +// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. +type CircuitBreaker[T any] struct { + config *Settings + name string + + mutex sync.Mutex + state *State[T] + + successCounter prometheus.Counter + errorCounter prometheus.Counter + overloadCounter prometheus.Counter + fastFailCounter prometheus.Counter +} + +// StateType is a type that represents a state of CircuitBreaker. +type StateType int + +// States of CircuitBreaker. +const ( + StateClosed StateType = iota + StateOpen + StateHalfOpen +) + +// String implements stringer interface. +func (s StateType) String() string { + switch s { + case StateClosed: + return "closed" + case StateOpen: + return "open" + case StateHalfOpen: + return "half-open" + default: + return fmt.Sprintf("unknown state: %d", s) + } +} + +var replacer = strings.NewReplacer(" ", "_", "-", "_") + +// NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. +func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { + cb := new(CircuitBreaker[T]) + cb.name = name + cb.config = &st + cb.state = cb.newState(time.Now(), StateClosed) + + metricName := replacer.Replace(name) + cb.successCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "success") + cb.errorCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "error") + cb.overloadCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "overload") + cb.fastFailCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "fast_fail") + return cb +} + +// ChangeSettings changes the CircuitBreaker settings. +// The changes will be reflected only in the next evaluation window. +func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + apply(cb.config) +} + +// Execute calls the given function if the CircuitBreaker is closed and returns the result of execution. +// Execute returns an error instantly if the CircuitBreaker is open. +// https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md +func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, error) { + state, err := cb.onRequest() + if err != nil { + cb.fastFailCounter.Inc() + var defaultValue T + return defaultValue, err + } + + defer func() { + e := recover() + if e != nil { + cb.emitMetric(Yes, err) + cb.onResult(state, Yes) + panic(e) + } + }() + + result, overloaded, err := call() + cb.emitMetric(overloaded, err) + cb.onResult(state, overloaded) + return result, err +} + +func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + state, err := cb.state.onRequest(cb) + cb.state = state + return state, err +} + +func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + // even if the circuit breaker already moved to a new state while the request was in progress, + // it is still ok to update the old state, but it is not relevant anymore + state.onResult(overloaded) +} + +func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { + switch overloaded { + case No: + cb.successCounter.Inc() + case Yes: + cb.overloadCounter.Inc() + default: + panic("unknown state") + } + if err != nil { + cb.errorCounter.Inc() + } +} + +// State represents the state of CircuitBreaker. +type State[T any] struct { + stateType StateType + cb *CircuitBreaker[T] + end time.Time + + pendingCount uint32 + successCount uint32 + failureCount uint32 +} + +// newState creates a new State with the given configuration and reset all success/failure counters. +func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State[T] { + var end time.Time + var pendingCount uint32 + switch stateType { + case StateClosed: + end = now.Add(cb.config.ErrorRateWindow) + case StateOpen: + end = now.Add(cb.config.CoolDownInterval) + case StateHalfOpen: + // we transition to HalfOpen state on the first request after the cooldown period, + // so we start with 1 pending request + pendingCount = 1 + default: + panic("unknown state") + } + return &State[T]{ + cb: cb, + stateType: stateType, + pendingCount: pendingCount, + end: end, + } +} + +// onRequest transitions the state to the next state based on the current state and the previous requests results +// The implementation represents a state machine for CircuitBreaker +// All state transitions happens at the request evaluation time only +// Circuit breaker start with a closed state, allows all requests to pass through and always lasts for a fixed duration of `Settings.ErrorRateWindow`. +// If `Settings.ErrorRateThresholdPct` is breached at the end of the window, then it moves to Open state, otherwise it moves to a new Closed state with a new window. +// Open state fails all request, it has a fixed duration of `Settings.CoolDownInterval` and always moves to HalfOpen state at the end of the interval. +// HalfOpen state does not have a fixed duration and lasts till `Settings.HalfOpenSuccessCount` are evaluated. +// If any of `Settings.HalfOpenSuccessCount` fails then it moves back to Open state, otherwise it moves to Closed state. +func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { + var now = time.Now() + switch s.stateType { + case StateClosed: + if now.After(s.end) { + // ErrorRateWindow is over, let's evaluate the error rate + if s.cb.config.ErrorRateThresholdPct > 0 { // otherwise circuit breaker is disabled + total := s.failureCount + s.successCount + if total > 0 { + observedErrorRatePct := s.failureCount * 100 / total + if total >= uint32(s.cb.config.ErrorRateWindow.Seconds())*s.cb.config.MinQPSForOpen && observedErrorRatePct >= s.cb.config.ErrorRateThresholdPct { + // the error threshold is breached, let's move to open state and start failing all requests + log.Error("Circuit breaker tripped. Starting to fail all requests", + zap.String("name", cb.name), + zap.Uint32("observedErrorRatePct", observedErrorRatePct), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen + } + } + } + // the error threshold is not breached or there were not enough requests to evaluate it, + // continue in the closed state and allow all requests + return cb.newState(now, StateClosed), nil + } + // continue in closed state till ErrorRateWindow is over + return s, nil + case StateOpen: + if now.After(s.end) { + // CoolDownInterval is over, it is time to transition to half-open state + log.Info("Circuit breaker cooldown period is over. Transitioning to half-open state to test the service", + zap.String("name", cb.name), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateHalfOpen), nil + } else { + // continue in the open state till CoolDownInterval is over + return s, errs.ErrCircuitBreakerOpen + } + case StateHalfOpen: + // do we need some expire time here in case of one of pending requests is stuck forever? + if s.failureCount > 0 { + // there were some failures during half-open state, let's go back to open state to wait a bit longer + log.Error("Circuit breaker goes from half-open to open again as errors persist and continue to fail all requests", + zap.String("name", cb.name), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen + } else if s.successCount == s.cb.config.HalfOpenSuccessCount { + // all probe requests are succeeded, we can move to closed state and allow all requests + log.Info("Circuit breaker is closed. Start allowing all requests", + zap.String("name", cb.name), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateClosed), nil + } else if s.pendingCount < s.cb.config.HalfOpenSuccessCount { + // allow more probe requests and continue in half-open state + s.pendingCount++ + return s, nil + } else { + // continue in half-open state till all probe requests are done and fail all other requests for now + return s, errs.ErrCircuitBreakerOpen + } + default: + panic("unknown state") + } +} + +func (s *State[T]) onResult(overloaded Overloading) { + switch overloaded { + case No: + s.successCount++ + case Yes: + s.failureCount++ + default: + panic("unknown state") + } +} diff --git a/client/circuitbreaker/circuit_breaker_test.go b/client/circuitbreaker/circuit_breaker_test.go new file mode 100644 index 00000000000..ca77b7f9f99 --- /dev/null +++ b/client/circuitbreaker/circuit_breaker_test.go @@ -0,0 +1,270 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package circuitbreaker + +import ( + "errors" + "testing" + "time" + + "github.com/tikv/pd/client/errs" + + "github.com/stretchr/testify/require" +) + +// advance emulate the state machine clock moves forward by the given duration +func (cb *CircuitBreaker[T]) advance(duration time.Duration) { + cb.state.end = cb.state.end.Add(-duration - 1) +} + +var settings = Settings{ + ErrorRateThresholdPct: 50, + MinQPSForOpen: 10, + ErrorRateWindow: 30 * time.Second, + CoolDownInterval: 10 * time.Second, + HalfOpenSuccessCount: 2, +} + +var minCountToOpen = int(settings.MinQPSForOpen * uint32(settings.ErrorRateWindow.Seconds())) + +func TestCircuitBreaker_Execute_Wrapper_Return_Values(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + originalError := errors.New("circuit breaker is open") + + result, err := cb.Execute(func() (int, Overloading, error) { + return 42, No, originalError + }) + re.Equal(err, originalError) + re.Equal(42, result) + + // same by interpret the result as overloading error + result, err = cb.Execute(func() (int, Overloading, error) { + return 42, Yes, originalError + }) + re.Equal(err, originalError) + re.Equal(42, result) +} + +func TestCircuitBreaker_OpenState(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + driveQPS(cb, minCountToOpen, Yes, re) + re.Equal(StateClosed, cb.state.stateType) + assertSucceeds(cb, re) // no error till ErrorRateWindow is finished + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) +} + +func TestCircuitBreaker_CloseState_Not_Enough_QPS(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen/2, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_CloseState_Not_Enough_Error_Rate(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen/4, Yes, re) + driveQPS(cb, minCountToOpen, No, re) + cb.advance(settings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_Half_Open_To_Closed(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) + cb.advance(settings.CoolDownInterval) + assertSucceeds(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + assertSucceeds(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + // state always transferred on the incoming request + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_Half_Open_To_Open(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) + cb.advance(settings.CoolDownInterval) + assertSucceeds(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + _, err := cb.Execute(func() (int, Overloading, error) { + return 42, Yes, nil // this trip circuit breaker again + }) + re.NoError(err) + re.Equal(StateHalfOpen, cb.state.stateType) + // state always transferred on the incoming request + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) +} + +// in half open state, circuit breaker will allow only HalfOpenSuccessCount pending and should fast fail all other request till HalfOpenSuccessCount requests is completed +// this test moves circuit breaker to the half open state and verifies that requests above HalfOpenSuccessCount are failing +func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { + re := require.New(t) + cb := newCircuitBreakerMovedToHalfOpenState(re) + + // the next request will move circuit breaker into the half open state + var started []chan bool + var waited []chan bool + var ended []chan bool + for range settings.HalfOpenSuccessCount { + start := make(chan bool) + wait := make(chan bool) + end := make(chan bool) + started = append(started, start) + waited = append(waited, wait) + ended = append(ended, end) + go func() { + defer func() { + end <- true + }() + _, err := cb.Execute(func() (int, Overloading, error) { + start <- true + <-wait + return 42, No, nil + }) + re.NoError(err) + }() + } + // make sure all requests are started + for i := range started { + <-started[i] + } + // validate that requests beyond HalfOpenSuccessCount are failing + assertFastFail(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + // unblock pending requests and wait till they are completed + for i := range ended { + waited[i] <- true + <-ended[i] + } + // validate that circuit breaker moves to closed state + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) + // make sure that after moving to open state all counters are reset + re.Equal(uint32(1), cb.state.successCount) +} + +func TestCircuitBreaker_Count_Only_Requests_In_Same_Window(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + + start := make(chan bool) + wait := make(chan bool) + end := make(chan bool) + go func() { + defer func() { + end <- true + }() + _, err := cb.Execute(func() (int, Overloading, error) { + start <- true + <-wait + return 42, No, nil + }) + re.NoError(err) + }() + <-start // make sure the request is started + // assert running request is not counted + re.Equal(uint32(0), cb.state.successCount) + + // advance request to the next window + cb.advance(settings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(uint32(1), cb.state.successCount) + + // complete the request from the previous window + wait <- true // resume + <-end // wait for the request to complete + // assert request from last window is not counted + re.Equal(uint32(1), cb.state.successCount) +} + +func TestCircuitBreaker_ChangeSettings(t *testing.T) { + re := require.New(t) + + cb := NewCircuitBreaker[int]("test_cb", AlwaysClosedSettings) + driveQPS(cb, int(AlwaysClosedSettings.MinQPSForOpen*uint32(AlwaysClosedSettings.ErrorRateWindow.Seconds())), Yes, re) + cb.advance(AlwaysClosedSettings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) + + cb.ChangeSettings(func(config *Settings) { + config.ErrorRateThresholdPct = settings.ErrorRateThresholdPct + }) + re.Equal(settings.ErrorRateThresholdPct, cb.config.ErrorRateThresholdPct) + + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) +} + +func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker[int] { + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) + cb.advance(settings.CoolDownInterval) + return cb +} + +func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) { + for range count { + _, err := cb.Execute(func() (int, Overloading, error) { + return 42, overload, nil + }) + re.NoError(err) + } +} + +func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) { + var executed = false + _, err := cb.Execute(func() (int, Overloading, error) { + executed = true + return 42, No, nil + }) + re.Equal(err, errs.ErrCircuitBreakerOpen) + re.False(executed) +} + +func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) { + result, err := cb.Execute(func() (int, Overloading, error) { + return 42, No, nil + }) + re.NoError(err) + re.Equal(42, result) +} diff --git a/client/client.go b/client/client.go index 49ce73bf9fb..c271f10591d 100644 --- a/client/client.go +++ b/client/client.go @@ -22,6 +22,8 @@ import ( "sync" "time" + cb "github.com/tikv/pd/client/circuitbreaker" + "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -456,6 +458,12 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error { return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int") } c.inner.option.SetTSOClientRPCConcurrency(value) + case opt.RegionMetadataCircuitBreakerSettings: + applySettingsChange, ok := value.(func(config *cb.Settings)) + if !ok { + return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings") + } + c.inner.regionMetaCircuitBreaker.ChangeSettings(applySettingsChange) default: return errors.New("[pd] unsupported client option") } @@ -650,7 +658,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { + region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) + return region, isOverloaded(err), err + }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -690,7 +701,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) + return resp, isOverloaded(err), err + }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -730,7 +744,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) + return resp, isOverloaded(err), err + }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { diff --git a/client/errs/errno.go b/client/errs/errno.go index df8b677525a..25665f01017 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,6 +70,7 @@ var ( ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint")) ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) + ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) ) // grpcutil errors diff --git a/client/inner_client.go b/client/inner_client.go index 7be35e9a3b9..ae15c763854 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -6,9 +6,12 @@ import ( "sync" "time" + "google.golang.org/grpc/codes" + "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + cb "github.com/tikv/pd/client/circuitbreaker" "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" @@ -16,6 +19,7 @@ import ( sd "github.com/tikv/pd/client/servicediscovery" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/status" ) const ( @@ -24,10 +28,11 @@ const ( ) type innerClient struct { - keyspaceID uint32 - svrUrls []string - pdSvcDiscovery sd.ServiceDiscovery - tokenDispatcher *tokenDispatcher + keyspaceID uint32 + svrUrls []string + pdSvcDiscovery sd.ServiceDiscovery + tokenDispatcher *tokenDispatcher + regionMetaCircuitBreaker *cb.CircuitBreaker[*pdpb.GetRegionResponse] // For service mode switching. serviceModeKeeper @@ -53,6 +58,7 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { } return err } + c.regionMetaCircuitBreaker = cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", c.option.RegionMetaCircuitBreakerSettings) return nil } @@ -245,3 +251,12 @@ func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) tso.TSFut } return req } + +func isOverloaded(err error) cb.Overloading { + switch status.Code(errors.Cause(err)) { + case codes.DeadlineExceeded, codes.Unavailable, codes.ResourceExhausted: + return cb.Yes + default: + return cb.No + } +} diff --git a/client/metrics/metrics.go b/client/metrics/metrics.go index da36217eb34..3a3199c74a6 100644 --- a/client/metrics/metrics.go +++ b/client/metrics/metrics.go @@ -56,6 +56,8 @@ var ( OngoingRequestCountGauge *prometheus.GaugeVec // EstimateTSOLatencyGauge is the gauge to indicate the estimated latency of TSO requests. EstimateTSOLatencyGauge *prometheus.GaugeVec + // CircuitBreakerCounters is a vector for different circuit breaker counters + CircuitBreakerCounters *prometheus.CounterVec ) func initMetrics(constLabels prometheus.Labels) { @@ -144,6 +146,15 @@ func initMetrics(constLabels prometheus.Labels) { Help: "Estimated latency of an RTT of getting TSO", ConstLabels: constLabels, }, []string{"stream"}) + + CircuitBreakerCounters = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "pd_client", + Subsystem: "request", + Name: "circuit_breaker_count", + Help: "Circuit Breaker counters", + ConstLabels: constLabels, + }, []string{"name", "success"}) } // CmdDurationXXX and CmdFailedDurationXXX are the durations of the client commands. @@ -259,4 +270,5 @@ func registerMetrics() { prometheus.MustRegister(TSOBatchSendLatency) prometheus.MustRegister(RequestForwarded) prometheus.MustRegister(EstimateTSOLatencyGauge) + prometheus.MustRegister(CircuitBreakerCounters) } diff --git a/client/opt/option.go b/client/opt/option.go index faeb232195f..9a80a895cc0 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -18,6 +18,8 @@ import ( "sync/atomic" "time" + cb "github.com/tikv/pd/client/circuitbreaker" + "github.com/pingcap/errors" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/client/pkg/retry" @@ -47,6 +49,8 @@ const ( EnableFollowerHandle // TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client. TSOClientRPCConcurrency + // RegionMetadataCircuitBreakerSettings controls settings for circuit breaker for region metadata requests. + RegionMetadataCircuitBreakerSettings dynamicOptionCount ) @@ -67,16 +71,18 @@ type Option struct { // Dynamic options. dynamicOptions [dynamicOptionCount]atomic.Value - EnableTSOFollowerProxyCh chan struct{} + EnableTSOFollowerProxyCh chan struct{} + RegionMetaCircuitBreakerSettings cb.Settings } // NewOption creates a new PD client option with the default values set. func NewOption() *Option { co := &Option{ - Timeout: defaultPDTimeout, - MaxRetryTimes: maxInitClusterRetries, - EnableTSOFollowerProxyCh: make(chan struct{}, 1), - InitMetrics: true, + Timeout: defaultPDTimeout, + MaxRetryTimes: maxInitClusterRetries, + EnableTSOFollowerProxyCh: make(chan struct{}, 1), + InitMetrics: true, + RegionMetaCircuitBreakerSettings: cb.AlwaysClosedSettings, } co.dynamicOptions[MaxTSOBatchWaitInterval].Store(defaultMaxTSOBatchWaitInterval) @@ -147,6 +153,11 @@ func (o *Option) GetTSOClientRPCConcurrency() int { return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int) } +// GetRegionMetadataCircuitBreakerSettings gets circuit breaker settings for PD region metadata calls. +func (o *Option) GetRegionMetadataCircuitBreakerSettings() cb.Settings { + return o.dynamicOptions[RegionMetadataCircuitBreakerSettings].Load().(cb.Settings) +} + // ClientOption configures client. type ClientOption func(*Option) @@ -201,6 +212,13 @@ func WithInitMetricsOption(initMetrics bool) ClientOption { } } +// WithRegionMetaCircuitBreaker configures the client with circuit breaker for region meta calls +func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption { + return func(op *Option) { + op.RegionMetaCircuitBreakerSettings = config + } +} + // WithBackoffer configures the client with backoffer. func WithBackoffer(bo *retry.Backoffer) ClientOption { return func(op *Option) {