diff --git a/domain/infosync/info.go b/domain/infosync/info.go index a1cbd79bf4991..22907239c8a80 100644 --- a/domain/infosync/info.go +++ b/domain/infosync/info.go @@ -334,6 +334,39 @@ func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { return is.getAllServerInfo(ctx) } +// UpdateServerLabel updates the server label for global info syncer. +func UpdateServerLabel(ctx context.Context, labels map[string]string) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + // when etcdCli is nil, the server infos are generated from the latest config, no need to update. + if is.etcdCli == nil { + return nil + } + selfInfo, err := is.getServerInfoByID(ctx, is.info.ID) + if err != nil { + return err + } + changed := false + for k, v := range labels { + if selfInfo.Labels[k] != v { + changed = true + selfInfo.Labels[k] = v + } + } + if !changed { + return nil + } + infoBuf, err := selfInfo.Marshal() + if err != nil { + return errors.Trace(err) + } + str := string(hack.String(infoBuf)) + err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) + return err +} + // DeleteTiFlashTableSyncProgress is used to delete the tiflash table replica sync progress. func DeleteTiFlashTableSyncProgress(tableInfo *model.TableInfo) error { is, err := getGlobalInfoSyncer() diff --git a/server/BUILD.bazel b/server/BUILD.bazel index ba3096330dbaa..8dc671b01f58c 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -214,6 +214,7 @@ go_test( "@com_github_tikv_client_go_v2//testutils", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_client_go_v2//tikvrpc", + "@io_etcd_go_etcd_tests_v3//integration", "@io_opencensus_go//stats/view", "@org_uber_go_goleak//:goleak", "@org_uber_go_zap//:zap", diff --git a/server/http_handler.go b/server/http_handler.go index 83aa10b48f28e..c5c186f37a9f8 100644 --- a/server/http_handler.go +++ b/server/http_handler.go @@ -2215,6 +2215,11 @@ func (h labelHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + if err := infosync.UpdateServerLabel(ctx, labels); err != nil { + logutil.BgLogger().Error("update etcd labels failed", zap.Any("labels", cfg.Labels), zap.Error(err)) + } + cancel() cfg.Labels = labels config.StoreGlobalConfig(&cfg) logutil.BgLogger().Info("update server labels", zap.Any("labels", cfg.Labels)) diff --git a/server/http_handler_test.go b/server/http_handler_test.go index 926a5d99b6997..628757aa4ac20 100644 --- a/server/http_handler_test.go +++ b/server/http_handler_test.go @@ -16,6 +16,7 @@ package server import ( "bytes" + "context" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -38,6 +39,7 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -59,6 +61,7 @@ import ( "github.com/pingcap/tidb/util/rowcodec" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/tikv" + "go.etcd.io/etcd/tests/v3/integration" "go.uber.org/zap" ) @@ -1201,6 +1204,64 @@ func TestSetLabels(t *testing.T) { }) } +func TestSetLabelsWithEtcd(t *testing.T) { + ts := createBasicHTTPHandlerTestSuite() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ts.startServer(t) + defer ts.stopServer(t) + + integration.BeforeTestExternal(t) + cluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer cluster.Terminate(t) + client := cluster.RandClient() + infosync.SetEtcdClient(client) + ts.domain.InfoSyncer().Restart(ctx) + + testUpdateLabels := func(labels, expected map[string]string) { + buffer := bytes.NewBuffer([]byte{}) + require.Nil(t, json.NewEncoder(buffer).Encode(labels)) + resp, err := ts.postStatus("/labels", "application/json", buffer) + require.NoError(t, err) + require.NotNil(t, resp) + defer func() { + require.NoError(t, resp.Body.Close()) + }() + require.Equal(t, http.StatusOK, resp.StatusCode) + newLabels := config.GetGlobalConfig().Labels + require.Equal(t, newLabels, expected) + servers, err := infosync.GetAllServerInfo(ctx) + require.NoError(t, err) + for _, server := range servers { + for k, expectV := range expected { + v, ok := server.Labels[k] + require.True(t, ok) + require.Equal(t, expectV, v) + } + return + } + require.Fail(t, "no server found") + } + + labels := map[string]string{ + "zone": "us-west-1", + "test": "123", + } + testUpdateLabels(labels, labels) + + updated := map[string]string{ + "zone": "bj-1", + } + labels["zone"] = "bj-1" + testUpdateLabels(updated, labels) + + // reset the global variable + config.UpdateGlobal(func(conf *config.Config) { + conf.Labels = map[string]string{} + }) +} + func TestSetLabelsConcurrentWithGetLabel(t *testing.T) { ts := createBasicHTTPHandlerTestSuite()