diff --git a/pkg/domain/infosync/info.go b/pkg/domain/infosync/info.go index d7c78fd852f6d..b20a0a18f1262 100644 --- a/pkg/domain/infosync/info.go +++ b/pkg/domain/infosync/info.go @@ -405,6 +405,39 @@ func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { return is.getAllServerInfo(ctx) } +// UpdateServerLabel updates the server label for global info syncer. +func UpdateServerLabel(ctx context.Context, labels map[string]string) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + // when etcdCli is nil, the server infos are generated from the latest config, no need to update. + if is.etcdCli == nil { + return nil + } + selfInfo, err := is.getServerInfoByID(ctx, is.info.ID) + if err != nil { + return err + } + changed := false + for k, v := range labels { + if selfInfo.Labels[k] != v { + changed = true + selfInfo.Labels[k] = v + } + } + if !changed { + return nil + } + infoBuf, err := selfInfo.Marshal() + if err != nil { + return errors.Trace(err) + } + str := string(hack.String(infoBuf)) + err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) + return err +} + // DeleteTiFlashTableSyncProgress is used to delete the tiflash table replica sync progress. func DeleteTiFlashTableSyncProgress(tableInfo *model.TableInfo) error { is, err := getGlobalInfoSyncer() diff --git a/pkg/server/handler/tests/BUILD.bazel b/pkg/server/handler/tests/BUILD.bazel index f992b3170f853..e7ec8a26d0fbd 100644 --- a/pkg/server/handler/tests/BUILD.bazel +++ b/pkg/server/handler/tests/BUILD.bazel @@ -9,7 +9,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 36, + shard_count = 37, deps = [ "//pkg/config", "//pkg/ddl", @@ -54,6 +54,7 @@ go_test( "@com_github_pingcap_log//:log", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//tikv", + "@io_etcd_go_etcd_tests_v3//integration", "@org_uber_go_goleak//:goleak", "@org_uber_go_zap//:zap", ], diff --git a/pkg/server/handler/tests/http_handler_test.go b/pkg/server/handler/tests/http_handler_test.go index b04f10f3379a0..217a988f410cd 100644 --- a/pkg/server/handler/tests/http_handler_test.go +++ b/pkg/server/handler/tests/http_handler_test.go @@ -16,6 +16,7 @@ package tests import ( "bytes" + "context" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -66,6 +67,7 @@ import ( "github.com/pingcap/tidb/pkg/util/rowcodec" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/tikv" + "go.etcd.io/etcd/tests/v3/integration" "go.uber.org/zap" ) @@ -1202,6 +1204,64 @@ func TestSetLabels(t *testing.T) { }) } +func TestSetLabelsWithEtcd(t *testing.T) { + ts := createBasicHTTPHandlerTestSuite() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ts.startServer(t) + defer ts.stopServer(t) + + integration.BeforeTestExternal(t) + cluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer cluster.Terminate(t) + client := cluster.RandClient() + infosync.SetEtcdClient(client) + ts.domain.InfoSyncer().Restart(ctx) + + testUpdateLabels := func(labels, expected map[string]string) { + buffer := bytes.NewBuffer([]byte{}) + require.Nil(t, json.NewEncoder(buffer).Encode(labels)) + resp, err := ts.PostStatus("/labels", "application/json", buffer) + require.NoError(t, err) + require.NotNil(t, resp) + defer func() { + require.NoError(t, resp.Body.Close()) + }() + require.Equal(t, http.StatusOK, resp.StatusCode) + newLabels := config.GetGlobalConfig().Labels + require.Equal(t, newLabels, expected) + servers, err := infosync.GetAllServerInfo(ctx) + require.NoError(t, err) + for _, server := range servers { + for k, expectV := range expected { + v, ok := server.Labels[k] + require.True(t, ok) + require.Equal(t, expectV, v) + } + return + } + require.Fail(t, "no server found") + } + + labels := map[string]string{ + "zone": "us-west-1", + "test": "123", + } + testUpdateLabels(labels, labels) + + updated := map[string]string{ + "zone": "bj-1", + } + labels["zone"] = "bj-1" + testUpdateLabels(updated, labels) + + // reset the global variable + config.UpdateGlobal(func(conf *config.Config) { + conf.Labels = map[string]string{} + }) +} + func TestSetLabelsConcurrentWithGetLabel(t *testing.T) { ts := createBasicHTTPHandlerTestSuite() diff --git a/pkg/server/handler/tikvhandler/tikv_handler.go b/pkg/server/handler/tikvhandler/tikv_handler.go index 7dc3029254bfd..aa525016a5ce9 100644 --- a/pkg/server/handler/tikvhandler/tikv_handler.go +++ b/pkg/server/handler/tikvhandler/tikv_handler.go @@ -2000,6 +2000,11 @@ func (LabelHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + if err := infosync.UpdateServerLabel(ctx, labels); err != nil { + logutil.BgLogger().Error("update etcd labels failed", zap.Any("labels", cfg.Labels), zap.Error(err)) + } + cancel() cfg.Labels = labels config.StoreGlobalConfig(&cfg) logutil.BgLogger().Info("update server labels", zap.Any("labels", cfg.Labels))