diff --git a/domain/domain.go b/domain/domain.go index 564ce8088d751..66fcf3ca0e3b3 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -17,7 +17,9 @@ package domain import ( "context" "fmt" + "math" "math/rand" + "sort" "strconv" "strings" "sync" @@ -1106,8 +1108,12 @@ func (do *Domain) closestReplicaReadCheckLoop(ctx context.Context, pdClient pd.C } } +// Periodically check and update the replica-read status when `tidb_replica_read` is set to "closest-adaptive" +// We disable "closest-adaptive" in following conditions to ensure the read traffic is evenly distributed across +// all AZs: +// - There are no TiKV servers in the AZ of this tidb instance +// - The AZ if this tidb contains more tidb than other AZ and this tidb's id is the bigger one. func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) error { - // fast path do.sysVarCache.RLock() replicaRead := do.sysVarCache.global[variable.TiDBReplicaRead] do.sysVarCache.RUnlock() @@ -1116,6 +1122,24 @@ func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) erro logutil.BgLogger().Debug("closest replica read is not enabled, skip check!", zap.String("mode", replicaRead)) return nil } + + serverInfo, err := infosync.GetServerInfo() + if err != nil { + return err + } + zone := "" + for k, v := range serverInfo.Labels { + if k == placement.DCLabelKey && v != "" { + zone = v + break + } + } + if zone == "" { + logutil.BgLogger().Debug("server contains no 'zone' label, disable closest replica read", zap.Any("labels", serverInfo.Labels)) + variable.SetEnableAdaptiveReplicaRead(false) + return nil + } + stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) if err != nil { return err @@ -1135,32 +1159,48 @@ func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) erro } } - enabled := false - // if stores don't have zone labels or are distribued in 1 zone, just disable cloeset replica read. - if len(storeZones) > 1 { - enabled = true - servers, err := infosync.GetAllServerInfo(ctx) - if err != nil { - return err - } - for _, s := range servers { - if v, ok := s.Labels[placement.DCLabelKey]; ok && v != "" { - if _, ok := storeZones[v]; !ok { - enabled = false - break - } + // no stores in this AZ + if _, ok := storeZones[zone]; !ok { + variable.SetEnableAdaptiveReplicaRead(false) + return nil + } + + servers, err := infosync.GetAllServerInfo(ctx) + if err != nil { + return err + } + svrIdsInThisZone := make([]string, 0) + for _, s := range servers { + if v, ok := s.Labels[placement.DCLabelKey]; ok && v != "" { + if _, ok := storeZones[v]; ok { storeZones[v] += 1 - } - } - if enabled { - for _, count := range storeZones { - if count == 0 { - enabled = false - break + if v == zone { + svrIdsInThisZone = append(svrIdsInThisZone, s.ID) } } } } + enabledCount := math.MaxInt + for _, count := range storeZones { + if count < enabledCount { + enabledCount = count + } + } + // sort tidb in the same AZ by ID and disable the tidb with bigger ID + // because ID is unchangeable, so this is a simple and stable algorithm to select + // some instances across all tidb servers. + if enabledCount < len(svrIdsInThisZone) { + sort.Slice(svrIdsInThisZone, func(i, j int) bool { + return strings.Compare(svrIdsInThisZone[i], svrIdsInThisZone[j]) < 0 + }) + } + enabled := true + for _, s := range svrIdsInThisZone[enabledCount:] { + if s == serverInfo.ID { + enabled = false + break + } + } if variable.SetEnableAdaptiveReplicaRead(enabled) { logutil.BgLogger().Info("tidb server adaptive closest replica read is changed", zap.Bool("enable", enabled)) diff --git a/domain/domain_test.go b/domain/domain_test.go index 621f0fb2c431f..c117ac2244b2e 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -17,6 +17,8 @@ package domain import ( "context" "crypto/tls" + "encoding/json" + "fmt" "net" "runtime" "testing" @@ -247,7 +249,29 @@ func TestClosestReplicaReadChecker(t *testing.T) { } dom.sysVarCache.Unlock() - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", `return("")`)) + makeFailpointRes := func(v interface{}) string { + bytes, err := json.Marshal(v) + require.NoError(t, err) + return fmt.Sprintf("return(`%s`)", string(bytes)) + } + + mockedAllServerInfos := map[string]*infosync.ServerInfo{ + "s1": { + ID: "s1", + Labels: map[string]string{ + "zone": "zone1", + }, + }, + "s2": { + ID: "s2", + Labels: map[string]string{ + "zone": "zone2", + }, + }, + } + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", makeFailpointRes(mockedAllServerInfos))) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetServerInfo", makeFailpointRes(mockedAllServerInfos["s2"]))) stores := []*metapb.Store{ { @@ -304,8 +328,77 @@ func TestClosestReplicaReadChecker(t *testing.T) { require.False(t, variable.IsAdaptiveReplicaReadEnabled()) } + // partial matches + mockedAllServerInfos = map[string]*infosync.ServerInfo{ + "s1": { + ID: "s1", + Labels: map[string]string{ + "zone": "zone1", + }, + }, + "s2": { + ID: "s2", + Labels: map[string]string{ + "zone": "zone2", + }, + }, + "s22": { + ID: "s22", + Labels: map[string]string{ + "zone": "zone2", + }, + }, + "s3": { + ID: "s3", + Labels: map[string]string{ + "zone": "zone3", + }, + }, + "s4": { + ID: "s4", + Labels: map[string]string{ + "zone": "zone4", + }, + }, + } + pdClient.stores = stores + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", makeFailpointRes(mockedAllServerInfos))) + cases := []struct { + id string + matches bool + }{ + { + id: "s1", + matches: true, + }, + { + id: "s2", + matches: true, + }, + { + id: "s22", + matches: false, + }, + { + id: "s3", + matches: true, + }, + { + id: "s4", + matches: false, + }, + } + for _, c := range cases { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetServerInfo", makeFailpointRes(mockedAllServerInfos[c.id]))) + variable.SetEnableAdaptiveReplicaRead(!c.matches) + err = dom.checkReplicaRead(ctx, pdClient) + require.Nil(t, err) + require.Equal(t, c.matches, variable.IsAdaptiveReplicaReadEnabled()) + } + variable.SetEnableAdaptiveReplicaRead(true) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/infosync/mockGetServerInfo")) } type mockInfoPdClient struct { diff --git a/domain/infosync/info.go b/domain/infosync/info.go index a2af6c5dfa58f..c501d7f16d695 100644 --- a/domain/infosync/info.go +++ b/domain/infosync/info.go @@ -282,6 +282,11 @@ func SetMockTiFlash(tiflash *MockTiFlash) { // GetServerInfo gets self server static information. func GetServerInfo() (*ServerInfo, error) { + failpoint.Inject("mockGetServerInfo", func(v failpoint.Value) { + var res ServerInfo + err := json.Unmarshal([]byte(v.(string)), &res) + failpoint.Return(&res, err) + }) is, err := getGlobalInfoSyncer() if err != nil { return nil, err @@ -316,20 +321,10 @@ func (is *InfoSyncer) getServerInfoByID(ctx context.Context, id string) (*Server // GetAllServerInfo gets all servers static information from etcd. func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { - failpoint.Inject("mockGetAllServerInfo", func() { - res := map[string]*ServerInfo{ - "fa598405-a08e-4e74-83ff-75c30b1daedc": { - Labels: map[string]string{ - "zone": "zone1", - }, - }, - "ad84dbbd-5a50-4742-a73c-4f674d41d4bd": { - Labels: map[string]string{ - "zone": "zone2", - }, - }, - } - failpoint.Return(res, nil) + failpoint.Inject("mockGetAllServerInfo", func(val failpoint.Value) { + res := make(map[string]*ServerInfo) + err := json.Unmarshal([]byte(val.(string)), &res) + failpoint.Return(res, err) }) is, err := getGlobalInfoSyncer() if err != nil {