diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml index 4fbed3641e5..47ef287d73f 100644 --- a/.github/workflows/check.yaml +++ b/.github/workflows/check.yaml @@ -6,6 +6,7 @@ concurrency: jobs: statics: runs-on: ubuntu-latest + timeout-minutes: 8 steps: - uses: actions/setup-go@v2 with: diff --git a/client/client_test.go b/client/client_test.go index 7b9470ace5b..c80b78bb96b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,38 +16,33 @@ package pd import ( "context" + "reflect" "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/client/testutil" "go.uber.org/goleak" "google.golang.org/grpc" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&testClientSuite{}) - -type testClientSuite struct{} - -func (s *testClientSuite) TestTsLessEqual(c *C) { - c.Assert(tsLessEqual(9, 9, 9, 9), IsTrue) - c.Assert(tsLessEqual(8, 9, 9, 8), IsTrue) - c.Assert(tsLessEqual(9, 8, 8, 9), IsFalse) - c.Assert(tsLessEqual(9, 8, 9, 6), IsFalse) - c.Assert(tsLessEqual(9, 6, 9, 8), IsTrue) +func TestTsLessEqual(t *testing.T) { + re := require.New(t) + re.True(tsLessEqual(9, 9, 9, 9)) + re.True(tsLessEqual(8, 9, 9, 8)) + re.False(tsLessEqual(9, 8, 8, 9)) + re.False(tsLessEqual(9, 8, 9, 6)) + re.True(tsLessEqual(9, 6, 9, 8)) } -func (s *testClientSuite) TestUpdateURLs(c *C) { +func TestUpdateURLs(t *testing.T) { + re := require.New(t) members := []*pdpb.Member{ {Name: "pd4", ClientUrls: []string{"tmp://pd4"}}, {Name: "pd1", ClientUrls: []string{"tmp://pd1"}}, @@ -63,40 +58,35 @@ func (s *testClientSuite) TestUpdateURLs(c *C) { cli := &baseClient{option: newOption()} cli.urls.Store([]string{}) cli.updateURLs(members[1:]) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) cli.updateURLs(members[1:]) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) cli.updateURLs(members) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetURLs())) } const testClientURL = "tmp://test.url:5255" -var _ = Suite(&testClientCtxSuite{}) - -type testClientCtxSuite struct{} - -func (s *testClientCtxSuite) TestClientCtx(c *C) { +func TestClientCtx(t *testing.T) { + re := require.New(t) start := time.Now() ctx, cancel := context.WithTimeout(context.TODO(), time.Second*3) defer cancel() _, err := NewClientWithContext(ctx, []string{testClientURL}, SecurityOption{}) - c.Assert(err, NotNil) - c.Assert(time.Since(start), Less, time.Second*5) + re.Error(err) + re.Less(time.Since(start), time.Second*5) } -func (s *testClientCtxSuite) TestClientWithRetry(c *C) { +func TestClientWithRetry(t *testing.T) { + re := require.New(t) start := time.Now() _, err := NewClientWithContext(context.TODO(), []string{testClientURL}, SecurityOption{}, WithMaxErrorRetry(5)) - c.Assert(err, NotNil) - c.Assert(time.Since(start), Less, time.Second*10) + re.Error(err) + re.Less(time.Since(start), time.Second*10) } -var _ = Suite(&testClientDialOptionSuite{}) - 
-type testClientDialOptionSuite struct{} - -func (s *testClientDialOptionSuite) TestGRPCDialOption(c *C) { +func TestGRPCDialOption(t *testing.T) { + re := require.New(t) start := time.Now() ctx, cancel := context.WithTimeout(context.TODO(), 500*time.Millisecond) defer cancel() @@ -111,15 +101,12 @@ func (s *testClientDialOptionSuite) TestGRPCDialOption(c *C) { cli.urls.Store([]string{testClientURL}) cli.option.gRPCDialOptions = []grpc.DialOption{grpc.WithBlock()} err := cli.updateMember() - c.Assert(err, NotNil) - c.Assert(time.Since(start), Greater, 500*time.Millisecond) + re.Error(err) + re.Greater(time.Since(start), 500*time.Millisecond) } -var _ = Suite(&testTsoRequestSuite{}) - -type testTsoRequestSuite struct{} - -func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { +func TestTsoRequestWait(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) req := &tsoRequest{ done: make(chan error, 1), @@ -130,7 +117,7 @@ func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { } cancel() _, _, err := req.Wait() - c.Assert(errors.Cause(err), Equals, context.Canceled) + re.ErrorIs(errors.Cause(err), context.Canceled) ctx, cancel = context.WithCancel(context.Background()) req = &tsoRequest{ @@ -142,5 +129,5 @@ func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { } cancel() _, _, err = req.Wait() - c.Assert(errors.Cause(err), Equals, context.Canceled) + re.ErrorIs(errors.Cause(err), context.Canceled) } diff --git a/client/go.mod b/client/go.mod index 893cada680d..56380b0b51c 100644 --- a/client/go.mod +++ b/client/go.mod @@ -4,12 +4,12 @@ go 1.16 require ( github.com/opentracing/opentracing-go v1.2.0 - github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee github.com/prometheus/client_golang v1.11.0 + github.com/stretchr/testify v1.7.0 go.uber.org/goleak v1.1.11 go.uber.org/zap v1.20.0 google.golang.org/grpc v1.43.0 diff --git a/client/go.sum b/client/go.sum index 6682bdb2893..90019b9b382 100644 --- a/client/go.sum +++ b/client/go.sum @@ -69,7 +69,6 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.12.1/go.mod h1:8XEsbTttt/W+VvjtQhLACqCisSPWTxCZ7sBRjU6iH9c= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= @@ -84,8 +83,10 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod 
h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -97,9 +98,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= -github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= -github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 h1:HVl5539r48eA+uDuX/ziBmQCxzT1pGrzWbKuXT46Bq0= -github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0/go.mod h1:PYMCGwN0JHjoqGr3HrZoD+b8Tgx8bKnArhSq8YVzUMc= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= @@ -108,7 +106,6 @@ github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 h1:C3N3itkduZXDZ github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00/go.mod h1:4qGtCB0QK0wBzKtFEGDhxXnSnbQApw1gc9siScUl8ew= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a h1:TxdHGOFeNa1q1mVv6TgReayf26iI4F8PQUm6RnZ/V/E= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= -github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee h1:VO2t6IBpfvW34TdtD/G10VvnGqjLic1jzOuHjUb5VqM= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -137,7 +134,6 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -154,8 +150,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= -go.uber.org/atomic v1.3.2/go.mod 
h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -163,21 +157,14 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.12.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.20.0 h1:N4oPlghZwYG55MlU6LXk/Zp00FVNE9X9wrYO8CEs4lc= go.uber.org/zap v1.20.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -186,7 +173,6 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -244,10 +230,7 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= 
-golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191107010934-f79515f33823/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -290,8 +273,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -307,4 +290,3 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/client/option_test.go b/client/option_test.go index b3d044bbd1b..2a7f7824e12 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -15,43 +15,41 @@ package pd import ( + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/client/testutil" ) -var _ = Suite(&testClientOptionSuite{}) - -type testClientOptionSuite struct{} - -func (s *testClientSuite) TestDynamicOptionChange(c *C) { +func TestDynamicOptionChange(t *testing.T) { + re := require.New(t) o := newOption() // Check the default value setting. - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, defaultMaxTSOBatchWaitInterval) - c.Assert(o.getEnableTSOFollowerProxy(), Equals, defaultEnableTSOFollowerProxy) + re.Equal(defaultMaxTSOBatchWaitInterval, o.getMaxTSOBatchWaitInterval()) + re.Equal(defaultEnableTSOFollowerProxy, o.getEnableTSOFollowerProxy()) // Check the invalid value setting. 
- c.Assert(o.setMaxTSOBatchWaitInterval(time.Second), NotNil) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, defaultMaxTSOBatchWaitInterval) + re.NotNil(o.setMaxTSOBatchWaitInterval(time.Second)) + re.Equal(defaultMaxTSOBatchWaitInterval, o.getMaxTSOBatchWaitInterval()) expectInterval := time.Millisecond o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectInterval = time.Duration(float64(time.Millisecond) * 0.5) o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectInterval = time.Duration(float64(time.Millisecond) * 1.5) o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectBool := true o.setEnableTSOFollowerProxy(expectBool) // Check the value changing notification. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntil(t, func() bool { <-o.enableTSOFollowerProxyCh return true }) - c.Assert(o.getEnableTSOFollowerProxy(), Equals, expectBool) + re.Equal(expectBool, o.getEnableTSOFollowerProxy()) // Check whether any data will be sent to the channel. // It will panic if the test fails. close(o.enableTSOFollowerProxyCh) diff --git a/client/testutil/testutil.go b/client/testutil/testutil.go index 3627566ecfb..095a31ae74a 100644 --- a/client/testutil/testutil.go +++ b/client/testutil/testutil.go @@ -15,9 +15,8 @@ package testutil import ( + "testing" "time" - - "github.com/pingcap/check" ) const ( @@ -45,8 +44,8 @@ func WithSleepInterval(sleep time.Duration) WaitOption { } // WaitUntil repeatedly evaluates f() for a period of time, util it returns true. 
-func WaitUntil(c *check.C, f func() bool, opts ...WaitOption) { - c.Log("wait start") +func WaitUntil(t *testing.T, f func() bool, opts ...WaitOption) { + t.Log("wait start") option := &WaitOp{ retryTimes: waitMaxRetry, sleepInterval: waitRetrySleep, @@ -60,5 +59,5 @@ func WaitUntil(c *check.C, f func() bool, opts ...WaitOption) { } time.Sleep(option.sleepInterval) } - c.Fatal("wait timeout") + t.Fatal("wait timeout") } diff --git a/metrics/grafana/pd.json b/metrics/grafana/pd.json index 32747d85965..1a35f91bddf 100644 --- a/metrics/grafana/pd.json +++ b/metrics/grafana/pd.json @@ -7187,6 +7187,194 @@ "align": false, "alignLevel": null } + }, { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The inner status of balance Hot Region scheduler", + "fill": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 51 + }, + "id": 1458, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sort": "current", + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(pd_scheduler_event_count{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=\"$instance\", type=\"balance-hot-region-scheduler\"}[5m])) by (name)", + "format": "time_series", + "intervalFactor": 2, + "legendFormat": "{{name}}", + "metric": "pd_scheduler_event_count", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Balance Hot Region scheduler", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ops", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + },{ + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The inner status of split bucket scheduler", + "fill": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 51 + }, + "id": 1459, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sort": "current", + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(pd_scheduler_event_count{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=\"$instance\", type=\"split-bucket-scheduler\"}[5m])) by (name)", + "format": "time_series", + "intervalFactor": 2, + "legendFormat": "{{name}}", + "metric": 
"pd_scheduler_event_count", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Split Bucket scheduler", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ops", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } }, { "aliasColors": {}, @@ -7199,7 +7387,7 @@ "h": 8, "w": 24, "x": 0, - "y": 51 + "y": 59 }, "id": 108, "legend": { @@ -7304,7 +7492,7 @@ "h": 8, "w": 12, "x": 0, - "y": 59 + "y": 67 }, "id": 1424, "interval": null, @@ -7377,7 +7565,7 @@ "h": 8, "w": 12, "x": 12, - "y": 59 + "y": 67 }, "id": 141, "legend": { @@ -7469,7 +7657,7 @@ "h": 8, "w": 12, "x": 0, - "y": 67 + "y": 75 }, "id": 70, "legend": { @@ -7561,7 +7749,7 @@ "h": 8, "w": 12, "x": 12, - "y": 67 + "y": 75 }, "id": 71, "legend": { @@ -7652,7 +7840,7 @@ "h": 8, "w": 12, "x": 0, - "y": 75 + "y": 83 }, "id": 109, "legend": { @@ -7746,7 +7934,7 @@ "h": 8, "w": 12, "x": 12, - "y": 75 + "y": 83 }, "id": 110, "legend": { @@ -10616,7 +10804,7 @@ "h": 8, "w": 12, "x": 12, - "y": 55 + "y": 47 }, "id": 1403, "legend": { @@ -10697,6 +10885,198 @@ "align": false, "alignLevel": null } + }, { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The Interval of TIkv bucket report interval", + "editable": true, + "error": false, + "fill": 0, + "grid": {}, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 55 + }, + "id": 1451, + "legend": { + "alignAsTable": true, + "avg": true, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "histogram_quantile(0.99, sum(rate(pd_server_bucket_report_interval_seconds_bucket{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", store=~\"$store\"}[1m])) by (address, store, le))", + "format": "time_series", + "hide": false, + "intervalFactor": 2, + "legendFormat": "{{address}}-store-{{store}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "99% Bucket Report Interval", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + },{ + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The State of Bucket Report", + "editable": true, + "error": 
false, + "fill": 0, + "grid": {}, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 63 + }, + "id": 1452, + "legend": { + "alignAsTable": true, + "avg": true, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(pd_server_bucket_report{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", store=~\"$store\", instance=\"$instance\"}[1m])) by (address, store, type,status)", + "format": "time_series", + "hide": false, + "intervalFactor": 2, + "legendFormat": "{{address}}-store-{{store}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Bucket Report State", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "opm", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "s", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, @@ -11072,6 +11452,102 @@ "title": "Region Heartbeat Interval", "transparent": true, "type": "bargauge" + },{ + "datasource": "${DS_TEST-CLUSTER}", + "fieldConfig": { + "defaults": { + "custom": { + "align": null + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 48 + }, + "id": 1454, + "interval": "", + "options": { + "displayMode": "lcd", + "orientation": "horizontal", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showUnfilled": true + }, + "pluginVersion": "7.1.5", + "repeatDirection": "h", + "targets": [ + { + "expr": "sum(delta(pd_server_bucket_report_interval_seconds_bucket{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", instance=~\"$instance\"}[1m])) by (le)", + "format": "heatmap", + "hide": false, + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Bucket Report Interval", + "transparent": true, + "type": "bargauge" + },{ + "datasource": "${DS_TEST-CLUSTER}", + "fieldConfig": { + "defaults": { + "custom": { + "align": null + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 48 + }, + "id": 1455, + "interval": "", + "options": { + "displayMode": "lcd", + "orientation": "horizontal", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "showUnfilled": true + }, + "pluginVersion": "7.1.5", + "repeatDirection": "h", + "targets": [ + { + "expr": "sum(delta(pd_scheduler_buckets_hot_degree_hist_bucket{k8s_cluster=\"$k8s_cluster\", tidb_cluster=~\"$tidb_cluster.*\", instance=~\"$instance\"}[1m])) by (le)", + "format": "heatmap", + "hide": false, + "interval": "", + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Hot Degree of Bucket", + "transparent": true, + "type": "bargauge" } ], 
"title": "Heartbeat distribution ", diff --git a/pkg/mock/mockhbstream/mockhbstream_test.go b/pkg/mock/mockhbstream/mockhbstream_test.go index e6d05f19d1b..5f9d814835b 100644 --- a/pkg/mock/mockhbstream/mockhbstream_test.go +++ b/pkg/mock/mockhbstream/mockhbstream_test.go @@ -19,40 +19,22 @@ import ( "testing" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/eraftpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/server/schedule/hbstream" ) -func TestHeaertbeatStreams(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testHeartbeatStreamSuite{}) - -type testHeartbeatStreamSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testHeartbeatStreamSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testHeartbeatStreamSuite) TearDownTest(c *C) { - s.cancel() -} - -func (s *testHeartbeatStreamSuite) TestActivity(c *C) { +func TestActivity(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddRegionStore(1, 1) cluster.AddRegionStore(2, 0) cluster.AddLeaderRegion(1, 1) @@ -66,24 +48,24 @@ func (s *testHeartbeatStreamSuite) TestActivity(c *C) { // Active stream is stream1. hbs.BindStream(1, stream1) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() != nil && stream2.Recv() == nil }) // Rebind to stream2. hbs.BindStream(1, stream2) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() == nil && stream2.Recv() != nil }) // SendErr to stream2. hbs.SendErr(pdpb.ErrorType_UNKNOWN, "test error", &metapb.Peer{Id: 1, StoreId: 1}) res := stream2.Recv() - c.Assert(res, NotNil) - c.Assert(res.GetHeader().GetError(), NotNil) + re.NotNil(res) + re.NotNil(res.GetHeader().GetError()) // Switch back to 1 again. hbs.BindStream(1, stream1) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { hbs.SendMsg(region, proto.Clone(msg).(*pdpb.RegionHeartbeatResponse)) return stream1.Recv() != nil && stream2.Recv() == nil }) diff --git a/pkg/movingaverage/avg_over_time_test.go b/pkg/movingaverage/avg_over_time_test.go index 74e54974656..9006fea5d5d 100644 --- a/pkg/movingaverage/avg_over_time_test.go +++ b/pkg/movingaverage/avg_over_time_test.go @@ -16,16 +16,14 @@ package movingaverage import ( "math/rand" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testAvgOverTimeSuite{}) - -type testAvgOverTimeSuite struct{} - -func (t *testAvgOverTimeSuite) TestPulse(c *C) { +func TestPulse(t *testing.T) { + re := require.New(t) aot := NewAvgOverTime(5 * time.Second) // warm up for i := 0; i < 5; i++ { @@ -38,27 +36,28 @@ func (t *testAvgOverTimeSuite) TestPulse(c *C) { } else { aot.Add(0, time.Second) } - c.Assert(aot.Get(), LessEqual, 600.) - c.Assert(aot.Get(), GreaterEqual, 400.) + re.LessOrEqual(aot.Get(), 600.) + re.GreaterOrEqual(aot.Get(), 400.) 
} } -func (t *testAvgOverTimeSuite) TestChange(c *C) { +func TestChange(t *testing.T) { + re := require.New(t) aot := NewAvgOverTime(5 * time.Second) // phase 1: 1000 for i := 0; i < 20; i++ { aot.Add(1000, time.Second) } - c.Assert(aot.Get(), LessEqual, 1010.) - c.Assert(aot.Get(), GreaterEqual, 990.) + re.LessOrEqual(aot.Get(), 1010.) + re.GreaterOrEqual(aot.Get(), 990.) // phase 2: 500 for i := 0; i < 5; i++ { aot.Add(500, time.Second) } - c.Assert(aot.Get(), LessEqual, 900.) - c.Assert(aot.Get(), GreaterEqual, 495.) + re.LessOrEqual(aot.Get(), 900.) + re.GreaterOrEqual(aot.Get(), 495.) for i := 0; i < 15; i++ { aot.Add(500, time.Second) } @@ -67,32 +66,34 @@ func (t *testAvgOverTimeSuite) TestChange(c *C) { for i := 0; i < 5; i++ { aot.Add(100, time.Second) } - c.Assert(aot.Get(), LessEqual, 678.) - c.Assert(aot.Get(), GreaterEqual, 99.) + re.LessOrEqual(aot.Get(), 678.) + re.GreaterOrEqual(aot.Get(), 99.) // clear aot.Set(10) - c.Assert(aot.Get(), Equals, 10.) + re.Equal(10., aot.Get()) } -func (t *testAvgOverTimeSuite) TestMinFilled(c *C) { +func TestMinFilled(t *testing.T) { + re := require.New(t) interval := 10 * time.Second rate := 1.0 for aotSize := 2; aotSize < 10; aotSize++ { for mfSize := 2; mfSize < 10; mfSize++ { tm := NewTimeMedian(aotSize, mfSize, interval) for i := 0; i < tm.GetFilledPeriod(); i++ { - c.Assert(tm.Get(), Equals, 0.0) + re.Equal(0.0, tm.Get()) tm.Add(rate*interval.Seconds(), interval) } - c.Assert(tm.Get(), Equals, rate) + re.Equal(rate, tm.Get()) } } } -func (t *testAvgOverTimeSuite) TestUnstableInterval(c *C) { +func TestUnstableInterval(t *testing.T) { + re := require.New(t) aot := NewAvgOverTime(5 * time.Second) - c.Assert(aot.Get(), Equals, 0.) + re.Equal(0., aot.Get()) // warm up for i := 0; i < 5; i++ { aot.Add(1000, time.Second) @@ -101,8 +102,8 @@ func (t *testAvgOverTimeSuite) TestUnstableInterval(c *C) { for i := 0; i < 1000; i++ { r := float64(rand.Intn(5)) aot.Add(1000*r, time.Second*time.Duration(r)) - c.Assert(aot.Get(), LessEqual, 1010.) - c.Assert(aot.Get(), GreaterEqual, 990.) + re.LessOrEqual(aot.Get(), 1010.) + re.GreaterOrEqual(aot.Get(), 990.) } // warm up for i := 0; i < 5; i++ { @@ -112,7 +113,7 @@ func (t *testAvgOverTimeSuite) TestUnstableInterval(c *C) { for i := 0; i < 1000; i++ { rate := float64(i%5*100) + 500 aot.Add(rate*3, time.Second*3) - c.Assert(aot.Get(), LessEqual, 910.) - c.Assert(aot.Get(), GreaterEqual, 490.) + re.LessOrEqual(aot.Get(), 910.) + re.GreaterOrEqual(aot.Get(), 490.) } } diff --git a/pkg/movingaverage/max_filter_test.go b/pkg/movingaverage/max_filter_test.go index 5651bbb4b8d..7d3906ec93c 100644 --- a/pkg/movingaverage/max_filter_test.go +++ b/pkg/movingaverage/max_filter_test.go @@ -15,22 +15,21 @@ package movingaverage import ( - . 
"github.com/pingcap/check" -) - -var _ = Suite(&testMaxFilter{}) + "testing" -type testMaxFilter struct{} + "github.com/stretchr/testify/require" +) -func (t *testMaxFilter) TestMaxFilter(c *C) { +func TestMaxFilter(t *testing.T) { + re := require.New(t) var empty float64 = 0 data := []float64{2, 1, 3, 4, 1, 1, 3, 3, 2, 0, 5} expected := []float64{2, 2, 3, 4, 4, 4, 4, 4, 3, 3, 5} mf := NewMaxFilter(5) - c.Assert(mf.Get(), Equals, empty) + re.Equal(empty, mf.Get()) - checkReset(c, mf, empty) - checkAdd(c, mf, data, expected) - checkSet(c, mf, data, expected) + checkReset(re, mf, empty) + checkAdd(re, mf, data, expected) + checkSet(re, mf, data, expected) } diff --git a/pkg/movingaverage/moving_average_test.go b/pkg/movingaverage/moving_average_test.go index 8ef6d89a670..e54aa70b64a 100644 --- a/pkg/movingaverage/moving_average_test.go +++ b/pkg/movingaverage/moving_average_test.go @@ -20,17 +20,9 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testMovingAvg{}) - -type testMovingAvg struct{} - func addRandData(ma MovingAvg, n int, mx float64) { rand.Seed(time.Now().UnixNano()) for i := 0; i < n; i++ { @@ -40,55 +32,56 @@ func addRandData(ma MovingAvg, n int, mx float64) { // checkReset checks the Reset works properly. // emptyValue is the moving average of empty data set. -func checkReset(c *C, ma MovingAvg, emptyValue float64) { +func checkReset(re *require.Assertions, ma MovingAvg, emptyValue float64) { addRandData(ma, 100, 1000) ma.Reset() - c.Assert(ma.Get(), Equals, emptyValue) + re.Equal(emptyValue, ma.Get()) } // checkAddGet checks Add works properly. -func checkAdd(c *C, ma MovingAvg, data []float64, expected []float64) { - c.Assert(len(data), Equals, len(expected)) +func checkAdd(re *require.Assertions, ma MovingAvg, data []float64, expected []float64) { + re.Equal(len(expected), len(data)) for i, x := range data { ma.Add(x) - c.Assert(math.Abs(ma.Get()-expected[i]), LessEqual, 1e-7) + re.LessOrEqual(math.Abs(ma.Get()-expected[i]), 1e-7) } } // checkSet checks Set = Reset + Add -func checkSet(c *C, ma MovingAvg, data []float64, expected []float64) { - c.Assert(len(data), Equals, len(expected)) +func checkSet(re *require.Assertions, ma MovingAvg, data []float64, expected []float64) { + re.Equal(len(expected), len(data)) // Reset + Add addRandData(ma, 100, 1000) ma.Reset() - checkAdd(c, ma, data, expected) + checkAdd(re, ma, data, expected) // Set addRandData(ma, 100, 1000) ma.Set(data[0]) - c.Assert(ma.Get(), Equals, expected[0]) - checkAdd(c, ma, data[1:], expected[1:]) + re.Equal(expected[0], ma.Get()) + checkAdd(re, ma, data[1:], expected[1:]) } // checkInstantaneous checks GetInstantaneous -func checkInstantaneous(c *C, ma MovingAvg) { +func checkInstantaneous(re *require.Assertions, ma MovingAvg) { value := 100.000000 ma.Add(value) - c.Assert(ma.GetInstantaneous(), Equals, value) + re.Equal(value, ma.GetInstantaneous()) } -func (t *testMovingAvg) TestMedianFilter(c *C) { +func TestMedianFilter(t *testing.T) { + re := require.New(t) var empty float64 = 0 data := []float64{2, 4, 2, 800, 600, 6, 3} expected := []float64{2, 3, 2, 3, 4, 6, 6} mf := NewMedianFilter(5) - c.Assert(mf.Get(), Equals, empty) + re.Equal(empty, mf.Get()) - checkReset(c, mf, empty) - checkAdd(c, mf, data, expected) - checkSet(c, mf, data, expected) + checkReset(re, mf, empty) + checkAdd(re, mf, data, expected) + checkSet(re, mf, data, expected) } type testCase struct { @@ -96,7 +89,8 @@ type 
testCase struct { expected []float64 } -func (t *testMovingAvg) TestMovingAvg(c *C) { +func TestMovingAvg(t *testing.T) { + re := require.New(t) var empty float64 = 0 data := []float64{1, 1, 1, 1, 5, 1, 1, 1} testCases := []testCase{{ @@ -116,11 +110,11 @@ func (t *testMovingAvg) TestMovingAvg(c *C) { expected: []float64{1.000000, 1.000000, 1.000000, 1.000000, 5.000000, 5.000000, 5.000000, 5.000000}, }, } - for _, test := range testCases { - c.Assert(test.ma.Get(), Equals, empty) - checkReset(c, test.ma, empty) - checkAdd(c, test.ma, data, test.expected) - checkSet(c, test.ma, data, test.expected) - checkInstantaneous(c, test.ma) + for _, testCase := range testCases { + re.Equal(empty, testCase.ma.Get()) + checkReset(re, testCase.ma, empty) + checkAdd(re, testCase.ma, data, testCase.expected) + checkSet(re, testCase.ma, data, testCase.expected) + checkInstantaneous(re, testCase.ma) } } diff --git a/pkg/movingaverage/queue_test.go b/pkg/movingaverage/queue_test.go index 90769bb1249..56c2337c9a1 100644 --- a/pkg/movingaverage/queue_test.go +++ b/pkg/movingaverage/queue_test.go @@ -15,26 +15,30 @@ package movingaverage import ( - . "github.com/pingcap/check" + "testing" + + "github.com/stretchr/testify/require" ) -func (t *testMovingAvg) TestQueue(c *C) { +func TestQueue(t *testing.T) { + re := require.New(t) sq := NewSafeQueue() sq.PushBack(1) sq.PushBack(2) v1 := sq.PopFront() v2 := sq.PopFront() - c.Assert(1, Equals, v1.(int)) - c.Assert(2, Equals, v2.(int)) + re.Equal(1, v1.(int)) + re.Equal(2, v2.(int)) } -func (t *testMovingAvg) TestClone(c *C) { +func TestClone(t *testing.T) { + re := require.New(t) s1 := NewSafeQueue() s1.PushBack(1) s1.PushBack(2) s2 := s1.Clone() s2.PopFront() s2.PopFront() - c.Assert(s1.que.Len(), Equals, 2) - c.Assert(s2.que.Len(), Equals, 0) + re.Equal(2, s1.que.Len()) + re.Equal(0, s2.que.Len()) } diff --git a/pkg/netutil/address_test.go b/pkg/netutil/address_test.go index 8c93be2a124..477f794c243 100644 --- a/pkg/netutil/address_test.go +++ b/pkg/netutil/address_test.go @@ -18,18 +18,11 @@ import ( "net/http" "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testNetSuite{}) - -type testNetSuite struct{} - -func (s *testNetSuite) TestResolveLoopBackAddr(c *C) { +func TestResolveLoopBackAddr(t *testing.T) { + re := require.New(t) nodes := []struct { address string backAddress string @@ -41,24 +34,25 @@ func (s *testNetSuite) TestResolveLoopBackAddr(c *C) { } for _, n := range nodes { - c.Assert(ResolveLoopBackAddr(n.address, n.backAddress), Equals, "192.168.130.22:2379") + re.Equal("192.168.130.22:2379", ResolveLoopBackAddr(n.address, n.backAddress)) } } -func (s *testNetSuite) TestIsEnableHttps(c *C) { - c.Assert(IsEnableHTTPS(http.DefaultClient), IsFalse) +func TestIsEnableHttps(t *testing.T) { + re := require.New(t) + re.False(IsEnableHTTPS(http.DefaultClient)) httpClient := &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: nil, }, } - c.Assert(IsEnableHTTPS(httpClient), IsFalse) + re.False(IsEnableHTTPS(httpClient)) httpClient = &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: &tls.Config{}, }, } - c.Assert(IsEnableHTTPS(httpClient), IsFalse) + re.False(IsEnableHTTPS(httpClient)) } diff --git a/pkg/progress/progress_test.go b/pkg/progress/progress_test.go index c4b030941f8..72d23c40a6a 100644 --- a/pkg/progress/progress_test.go +++ b/pkg/progress/progress_test.go @@ -20,82 +20,76 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testProgressSuite{}) - -type testProgressSuite struct{} - -func (s *testProgressSuite) Test(c *C) { +func TestProgress(t *testing.T) { + re := require.New(t) n := "test" m := NewManager() - c.Assert(m.AddProgress(n, 100, 100, 10*time.Second), IsFalse) + re.False(m.AddProgress(n, 100, 100, 10*time.Second)) p, ls, cs, err := m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) time.Sleep(time.Second) - c.Assert(m.AddProgress(n, 100, 100, 10*time.Second), IsTrue) + re.True(m.AddProgress(n, 100, 100, 10*time.Second)) m.UpdateProgress(n, 30, 30, false) p, ls, cs, err = m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.7) + re.NoError(err) + re.Equal(0.7, p) // 30/(70/1s+) > 30/70 - c.Assert(ls, Greater, 30.0/70.0) + re.Greater(ls, 30.0/70.0) // 70/1s+ > 70 - c.Assert(cs, Less, 70.0) + re.Less(cs, 70.0) // there is no scheduling for i := 0; i < 100; i++ { m.UpdateProgress(n, 30, 30, false) } - c.Assert(m.progesses[n].history.Len(), Equals, 61) + re.Equal(61, m.progesses[n].history.Len()) p, ls, cs, err = m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.7) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.7, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) ps := m.GetProgresses(func(p string) bool { return strings.Contains(p, n) }) - c.Assert(ps, HasLen, 1) - c.Assert(ps[0], Equals, n) + re.Len(ps, 1) + re.Equal(n, ps[0]) ps = m.GetProgresses(func(p string) bool { return strings.Contains(p, "a") }) - c.Assert(ps, HasLen, 0) - c.Assert(m.RemoveProgress(n), IsTrue) - c.Assert(m.RemoveProgress(n), IsFalse) + re.Len(ps, 0) + re.True(m.RemoveProgress(n)) + re.False(m.RemoveProgress(n)) } -func (s *testProgressSuite) TestAbnormal(c *C) { +func TestAbnormal(t *testing.T) { + re := 
require.New(t) n := "test" m := NewManager() - c.Assert(m.AddProgress(n, 100, 100, 10*time.Second), IsFalse) + re.False(m.AddProgress(n, 100, 100, 10*time.Second)) p, ls, cs, err := m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) // When offline a store, but there are still many write operations m.UpdateProgress(n, 110, 110, false) p, ls, cs, err = m.Status(n) - c.Assert(err, IsNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, math.MaxFloat64) - c.Assert(cs, Equals, 0.0) + re.NoError(err) + re.Equal(0.0, p) + re.Equal(math.MaxFloat64, ls) + re.Equal(0.0, cs) // It usually won't happens m.UpdateProgressTotal(n, 10) p, ls, cs, err = m.Status(n) - c.Assert(err, NotNil) - c.Assert(p, Equals, 0.0) - c.Assert(ls, Equals, 0.0) - c.Assert(cs, Equals, 0.0) + re.Error(err) + re.Equal(0.0, p) + re.Equal(0.0, ls) + re.Equal(0.0, cs) } diff --git a/pkg/rangetree/range_tree_test.go b/pkg/rangetree/range_tree_test.go index d1e9cd79de5..695183f2f90 100644 --- a/pkg/rangetree/range_tree_test.go +++ b/pkg/rangetree/range_tree_test.go @@ -18,19 +18,10 @@ import ( "bytes" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/btree" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRangeTreeSuite{}) - -type testRangeTreeSuite struct { -} - type simpleBucketItem struct { startKey []byte endKey []byte @@ -79,7 +70,7 @@ func bucketDebrisFactory(startKey, endKey []byte, item RangeItem) []RangeItem { left := maxKey(startKey, item.GetStartKey()) right := minKey(endKey, item.GetEndKey()) - // they have no intersection if they are neighbour like |010 - 100| and |100 - 200|. + // they have no intersection if they are neighbors like |010 - 100| and |100 - 200|. 
if bytes.Compare(left, right) >= 0 { return nil } @@ -94,52 +85,54 @@ func bucketDebrisFactory(startKey, endKey []byte, item RangeItem) []RangeItem { return res } -func (bs *testRangeTreeSuite) TestRingPutItem(c *C) { +func TestRingPutItem(t *testing.T) { + re := require.New(t) bucketTree := NewRangeTree(2, bucketDebrisFactory) bucketTree.Update(newSimpleBucketItem([]byte("002"), []byte("100"))) - c.Assert(bucketTree.Len(), Equals, 1) + re.Equal(1, bucketTree.Len()) bucketTree.Update(newSimpleBucketItem([]byte("100"), []byte("200"))) - c.Assert(bucketTree.Len(), Equals, 2) + re.Equal(2, bucketTree.Len()) // init key range: [002,100], [100,200] - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("002"))), HasLen, 0) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("009"))), HasLen, 1) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 1) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("110"))), HasLen, 2) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("200"), []byte("300"))), HasLen, 0) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("002"))), 0) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("000"), []byte("009"))), 1) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("110"))), 2) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("200"), []byte("300"))), 0) // test1: insert one key range, the old overlaps will retain like split buckets. // key range: [002,010],[010,090],[090,100],[100,200] bucketTree.Update(newSimpleBucketItem([]byte("010"), []byte("090"))) - c.Assert(bucketTree.Len(), Equals, 4) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 1) + re.Equal(4, bucketTree.Len()) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) // test2: insert one key range, the old overlaps will retain like merge . // key range: [001,080], [080,090],[090,100],[100,200] bucketTree.Update(newSimpleBucketItem([]byte("001"), []byte("080"))) - c.Assert(bucketTree.Len(), Equals, 4) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 2) + re.Equal(4, bucketTree.Len()) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 2) // test2: insert one keyrange, the old overlaps will retain like merge . 
// key range: [001,120],[120,200] bucketTree.Update(newSimpleBucketItem([]byte("001"), []byte("120"))) - c.Assert(bucketTree.Len(), Equals, 2) - c.Assert(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), HasLen, 1) + re.Equal(2, bucketTree.Len()) + re.Len(bucketTree.GetOverlaps(newSimpleBucketItem([]byte("010"), []byte("090"))), 1) } -func (bs *testRangeTreeSuite) TestDebris(c *C) { +func TestDebris(t *testing.T) { + re := require.New(t) ringItem := newSimpleBucketItem([]byte("010"), []byte("090")) var overlaps []RangeItem overlaps = bucketDebrisFactory([]byte("000"), []byte("100"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) overlaps = bucketDebrisFactory([]byte("000"), []byte("080"), ringItem) - c.Assert(overlaps, HasLen, 1) + re.Len(overlaps, 1) overlaps = bucketDebrisFactory([]byte("020"), []byte("080"), ringItem) - c.Assert(overlaps, HasLen, 2) + re.Len(overlaps, 2) overlaps = bucketDebrisFactory([]byte("010"), []byte("090"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) overlaps = bucketDebrisFactory([]byte("010"), []byte("100"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) overlaps = bucketDebrisFactory([]byte("100"), []byte("200"), ringItem) - c.Assert(overlaps, HasLen, 0) + re.Len(overlaps, 0) } diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 86dfda0eef6..6a2a5c80b9c 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -15,29 +15,24 @@ package ratelimit import ( - . "github.com/pingcap/check" -) - -var _ = Suite(&testConcurrencyLimiterSuite{}) - -type testConcurrencyLimiterSuite struct { -} + "testing" -func (s *testConcurrencyLimiterSuite) TestConcurrencyLimiter(c *C) { - c.Parallel() + "github.com/stretchr/testify/require" +) +func TestConcurrencyLimiter(t *testing.T) { + re := require.New(t) cl := newConcurrencyLimiter(10) - for i := 0; i < 10; i++ { - c.Assert(cl.allow(), Equals, true) + re.True(cl.allow()) } - c.Assert(cl.allow(), Equals, false) + re.False(cl.allow()) cl.release() - c.Assert(cl.allow(), Equals, true) - c.Assert(cl.getLimit(), Equals, uint64(10)) + re.True(cl.allow()) + re.Equal(uint64(10), cl.getLimit()) cl.setLimit(5) - c.Assert(cl.getLimit(), Equals, uint64(5)) - c.Assert(cl.getCurrent(), Equals, uint64(10)) + re.Equal(uint64(5), cl.getLimit()) + re.Equal(uint64(10), cl.getCurrent()) cl.release() - c.Assert(cl.getCurrent(), Equals, uint64(9)) + re.Equal(uint64(9), cl.getCurrent()) } diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 43f01cea41b..4bf930ed6c5 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -20,6 +20,15 @@ import ( "golang.org/x/time/rate" ) +// DimensionConfig is the limit dimension config of one label +type DimensionConfig struct { + // qps config + QPS float64 + QPSBurst int + // concurrency config + ConcurrencyLimit uint64 +} + // Limiter is a controller for the request rate. type Limiter struct { qpsLimiter sync.Map @@ -30,7 +39,9 @@ type Limiter struct { // NewLimiter returns a global limiter which can be updated in the later. func NewLimiter() *Limiter { - return &Limiter{labelAllowList: make(map[string]struct{})} + return &Limiter{ + labelAllowList: make(map[string]struct{}), + } } // Allow is used to check whether it has enough token.
@@ -65,10 +76,12 @@ func (l *Limiter) Release(label string) { } // Update is used to update Ratelimiter with Options -func (l *Limiter) Update(label string, opts ...Option) { +func (l *Limiter) Update(label string, opts ...Option) UpdateStatus { + var status UpdateStatus for _, opt := range opts { - opt(label, l) + status |= opt(label, l) } + return status } // GetQPSLimiterStatus returns the status of a given label's QPS limiter. @@ -80,8 +93,8 @@ func (l *Limiter) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int return 0, 0 } -// DeleteQPSLimiter deletes QPS limiter of given label -func (l *Limiter) DeleteQPSLimiter(label string) { +// QPSUnlimit deletes QPS limiter of the given label +func (l *Limiter) QPSUnlimit(label string) { l.qpsLimiter.Delete(label) } @@ -94,8 +107,8 @@ func (l *Limiter) GetConcurrencyLimiterStatus(label string) (limit uint64, curre return 0, 0 } -// DeleteConcurrencyLimiter deletes concurrency limiter of given label -func (l *Limiter) DeleteConcurrencyLimiter(label string) { +// ConcurrencyUnlimit deletes concurrency limiter of the given label +func (l *Limiter) ConcurrencyUnlimit(label string) { l.concurrencyLimiter.Delete(label) } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index bd095543a05..d1a570ccb35 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -16,146 +16,152 @@ package ratelimit import ( "sync" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "golang.org/x/time/rate" ) -var _ = Suite(&testRatelimiterSuite{}) - -type testRatelimiterSuite struct { -} - -func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { - c.Parallel() +func TestUpdateConcurrencyLimiter(t *testing.T) { + re := require.New(t) opts := []Option{UpdateConcurrencyLimiter(10)} limiter := NewLimiter() label := "test" - for _, opt := range opts { - opt(label, limiter) - } + status := limiter.Update(label, opts...) 
+ re.True(status&ConcurrencyChanged != 0) var lock sync.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup for i := 0; i < 15; i++ { wg.Add(1) go func() { - CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) }() } wg.Wait() - c.Assert(failedCount, Equals, 5) - c.Assert(successCount, Equals, 10) + re.Equal(5, failedCount) + re.Equal(10, successCount) for i := 0; i < 10; i++ { limiter.Release(label) } limit, current := limiter.GetConcurrencyLimiterStatus(label) - c.Assert(limit, Equals, uint64(10)) - c.Assert(current, Equals, uint64(0)) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) - limiter.Update(label, UpdateConcurrencyLimiter(5)) + status = limiter.Update(label, UpdateConcurrencyLimiter(10)) + re.True(status&ConcurrencyNoChange != 0) + + status = limiter.Update(label, UpdateConcurrencyLimiter(5)) + re.True(status&ConcurrencyChanged != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 10) - c.Assert(successCount, Equals, 5) + re.Equal(10, failedCount) + re.Equal(5, successCount) for i := 0; i < 5; i++ { limiter.Release(label) } - limiter.DeleteConcurrencyLimiter(label) + status = limiter.Update(label, UpdateConcurrencyLimiter(0)) + re.True(status&ConcurrencyDeleted != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 0) - c.Assert(successCount, Equals, 15) + re.Equal(0, failedCount) + re.Equal(15, successCount) limit, current = limiter.GetConcurrencyLimiterStatus(label) - c.Assert(limit, Equals, uint64(0)) - c.Assert(current, Equals, uint64(0)) + re.Equal(uint64(0), limit) + re.Equal(uint64(0), current) } -func (s *testRatelimiterSuite) TestBlockList(c *C) { - c.Parallel() +func TestBlockList(t *testing.T) { + re := require.New(t) opts := []Option{AddLabelAllowList()} limiter := NewLimiter() label := "test" - c.Assert(limiter.IsInAllowList(label), Equals, false) + re.False(limiter.IsInAllowList(label)) for _, opt := range opts { opt(label, limiter) } - c.Assert(limiter.IsInAllowList(label), Equals, true) + re.True(limiter.IsInAllowList(label)) - UpdateQPSLimiter(rate.Every(time.Second), 1)(label, limiter) + status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) + re.True(status&InAllowList != 0) for i := 0; i < 10; i++ { - c.Assert(limiter.Allow(label), Equals, true) + re.True(limiter.Allow(label)) } } -func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { - c.Parallel() - opts := []Option{UpdateQPSLimiter(rate.Every(time.Second), 1)} +func TestUpdateQPSLimiter(t *testing.T) { + re := require.New(t) + opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} limiter := NewLimiter() label := "test" - for _, opt := range opts { - opt(label, limiter) - } + status := limiter.Update(label, opts...) 
+ re.True(status&QPSChanged != 0) var lock sync.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup wg.Add(3) for i := 0; i < 3; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount, Equals, 2) - c.Assert(successCount, Equals, 1) + re.Equal(2, failedCount) + re.Equal(1, successCount) limit, burst := limiter.GetQPSLimiterStatus(label) - c.Assert(limit, Equals, rate.Limit(1)) - c.Assert(burst, Equals, 1) + re.Equal(rate.Limit(1), limit) + re.Equal(1, burst) + + status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) + re.True(status&QPSNoChange != 0) - limiter.Update(label, UpdateQPSLimiter(5, 5)) + status = limiter.Update(label, UpdateQPSLimiter(5, 5)) + re.True(status&QPSChanged != 0) limit, burst = limiter.GetQPSLimiterStatus(label) - c.Assert(limit, Equals, rate.Limit(5)) - c.Assert(burst, Equals, 5) + re.Equal(rate.Limit(5), limit) + re.Equal(5, burst) time.Sleep(time.Second) for i := 0; i < 10; i++ { if i < 5 { - c.Assert(limiter.Allow(label), Equals, true) + re.True(limiter.Allow(label)) } else { - c.Assert(limiter.Allow(label), Equals, false) + re.False(limiter.Allow(label)) } } time.Sleep(time.Second) - limiter.DeleteQPSLimiter(label) + + status = limiter.Update(label, UpdateQPSLimiter(0, 0)) + re.True(status&QPSDeleted != 0) for i := 0; i < 10; i++ { - c.Assert(limiter.Allow(label), Equals, true) + re.True(limiter.Allow(label)) } qLimit, qCurrent := limiter.GetQPSLimiterStatus(label) - c.Assert(qLimit, Equals, rate.Limit(0)) - c.Assert(qCurrent, Equals, 0) + re.Equal(rate.Limit(0), qLimit) + re.Equal(0, qCurrent) } -func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { - c.Parallel() - opts := []Option{UpdateQPSLimiter(rate.Every(3*time.Second), 100)} +func TestQPSLimiter(t *testing.T) { + re := require.New(t) + opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} limiter := NewLimiter() label := "test" @@ -168,25 +174,28 @@ func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { var wg sync.WaitGroup wg.Add(200) for i := 0; i < 200; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) } wg.Wait() - c.Assert(failedCount+successCount, Equals, 200) - c.Assert(failedCount, Equals, 100) - c.Assert(successCount, Equals, 100) + re.Equal(200, failedCount+successCount) + re.Equal(100, failedCount) + re.Equal(100, successCount) time.Sleep(4 * time.Second) // 3+1 wg.Add(1) - CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) wg.Wait() - c.Assert(successCount, Equals, 101) + re.Equal(101, successCount) } -func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { - c.Parallel() - opts := []Option{UpdateQPSLimiter(100, 100), - UpdateConcurrencyLimiter(100), +func TestTwoLimiters(t *testing.T) { + re := require.New(t) + cfg := &DimensionConfig{ + QPS: 100, + QPSBurst: 100, + ConcurrencyLimit: 100, } + opts := []Option{UpdateDimensionConfig(cfg)} limiter := NewLimiter() label := "test" @@ -199,38 +208,38 @@ func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { var wg sync.WaitGroup wg.Add(200) for i := 0; i < 200; i++ { - go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go 
countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg)
 }
 wg.Wait()
- c.Assert(failedCount, Equals, 100)
- c.Assert(successCount, Equals, 100)
+ re.Equal(100, failedCount)
+ re.Equal(100, successCount)
 time.Sleep(1 * time.Second)
 wg.Add(100)
 for i := 0; i < 100; i++ {
- go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg)
+ go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg)
 }
 wg.Wait()
- c.Assert(failedCount, Equals, 200)
- c.Assert(successCount, Equals, 100)
+ re.Equal(200, failedCount)
+ re.Equal(100, successCount)
 for i := 0; i < 100; i++ {
 limiter.Release(label)
 }
- limiter.Update(label, UpdateQPSLimiter(rate.Every(10*time.Second), 1))
+ limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1))
 wg.Add(100)
 for i := 0; i < 100; i++ {
- go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg)
+ go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg)
 }
 wg.Wait()
- c.Assert(successCount, Equals, 101)
- c.Assert(failedCount, Equals, 299)
+ re.Equal(101, successCount)
+ re.Equal(299, failedCount)
 limit, current := limiter.GetConcurrencyLimiterStatus(label)
- c.Assert(limit, Equals, uint64(100))
- c.Assert(current, Equals, uint64(1))
+ re.Equal(uint64(100), limit)
+ re.Equal(uint64(1), current)
 }
-func CountRateLimiterHandleResult(limiter *Limiter, label string, successCount *int,
+func countRateLimiterHandleResult(limiter *Limiter, label string, successCount *int,
 failedCount *int, lock *sync.Mutex, wg *sync.WaitGroup) {
 result := limiter.Allow(label)
 lock.Lock()
diff --git a/pkg/ratelimit/option.go b/pkg/ratelimit/option.go
index af98eddb827..53afb9926d4 100644
--- a/pkg/ratelimit/option.go
+++ b/pkg/ratelimit/option.go
@@ -16,39 +16,101 @@ package ratelimit
 import "golang.org/x/time/rate"
+// UpdateStatus is flags for updating limiter config.
+type UpdateStatus uint32
+
+// Flags for limiter.
+const (
+ eps float64 = 1e-8
+ // QPSNoChange shows that limiter's config isn't changed.
+ QPSNoChange UpdateStatus = 1 << iota
+ // QPSChanged shows that limiter's config is changed and not deleted.
+ QPSChanged
+ // QPSDeleted shows that limiter's config is deleted.
+ QPSDeleted
+ // ConcurrencyNoChange shows that limiter's config isn't changed.
+ ConcurrencyNoChange
+ // ConcurrencyChanged shows that limiter's config is changed and not deleted.
+ ConcurrencyChanged
+ // ConcurrencyDeleted shows that limiter's config is deleted.
+ ConcurrencyDeleted
+ // InAllowList shows that limiter's config isn't changed because it is in the allow list.
+ InAllowList
+)
+
 // Option is used to create a limiter with the optional settings.
 // these setting is used to add a kind of limiter for a service
-type Option func(string, *Limiter)
+type Option func(string, *Limiter) UpdateStatus
 // AddLabelAllowList adds a label into allow list.
// It means the given label will not be limited func AddLabelAllowList() Option { - return func(label string, l *Limiter) { + return func(label string, l *Limiter) UpdateStatus { l.labelAllowList[label] = struct{}{} + return 0 + } +} + +func updateConcurrencyConfig(l *Limiter, label string, limit uint64) UpdateStatus { + oldConcurrencyLimit, _ := l.GetConcurrencyLimiterStatus(label) + if oldConcurrencyLimit == limit { + return ConcurrencyNoChange + } + if limit < 1 { + l.ConcurrencyUnlimit(label) + return ConcurrencyDeleted + } + if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { + limiter.(*concurrencyLimiter).setLimit(limit) + } + return ConcurrencyChanged +} + +func updateQPSConfig(l *Limiter, label string, limit float64, burst int) UpdateStatus { + oldQPSLimit, oldBurst := l.GetQPSLimiterStatus(label) + + if (float64(oldQPSLimit)-limit < eps && float64(oldQPSLimit)-limit > -eps) && oldBurst == burst { + return QPSNoChange + } + if limit <= eps || burst < 1 { + l.QPSUnlimit(label) + return QPSDeleted } + if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(limit, burst)); exist { + limiter.(*RateLimiter).SetLimit(rate.Limit(limit)) + limiter.(*RateLimiter).SetBurst(burst) + } + return QPSChanged } // UpdateConcurrencyLimiter creates a concurrency limiter for a given label if it doesn't exist. func UpdateConcurrencyLimiter(limit uint64) Option { - return func(label string, l *Limiter) { + return func(label string, l *Limiter) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { - return - } - if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { - limiter.(*concurrencyLimiter).setLimit(limit) + return InAllowList } + return updateConcurrencyConfig(l, label, limit) } } // UpdateQPSLimiter creates a QPS limiter for a given label if it doesn't exist. -func UpdateQPSLimiter(limit rate.Limit, burst int) Option { - return func(label string, l *Limiter) { +func UpdateQPSLimiter(limit float64, burst int) Option { + return func(label string, l *Limiter) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { - return + return InAllowList } - if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(float64(limit), burst)); exist { - limiter.(*RateLimiter).SetLimit(limit) - limiter.(*RateLimiter).SetBurst(burst) + return updateQPSConfig(l, label, limit, burst) + } +} + +// UpdateDimensionConfig creates QPS limiter and concurrency limiter for a given label by config if it doesn't exist. +func UpdateDimensionConfig(cfg *DimensionConfig) Option { + return func(label string, l *Limiter) UpdateStatus { + if _, allow := l.labelAllowList[label]; allow { + return InAllowList } + status := updateQPSConfig(l, label, cfg.QPS, cfg.QPSBurst) + status |= updateConcurrencyConfig(l, label, cfg.ConcurrencyLimit) + return status } } diff --git a/pkg/ratelimit/ratelimiter.go b/pkg/ratelimit/ratelimiter.go index e15c858009e..b2b6e3a036a 100644 --- a/pkg/ratelimit/ratelimiter.go +++ b/pkg/ratelimit/ratelimiter.go @@ -15,6 +15,7 @@ package ratelimit import ( + "context" "time" "github.com/tikv/pd/pkg/syncutil" @@ -25,14 +26,14 @@ import ( // It implements `Available` function which is not included in `golang.org/x/time/rate`. // Note: AvailableN will increase the wait time of WaitN. 
type RateLimiter struct { - mu syncutil.Mutex - *rate.Limiter + mu syncutil.Mutex + limiter *rate.Limiter } // NewRateLimiter returns a new Limiter that allows events up to rate r (it means limiter refill r token per second) // and permits bursts of at most b tokens. func NewRateLimiter(r float64, b int) *RateLimiter { - return &RateLimiter{Limiter: rate.NewLimiter(rate.Limit(r), b)} + return &RateLimiter{limiter: rate.NewLimiter(rate.Limit(r), b)} } // Available returns whether limiter has enough tokens. @@ -41,7 +42,7 @@ func (l *RateLimiter) Available(n int) bool { l.mu.Lock() defer l.mu.Unlock() now := time.Now() - r := l.Limiter.ReserveN(now, n) + r := l.limiter.ReserveN(now, n) delay := r.DelayFrom(now) r.CancelAt(now) return delay == 0 @@ -57,5 +58,42 @@ func (l *RateLimiter) AllowN(n int) bool { l.mu.Lock() defer l.mu.Unlock() now := time.Now() - return l.Limiter.AllowN(now, n) + return l.limiter.AllowN(now, n) +} + +// SetBurst is shorthand for SetBurstAt(time.Now(), newBurst). +func (l *RateLimiter) SetBurst(burst int) { + l.mu.Lock() + defer l.mu.Unlock() + l.limiter.SetBurst(burst) +} + +// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit). +func (l *RateLimiter) SetLimit(limit rate.Limit) { + l.mu.Lock() + defer l.mu.Unlock() + l.limiter.SetLimit(limit) +} + +// Limit returns the maximum overall event rate. +func (l *RateLimiter) Limit() rate.Limit { + return l.limiter.Limit() +} + +// Burst returns the maximum burst size. Burst is the maximum number of tokens +// that can be consumed in a single call to Allow, Reserve, or Wait, so higher +// Burst values allow more events to happen at once. +// A zero Burst allows no events, unless limit == Inf. +func (l *RateLimiter) Burst() int { + return l.limiter.Burst() +} + +// WaitN blocks until lim permits n events to happen. +// It returns an error if n exceeds the Limiter's burst size, the Context is +// canceled, or the expected wait time exceeds the Context's Deadline. +// The burst limit is ignored if the rate limit is Inf. +func (l *RateLimiter) WaitN(ctx context.Context, n int) error { + l.mu.Lock() + defer l.mu.Unlock() + return l.limiter.WaitN(ctx, n) } diff --git a/pkg/ratelimit/ratelimiter_test.go b/pkg/ratelimit/ratelimiter_test.go index ccc8d05090a..f16bb6a83d2 100644 --- a/pkg/ratelimit/ratelimiter_test.go +++ b/pkg/ratelimit/ratelimiter_test.go @@ -18,34 +18,24 @@ import ( "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRateLimiterSuite{}) - -type testRateLimiterSuite struct { -} - -func (s *testRateLimiterSuite) TestRateLimiter(c *C) { - c.Parallel() - +func TestRateLimiter(t *testing.T) { + re := require.New(t) limiter := NewRateLimiter(100, 100) - c.Assert(limiter.Available(1), Equals, true) + re.True(limiter.Available(1)) - c.Assert(limiter.AllowN(50), Equals, true) - c.Assert(limiter.Available(50), Equals, true) - c.Assert(limiter.Available(100), Equals, false) - c.Assert(limiter.Available(50), Equals, true) - c.Assert(limiter.AllowN(50), Equals, true) - c.Assert(limiter.Available(50), Equals, false) + re.True(limiter.AllowN(50)) + re.True(limiter.Available(50)) + re.False(limiter.Available(100)) + re.True(limiter.Available(50)) + re.True(limiter.AllowN(50)) + re.False(limiter.Available(50)) time.Sleep(time.Second) - c.Assert(limiter.Available(1), Equals, true) - c.Assert(limiter.AllowN(99), Equals, true) - c.Assert(limiter.Allow(), Equals, true) - c.Assert(limiter.Available(1), Equals, false) + re.True(limiter.Available(1)) + re.True(limiter.AllowN(99)) + re.True(limiter.Allow()) + re.False(limiter.Available(1)) } diff --git a/pkg/reflectutil/tag_test.go b/pkg/reflectutil/tag_test.go index d7ddccd22e8..8e8c4dc7754 100644 --- a/pkg/reflectutil/tag_test.go +++ b/pkg/reflectutil/tag_test.go @@ -18,7 +18,7 @@ import ( "reflect" "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) type testStruct1 struct { @@ -34,41 +34,36 @@ type testStruct3 struct { Enable bool `json:"enable,string"` } -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testTagSuite{}) - -type testTagSuite struct{} - -func (s *testTagSuite) TestFindJSONFullTagByChildTag(c *C) { +func TestFindJSONFullTagByChildTag(t *testing.T) { + re := require.New(t) key := "enable" result := FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - c.Assert(result, Equals, "object.action.enable") + re.Equal("object.action.enable", result) key = "action" result = FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - c.Assert(result, Equals, "object.action") + re.Equal("object.action", result) key = "disable" result = FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) - c.Assert(result, HasLen, 0) + re.Len(result, 0) } -func (s *testTagSuite) TestFindSameFieldByJSON(c *C) { +func TestFindSameFieldByJSON(t *testing.T) { + re := require.New(t) input := map[string]interface{}{ "name": "test2", } t2 := testStruct2{} - c.Assert(FindSameFieldByJSON(&t2, input), Equals, true) + re.True(FindSameFieldByJSON(&t2, input)) input = map[string]interface{}{ "enable": "test2", } - c.Assert(FindSameFieldByJSON(&t2, input), Equals, false) + re.False(FindSameFieldByJSON(&t2, input)) } -func (s *testTagSuite) TestFindFieldByJSONTag(c *C) { +func TestFindFieldByJSONTag(t *testing.T) { + re := require.New(t) t1 := testStruct1{} t2 := testStruct2{} t3 := testStruct3{} @@ -77,17 +72,17 @@ func (s *testTagSuite) TestFindFieldByJSONTag(c *C) { tags := []string{"object"} result := FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result, Equals, type2) + re.Equal(type2, result) tags = []string{"object", "action"} result = FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result, Equals, type3) + re.Equal(type3, result) tags = []string{"object", "name"} result = FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result.Kind(), Equals, reflect.String) + 
re.Equal(reflect.String, result.Kind()) tags = []string{"object", "action", "enable"} result = FindFieldByJSONTag(reflect.TypeOf(t1), tags) - c.Assert(result.Kind(), Equals, reflect.Bool) + re.Equal(reflect.Bool, result.Kind()) } diff --git a/pkg/requestutil/context_test.go b/pkg/requestutil/context_test.go index 98560577183..fe93182d537 100644 --- a/pkg/requestutil/context_test.go +++ b/pkg/requestutil/context_test.go @@ -19,22 +19,14 @@ import ( "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testRequestContextSuite{}) - -type testRequestContextSuite struct { -} - -func (s *testRequestContextSuite) TestRequestInfo(c *C) { +func TestRequestInfo(t *testing.T) { + re := require.New(t) ctx := context.Background() _, ok := RequestInfoFrom(ctx) - c.Assert(ok, Equals, false) + re.False(ok) timeNow := time.Now().Unix() ctx = WithRequestInfo(ctx, RequestInfo{ @@ -47,25 +39,26 @@ func (s *testRequestContextSuite) TestRequestInfo(c *C) { StartTimeStamp: timeNow, }) result, ok := RequestInfoFrom(ctx) - c.Assert(result, NotNil) - c.Assert(ok, Equals, true) - c.Assert(result.ServiceLabel, Equals, "test label") - c.Assert(result.Method, Equals, "POST") - c.Assert(result.Component, Equals, "pdctl") - c.Assert(result.IP, Equals, "localhost") - c.Assert(result.URLParam, Equals, "{\"id\"=1}") - c.Assert(result.BodyParam, Equals, "{\"state\"=\"Up\"}") - c.Assert(result.StartTimeStamp, Equals, timeNow) + re.NotNil(result) + re.True(ok) + re.Equal("test label", result.ServiceLabel) + re.Equal("POST", result.Method) + re.Equal("pdctl", result.Component) + re.Equal("localhost", result.IP) + re.Equal("{\"id\"=1}", result.URLParam) + re.Equal("{\"state\"=\"Up\"}", result.BodyParam) + re.Equal(timeNow, result.StartTimeStamp) } -func (s *testRequestContextSuite) TestEndTime(c *C) { +func TestEndTime(t *testing.T) { + re := require.New(t) ctx := context.Background() _, ok := EndTimeFrom(ctx) - c.Assert(ok, Equals, false) + re.False(ok) timeNow := time.Now().Unix() ctx = WithEndTime(ctx, timeNow) result, ok := EndTimeFrom(ctx) - c.Assert(result, NotNil) - c.Assert(ok, Equals, true) - c.Assert(result, Equals, timeNow) + re.NotNil(result) + re.True(ok) + re.Equal(timeNow, result) } diff --git a/pkg/slice/slice_test.go b/pkg/slice/slice_test.go index 6c7030b977e..809dd2c54b3 100644 --- a/pkg/slice/slice_test.go +++ b/pkg/slice/slice_test.go @@ -17,21 +17,13 @@ package slice_test import ( "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/slice" ) -func Test(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testSliceSuite{}) - -type testSliceSuite struct { -} - -func (s *testSliceSuite) Test(c *C) { - tests := []struct { +func TestSlice(t *testing.T) { + re := require.New(t) + testCases := []struct { a []int anyOf bool noneOf bool @@ -43,24 +35,25 @@ func (s *testSliceSuite) Test(c *C) { {[]int{2, 2, 4}, true, false, true}, } - for _, t := range tests { - even := func(i int) bool { return t.a[i]%2 == 0 } - c.Assert(slice.AnyOf(t.a, even), Equals, t.anyOf) - c.Assert(slice.NoneOf(t.a, even), Equals, t.noneOf) - c.Assert(slice.AllOf(t.a, even), Equals, t.allOf) + for _, testCase := range testCases { + even := func(i int) bool { return testCase.a[i]%2 == 0 } + re.Equal(testCase.anyOf, slice.AnyOf(testCase.a, even)) + re.Equal(testCase.noneOf, slice.NoneOf(testCase.a, even)) + re.Equal(testCase.allOf, slice.AllOf(testCase.a, even)) } } -func (s *testSliceSuite) TestSliceContains(c *C) { +func TestSliceContains(t *testing.T) { + re := require.New(t) ss := []string{"a", "b", "c"} - c.Assert(slice.Contains(ss, "a"), IsTrue) - c.Assert(slice.Contains(ss, "d"), IsFalse) + re.Contains(ss, "a") + re.NotContains(ss, "d") us := []uint64{1, 2, 3} - c.Assert(slice.Contains(us, uint64(1)), IsTrue) - c.Assert(slice.Contains(us, uint64(4)), IsFalse) + re.Contains(us, uint64(1)) + re.NotContains(us, uint64(4)) is := []int64{1, 2, 3} - c.Assert(slice.Contains(is, int64(1)), IsTrue) - c.Assert(slice.Contains(is, int64(4)), IsFalse) + re.Contains(is, int64(1)) + re.NotContains(is, int64(4)) } diff --git a/pkg/systimemon/systimemon_test.go b/pkg/systimemon/systimemon_test.go index 73be25e2edb..d267d15d965 100644 --- a/pkg/systimemon/systimemon_test.go +++ b/pkg/systimemon/systimemon_test.go @@ -26,11 +26,11 @@ func TestSystimeMonitor(t *testing.T) { defer cancel() var jumpForward int32 - trigged := false + triggered := false go StartMonitor(ctx, func() time.Time { - if !trigged { - trigged = true + if !triggered { + triggered = true return time.Now() } diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index c3c917d7b3a..dfb209c648d 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -17,10 +17,12 @@ package testutil import ( "os" "strings" + "testing" "time" "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -71,6 +73,26 @@ func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { c.Fatal("wait timeout") } +// WaitUntilWithTestingT repeatedly evaluates f() for a period of time, util it returns true. +// NOTICE: this is a temporary function that we will be used to replace `WaitUntil` later. +func WaitUntilWithTestingT(t *testing.T, f CheckFunc, opts ...WaitOption) { + t.Log("wait start") + option := &WaitOp{ + retryTimes: waitMaxRetry, + sleepInterval: waitRetrySleep, + } + for _, opt := range opts { + opt(option) + } + for i := 0; i < option.retryTimes; i++ { + if f() { + return + } + time.Sleep(option.sleepInterval) + } + t.Fatal("wait timeout") +} + // NewRequestHeader creates a new request header. func NewRequestHeader(clusterID uint64) *pdpb.RequestHeader { return &pdpb.RequestHeader{ @@ -86,6 +108,15 @@ func MustNewGrpcClient(c *check.C, addr string) pdpb.PDClient { return pdpb.NewPDClient(conn) } +// MustNewGrpcClientWithTestify must create a new grpc client. 
+// NOTICE: this is a temporary function that will be used to replace `MustNewGrpcClient` later.
+func MustNewGrpcClientWithTestify(re *require.Assertions, addr string) pdpb.PDClient {
+ conn, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure())
+
+ re.NoError(err)
+ return pdpb.NewPDClient(conn)
+}
+
 // CleanServer is used to clean data directory.
 func CleanServer(dataDir string) {
 // Clean data directory
diff --git a/pkg/typeutil/comparison_test.go b/pkg/typeutil/comparison_test.go
index 7f6c7348040..24934684b03 100644
--- a/pkg/typeutil/comparison_test.go
+++ b/pkg/typeutil/comparison_test.go
@@ -18,31 +18,26 @@ import (
 "testing"
 "time"
- . "github.com/pingcap/check"
+ "github.com/stretchr/testify/require"
 )
-func TestTypeUtil(t *testing.T) {
- TestingT(t)
+func TestMinUint64(t *testing.T) {
+ re := require.New(t)
+ re.Equal(uint64(1), MinUint64(1, 2))
+ re.Equal(uint64(1), MinUint64(2, 1))
+ re.Equal(uint64(1), MinUint64(1, 1))
 }
-var _ = Suite(&testMinMaxSuite{})
-
-type testMinMaxSuite struct{}
-
-func (s *testMinMaxSuite) TestMinUint64(c *C) {
- c.Assert(MinUint64(1, 2), Equals, uint64(1))
- c.Assert(MinUint64(2, 1), Equals, uint64(1))
- c.Assert(MinUint64(1, 1), Equals, uint64(1))
-}
-
-func (s *testMinMaxSuite) TestMaxUint64(c *C) {
- c.Assert(MaxUint64(1, 2), Equals, uint64(2))
- c.Assert(MaxUint64(2, 1), Equals, uint64(2))
- c.Assert(MaxUint64(1, 1), Equals, uint64(1))
+func TestMaxUint64(t *testing.T) {
+ re := require.New(t)
+ re.Equal(uint64(2), MaxUint64(1, 2))
+ re.Equal(uint64(2), MaxUint64(2, 1))
+ re.Equal(uint64(1), MaxUint64(1, 1))
 }
-func (s *testMinMaxSuite) TestMinDuration(c *C) {
- c.Assert(MinDuration(time.Minute, time.Second), Equals, time.Second)
- c.Assert(MinDuration(time.Second, time.Minute), Equals, time.Second)
- c.Assert(MinDuration(time.Second, time.Second), Equals, time.Second)
+func TestMinDuration(t *testing.T) {
+ re := require.New(t)
+ re.Equal(time.Second, MinDuration(time.Minute, time.Second))
+ re.Equal(time.Second, MinDuration(time.Second, time.Minute))
+ re.Equal(time.Second, MinDuration(time.Second, time.Second))
 }
diff --git a/pkg/typeutil/conversion_test.go b/pkg/typeutil/conversion_test.go
index a2d9764ade0..4d28fa152f3 100644
--- a/pkg/typeutil/conversion_test.go
+++ b/pkg/typeutil/conversion_test.go
@@ -17,33 +17,29 @@ package typeutil
 import (
 "encoding/json"
 "reflect"
+ "testing"
- . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testUint64BytesSuite{}) - -type testUint64BytesSuite struct{} - -func (s *testUint64BytesSuite) TestBytesToUint64(c *C) { +func TestBytesToUint64(t *testing.T) { + re := require.New(t) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" a, err := BytesToUint64([]byte(str)) - c.Assert(err, IsNil) - c.Assert(a, Equals, uint64(1000)) + re.NoError(err) + re.Equal(uint64(1000), a) } -func (s *testUint64BytesSuite) TestUint64ToBytes(c *C) { +func TestUint64ToBytes(t *testing.T) { + re := require.New(t) var a uint64 = 1000 b := Uint64ToBytes(a) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" - c.Assert(b, DeepEquals, []byte(str)) + re.True(reflect.DeepEqual([]byte(str), b)) } -var _ = Suite(&testJSONSuite{}) - -type testJSONSuite struct{} - -func (s *testJSONSuite) TestJSONToUint64Slice(c *C) { +func TestJSONToUint64Slice(t *testing.T) { + re := require.New(t) type testArray struct { Array []uint64 `json:"array"` } @@ -51,16 +47,16 @@ func (s *testJSONSuite) TestJSONToUint64Slice(c *C) { Array: []uint64{1, 2, 3}, } bytes, _ := json.Marshal(a) - var t map[string]interface{} - err := json.Unmarshal(bytes, &t) - c.Assert(err, IsNil) + var jsonStr map[string]interface{} + err := json.Unmarshal(bytes, &jsonStr) + re.NoError(err) // valid case - res, ok := JSONToUint64Slice(t["array"]) - c.Assert(ok, IsTrue) - c.Assert(reflect.TypeOf(res[0]).Kind(), Equals, reflect.Uint64) + res, ok := JSONToUint64Slice(jsonStr["array"]) + re.True(ok) + re.Equal(reflect.Uint64, reflect.TypeOf(res[0]).Kind()) // invalid case - _, ok = t["array"].([]uint64) - c.Assert(ok, IsFalse) + _, ok = jsonStr["array"].([]uint64) + re.False(ok) // invalid type type testArray1 struct { @@ -70,10 +66,10 @@ func (s *testJSONSuite) TestJSONToUint64Slice(c *C) { Array: []string{"1", "2", "3"}, } bytes, _ = json.Marshal(a1) - var t1 map[string]interface{} - err = json.Unmarshal(bytes, &t1) - c.Assert(err, IsNil) - res, ok = JSONToUint64Slice(t1["array"]) - c.Assert(ok, IsFalse) - c.Assert(res, IsNil) + var jsonStr1 map[string]interface{} + err = json.Unmarshal(bytes, &jsonStr1) + re.NoError(err) + res, ok = JSONToUint64Slice(jsonStr1["array"]) + re.False(ok) + re.Nil(res) } diff --git a/pkg/typeutil/duration_test.go b/pkg/typeutil/duration_test.go index c3b6f9182da..a7db13ffd04 100644 --- a/pkg/typeutil/duration_test.go +++ b/pkg/typeutil/duration_test.go @@ -16,35 +16,34 @@ package typeutil import ( "encoding/json" + "testing" "github.com/BurntSushi/toml" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testDurationSuite{}) - -type testDurationSuite struct{} - type example struct { Interval Duration `json:"interval" toml:"interval"` } -func (s *testDurationSuite) TestJSON(c *C) { +func TestDurationJSON(t *testing.T) { + re := require.New(t) example := &example{} text := []byte(`{"interval":"1h1m1s"}`) - c.Assert(json.Unmarshal(text, example), IsNil) - c.Assert(example.Interval.Seconds(), Equals, float64(60*60+60+1)) + re.Nil(json.Unmarshal(text, example)) + re.Equal(float64(60*60+60+1), example.Interval.Seconds()) b, err := json.Marshal(example) - c.Assert(err, IsNil) - c.Assert(string(b), Equals, string(text)) + re.NoError(err) + re.Equal(string(text), string(b)) } -func (s *testDurationSuite) TestTOML(c *C) { +func TestDurationTOML(t *testing.T) { + re := require.New(t) example := &example{} text := []byte(`interval = "1h1m1s"`) - c.Assert(toml.Unmarshal(text, example), IsNil) - c.Assert(example.Interval.Seconds(), Equals, float64(60*60+60+1)) + re.Nil(toml.Unmarshal(text, example)) + re.Equal(float64(60*60+60+1), example.Interval.Seconds()) } diff --git a/pkg/typeutil/size_test.go b/pkg/typeutil/size_test.go index eae092cdb5c..4cc9e66f3de 100644 --- a/pkg/typeutil/size_test.go +++ b/pkg/typeutil/size_test.go @@ -16,32 +16,30 @@ package typeutil import ( "encoding/json" + "testing" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testSizeSuite{}) - -type testSizeSuite struct { -} - -func (s *testSizeSuite) TestJSON(c *C) { +func TestSizeJSON(t *testing.T) { + re := require.New(t) b := ByteSize(265421587) o, err := json.Marshal(b) - c.Assert(err, IsNil) + re.NoError(err) var nb ByteSize err = json.Unmarshal(o, &nb) - c.Assert(err, IsNil) + re.NoError(err) b = ByteSize(1756821276000) o, err = json.Marshal(b) - c.Assert(err, IsNil) - c.Assert(string(o), Equals, `"1.598TiB"`) + re.NoError(err) + re.Equal(`"1.598TiB"`, string(o)) } -func (s *testSizeSuite) TestParseMbFromText(c *C) { - testdata := []struct { +func TestParseMbFromText(t *testing.T) { + re := require.New(t) + testCases := []struct { body []string size uint64 }{{ @@ -55,9 +53,9 @@ func (s *testSizeSuite) TestParseMbFromText(c *C) { size: uint64(1), }} - for _, t := range testdata { - for _, b := range t.body { - c.Assert(int(ParseMBFromText(b, 1)), Equals, int(t.size)) + for _, testCase := range testCases { + for _, b := range testCase.body { + re.Equal(int(testCase.size), int(ParseMBFromText(b, 1))) } } } diff --git a/pkg/typeutil/string_slice_test.go b/pkg/typeutil/string_slice_test.go index 8950dea1e00..f50ddb9218d 100644 --- a/pkg/typeutil/string_slice_test.go +++ b/pkg/typeutil/string_slice_test.go @@ -16,34 +16,33 @@ package typeutil import ( "encoding/json" + "reflect" + "testing" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testStringSliceSuite{}) - -type testStringSliceSuite struct { -} - -func (s *testStringSliceSuite) TestJSON(c *C) { +func TestStringSliceJSON(t *testing.T) { + re := require.New(t) b := StringSlice([]string{"zone", "rack"}) o, err := json.Marshal(b) - c.Assert(err, IsNil) - c.Assert(string(o), Equals, "\"zone,rack\"") + re.NoError(err) + re.Equal("\"zone,rack\"", string(o)) var nb StringSlice err = json.Unmarshal(o, &nb) - c.Assert(err, IsNil) - c.Assert(nb, DeepEquals, b) + re.NoError(err) + re.True(reflect.DeepEqual(b, nb)) } -func (s *testStringSliceSuite) TestEmpty(c *C) { +func TestEmpty(t *testing.T) { + re := require.New(t) ss := StringSlice([]string{}) b, err := json.Marshal(ss) - c.Assert(err, IsNil) - c.Assert(string(b), Equals, "\"\"") + re.NoError(err) + re.Equal("\"\"", string(b)) var ss2 StringSlice - c.Assert(ss2.UnmarshalJSON(b), IsNil) - c.Assert(ss2, DeepEquals, ss) + re.NoError(ss2.UnmarshalJSON(b)) + re.True(reflect.DeepEqual(ss, ss2)) } diff --git a/pkg/typeutil/time_test.go b/pkg/typeutil/time_test.go index 3e728c14eb9..b8078f63fa8 100644 --- a/pkg/typeutil/time_test.go +++ b/pkg/typeutil/time_test.go @@ -16,62 +16,62 @@ package typeutil import ( "math/rand" + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testTimeSuite{}) - -type testTimeSuite struct{} - -func (s *testTimeSuite) TestParseTimestamp(c *C) { +func TestParseTimestamp(t *testing.T) { + re := require.New(t) for i := 0; i < 3; i++ { t := time.Now().Add(time.Second * time.Duration(rand.Int31n(1000))) data := Uint64ToBytes(uint64(t.UnixNano())) nt, err := ParseTimestamp(data) - c.Assert(err, IsNil) - c.Assert(nt.Equal(t), IsTrue) + re.NoError(err) + re.True(nt.Equal(t)) } data := []byte("pd") nt, err := ParseTimestamp(data) - c.Assert(err, NotNil) - c.Assert(nt.Equal(ZeroTime), IsTrue) + re.Error(err) + re.True(nt.Equal(ZeroTime)) } -func (s *testTimeSuite) TestSubTimeByWallClock(c *C) { +func TestSubTimeByWallClock(t *testing.T) { + re := require.New(t) for i := 0; i < 100; i++ { r := rand.Int63n(1000) t1 := time.Now() // Add r seconds. t2 := t1.Add(time.Second * time.Duration(r)) duration := SubRealTimeByWallClock(t2, t1) - c.Assert(duration, Equals, time.Second*time.Duration(r)) + re.Equal(time.Second*time.Duration(r), duration) milliseconds := SubTSOPhysicalByWallClock(t2, t1) - c.Assert(milliseconds, Equals, r*time.Second.Milliseconds()) - // Add r millionseconds. + re.Equal(r*time.Second.Milliseconds(), milliseconds) + // Add r milliseconds. t3 := t1.Add(time.Millisecond * time.Duration(r)) milliseconds = SubTSOPhysicalByWallClock(t3, t1) - c.Assert(milliseconds, Equals, r) + re.Equal(r, milliseconds) // Add r nanoseconds. t4 := t1.Add(time.Duration(-r)) duration = SubRealTimeByWallClock(t4, t1) - c.Assert(duration, Equals, time.Duration(-r)) + re.Equal(time.Duration(-r), duration) // For the millisecond comparison, please see TestSmallTimeDifference. 
} } -func (s *testTimeSuite) TestSmallTimeDifference(c *C) { +func TestSmallTimeDifference(t *testing.T) { + re := require.New(t) t1, err := time.Parse("2006-01-02 15:04:05.999", "2021-04-26 00:44:25.682") - c.Assert(err, IsNil) + re.NoError(err) t2, err := time.Parse("2006-01-02 15:04:05.999", "2021-04-26 00:44:25.681918") - c.Assert(err, IsNil) + re.NoError(err) duration := SubRealTimeByWallClock(t1, t2) - c.Assert(duration, Equals, time.Duration(82)*time.Microsecond) + re.Equal(time.Duration(82)*time.Microsecond, duration) duration = SubRealTimeByWallClock(t2, t1) - c.Assert(duration, Equals, time.Duration(-82)*time.Microsecond) + re.Equal(time.Duration(-82)*time.Microsecond, duration) milliseconds := SubTSOPhysicalByWallClock(t1, t2) - c.Assert(milliseconds, Equals, int64(1)) + re.Equal(int64(1), milliseconds) milliseconds = SubTSOPhysicalByWallClock(t2, t1) - c.Assert(milliseconds, Equals, int64(-1)) + re.Equal(int64(-1), milliseconds) } diff --git a/scripts/check-testing-t.sh b/scripts/check-testing-t.sh index 0697a007480..6d107b5a0d1 100755 --- a/scripts/check-testing-t.sh +++ b/scripts/check-testing-t.sh @@ -1,5 +1,7 @@ #!/bin/bash +# TODO: remove this script after migrating all tests to the new test framework. + # Check if there are any packages foget to add `TestingT` when use "github.com/pingcap/check". res=$(diff <(grep -rl --include=\*_test.go "github.com/pingcap/check" . | xargs -L 1 dirname | sort -u) \ @@ -13,7 +15,7 @@ fi # Check if there are duplicated `TestingT` in package. -res=$(grep -r --include=\*_test.go "TestingT(" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) +res=$(grep -r --include=\*_test.go "TestingT(t)" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) if [ "$res" ]; then echo "following packages may have duplicated TestingT:" diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 3be3b38b484..1ece28a5239 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -170,22 +170,3 @@ func (s *testTSOSuite) TestResetTS(c *C) { tu.StringEqual(c, "\"invalid tso value\"\n")) c.Assert(err, IsNil) } - -var _ = Suite(&testServiceSuite{}) - -type testServiceSuite struct { - svr *server.Server - cleanup cleanUpFunc -} - -func (s *testServiceSuite) SetUpSuite(c *C) { - s.svr, s.cleanup = mustNewServer(c) - mustWaitLeader(c, []*server.Server{s.svr}) - - mustBootstrapCluster(c, s.svr) - mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) -} - -func (s *testServiceSuite) TearDownSuite(c *C) { - s.cleanup() -} diff --git a/server/api/scheduler.go b/server/api/scheduler.go index ecf0d51a2a6..5faa01c764b 100644 --- a/server/api/scheduler.go +++ b/server/api/scheduler.go @@ -18,6 +18,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gorilla/mux" "github.com/pingcap/errors" @@ -44,6 +45,12 @@ func newSchedulerHandler(svr *server.Server, r *render.Render) *schedulerHandler } } +type schedulerPausedPeriod struct { + Name string `json:"name"` + PausedAt time.Time `json:"paused_at"` + ResumeAt time.Time `json:"resume_at"` +} + // @Tags scheduler // @Summary List all created schedulers by status. 
// @Produce json @@ -58,9 +65,11 @@ func (h *schedulerHandler) GetSchedulers(w http.ResponseWriter, r *http.Request) } status := r.URL.Query().Get("status") + _, tsFlag := r.URL.Query()["timestamp"] switch status { case "paused": var pausedSchedulers []string + pausedPeriods := []schedulerPausedPeriod{} for _, scheduler := range schedulers { paused, err := h.Handler.IsSchedulerPaused(scheduler) if err != nil { @@ -69,10 +78,35 @@ func (h *schedulerHandler) GetSchedulers(w http.ResponseWriter, r *http.Request) } if paused { - pausedSchedulers = append(pausedSchedulers, scheduler) + if tsFlag { + s := schedulerPausedPeriod{ + Name: scheduler, + PausedAt: time.Time{}, + ResumeAt: time.Time{}, + } + pausedAt, err := h.Handler.GetPausedSchedulerDelayAt(scheduler) + if err != nil { + h.r.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + s.PausedAt = time.Unix(pausedAt, 0) + resumeAt, err := h.Handler.GetPausedSchedulerDelayUntil(scheduler) + if err != nil { + h.r.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + s.ResumeAt = time.Unix(resumeAt, 0) + pausedPeriods = append(pausedPeriods, s) + } else { + pausedSchedulers = append(pausedSchedulers, scheduler) + } } } - h.r.JSON(w, http.StatusOK, pausedSchedulers) + if tsFlag { + h.r.JSON(w, http.StatusOK, pausedPeriods) + } else { + h.r.JSON(w, http.StatusOK, pausedSchedulers) + } return case "disabled": var disabledSchedulers []string diff --git a/server/api/scheduler_test.go b/server/api/scheduler_test.go index 04a2aee900e..8c20bdf6182 100644 --- a/server/api/scheduler_test.go +++ b/server/api/scheduler_test.go @@ -490,6 +490,11 @@ func (s *testScheduleSuite) testPauseOrResume(name, createdName string, body []b c.Assert(err, IsNil) err = tu.CheckPostJSON(testDialClient, s.urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(c)) c.Assert(err, IsNil) + pausedAt, err := handler.GetPausedSchedulerDelayAt(createdName) + c.Assert(err, IsNil) + resumeAt, err := handler.GetPausedSchedulerDelayUntil(createdName) + c.Assert(err, IsNil) + c.Assert(resumeAt-pausedAt, Equals, int64(1)) time.Sleep(time.Second) isPaused, err = handler.IsSchedulerPaused(createdName) c.Assert(err, IsNil) diff --git a/server/api/server_test.go b/server/api/server_test.go index 51467db3938..8d9f1b4c227 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -154,14 +154,14 @@ func mustBootstrapCluster(c *C, s *server.Server) { c.Assert(resp.GetHeader().GetError().GetType(), Equals, pdpb.ErrorType_OK) } -var _ = Suite(&testServerServiceSuite{}) +var _ = Suite(&testServiceSuite{}) -type testServerServiceSuite struct { +type testServiceSuite struct { svr *server.Server cleanup cleanUpFunc } -func (s *testServerServiceSuite) SetUpSuite(c *C) { +func (s *testServiceSuite) SetUpSuite(c *C) { s.svr, s.cleanup = mustNewServer(c) mustWaitLeader(c, []*server.Server{s.svr}) @@ -169,7 +169,7 @@ func (s *testServerServiceSuite) SetUpSuite(c *C) { mustPutStore(c, s.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) } -func (s *testServerServiceSuite) TearDownSuite(c *C) { +func (s *testServiceSuite) TearDownSuite(c *C) { s.cleanup() } diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 5a24444a1ad..4ff1232752e 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -2351,3 +2351,13 @@ func newCacheCluster(c *RaftCluster) *cacheCluster { stores: c.GetStores(), } } + +// GetPausedSchedulerDelayAt returns DelayAt of a paused scheduler +func (c *RaftCluster) GetPausedSchedulerDelayAt(name string) 
(int64, error) {
+ return c.coordinator.getPausedSchedulerDelayAt(name)
+}
+
+// GetPausedSchedulerDelayUntil returns DelayUntil of a paused scheduler
+func (c *RaftCluster) GetPausedSchedulerDelayUntil(name string) (int64, error) {
+ return c.coordinator.getPausedSchedulerDelayUntil(name)
+}
diff --git a/server/cluster/coordinator.go b/server/cluster/coordinator.go
index 108a538034e..530e858877f 100644
--- a/server/cluster/coordinator.go
+++ b/server/cluster/coordinator.go
@@ -718,10 +718,12 @@ func (c *coordinator) pauseOrResumeScheduler(name string, t int64) error {
 }
 var err error
 for _, sc := range s {
- var delayUntil int64
+ var delayAt, delayUntil int64
 if t > 0 {
- delayUntil = time.Now().Unix() + t
+ delayAt = time.Now().Unix()
+ delayUntil = delayAt + t
 }
+ atomic.StoreInt64(&sc.delayAt, delayAt)
 atomic.StoreInt64(&sc.delayUntil, delayUntil)
 }
 return err
@@ -851,6 +853,7 @@ type scheduleController struct {
 nextInterval time.Duration
 ctx context.Context
 cancel context.CancelFunc
+ delayAt int64
 delayUntil int64
 }
@@ -909,3 +912,45 @@ func (s *scheduleController) IsPaused() bool {
 delayUntil := atomic.LoadInt64(&s.delayUntil)
 return time.Now().Unix() < delayUntil
 }
+
+// GetDelayAt returns the paused timestamp of a paused scheduler
+func (s *scheduleController) GetDelayAt() int64 {
+ if s.IsPaused() {
+ return atomic.LoadInt64(&s.delayAt)
+ }
+ return 0
+}
+
+// GetDelayUntil returns the resume timestamp of a paused scheduler
+func (s *scheduleController) GetDelayUntil() int64 {
+ if s.IsPaused() {
+ return atomic.LoadInt64(&s.delayUntil)
+ }
+ return 0
+}
+
+func (c *coordinator) getPausedSchedulerDelayAt(name string) (int64, error) {
+ c.RLock()
+ defer c.RUnlock()
+ if c.cluster == nil {
+ return -1, errs.ErrNotBootstrapped.FastGenByArgs()
+ }
+ s, ok := c.schedulers[name]
+ if !ok {
+ return -1, errs.ErrSchedulerNotFound.FastGenByArgs()
+ }
+ return s.GetDelayAt(), nil
+}
+
+func (c *coordinator) getPausedSchedulerDelayUntil(name string) (int64, error) {
+ c.RLock()
+ defer c.RUnlock()
+ if c.cluster == nil {
+ return -1, errs.ErrNotBootstrapped.FastGenByArgs()
+ }
+ s, ok := c.schedulers[name]
+ if !ok {
+ return -1, errs.ErrSchedulerNotFound.FastGenByArgs()
+ }
+ return s.GetDelayUntil(), nil
+}
diff --git a/server/cluster/coordinator_test.go b/server/cluster/coordinator_test.go
index 20ab1f4f8fa..b234374a765 100644
--- a/server/cluster/coordinator_test.go
+++ b/server/cluster/coordinator_test.go
@@ -907,6 +907,11 @@ func (s *testCoordinatorSuite) TestPauseScheduler(c *C) {
 co.pauseOrResumeScheduler(schedulers.BalanceLeaderName, 60)
 paused, _ := co.isSchedulerPaused(schedulers.BalanceLeaderName)
 c.Assert(paused, Equals, true)
+ pausedAt, err := co.getPausedSchedulerDelayAt(schedulers.BalanceLeaderName)
+ c.Assert(err, IsNil)
+ resumeAt, err := co.getPausedSchedulerDelayUntil(schedulers.BalanceLeaderName)
+ c.Assert(err, IsNil)
+ c.Assert(resumeAt-pausedAt, Equals, int64(60))
 allowed, _ := co.isSchedulerAllowed(schedulers.BalanceLeaderName)
 c.Assert(allowed, Equals, false)
 }
diff --git a/server/cluster/unsafe_recovery_controller.go b/server/cluster/unsafe_recovery_controller.go
index 2a84535624c..9782d1a20d8 100644
--- a/server/cluster/unsafe_recovery_controller.go
+++ b/server/cluster/unsafe_recovery_controller.go
@@ -454,6 +454,7 @@ func (u *unsafeRecoveryController) changeStage(stage unsafeRecoveryStage) {
 stores += ", "
 }
 }
+ // TODO: clean up existing operators
 output.Info = fmt.Sprintf("Unsafe recovery enters collect report stage: failed stores 
%s", stores) case tombstoneTiFlashLearner: output.Info = "Unsafe recovery enters tombstone TiFlash learner stage" @@ -967,6 +968,17 @@ func (u *unsafeRecoveryController) generateForceLeaderPlan(newestRegionTree *reg return true }) + if hasPlan { + for storeID := range u.storeReports { + plan := u.getRecoveryPlan(storeID) + if plan.ForceLeader == nil { + // Fill an empty force leader plan to the stores that doesn't have any force leader plan + // to avoid exiting existing force leaders. + plan.ForceLeader = &pdpb.ForceLeader{} + } + } + } + return hasPlan } diff --git a/server/cluster/unsafe_recovery_controller_test.go b/server/cluster/unsafe_recovery_controller_test.go index 8b70b4cd0b6..edd6bf9c187 100644 --- a/server/cluster/unsafe_recovery_controller_test.go +++ b/server/cluster/unsafe_recovery_controller_test.go @@ -328,14 +328,14 @@ func (s *testUnsafeRecoverySuite) TestForceLeaderFail(c *C) { cluster := newTestRaftCluster(s.ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) cluster.coordinator = newCoordinator(s.ctx, cluster, hbstream.NewTestHeartbeatStreams(s.ctx, cluster.meta.GetId(), cluster, true)) cluster.coordinator.run() - for _, store := range newTestStores(3, "6.0.0") { + for _, store := range newTestStores(4, "6.0.0") { c.Assert(cluster.PutStore(store.GetMeta()), IsNil) } recoveryController := newUnsafeRecoveryController(cluster) c.Assert(recoveryController.RemoveFailedStores(map[uint64]struct{}{ - 2: {}, 3: {}, - }, 1), IsNil) + 4: {}, + }, 60), IsNil) reports := map[uint64]*pdpb.StoreReport{ 1: { @@ -345,28 +345,57 @@ func (s *testUnsafeRecoverySuite) TestForceLeaderFail(c *C) { RegionState: &raft_serverpb.RegionLocalState{ Region: &metapb.Region{ Id: 1001, + StartKey: []byte(""), + EndKey: []byte("x"), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, Peers: []*metapb.Peer{ - {Id: 11, StoreId: 1}, {Id: 21, StoreId: 2}, {Id: 31, StoreId: 3}}}}}, + {Id: 11, StoreId: 1}, {Id: 21, StoreId: 3}, {Id: 31, StoreId: 4}}}}}, + }, + }, + 2: { + PeerReports: []*pdpb.PeerReport{ + { + RaftState: &raft_serverpb.RaftLocalState{LastIndex: 10, HardState: &eraftpb.HardState{Term: 1, Commit: 10}}, + RegionState: &raft_serverpb.RegionLocalState{ + Region: &metapb.Region{ + Id: 1002, + StartKey: []byte("x"), + EndKey: []byte(""), + RegionEpoch: &metapb.RegionEpoch{ConfVer: 10, Version: 1}, + Peers: []*metapb.Peer{ + {Id: 12, StoreId: 2}, {Id: 22, StoreId: 3}, {Id: 32, StoreId: 4}}}}}, }, }, } - req := newStoreHeartbeat(1, reports[1]) - resp := &pdpb.StoreHeartbeatResponse{} - req.StoreReport.Step = 1 - recoveryController.HandleStoreHeartbeat(req, resp) + req1 := newStoreHeartbeat(1, reports[1]) + resp1 := &pdpb.StoreHeartbeatResponse{} + req1.StoreReport.Step = 1 + recoveryController.HandleStoreHeartbeat(req1, resp1) + req2 := newStoreHeartbeat(2, reports[2]) + resp2 := &pdpb.StoreHeartbeatResponse{} + req2.StoreReport.Step = 1 + recoveryController.HandleStoreHeartbeat(req2, resp2) c.Assert(recoveryController.GetStage(), Equals, forceLeader) + recoveryController.HandleStoreHeartbeat(req1, resp1) - applyRecoveryPlan(c, 1, reports, resp) - // force leader doesn't succeed - reports[1].PeerReports[0].IsForceLeader = false - recoveryController.HandleStoreHeartbeat(req, resp) + // force leader on store 1 succeed + applyRecoveryPlan(c, 1, reports, resp1) + applyRecoveryPlan(c, 2, reports, resp2) + // force leader on store 2 doesn't succeed + reports[2].PeerReports[0].IsForceLeader = false + + // force leader should retry on store 2 + 
recoveryController.HandleStoreHeartbeat(req1, resp1) + recoveryController.HandleStoreHeartbeat(req2, resp2) c.Assert(recoveryController.GetStage(), Equals, forceLeader) + recoveryController.HandleStoreHeartbeat(req1, resp1) // force leader succeed this time - applyRecoveryPlan(c, 1, reports, resp) - recoveryController.HandleStoreHeartbeat(req, resp) + applyRecoveryPlan(c, 1, reports, resp1) + applyRecoveryPlan(c, 2, reports, resp2) + recoveryController.HandleStoreHeartbeat(req1, resp1) + recoveryController.HandleStoreHeartbeat(req2, resp2) c.Assert(recoveryController.GetStage(), Equals, demoteFailedVoter) } diff --git a/server/config/store_config.go b/server/config/store_config.go index 6e8ba7c22f7..27fc456dd08 100644 --- a/server/config/store_config.go +++ b/server/config/store_config.go @@ -35,6 +35,8 @@ var ( defaultRegionMaxSize = uint64(144) // default region split size is 96MB defaultRegionSplitSize = uint64(96) + // default bucket size is 96MB + defaultBucketSize = uint64(96) // default region max key is 144000 defaultRegionMaxKey = uint64(1440000) // default region split key is 960000 @@ -58,7 +60,7 @@ type Coprocessor struct { RegionMaxKeys int `json:"region-max-keys"` RegionSplitKeys int `json:"region-split-keys"` EnableRegionBucket bool `json:"enable-region-bucket"` - RegionBucketSize int `json:"region-bucket-size"` + RegionBucketSize string `json:"region-bucket-size"` } // String implements fmt.Stringer interface. @@ -111,11 +113,14 @@ func (c *StoreConfig) IsEnableRegionBucket() bool { } // GetRegionBucketSize returns region bucket size if enable region buckets. -func (c *StoreConfig) GetRegionBucketSize() int { +func (c *StoreConfig) GetRegionBucketSize() uint64 { if c == nil || !c.Coprocessor.EnableRegionBucket { return 0 } - return c.Coprocessor.RegionBucketSize + if len(c.Coprocessor.RegionBucketSize) == 0 { + return defaultBucketSize + } + return typeutil.ParseMBFromText(c.Coprocessor.RegionBucketSize, defaultBucketSize) } // CheckRegionSize return error if the smallest region's size is less than mergeSize diff --git a/server/config/store_config_test.go b/server/config/store_config_test.go index 106d8b7bf4e..478e1ebb3d7 100644 --- a/server/config/store_config_test.go +++ b/server/config/store_config_test.go @@ -77,6 +77,30 @@ func (t *testTiKVConfigSuite) TestUpdateConfig(c *C) { c.Assert(manager.source.(*TiKVConfigSource).schema, Equals, "http") } +func (t *testTiKVConfigSuite) TestParseConfig(c *C) { + body := ` +{ +"coprocessor":{ +"split-region-on-table":false, +"batch-split-limit":10, +"region-max-size":"384MiB", +"region-split-size":"256MiB", +"region-max-keys":3840000, +"region-split-keys":2560000, +"consistency-check-method":"mvcc", +"enable-region-bucket":true, +"region-bucket-size":"96MiB", +"region-size-threshold-for-approximate":"384MiB", +"region-bucket-merge-size-ratio":0.33 +} +} +` + + var config StoreConfig + c.Assert(json.Unmarshal([]byte(body), &config), IsNil) + c.Assert(config.GetRegionBucketSize(), Equals, uint64(96)) +} + func (t *testTiKVConfigSuite) TestMergeCheck(c *C) { testdata := []struct { size uint64 diff --git a/server/handler.go b/server/handler.go index 8567c6ec12b..238d0dfdcc3 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1123,3 +1123,21 @@ func (h *Handler) AddEvictOrGrant(storeID float64, name string) error { } return nil } + +// GetPausedSchedulerDelayAt returns paused unix timestamp when a scheduler is paused +func (h *Handler) GetPausedSchedulerDelayAt(name string) (int64, error) { + rc, err := h.GetRaftCluster() + if 
err != nil { + return -1, err + } + return rc.GetPausedSchedulerDelayAt(name) +} + +// GetPausedSchedulerDelayUntil returns resume unix timestamp when a scheduler is paused +func (h *Handler) GetPausedSchedulerDelayUntil(name string) (int64, error) { + rc, err := h.GetRaftCluster() + if err != nil { + return -1, err + } + return rc.GetPausedSchedulerDelayUntil(name) +} diff --git a/server/replication/replication_mode_test.go b/server/replication/replication_mode_test.go index 1f4afb01ca9..8162da599ff 100644 --- a/server/replication/replication_mode_test.go +++ b/server/replication/replication_mode_test.go @@ -21,9 +21,9 @@ import ( "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" pb "github.com/pingcap/kvproto/pkg/replication_modepb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/mock/mockcluster" "github.com/tikv/pd/pkg/typeutil" "github.com/tikv/pd/server/config" @@ -31,32 +31,16 @@ import ( "github.com/tikv/pd/server/storage" ) -func TestReplicationMode(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testReplicationMode{}) - -type testReplicationMode struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testReplicationMode) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) -} - -func (s *testReplicationMode) TearDownTest(c *C) { - s.cancel() -} - -func (s *testReplicationMode) TestInitial(c *C) { +func TestInitial(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeMajority} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{Mode: pb.ReplicationMode_MAJORITY}) + re.NoError(err) + re.Equal(&pb.ReplicationStatus{Mode: pb.ReplicationMode_MAJORITY}, rep.GetReplicationStatus()) conf = config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "dr-label", @@ -68,8 +52,8 @@ func (s *testReplicationMode) TestInitial(c *C) { WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} rep, err = NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -77,19 +61,22 @@ func (s *testReplicationMode) TestInitial(c *C) { StateId: 1, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) } -func (s *testReplicationMode) TestStatus(c *C) { +func TestStatus(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "dr-label", WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - 
c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -97,11 +84,11 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: 1, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) err = rep.drSwitchToAsync(nil) - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -109,12 +96,12 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: 2, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) err = rep.drSwitchToSyncRecover() - c.Assert(err, IsNil) + re.NoError(err) stateID := rep.drAutoSync.StateID - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -122,16 +109,16 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: stateID, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) // test reload rep, err = NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) - c.Assert(rep.drAutoSync.State, Equals, drStateSyncRecover) + re.NoError(err) + re.Equal(drStateSyncRecover, rep.drAutoSync.State) err = rep.drSwitchToSync() - c.Assert(err, IsNil) - c.Assert(rep.GetReplicationStatus(), DeepEquals, &pb.ReplicationStatus{ + re.NoError(err) + re.Equal(&pb.ReplicationStatus{ Mode: pb.ReplicationMode_DR_AUTO_SYNC, DrAutoSync: &pb.DRAutoSync{ LabelKey: "dr-label", @@ -139,7 +126,7 @@ func (s *testReplicationMode) TestStatus(c *C) { StateId: rep.drAutoSync.StateID, WaitSyncTimeoutHint: 60, }, - }) + }, rep.GetReplicationStatus()) } type mockFileReplicator struct { @@ -172,7 +159,10 @@ func newMockReplicator(ids []uint64) *mockFileReplicator { } } -func (s *testReplicationMode) TestStateSwitch(c *C) { +func TestStateSwitch(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", @@ -183,10 +173,10 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) replicator := newMockReplicator([]uint64{1}) rep, err := NewReplicationModeManager(conf, store, cluster, replicator) - c.Assert(err, IsNil) + re.NoError(err) cluster.AddLabelsStore(1, 1, map[string]string{"zone": "zone1"}) cluster.AddLabelsStore(2, 1, map[string]string{"zone": "zone1"}) @@ -194,12 +184,12 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { cluster.AddLabelsStore(4, 1, map[string]string{"zone": "zone1"}) // initial state is sync - c.Assert(rep.drGetState(), Equals, drStateSync) + re.Equal(drStateSync, rep.drGetState()) stateID := rep.drAutoSync.StateID - c.Assert(stateID, Not(Equals), uint64(0)) - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) + re.NotEqual(uint64(0), stateID) + 
re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[1]) assertStateIDUpdate := func() { - c.Assert(rep.drAutoSync.StateID, Not(Equals), stateID) + re.NotEqual(stateID, rep.drAutoSync.StateID) stateID = rep.drAutoSync.StateID } syncStoreStatus := func(storeIDs ...uint64) { @@ -211,124 +201,124 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { // only one zone, sync -> async_wait -> async rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) - c.Assert(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit(), IsFalse) + re.False(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) conf.DRAutoSync.PauseRegionSplit = true rep.UpdateConfig(conf) - c.Assert(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit(), IsTrue) + re.True(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) syncStoreStatus(1, 2, 3, 4) rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) // add new store in dr zone. cluster.AddLabelsStore(5, 1, map[string]string{"zone": "zone2"}) cluster.AddLabelsStore(6, 1, map[string]string{"zone": "zone2"}) // async -> sync rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) rep.drSwitchToSync() - c.Assert(rep.drGetState(), Equals, drStateSync) + re.Equal(drStateSync, rep.drGetState()) assertStateIDUpdate() // sync -> async_wait rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) - s.setStoreState(cluster, "down", "up", "up", "up", "up", "up") + re.Equal(drStateSync, rep.drGetState()) + setStoreState(cluster, "down", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) - s.setStoreState(cluster, "down", "down", "up", "up", "up", "up") - s.setStoreState(cluster, "down", "down", "down", "up", "up", "up") + re.Equal(drStateSync, rep.drGetState()) + setStoreState(cluster, "down", "down", "up", "up", "up", "up") + setStoreState(cluster, "down", "down", "down", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) // cannot guarantee majority, keep sync. + re.Equal(drStateSync, rep.drGetState()) // cannot guarantee majority, keep sync. 
- s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() rep.drSwitchToSync() replicator.errors[2] = errors.New("fail to replicate") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() delete(replicator.errors, 1) // async_wait -> sync - s.setStoreState(cluster, "up", "up", "up", "up", "up", "up") + setStoreState(cluster, "up", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) - c.Assert(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit(), IsFalse) + re.Equal(drStateSync, rep.drGetState()) + re.False(rep.GetReplicationStatus().GetDrAutoSync().GetPauseRegionSplit()) // async_wait -> async_wait - s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) - s.setStoreState(cluster, "down", "up", "up", "up", "down", "up") + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) + setStoreState(cluster, "down", "up", "up", "up", "down", "up") rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[2,3,4]}`, stateID)) - s.setStoreState(cluster, "up", "down", "up", "up", "down", "up") + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[2,3,4]}`, stateID), replicator.lastData[1]) + setStoreState(cluster, "up", "down", "up", "up", "down", "up") rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) // async_wait -> async rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) syncStoreStatus(1, 3) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) syncStoreStatus(4) rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) // async -> async - s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() // store 2 won't be available before it syncs status. 
- c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,3,4]}`, stateID), replicator.lastData[1]) syncStoreStatus(1, 2, 3, 4) rep.tickDR() assertStateIDUpdate() - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"async","state_id":%d,"available_stores":[1,2,3,4]}`, stateID), replicator.lastData[1]) // async -> sync_recover - s.setStoreState(cluster, "up", "up", "up", "up", "up", "up") + setStoreState(cluster, "up", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) assertStateIDUpdate() rep.drSwitchToAsync([]uint64{1, 2, 3, 4, 5}) - s.setStoreState(cluster, "down", "up", "up", "up", "up", "up") + setStoreState(cluster, "down", "up", "up", "up", "up", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) assertStateIDUpdate() // sync_recover -> async rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) - s.setStoreState(cluster, "up", "up", "up", "up", "down", "up") + re.Equal(drStateSyncRecover, rep.drGetState()) + setStoreState(cluster, "up", "up", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsync) + re.Equal(drStateAsync, rep.drGetState()) assertStateIDUpdate() // lost majority, does not switch to async. rep.drSwitchToSyncRecover() assertStateIDUpdate() - s.setStoreState(cluster, "down", "down", "up", "up", "down", "up") + setStoreState(cluster, "down", "down", "up", "up", "down", "up") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) // sync_recover -> sync rep.drSwitchToSyncRecover() assertStateIDUpdate() - s.setStoreState(cluster, "up", "up", "up", "up", "up", "up") + setStoreState(cluster, "up", "up", "up", "up", "up", "up") cluster.AddLeaderRegion(1, 1, 2, 3, 4, 5) region := cluster.GetRegion(1) @@ -337,7 +327,7 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { })) cluster.PutRegion(region) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) region = region.Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_INTEGRITY_OVER_LABEL, @@ -345,18 +335,21 @@ func (s *testReplicationMode) TestStateSwitch(c *C) { })) cluster.PutRegion(region) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSyncRecover) + re.Equal(drStateSyncRecover, rep.drGetState()) region = region.Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_INTEGRITY_OVER_LABEL, StateId: rep.drAutoSync.StateID, })) cluster.PutRegion(region) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) + re.Equal(drStateSync, rep.drGetState()) assertStateIDUpdate() } -func (s *testReplicationMode) TestReplicateState(c *C) { +func TestReplicateState(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", @@ -367,36 +360,39 @@ func (s *testReplicationMode) TestReplicateState(c *C) { 
WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) replicator := newMockReplicator([]uint64{1}) rep, err := NewReplicationModeManager(conf, store, cluster, replicator) - c.Assert(err, IsNil) + re.NoError(err) stateID := rep.drAutoSync.StateID // replicate after initialized - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[1]) // repliate state to new member replicator.memberIDs = append(replicator.memberIDs, 2, 3) rep.checkReplicateFile() - c.Assert(replicator.lastData[2], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) - c.Assert(replicator.lastData[3], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[2]) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[3]) // inject error replicator.errors[2] = errors.New("failed to persist") rep.tickDR() // switch async_wait since there is only one zone newStateID := rep.drAutoSync.StateID - c.Assert(replicator.lastData[1], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID)) - c.Assert(replicator.lastData[2], Equals, fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID)) - c.Assert(replicator.lastData[3], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID), replicator.lastData[1]) + re.Equal(fmt.Sprintf(`{"state":"sync","state_id":%d}`, stateID), replicator.lastData[2]) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID), replicator.lastData[3]) // clear error, replicate to node 2 next time delete(replicator.errors, 2) rep.checkReplicateFile() - c.Assert(replicator.lastData[2], Equals, fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID)) + re.Equal(fmt.Sprintf(`{"state":"async_wait","state_id":%d}`, newStateID), replicator.lastData[2]) } -func (s *testReplicationMode) TestAsynctimeout(c *C) { +func TestAsynctimeout(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() store := storage.NewStorageWithMemoryBackend() conf := config.ReplicationModeConfig{ReplicationMode: modeDRAutoSync, DRAutoSync: config.DRAutoSyncReplicationConfig{ LabelKey: "zone", @@ -408,34 +404,34 @@ func (s *testReplicationMode) TestAsynctimeout(c *C) { WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, WaitAsyncTimeout: typeutil.Duration{Duration: 2 * time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) var replicator mockFileReplicator rep, err := NewReplicationModeManager(conf, store, cluster, &replicator) - c.Assert(err, IsNil) + re.NoError(err) cluster.AddLabelsStore(1, 1, map[string]string{"zone": "zone1"}) cluster.AddLabelsStore(2, 1, map[string]string{"zone": "zone1"}) cluster.AddLabelsStore(3, 1, map[string]string{"zone": "zone2"}) - s.setStoreState(cluster, "up", "up", "down") + setStoreState(cluster, "up", "up", "down") rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) // cannot switch state due to recently start + re.Equal(drStateSync, rep.drGetState()) // 
cannot switch state due to recently start rep.initTime = time.Now().Add(-3 * time.Minute) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) rep.drSwitchToSync() rep.UpdateMemberWaitAsyncTime(42) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateSync) // cannot switch state due to member not timeout + re.Equal(drStateSync, rep.drGetState()) // cannot switch state due to member not timeout rep.drMemberWaitAsyncTime[42] = time.Now().Add(-3 * time.Minute) rep.tickDR() - c.Assert(rep.drGetState(), Equals, drStateAsyncWait) + re.Equal(drStateAsyncWait, rep.drGetState()) } -func (s *testReplicationMode) setStoreState(cluster *mockcluster.Cluster, states ...string) { +func setStoreState(cluster *mockcluster.Cluster, states ...string) { for i, state := range states { store := cluster.GetStore(uint64(i + 1)) if state == "down" { @@ -447,7 +443,11 @@ func (s *testReplicationMode) setStoreState(cluster *mockcluster.Cluster, states } } -func (s *testReplicationMode) TestRecoverProgress(c *C) { +func TestRecoverProgress(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + regionScanBatchSize = 10 regionMinSampleSize = 5 @@ -461,14 +461,14 @@ func (s *testReplicationMode) TestRecoverProgress(c *C) { WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddLabelsStore(1, 1, map[string]string{}) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) + re.NoError(err) prepare := func(n int, asyncRegions []int) { rep.drSwitchToSyncRecover() - regions := s.genRegions(cluster, rep.drAutoSync.StateID, n) + regions := genRegions(cluster, rep.drAutoSync.StateID, n) for _, i := range asyncRegions { regions[i] = regions[i].Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_SIMPLE_MAJORITY, @@ -482,32 +482,35 @@ func (s *testReplicationMode) TestRecoverProgress(c *C) { } prepare(20, nil) - c.Assert(rep.drRecoverCount, Equals, 20) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(20, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) prepare(10, []int{9}) - c.Assert(rep.drRecoverCount, Equals, 9) - c.Assert(rep.drTotalRegion, Equals, 10) - c.Assert(rep.drSampleTotalRegion, Equals, 1) - c.Assert(rep.drSampleRecoverCount, Equals, 0) - c.Assert(rep.estimateProgress(), Equals, float32(9)/float32(10)) + re.Equal(9, rep.drRecoverCount) + re.Equal(10, rep.drTotalRegion) + re.Equal(1, rep.drSampleTotalRegion) + re.Equal(0, rep.drSampleRecoverCount) + re.Equal(float32(9)/float32(10), rep.estimateProgress()) prepare(30, []int{3, 4, 5, 6, 7, 8, 9}) - c.Assert(rep.drRecoverCount, Equals, 3) - c.Assert(rep.drTotalRegion, Equals, 30) - c.Assert(rep.drSampleTotalRegion, Equals, 7) - c.Assert(rep.drSampleRecoverCount, Equals, 0) - c.Assert(rep.estimateProgress(), Equals, float32(3)/float32(30)) + re.Equal(3, rep.drRecoverCount) + re.Equal(30, rep.drTotalRegion) + re.Equal(7, rep.drSampleTotalRegion) + re.Equal(0, rep.drSampleRecoverCount) + re.Equal(float32(3)/float32(30), rep.estimateProgress()) prepare(30, []int{9, 13, 14}) - c.Assert(rep.drRecoverCount, Equals, 9) - c.Assert(rep.drTotalRegion, Equals, 30) - c.Assert(rep.drSampleTotalRegion, Equals, 6) // 
9 + 10,11,12,13,14 - c.Assert(rep.drSampleRecoverCount, Equals, 3) - c.Assert(rep.estimateProgress(), Equals, (float32(9)+float32(30-9)/2)/float32(30)) + re.Equal(9, rep.drRecoverCount) + re.Equal(30, rep.drTotalRegion) + re.Equal(6, rep.drSampleTotalRegion) // 9 + 10,11,12,13,14 + re.Equal(3, rep.drSampleRecoverCount) + re.Equal((float32(9)+float32(30-9)/2)/float32(30), rep.estimateProgress()) } -func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { +func TestRecoverProgressWithSplitAndMerge(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() regionScanBatchSize = 10 regionMinSampleSize = 5 @@ -521,14 +524,14 @@ func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { WaitStoreTimeout: typeutil.Duration{Duration: time.Minute}, WaitSyncTimeout: typeutil.Duration{Duration: time.Minute}, }} - cluster := mockcluster.NewCluster(s.ctx, config.NewTestOptions()) + cluster := mockcluster.NewCluster(ctx, config.NewTestOptions()) cluster.AddLabelsStore(1, 1, map[string]string{}) rep, err := NewReplicationModeManager(conf, store, cluster, newMockReplicator([]uint64{1})) - c.Assert(err, IsNil) + re.NoError(err) prepare := func(n int, asyncRegions []int) { rep.drSwitchToSyncRecover() - regions := s.genRegions(cluster, rep.drAutoSync.StateID, n) + regions := genRegions(cluster, rep.drAutoSync.StateID, n) for _, i := range asyncRegions { regions[i] = regions[i].Clone(core.SetReplicationStatus(&pb.RegionReplicationStatus{ State: pb.RegionReplicationState_SIMPLE_MAJORITY, @@ -545,8 +548,8 @@ func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { r := cluster.GetRegion(1).Clone(core.WithEndKey(cluster.GetRegion(2).GetEndKey())) cluster.PutRegion(r) rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 19) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(19, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) // merged happened during the scan prepare(20, nil) @@ -557,23 +560,23 @@ func (s *testReplicationMode) TestRecoverProgressWithSplitAndMerge(c *C) { rep.drRecoverCount = 1 rep.drRecoverKey = r1.GetEndKey() rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 20) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(20, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) // split, region gap happened during the scan rep.drRecoverCount, rep.drRecoverKey = 0, nil cluster.PutRegion(r1) rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 1) - c.Assert(rep.estimateProgress(), Not(Equals), float32(1.0)) + re.Equal(1, rep.drRecoverCount) + re.NotEqual(float32(1.0), rep.estimateProgress()) // region gap missing cluster.PutRegion(r2) rep.updateProgress() - c.Assert(rep.drRecoverCount, Equals, 20) - c.Assert(rep.estimateProgress(), Equals, float32(1.0)) + re.Equal(20, rep.drRecoverCount) + re.Equal(float32(1.0), rep.estimateProgress()) } -func (s *testReplicationMode) genRegions(cluster *mockcluster.Cluster, stateID uint64, n int) []*core.RegionInfo { +func genRegions(cluster *mockcluster.Cluster, stateID uint64, n int) []*core.RegionInfo { var regions []*core.RegionInfo for i := 1; i <= n; i++ { cluster.AddLeaderRegion(uint64(i), 1) diff --git a/server/schedule/rangelist/range_list_test.go b/server/schedule/rangelist/range_list_test.go index 4b737d7f1fd..0f9ba595aa0 100644 --- a/server/schedule/rangelist/range_list_test.go +++ b/server/schedule/rangelist/range_list_test.go @@ -17,45 +17,42 @@ package 
rangelist import ( "testing" - "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) func TestRangeList(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&testRangeListSuite{}) - -type testRangeListSuite struct{} - -func (s *testRangeListSuite) TestRangeList(c *check.C) { + re := require.New(t) rl := NewBuilder().Build() - c.Assert(rl.Len(), check.Equals, 0) + re.Equal(0, rl.Len()) i, data := rl.GetDataByKey([]byte("a")) - c.Assert(i, check.Equals, -1) - c.Assert(data, check.IsNil) + re.Equal(-1, i) + re.Nil(data) + i, data = rl.GetData([]byte("a"), []byte("b")) - c.Assert(i, check.Equals, -1) - c.Assert(data, check.IsNil) - c.Assert(rl.GetSplitKeys(nil, []byte("foo")), check.IsNil) + re.Equal(-1, i) + re.Nil(data) + + re.Nil(rl.GetSplitKeys(nil, []byte("foo"))) b := NewBuilder() b.AddItem(nil, nil, 1) rl = b.Build() - c.Assert(rl.Len(), check.Equals, 1) + re.Equal(1, rl.Len()) key, data := rl.Get(0) - c.Assert(key, check.IsNil) - c.Assert(data, check.DeepEquals, []interface{}{1}) + re.Nil(key) + + re.Equal([]interface{}{1}, data) i, data = rl.GetDataByKey([]byte("foo")) - c.Assert(i, check.Equals, 0) - c.Assert(data, check.DeepEquals, []interface{}{1}) + re.Equal(0, i) + re.Equal([]interface{}{1}, data) i, data = rl.GetData([]byte("a"), []byte("b")) - c.Assert(i, check.Equals, 0) - c.Assert(data, check.DeepEquals, []interface{}{1}) - c.Assert(rl.GetSplitKeys(nil, []byte("foo")), check.IsNil) + re.Equal(0, i) + re.Equal([]interface{}{1}, data) + re.Nil(rl.GetSplitKeys(nil, []byte("foo"))) } -func (s *testRangeListSuite) TestRangeList2(c *check.C) { +func TestRangeList2(t *testing.T) { + re := require.New(t) b := NewBuilder() b.SetCompareFunc(func(a, b interface{}) int { if a.(int) > b.(int) { @@ -88,11 +85,11 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { } rl := b.Build() - c.Assert(rl.Len(), check.Equals, len(expectKeys)) + re.Equal(len(expectKeys), rl.Len()) for i := 0; i < rl.Len(); i++ { key, data := rl.Get(i) - c.Assert(key, check.DeepEquals, expectKeys[i]) - c.Assert(data, check.DeepEquals, expectData[i]) + re.Equal(expectKeys[i], key) + re.Equal(expectData[i], data) } getDataByKeyCases := []struct { @@ -103,8 +100,8 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { } for _, tc := range getDataByKeyCases { i, data := rl.GetDataByKey([]byte(tc.key)) - c.Assert(i, check.Equals, tc.pos) - c.Assert(data, check.DeepEquals, expectData[i]) + re.Equal(tc.pos, i) + re.Equal(expectData[i], data) } getDataCases := []struct { @@ -116,9 +113,9 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { } for _, tc := range getDataCases { i, data := rl.GetData([]byte(tc.start), []byte(tc.end)) - c.Assert(i, check.Equals, tc.pos) + re.Equal(tc.pos, i) if i >= 0 { - c.Assert(data, check.DeepEquals, expectData[i]) + re.Equal(expectData[i], data) } } @@ -131,6 +128,6 @@ func (s *testRangeListSuite) TestRangeList2(c *check.C) { {"cc", "fx", 4, 7}, } for _, tc := range getSplitKeysCases { - c.Assert(rl.GetSplitKeys([]byte(tc.start), []byte(tc.end)), check.DeepEquals, expectKeys[tc.indexStart:tc.indexEnd]) + re.Equal(expectKeys[tc.indexStart:tc.indexEnd], rl.GetSplitKeys([]byte(tc.start), []byte(tc.end))) } } diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 8b195eaa587..3afda979c44 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -21,16 +21,18 @@ import ( "fmt" "math" "path" + "reflect" "sort" "sync" "testing" "time" "github.com/gogo/protobuf/proto" - . 
"github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/mock/mockid" @@ -50,30 +52,10 @@ const ( tsoRequestRound = 30 ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&clientTestSuite{}) - -type clientTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clientTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *clientTestSuite) TearDownSuite(c *C) { - s.cancel() -} - type client interface { GetLeaderAddr() string ScheduleCheckLeader() @@ -81,75 +63,81 @@ type client interface { GetAllocatorLeaderURLs() map[string]string } -func (s *clientTestSuite) TestClientLeaderChange(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) +func TestClientLeaderChange(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) var ts1, ts2 uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { p1, l1, err := cli.GetTS(context.TODO()) if err == nil { ts1 = tsoutil.ComposeTS(p1, l1) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(ts1), IsTrue) + re.True(cluster.CheckTSOUnique(ts1)) leader := cluster.GetLeader() - waitLeader(c, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) err = cluster.GetServer(leader).Stop() - c.Assert(err, IsNil) + re.NoError(err) leader = cluster.WaitLeader() - c.Assert(leader, Not(Equals), "") - waitLeader(c, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + re.NotEmpty(leader) + waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) // Check TS won't fall back after leader changed. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { p2, l2, err := cli.GetTS(context.TODO()) if err == nil { ts2 = tsoutil.ComposeTS(p2, l2) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(ts2), IsTrue) - c.Assert(ts1, Less, ts2) + re.True(cluster.CheckTSOUnique(ts2)) + re.Less(ts1, ts2) // Check URL list. 
cli.Close() urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - c.Assert(urls, DeepEquals, endpoints) + re.True(reflect.DeepEqual(endpoints, urls)) } -func (s *clientTestSuite) TestLeaderTransfer(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestLeaderTransfer(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(lastTS), IsTrue) + re.True(cluster.CheckTSOUnique(lastTS)) // Start a goroutine the make sure TS won't fall back. quit := make(chan struct{}) @@ -167,8 +155,8 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { ts := tsoutil.ComposeTS(physical, logical) - c.Assert(cluster.CheckTSOUnique(ts), IsTrue) - c.Assert(lastTS, Less, ts) + re.True(cluster.CheckTSOUnique(ts)) + re.Less(lastTS, ts) lastTS = ts } time.Sleep(time.Millisecond) @@ -179,69 +167,75 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) { for i := 0; i < 5; i++ { oldLeaderName := cluster.WaitLeader() err := cluster.GetServer(oldLeaderName).ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) newLeaderName := cluster.WaitLeader() - c.Assert(newLeaderName, Not(Equals), oldLeaderName) + re.NotEqual(oldLeaderName, newLeaderName) } close(quit) wg.Wait() } // More details can be found in this issue: https://github.com/tikv/pd/issues/4884 -func (s *clientTestSuite) TestUpdateAfterResetTSO(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestUpdateAfterResetTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) // Transfer leader to trigger the TSO resetting. - c.Assert(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)")) oldLeaderName := cluster.WaitLeader() err = cluster.GetServer(oldLeaderName).ResignLeader() - c.Assert(err, IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO"), IsNil) + re.NoError(err) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO")) newLeaderName := cluster.WaitLeader() - c.Assert(newLeaderName, Not(Equals), oldLeaderName) + re.NotEqual(oldLeaderName, newLeaderName) // Request a new TSO. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) // Transfer leader back. 
- c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`)) err = cluster.GetServer(newLeaderName).ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) // Should NOT panic here. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp")) } -func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { +func TestTSOAllocatorLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) - cluster.WaitAllLeaders(c, dcLocationConfig) + re.NoError(err) + cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) var ( testServers = cluster.GetServers() @@ -255,13 +249,13 @@ func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { var allocatorLeaderMap = make(map[string]string) for _, dcLocation := range dcLocationConfig { var pdName string - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { pdName = cluster.WaitAllocatorLeader(dcLocation) return len(pdName) > 0 }) allocatorLeaderMap[dcLocation] = pdName } - cli := setupCli(c, s.ctx, endpoints) + cli := setupCli(re, ctx, endpoints) // Check allocator leaders URL map. 
cli.Close() @@ -270,27 +264,30 @@ func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - c.Assert(urls, DeepEquals, endpoints) + re.True(reflect.DeepEqual(endpoints, urls)) continue } pdName, exist := allocatorLeaderMap[dcLocation] - c.Assert(exist, IsTrue) - c.Assert(len(pdName), Greater, 0) + re.True(exist) + re.Greater(len(pdName), 0) pdURL, exist := endpointsMap[pdName] - c.Assert(exist, IsTrue) - c.Assert(len(pdURL), Greater, 0) - c.Assert(url, Equals, pdURL) + re.True(exist) + re.Greater(len(pdURL), 0) + re.Equal(pdURL, url) } } -func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) +func TestTSOFollowerProxy(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli1 := setupCli(c, s.ctx, endpoints) - cli2 := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli1 := setupCli(re, ctx, endpoints) + cli2 := setupCli(re, ctx, endpoints) cli2.UpdateOption(pd.EnableTSOFollowerProxy, true) var wg sync.WaitGroup @@ -301,15 +298,15 @@ func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli2.GetTS(context.Background()) - c.Assert(err, IsNil) + re.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts // After requesting with the follower proxy, request with the leader directly. physical, logical, err = cli1.GetTS(context.Background()) - c.Assert(err, IsNil) + re.NoError(err) ts = tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts } }() @@ -317,71 +314,79 @@ func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { wg.Wait() } -func (s *clientTestSuite) TestGlobalAndLocalTSO(c *C) { +func TestGlobalAndLocalTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) // Wait for all nodes becoming healthy. 
time.Sleep(time.Second * 5) // Join a new dc-location - pd4, err := cluster.Join(s.ctx, func(conf *config.Config, serverName string) { + pd4, err := cluster.Join(ctx, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = "dc-4" }) - c.Assert(err, IsNil) + re.NoError(err) err = pd4.Run() - c.Assert(err, IsNil) + re.NoError(err) dcLocationConfig["pd4"] = "dc-4" cluster.CheckClusterDCLocation() - cluster.WaitAllLeaders(c, dcLocationConfig) + cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) // Test a nonexistent dc-location for Local TSO p, l, err := cli.GetLocalTS(context.TODO(), "nonexistent-dc") - c.Assert(p, Equals, int64(0)) - c.Assert(l, Equals, int64(0)) - c.Assert(err, NotNil) - c.Assert(err, ErrorMatches, ".*unknown dc-location.*") + re.Equal(int64(0), p) + re.Equal(int64(0), l) + re.Error(err) + re.Contains(err.Error(), "unknown dc-location") wg := &sync.WaitGroup{} - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) // assert global tso after resign leader - c.Assert(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`)) err = cluster.ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() - _, _, err = cli.GetTS(s.ctx) - c.Assert(err, NotNil) - c.Assert(pd.IsLeaderChange(err), IsTrue) - _, _, err = cli.GetTS(s.ctx) - c.Assert(err, IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember"), IsNil) + _, _, err = cli.GetTS(ctx) + re.Error(err) + re.True(pd.IsLeaderChange(err)) + _, _, err = cli.GetTS(ctx) + re.NoError(err) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember")) // Test the TSO follower proxy while enabling the Local TSO. cli.UpdateOption(pd.EnableTSOFollowerProxy, true) // Sleep a while here to prevent from canceling the ongoing TSO request. 
time.Sleep(time.Millisecond * 50) - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) cli.UpdateOption(pd.EnableTSOFollowerProxy, false) time.Sleep(time.Millisecond * 50) - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) } -func requestGlobalAndLocalTSO(c *C, wg *sync.WaitGroup, dcLocationConfig map[string]string, cli pd.Client) { +func requestGlobalAndLocalTSO( + re *require.Assertions, + wg *sync.WaitGroup, + dcLocationConfig map[string]string, + cli pd.Client, +) { for _, dcLocation := range dcLocationConfig { wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -390,131 +395,143 @@ func requestGlobalAndLocalTSO(c *C, wg *sync.WaitGroup, dcLocationConfig map[str var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { globalPhysical1, globalLogical1, err := cli.GetTS(context.TODO()) - c.Assert(err, IsNil) + re.NoError(err) globalTS1 := tsoutil.ComposeTS(globalPhysical1, globalLogical1) localPhysical, localLogical, err := cli.GetLocalTS(context.TODO(), dc) - c.Assert(err, IsNil) + re.NoError(err) localTS := tsoutil.ComposeTS(localPhysical, localLogical) globalPhysical2, globalLogical2, err := cli.GetTS(context.TODO()) - c.Assert(err, IsNil) + re.NoError(err) globalTS2 := tsoutil.ComposeTS(globalPhysical2, globalLogical2) - c.Assert(lastTS, Less, globalTS1) - c.Assert(globalTS1, Less, localTS) - c.Assert(localTS, Less, globalTS2) + re.Less(lastTS, globalTS1) + re.Less(globalTS1, localTS) + re.Less(localTS, globalTS2) lastTS = globalTS2 } - c.Assert(lastTS, Greater, uint64(0)) + re.Greater(lastTS, uint64(0)) }(dcLocation) } } wg.Wait() } -func (s *clientTestSuite) TestCustomTimeout(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestCustomTimeout(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) start := time.Now() - c.Assert(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)")) _, err = cli.GetAllStores(context.TODO()) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/customTimeout"), IsNil) - c.Assert(err, NotNil) - c.Assert(time.Since(start), GreaterEqual, 1*time.Second) - c.Assert(time.Since(start), Less, 2*time.Second) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/customTimeout")) + re.Error(err) + re.GreaterOrEqual(time.Since(start), 1*time.Second) + re.Less(time.Since(start), 2*time.Second) } -func (s *clientTestSuite) TestGetRegionFromFollowerClient(c *C) { +func TestGetRegionFromFollowerClient(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := 
setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) time.Sleep(200 * time.Millisecond) r, err := cli.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) - c.Assert(r, NotNil) + re.NoError(err) + re.NotNil(r) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) time.Sleep(200 * time.Millisecond) r, err = cli.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) - c.Assert(r, NotNil) + re.NoError(err) + re.NotNil(r) } // case 1: unreachable -> normal -func (s *clientTestSuite) TestGetTsoFromFollowerClient1(c *C) { +func TestGetTsoFromFollowerClient1(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - lastTS = checkTS(c, cli, lastTS) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork"), IsNil) + lastTS = checkTS(re, cli, lastTS) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(2 * time.Second) - checkTS(c, cli, lastTS) + checkTS(re, cli, lastTS) } // case 2: unreachable -> leader transfer -> normal -func (s *clientTestSuite) TestGetTsoFromFollowerClient2(c *C) { +func TestGetTsoFromFollowerClient2(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - lastTS = checkTS(c, cli, lastTS) - c.Assert(cluster.GetServer(cluster.GetLeader()).ResignLeader(), IsNil) + lastTS = checkTS(re, cli, lastTS) + 
re.NoError(cluster.GetServer(cluster.GetLeader()).ResignLeader()) cluster.WaitLeader() - lastTS = checkTS(c, cli, lastTS) + lastTS = checkTS(re, cli, lastTS) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(5 * time.Second) - checkTS(c, cli, lastTS) + checkTS(re, cli, lastTS) } -func checkTS(c *C, cli pd.Client, lastTS uint64) uint64 { +func checkTS(re *require.Assertions, cli pd.Client, lastTS uint64) uint64 { for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts } time.Sleep(time.Millisecond) @@ -522,12 +539,12 @@ func checkTS(c *C, cli pd.Client, lastTS uint64) uint64 { return lastTS } -func (s *clientTestSuite) runServer(c *C, cluster *tests.TestCluster) []string { +func runServer(re *require.Assertions, cluster *tests.TestCluster) []string { err := cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) testServers := cluster.GetServers() endpoints := make([]string, 0, len(testServers)) @@ -537,32 +554,80 @@ func (s *clientTestSuite) runServer(c *C, cluster *tests.TestCluster) []string { return endpoints } -func setupCli(c *C, ctx context.Context, endpoints []string, opts ...pd.ClientOption) pd.Client { +func setupCli(re *require.Assertions, ctx context.Context, endpoints []string, opts ...pd.ClientOption) pd.Client { cli, err := pd.NewClientWithContext(ctx, endpoints, pd.SecurityOption{}, opts...) 
- c.Assert(err, IsNil) + re.NoError(err) return cli } -func waitLeader(c *C, cli client, leader string) { - testutil.WaitUntil(c, func() bool { +func waitLeader(t *testing.T, cli client, leader string) { + testutil.WaitUntilWithTestingT(t, func() bool { cli.ScheduleCheckLeader() return cli.GetLeaderAddr() == leader }) } -func (s *clientTestSuite) TestCloseClient(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestConfigTTLAfterTransferLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) + defer cluster.Destroy() + err = cluster.RunInitialServers() + re.NoError(err) + leader := cluster.GetServer(cluster.WaitLeader()) + re.NoError(leader.BootstrapCluster()) + addr := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=5", leader.GetAddr()) + postData, err := json.Marshal(map[string]interface{}{ + "schedule.max-snapshot-count": 999, + "schedule.enable-location-replacement": false, + "schedule.max-merge-region-size": 999, + "schedule.max-merge-region-keys": 999, + "schedule.scheduler-max-waiting-operator": 999, + "schedule.leader-schedule-limit": 999, + "schedule.region-schedule-limit": 999, + "schedule.hot-region-schedule-limit": 999, + "schedule.replica-schedule-limit": 999, + "schedule.merge-schedule-limit": 999, + }) + re.NoError(err) + resp, err := leader.GetHTTPClient().Post(addr, "application/json", bytes.NewBuffer(postData)) + resp.Body.Close() + re.NoError(err) + time.Sleep(2 * time.Second) + re.NoError(leader.Destroy()) + time.Sleep(2 * time.Second) + leader = cluster.GetServer(cluster.WaitLeader()) + re.NotNil(leader) + options := leader.GetPersistOptions() + re.NotNil(options) + re.Equal(uint64(999), options.GetMaxSnapshotCount()) + re.False(options.IsLocationReplacementEnabled()) + re.Equal(uint64(999), options.GetMaxMergeRegionSize()) + re.Equal(uint64(999), options.GetMaxMergeRegionKeys()) + re.Equal(uint64(999), options.GetSchedulerMaxWaitingOperator()) + re.Equal(uint64(999), options.GetLeaderScheduleLimit()) + re.Equal(uint64(999), options.GetRegionScheduleLimit()) + re.Equal(uint64(999), options.GetHotRegionScheduleLimit()) + re.Equal(uint64(999), options.GetReplicaScheduleLimit()) + re.Equal(uint64(999), options.GetMergeScheduleLimit()) +} + +func TestCloseClient(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) cli.GetTSAsync(context.TODO()) time.Sleep(time.Second) cli.Close() } -var _ = Suite(&testClientSuite{}) - type idAllocator struct { allocator *mockid.IDAllocator } @@ -605,7 +670,8 @@ var ( } ) -type testClientSuite struct { +type clientTestSuite struct { + suite.Suite cleanup server.CleanupFunc ctx context.Context clean context.CancelFunc @@ -617,38 +683,34 @@ type testClientSuite struct { reportBucket pdpb.PD_ReportBucketsClient } -func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) - checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) - } - return checker +func TestClientTestSuite(t *testing.T) { + suite.Run(t, new(clientTestSuite)) } -func (s *testClientSuite) SetUpSuite(c *C) { +func (suite *clientTestSuite) SetupSuite() { var err 
error - s.srv, s.cleanup, err = server.NewTestServer(checkerWithNilAssert(c)) - c.Assert(err, IsNil) - s.grpcPDClient = testutil.MustNewGrpcClient(c, s.srv.GetAddr()) - s.grpcSvr = &server.GrpcServer{Server: s.srv} - - mustWaitLeader(c, map[string]*server.Server{s.srv.GetAddr(): s.srv}) - bootstrapServer(c, newHeader(s.srv), s.grpcPDClient) - - s.ctx, s.clean = context.WithCancel(context.Background()) - s.client = setupCli(c, s.ctx, s.srv.GetEndpoints()) - - c.Assert(err, IsNil) - s.regionHeartbeat, err = s.grpcPDClient.RegionHeartbeat(s.ctx) - c.Assert(err, IsNil) - s.reportBucket, err = s.grpcPDClient.ReportBuckets(s.ctx) - c.Assert(err, IsNil) - cluster := s.srv.GetRaftCluster() - c.Assert(cluster, NotNil) + re := suite.Require() + suite.srv, suite.cleanup, err = server.NewTestServer(suite.checkerWithNilAssert()) + suite.NoError(err) + suite.grpcPDClient = testutil.MustNewGrpcClientWithTestify(re, suite.srv.GetAddr()) + suite.grpcSvr = &server.GrpcServer{Server: suite.srv} + + suite.mustWaitLeader(map[string]*server.Server{suite.srv.GetAddr(): suite.srv}) + suite.bootstrapServer(newHeader(suite.srv), suite.grpcPDClient) + + suite.ctx, suite.clean = context.WithCancel(context.Background()) + suite.client = setupCli(re, suite.ctx, suite.srv.GetEndpoints()) + + suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx) + suite.NoError(err) + suite.reportBucket, err = suite.grpcPDClient.ReportBuckets(suite.ctx) + suite.NoError(err) + cluster := suite.srv.GetRaftCluster() + suite.NotNil(cluster) now := time.Now().UnixNano() for _, store := range stores { - s.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: newHeader(s.srv), + suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: newHeader(suite.srv), Store: &metapb.Store{ Id: store.Id, Address: store.Address, @@ -660,13 +722,23 @@ func (s *testClientSuite) SetUpSuite(c *C) { config.EnableRegionBucket = true } -func (s *testClientSuite) TearDownSuite(c *C) { - s.client.Close() - s.clean() - s.cleanup() +func (suite *clientTestSuite) TearDownSuite() { + suite.client.Close() + suite.clean() + suite.cleanup() } -func mustWaitLeader(c *C, svrs map[string]*server.Server) *server.Server { +func (suite *clientTestSuite) checkerWithNilAssert() *assertutil.Checker { + checker := assertutil.NewChecker(func() { + suite.FailNow("should be nil") + }) + checker.IsNil = func(obtained interface{}) { + suite.Nil(obtained) + } + return checker +} + +func (suite *clientTestSuite) mustWaitLeader(svrs map[string]*server.Server) *server.Server { for i := 0; i < 500; i++ { for _, s := range svrs { if !s.IsClosed() && s.GetMember().IsLeader() { @@ -675,7 +747,7 @@ func mustWaitLeader(c *C, svrs map[string]*server.Server) *server.Server { } time.Sleep(100 * time.Millisecond) } - c.Fatal("no leader") + suite.FailNow("no leader") return nil } @@ -685,7 +757,7 @@ func newHeader(srv *server.Server) *pdpb.RequestHeader { } } -func bootstrapServer(c *C, header *pdpb.RequestHeader, client pdpb.PDClient) { +func (suite *clientTestSuite) bootstrapServer(header *pdpb.RequestHeader, client pdpb.PDClient) { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -701,10 +773,10 @@ func bootstrapServer(c *C, header *pdpb.RequestHeader, client pdpb.PDClient) { Region: region, } _, err := client.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) + suite.NoError(err) } -func (s *testClientSuite) TestNormalTSO(c *C) { +func (suite *clientTestSuite) TestNormalTSO() { var wg sync.WaitGroup 
wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -712,10 +784,10 @@ func (s *testClientSuite) TestNormalTSO(c *C) { defer wg.Done() var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { - physical, logical, err := s.client.GetTS(context.Background()) - c.Assert(err, IsNil) + physical, logical, err := suite.client.GetTS(context.Background()) + suite.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + suite.Less(lastTS, ts) lastTS = ts } }() @@ -723,7 +795,7 @@ func (s *testClientSuite) TestNormalTSO(c *C) { wg.Wait() } -func (s *testClientSuite) TestGetTSAsync(c *C) { +func (suite *clientTestSuite) TestGetTSAsync() { var wg sync.WaitGroup wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -731,14 +803,14 @@ func (s *testClientSuite) TestGetTSAsync(c *C) { defer wg.Done() tsFutures := make([]pd.TSFuture, tsoRequestRound) for i := range tsFutures { - tsFutures[i] = s.client.GetTSAsync(context.Background()) + tsFutures[i] = suite.client.GetTSAsync(context.Background()) } var lastTS uint64 = math.MaxUint64 for i := len(tsFutures) - 1; i >= 0; i-- { physical, logical, err := tsFutures[i].Wait() - c.Assert(err, IsNil) + suite.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Greater, ts) + suite.Greater(lastTS, ts) lastTS = ts } }() @@ -746,7 +818,7 @@ func (s *testClientSuite) TestGetTSAsync(c *C) { wg.Wait() } -func (s *testClientSuite) TestGetRegion(c *C) { +func (suite *clientTestSuite) TestGetRegion() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -757,24 +829,25 @@ func (s *testClientSuite) TestGetRegion(c *C) { Peers: peers, } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a")) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Meta, DeepEquals, region) && - c.Check(r.Leader, DeepEquals, peers[0]) && - c.Check(r.Buckets, IsNil) + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil }) breq := &pdpb.ReportBucketsRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Buckets: &metapb.Buckets{ RegionId: regionID, Version: 1, @@ -790,30 +863,29 @@ func (s *testClientSuite) TestGetRegion(c *C) { }, }, } - c.Assert(s.reportBucket.Send(breq), IsNil) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) - c.Assert(err, IsNil) + suite.NoError(suite.reportBucket.Send(breq)) + testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Buckets, NotNil) + return r.Buckets != nil }) - config := s.srv.GetRaftCluster().GetStoreConfig() + config := suite.srv.GetRaftCluster().GetStoreConfig() config.EnableRegionBucket = false - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) - c.Assert(err, IsNil) + 
testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Buckets, IsNil) + return r.Buckets == nil }) config.EnableRegionBucket = true - c.Succeed() } -func (s *testClientSuite) TestGetPrevRegion(c *C) { +func (suite *clientTestSuite) TestGetPrevRegion() { regionLen := 10 regions := make([]*metapb.Region, 0, regionLen) for i := 0; i < regionLen; i++ { @@ -830,29 +902,28 @@ func (s *testClientSuite) TestGetPrevRegion(c *C) { } regions = append(regions, r) req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: r, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) } time.Sleep(500 * time.Millisecond) for i := 0; i < 20; i++ { - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetPrevRegion(context.Background(), []byte{byte(i)}) - c.Assert(err, IsNil) + testutil.WaitUntilWithTestingT(suite.T(), func() bool { + r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) + suite.NoError(err) if i > 0 && i < regionLen { - return c.Check(r.Leader, DeepEquals, peers[0]) && - c.Check(r.Meta, DeepEquals, regions[i-1]) + return reflect.DeepEqual(peers[0], r.Leader) && + reflect.DeepEqual(regions[i-1], r.Meta) } - return c.Check(r, IsNil) + return r == nil }) } - c.Succeed() } -func (s *testClientSuite) TestScanRegions(c *C) { +func (suite *clientTestSuite) TestScanRegions() { regionLen := 10 regions := make([]*metapb.Region, 0, regionLen) for i := 0; i < regionLen; i++ { @@ -869,53 +940,54 @@ func (s *testClientSuite) TestScanRegions(c *C) { } regions = append(regions, r) req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: r, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) } // Wait for region heartbeats. - testutil.WaitUntil(c, func() bool { - scanRegions, err := s.client.ScanRegions(context.Background(), []byte{0}, nil, 10) + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + scanRegions, err := suite.client.ScanRegions(context.Background(), []byte{0}, nil, 10) return err == nil && len(scanRegions) == 10 }) // Set leader of region3 to nil. region3 := core.NewRegionInfo(regions[3], nil) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region3) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region3) // Add down peer for region4. region4 := core.NewRegionInfo(regions[4], regions[4].Peers[0], core.WithDownPeers([]*pdpb.PeerStats{{Peer: regions[4].Peers[1]}})) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region4) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region4) // Add pending peers for region5. 
region5 := core.NewRegionInfo(regions[5], regions[5].Peers[0], core.WithPendingPeers([]*metapb.Peer{regions[5].Peers[1], regions[5].Peers[2]})) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region5) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region5) check := func(start, end []byte, limit int, expect []*metapb.Region) { - scanRegions, err := s.client.ScanRegions(context.Background(), start, end, limit) - c.Assert(err, IsNil) - c.Assert(scanRegions, HasLen, len(expect)) - c.Log("scanRegions", scanRegions) - c.Log("expect", expect) + scanRegions, err := suite.client.ScanRegions(context.Background(), start, end, limit) + suite.NoError(err) + suite.Len(scanRegions, len(expect)) + t.Log("scanRegions", scanRegions) + t.Log("expect", expect) for i := range expect { - c.Assert(scanRegions[i].Meta, DeepEquals, expect[i]) + suite.True(reflect.DeepEqual(expect[i], scanRegions[i].Meta)) if scanRegions[i].Meta.GetId() == region3.GetID() { - c.Assert(scanRegions[i].Leader, DeepEquals, &metapb.Peer{}) + suite.True(reflect.DeepEqual(&metapb.Peer{}, scanRegions[i].Leader)) } else { - c.Assert(scanRegions[i].Leader, DeepEquals, expect[i].Peers[0]) + suite.True(reflect.DeepEqual(expect[i].Peers[0], scanRegions[i].Leader)) } if scanRegions[i].Meta.GetId() == region4.GetID() { - c.Assert(scanRegions[i].DownPeers, DeepEquals, []*metapb.Peer{expect[i].Peers[1]}) + suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1]}, scanRegions[i].DownPeers)) } if scanRegions[i].Meta.GetId() == region5.GetID() { - c.Assert(scanRegions[i].PendingPeers, DeepEquals, []*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}) + suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}, scanRegions[i].PendingPeers)) } } } @@ -927,7 +999,7 @@ func (s *testClientSuite) TestScanRegions(c *C) { check([]byte{1}, []byte{6}, 2, regions[1:3]) } -func (s *testClientSuite) TestGetRegionByID(c *C) { +func (suite *clientTestSuite) TestGetRegionByID() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -938,125 +1010,125 @@ func (s *testClientSuite) TestGetRegionByID(c *C) { Peers: peers, } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegionByID(context.Background(), regionID) - c.Assert(err, IsNil) + testutil.WaitUntilWithTestingT(suite.T(), func() bool { + r, err := suite.client.GetRegionByID(context.Background(), regionID) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Meta, DeepEquals, region) && - c.Check(r.Leader, DeepEquals, peers[0]) + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) }) - c.Succeed() } -func (s *testClientSuite) TestGetStore(c *C) { - cluster := s.srv.GetRaftCluster() - c.Assert(cluster, NotNil) +func (suite *clientTestSuite) TestGetStore() { + cluster := suite.srv.GetRaftCluster() + suite.NotNil(cluster) store := stores[0] // Get an up store should be OK. 
- n, err := s.client.GetStore(context.Background(), store.GetId()) - c.Assert(err, IsNil) - c.Assert(n, DeepEquals, store) + n, err := suite.client.GetStore(context.Background(), store.GetId()) + suite.NoError(err) + suite.True(reflect.DeepEqual(store, n)) - stores, err := s.client.GetAllStores(context.Background()) - c.Assert(err, IsNil) - c.Assert(stores, DeepEquals, stores) + actualStores, err := suite.client.GetAllStores(context.Background()) + suite.NoError(err) + suite.Len(actualStores, len(stores)) + stores = actualStores // Mark the store as offline. err = cluster.RemoveStore(store.GetId(), false) - c.Assert(err, IsNil) + suite.NoError(err) offlineStore := proto.Clone(store).(*metapb.Store) offlineStore.State = metapb.StoreState_Offline offlineStore.NodeState = metapb.NodeState_Removing // Get an offline store should be OK. - n, err = s.client.GetStore(context.Background(), store.GetId()) - c.Assert(err, IsNil) - c.Assert(n, DeepEquals, offlineStore) + n, err = suite.client.GetStore(context.Background(), store.GetId()) + suite.NoError(err) + suite.True(reflect.DeepEqual(offlineStore, n)) // Should return offline stores. contains := false - stores, err = s.client.GetAllStores(context.Background()) - c.Assert(err, IsNil) + stores, err = suite.client.GetAllStores(context.Background()) + suite.NoError(err) for _, store := range stores { if store.GetId() == offlineStore.GetId() { contains = true - c.Assert(store, DeepEquals, offlineStore) + suite.True(reflect.DeepEqual(offlineStore, store)) } } - c.Assert(contains, IsTrue) + suite.True(contains) // Mark the store as physically destroyed and offline. err = cluster.RemoveStore(store.GetId(), true) - c.Assert(err, IsNil) + suite.NoError(err) physicallyDestroyedStoreID := store.GetId() // Get a physically destroyed and offline store // It should be Tombstone(become Tombstone automically) or Offline - n, err = s.client.GetStore(context.Background(), physicallyDestroyedStoreID) - c.Assert(err, IsNil) + n, err = suite.client.GetStore(context.Background(), physicallyDestroyedStoreID) + suite.NoError(err) if n != nil { // store is still offline and physically destroyed - c.Assert(n.GetNodeState(), Equals, metapb.NodeState_Removing) - c.Assert(n.PhysicallyDestroyed, IsTrue) + suite.Equal(metapb.NodeState_Removing, n.GetNodeState()) + suite.True(n.PhysicallyDestroyed) } // Should return tombstone stores. contains = false - stores, err = s.client.GetAllStores(context.Background()) - c.Assert(err, IsNil) + stores, err = suite.client.GetAllStores(context.Background()) + suite.NoError(err) for _, store := range stores { if store.GetId() == physicallyDestroyedStoreID { contains = true - c.Assert(store.GetState(), Not(Equals), metapb.StoreState_Up) - c.Assert(store.PhysicallyDestroyed, IsTrue) + suite.NotEqual(metapb.StoreState_Up, store.GetState()) + suite.True(store.PhysicallyDestroyed) } } - c.Assert(contains, IsTrue) + suite.True(contains) // Should not return tombstone stores. 
- stores, err = s.client.GetAllStores(context.Background(), pd.WithExcludeTombstone()) - c.Assert(err, IsNil) + stores, err = suite.client.GetAllStores(context.Background(), pd.WithExcludeTombstone()) + suite.NoError(err) for _, store := range stores { if store.GetId() == physicallyDestroyedStoreID { - c.Assert(store.GetState(), Equals, metapb.StoreState_Offline) - c.Assert(store.PhysicallyDestroyed, IsTrue) + suite.Equal(metapb.StoreState_Offline, store.GetState()) + suite.True(store.PhysicallyDestroyed) } } } -func (s *testClientSuite) checkGCSafePoint(c *C, expectedSafePoint uint64) { +func (suite *clientTestSuite) checkGCSafePoint(expectedSafePoint uint64) { req := &pdpb.GetGCSafePointRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), } - resp, err := s.grpcSvr.GetGCSafePoint(context.Background(), req) - c.Assert(err, IsNil) - c.Assert(resp.SafePoint, Equals, expectedSafePoint) + resp, err := suite.grpcSvr.GetGCSafePoint(context.Background(), req) + suite.NoError(err) + suite.Equal(expectedSafePoint, resp.SafePoint) } -func (s *testClientSuite) TestUpdateGCSafePoint(c *C) { - s.checkGCSafePoint(c, 0) +func (suite *clientTestSuite) TestUpdateGCSafePoint() { + suite.checkGCSafePoint(0) for _, safePoint := range []uint64{0, 1, 2, 3, 233, 23333, 233333333333, math.MaxUint64} { - newSafePoint, err := s.client.UpdateGCSafePoint(context.Background(), safePoint) - c.Assert(err, IsNil) - c.Assert(newSafePoint, Equals, safePoint) - s.checkGCSafePoint(c, safePoint) + newSafePoint, err := suite.client.UpdateGCSafePoint(context.Background(), safePoint) + suite.NoError(err) + suite.Equal(safePoint, newSafePoint) + suite.checkGCSafePoint(safePoint) } // If the new safe point is less than the old one, it should not be updated. - newSafePoint, err := s.client.UpdateGCSafePoint(context.Background(), 1) - c.Assert(newSafePoint, Equals, uint64(math.MaxUint64)) - c.Assert(err, IsNil) - s.checkGCSafePoint(c, math.MaxUint64) + newSafePoint, err := suite.client.UpdateGCSafePoint(context.Background(), 1) + suite.Equal(uint64(math.MaxUint64), newSafePoint) + suite.NoError(err) + suite.checkGCSafePoint(math.MaxUint64) } -func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) { +func (suite *clientTestSuite) TestUpdateServiceGCSafePoint() { serviceSafePoints := []struct { ServiceID string TTL int64 @@ -1067,105 +1139,105 @@ func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) { {"c", 1000, 3}, } for _, ssp := range serviceSafePoints { - min, err := s.client.UpdateServiceGCSafePoint(context.Background(), + min, err := suite.client.UpdateServiceGCSafePoint(context.Background(), ssp.ServiceID, 1000, ssp.SafePoint) - c.Assert(err, IsNil) + suite.NoError(err) // An service safepoint of ID "gc_worker" is automatically initialized as 0 - c.Assert(min, Equals, uint64(0)) + suite.Equal(uint64(0), min) } - min, err := s.client.UpdateServiceGCSafePoint(context.Background(), + min, err := suite.client.UpdateServiceGCSafePoint(context.Background(), "gc_worker", math.MaxInt64, 10) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(1)) + suite.NoError(err) + suite.Equal(uint64(1), min) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", 1000, 4) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(2)) + suite.NoError(err) + suite.Equal(uint64(2), min) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "b", 
-100, 2) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) + suite.NoError(err) + suite.Equal(uint64(3), min) // Minimum safepoint does not regress - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "b", 1000, 2) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) + suite.NoError(err) + suite.Equal(uint64(3), min) // Update only the TTL of the minimum safepoint - oldMinSsp, err := s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(oldMinSsp.ServiceID, Equals, "c") - c.Assert(oldMinSsp.SafePoint, Equals, uint64(3)) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + oldMinSsp, err := suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", oldMinSsp.ServiceID) + suite.Equal(uint64(3), oldMinSsp.SafePoint) + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", 2000, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) - minSsp, err := s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "c") - c.Assert(oldMinSsp.SafePoint, Equals, uint64(3)) - c.Assert(minSsp.ExpiredAt-oldMinSsp.ExpiredAt, GreaterEqual, int64(1000)) + suite.NoError(err) + suite.Equal(uint64(3), min) + minSsp, err := suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", minSsp.ServiceID) + suite.Equal(uint64(3), oldMinSsp.SafePoint) + suite.GreaterOrEqual(minSsp.ExpiredAt-oldMinSsp.ExpiredAt, int64(1000)) // Shrinking TTL is also allowed - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", 1, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "c") - c.Assert(minSsp.ExpiredAt, Less, oldMinSsp.ExpiredAt) + suite.NoError(err) + suite.Equal(uint64(3), min) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", minSsp.ServiceID) + suite.Less(minSsp.ExpiredAt, oldMinSsp.ExpiredAt) // TTL can be infinite (represented by math.MaxInt64) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", math.MaxInt64, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(3)) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "c") - c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64)) + suite.NoError(err) + suite.Equal(uint64(3), min) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("c", minSsp.ServiceID) + suite.Equal(minSsp.ExpiredAt, int64(math.MaxInt64)) // Delete "a" and "c" - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", -1, 3) - c.Assert(err, IsNil) - c.Assert(min, Equals, uint64(4)) - min, err = s.client.UpdateServiceGCSafePoint(context.Background(), + suite.NoError(err) + suite.Equal(uint64(4), min) + min, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", -1, 4) - c.Assert(err, IsNil) + suite.NoError(err) // Now gc_worker is the only remaining 
service safe point. - c.Assert(min, Equals, uint64(10)) + suite.Equal(uint64(10), min) // gc_worker cannot be deleted. - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "gc_worker", -1, 10) - c.Assert(err, NotNil) + suite.Error(err) // Cannot set non-infinity TTL for gc_worker - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "gc_worker", 10000000, 10) - c.Assert(err, NotNil) + suite.Error(err) // Service safepoint must have a non-empty ID - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "", 1000, 15) - c.Assert(err, NotNil) + suite.Error(err) // Put some other safepoints to test fixing gc_worker's safepoint when there exists other safepoints. - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", 1000, 11) - c.Assert(err, IsNil) - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + suite.NoError(err) + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "b", 1000, 12) - c.Assert(err, IsNil) - _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + suite.NoError(err) + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "c", 1000, 13) - c.Assert(err, IsNil) + suite.NoError(err) // Force set invalid ttl to gc_worker gcWorkerKey := path.Join("gc", "safe_point", "service", "gc_worker") @@ -1176,38 +1248,38 @@ func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) { SafePoint: 10, } value, err := json.Marshal(gcWorkerSsp) - c.Assert(err, IsNil) - err = s.srv.GetStorage().Save(gcWorkerKey, string(value)) - c.Assert(err, IsNil) + suite.NoError(err) + err = suite.srv.GetStorage().Save(gcWorkerKey, string(value)) + suite.NoError(err) } - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "gc_worker") - c.Assert(minSsp.SafePoint, Equals, uint64(10)) - c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64)) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("gc_worker", minSsp.ServiceID) + suite.Equal(uint64(10), minSsp.SafePoint) + suite.Equal(int64(math.MaxInt64), minSsp.ExpiredAt) // Force delete gc_worker, then the min service safepoint is 11 of "a". - err = s.srv.GetStorage().Remove(gcWorkerKey) - c.Assert(err, IsNil) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.SafePoint, Equals, uint64(11)) + err = suite.srv.GetStorage().Remove(gcWorkerKey) + suite.NoError(err) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal(uint64(11), minSsp.SafePoint) // After calling LoadMinServiceGCS when "gc_worker"'s service safepoint is missing, "gc_worker"'s service safepoint // will be newly created. // Increase "a" so that "gc_worker" is the only minimum that will be returned by LoadMinServiceGCSafePoint. 
- _, err = s.client.UpdateServiceGCSafePoint(context.Background(), + _, err = suite.client.UpdateServiceGCSafePoint(context.Background(), "a", 1000, 14) - c.Assert(err, IsNil) + suite.NoError(err) - minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) - c.Assert(err, IsNil) - c.Assert(minSsp.ServiceID, Equals, "gc_worker") - c.Assert(minSsp.SafePoint, Equals, uint64(11)) - c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64)) + minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now()) + suite.NoError(err) + suite.Equal("gc_worker", minSsp.ServiceID) + suite.Equal(uint64(11), minSsp.SafePoint) + suite.Equal(int64(math.MaxInt64), minSsp.ExpiredAt) } -func (s *testClientSuite) TestScatterRegion(c *C) { +func (suite *clientTestSuite) TestScatterRegion() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -1220,106 +1292,46 @@ func (s *testClientSuite) TestScatterRegion(c *C) { EndKey: []byte("ggg"), } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) + err := suite.regionHeartbeat.Send(req) regionsID := []uint64{regionID} - c.Assert(err, IsNil) + suite.NoError(err) // Test interface `ScatterRegions`. - testutil.WaitUntil(c, func() bool { - scatterResp, err := s.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1)) - if c.Check(err, NotNil) { + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + scatterResp, err := suite.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1)) + if err != nil { return false } - if c.Check(scatterResp.FinishedPercentage, Not(Equals), uint64(100)) { + if scatterResp.FinishedPercentage != uint64(100) { return false } - resp, err := s.client.GetOperator(context.Background(), regionID) - if c.Check(err, NotNil) { + resp, err := suite.client.GetOperator(context.Background(), regionID) + if err != nil { return false } - return c.Check(resp.GetRegionId(), Equals, regionID) && c.Check(string(resp.GetDesc()), Equals, "scatter-region") && c.Check(resp.GetStatus(), Equals, pdpb.OperatorStatus_RUNNING) + return resp.GetRegionId() == regionID && + string(resp.GetDesc()) == "scatter-region" && + resp.GetStatus() == pdpb.OperatorStatus_RUNNING }, testutil.WithSleepInterval(1*time.Second)) // Test interface `ScatterRegion`. // TODO: Deprecate interface `ScatterRegion`. 
- testutil.WaitUntil(c, func() bool { - err := s.client.ScatterRegion(context.Background(), regionID) - if c.Check(err, NotNil) { + testutil.WaitUntilWithTestingT(t, func() bool { + err := suite.client.ScatterRegion(context.Background(), regionID) + if err != nil { fmt.Println(err) return false } - resp, err := s.client.GetOperator(context.Background(), regionID) - if c.Check(err, NotNil) { + resp, err := suite.client.GetOperator(context.Background(), regionID) + if err != nil { return false } - return c.Check(resp.GetRegionId(), Equals, regionID) && c.Check(string(resp.GetDesc()), Equals, "scatter-region") && c.Check(resp.GetStatus(), Equals, pdpb.OperatorStatus_RUNNING) + return resp.GetRegionId() == regionID && + string(resp.GetDesc()) == "scatter-region" && + resp.GetStatus() == pdpb.OperatorStatus_RUNNING }, testutil.WithSleepInterval(1*time.Second)) - - c.Succeed() -} - -type testConfigTTLSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *testConfigTTLSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *testConfigTTLSuite) TearDownSuite(c *C) { - s.cancel() -} - -var _ = SerialSuites(&testConfigTTLSuite{}) - -var ttlConfig = map[string]interface{}{ - "schedule.max-snapshot-count": 999, - "schedule.enable-location-replacement": false, - "schedule.max-merge-region-size": 999, - "schedule.max-merge-region-keys": 999, - "schedule.scheduler-max-waiting-operator": 999, - "schedule.leader-schedule-limit": 999, - "schedule.region-schedule-limit": 999, - "schedule.hot-region-schedule-limit": 999, - "schedule.replica-schedule-limit": 999, - "schedule.merge-schedule-limit": 999, -} - -func assertTTLConfig(c *C, options *config.PersistOptions, checker Checker) { - c.Assert(options.GetMaxSnapshotCount(), checker, uint64(999)) - c.Assert(options.IsLocationReplacementEnabled(), checker, false) - c.Assert(options.GetMaxMergeRegionSize(), checker, uint64(999)) - c.Assert(options.GetMaxMergeRegionKeys(), checker, uint64(999)) - c.Assert(options.GetSchedulerMaxWaitingOperator(), checker, uint64(999)) - c.Assert(options.GetLeaderScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetRegionScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetHotRegionScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetReplicaScheduleLimit(), checker, uint64(999)) - c.Assert(options.GetMergeScheduleLimit(), checker, uint64(999)) -} - -func (s *testConfigTTLSuite) TestConfigTTLAfterTransferLeader(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) - defer cluster.Destroy() - err = cluster.RunInitialServers() - c.Assert(err, IsNil) - leader := cluster.GetServer(cluster.WaitLeader()) - c.Assert(leader.BootstrapCluster(), IsNil) - addr := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=5", leader.GetAddr()) - postData, err := json.Marshal(ttlConfig) - c.Assert(err, IsNil) - resp, err := leader.GetHTTPClient().Post(addr, "application/json", bytes.NewBuffer(postData)) - resp.Body.Close() - c.Assert(err, IsNil) - time.Sleep(2 * time.Second) - _ = leader.Destroy() - time.Sleep(2 * time.Second) - leader = cluster.GetServer(cluster.WaitLeader()) - assertTTLConfig(c, leader.GetPersistOptions(), Equals) } diff --git a/tests/client/client_tls_test.go b/tests/client/client_tls_test.go index 3fbca17c835..48a6fec3d2d 100644 --- a/tests/client/client_tls_test.go +++ b/tests/client/client_tls_test.go @@ -22,21 +22,19 @@ import ( "os" "path/filepath" "strings" + "testing" "time" - . 
"github.com/pingcap/check" + "github.com/stretchr/testify/require" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/netutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "go.etcd.io/etcd/pkg/transport" "google.golang.org/grpc" ) -var _ = Suite(&clientTLSTestSuite{}) - var ( testTLSInfo = transport.TLSInfo{ KeyFile: "./cert/pd-server-key.pem", @@ -57,49 +55,38 @@ var ( } ) -type clientTLSTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clientTLSTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *clientTLSTestSuite) TearDownSuite(c *C) { - s.cancel() -} - // TestTLSReloadAtomicReplace ensures server reloads expired/valid certs // when all certs are atomically replaced by directory renaming. // And expects server to reject client requests, and vice versa. -func (s *clientTLSTestSuite) TestTLSReloadAtomicReplace(c *C) { +func TestTLSReloadAtomicReplace(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() tmpDir, err := os.MkdirTemp(os.TempDir(), "cert-tmp") - c.Assert(err, IsNil) + re.NoError(err) os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir) certsDir, err := os.MkdirTemp(os.TempDir(), "cert-to-load") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(certsDir) certsDirExp, err := os.MkdirTemp(os.TempDir(), "cert-expired") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(certsDirExp) cloneFunc := func() transport.TLSInfo { tlsInfo, terr := copyTLSFiles(testTLSInfo, certsDir) - c.Assert(terr, IsNil) + re.NoError(terr) _, err = copyTLSFiles(testTLSInfoExpired, certsDirExp) - c.Assert(err, IsNil) + re.NoError(err) return tlsInfo } replaceFunc := func() { err = os.Rename(certsDir, tmpDir) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDirExp, certsDir) - c.Assert(err, IsNil) + re.NoError(err) // after rename, // 'certsDir' contains expired certs // 'tmpDir' contains valid certs @@ -107,25 +94,26 @@ func (s *clientTLSTestSuite) TestTLSReloadAtomicReplace(c *C) { } revertFunc := func() { err = os.Rename(tmpDir, certsDirExp) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDir, tmpDir) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDirExp, certsDir) - c.Assert(err, IsNil) + re.NoError(err) } - s.testTLSReload(c, cloneFunc, replaceFunc, revertFunc) + testTLSReload(re, ctx, cloneFunc, replaceFunc, revertFunc) } -func (s *clientTLSTestSuite) testTLSReload( - c *C, +func testTLSReload( + re *require.Assertions, + ctx context.Context, cloneFunc func() transport.TLSInfo, replaceFunc func(), revertFunc func()) { tlsInfo := cloneFunc() // 1. 
start cluster with valid certs - clus, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) { + clus, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) { conf.Security.TLSConfig = grpcutil.TLSConfig{ KeyPath: tlsInfo.KeyFile, CertPath: tlsInfo.CertFile, @@ -137,10 +125,10 @@ func (s *clientTLSTestSuite) testTLSReload( conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https") conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https") }) - c.Assert(err, IsNil) + re.NoError(err) defer clus.Destroy() err = clus.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) clus.WaitLeader() testServers := clus.GetServers() @@ -148,20 +136,20 @@ func (s *clientTLSTestSuite) testTLSReload( for _, s := range testServers { endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) tlsConfig, err := s.GetConfig().Security.ToTLSConfig() - c.Assert(err, IsNil) + re.NoError(err) httpClient := &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, TLSClientConfig: tlsConfig, }, } - c.Assert(netutil.IsEnableHTTPS(httpClient), IsTrue) + re.True(netutil.IsEnableHTTPS(httpClient)) } // 2. concurrent client dialing while certs become expired errc := make(chan error, 1) go func() { for { - dctx, dcancel := context.WithTimeout(s.ctx, time.Second) + dctx, dcancel := context.WithTimeout(ctx, time.Second) cli, err := pd.NewClientWithContext(dctx, endpoints, pd.SecurityOption{ CAPath: testClientTLSInfo.TrustedCAFile, CertPath: testClientTLSInfo.CertFile, @@ -183,46 +171,46 @@ func (s *clientTLSTestSuite) testTLSReload( // 4. expect dial time-out when loading expired certs select { case cerr := <-errc: - c.Assert(strings.Contains(cerr.Error(), "failed to get cluster id"), IsTrue) + re.Contains(cerr.Error(), "failed to get cluster id") case <-time.After(5 * time.Second): - c.Fatal("failed to receive dial timeout error") + re.FailNow("failed to receive dial timeout error") } // 5. replace expired certs back with valid ones revertFunc() // 6. new requests should trigger listener to reload valid certs - dctx, dcancel := context.WithTimeout(s.ctx, 5*time.Second) + dctx, dcancel := context.WithTimeout(ctx, 5*time.Second) cli, err := pd.NewClientWithContext(dctx, endpoints, pd.SecurityOption{ CAPath: testClientTLSInfo.TrustedCAFile, CertPath: testClientTLSInfo.CertFile, KeyPath: testClientTLSInfo.KeyFile, }, pd.WithGRPCDialOptions(grpc.WithBlock())) - c.Assert(err, IsNil) + re.NoError(err) dcancel() cli.Close() // 7. 
test use raw bytes to init tls config - caData, certData, keyData := loadTLSContent(c, + caData, certData, keyData := loadTLSContent(re, testClientTLSInfo.TrustedCAFile, testClientTLSInfo.CertFile, testClientTLSInfo.KeyFile) - ctx1, cancel1 := context.WithTimeout(s.ctx, 2*time.Second) + ctx1, cancel1 := context.WithTimeout(ctx, 2*time.Second) _, err = pd.NewClientWithContext(ctx1, endpoints, pd.SecurityOption{ SSLCABytes: caData, SSLCertBytes: certData, SSLKEYBytes: keyData, }, pd.WithGRPCDialOptions(grpc.WithBlock())) - c.Assert(err, IsNil) + re.NoError(err) cancel1() } -func loadTLSContent(c *C, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { +func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { var err error caData, err = os.ReadFile(caPath) - c.Assert(err, IsNil) + re.NoError(err) certData, err = os.ReadFile(certPath) - c.Assert(err, IsNil) + re.NoError(err) keyData, err = os.ReadFile(keyPath) - c.Assert(err, IsNil) + re.NoError(err) return } @@ -245,6 +233,7 @@ func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) { } return ci, nil } + func copyFile(src, dst string) error { f, err := os.Open(src) if err != nil { diff --git a/tests/client/go.mod b/tests/client/go.mod index 93fb9d96eaa..9d539193d52 100644 --- a/tests/client/go.mod +++ b/tests/client/go.mod @@ -5,9 +5,9 @@ go 1.16 require ( github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.5.2 // indirect - github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a + github.com/stretchr/testify v1.7.0 github.com/tikv/pd v0.0.0-00010101000000-000000000000 github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738 diff --git a/tests/cluster.go b/tests/cluster.go index 2061668f393..3b0e10a02e6 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -19,6 +19,7 @@ import ( "net/http" "os" "sync" + "testing" "time" "github.com/coreos/go-semver/semver" @@ -622,6 +623,26 @@ func (c *TestCluster) WaitAllLeaders(testC *check.C, dcLocations map[string]stri wg.Wait() } +// WaitAllLeadersWithTestingT will block and wait for the election of PD leader and all Local TSO Allocator leaders. +// NOTICE: this is a temporary function that we will be used to replace `WaitAllLeaders` later. +func (c *TestCluster) WaitAllLeadersWithTestingT(t *testing.T, dcLocations map[string]string) { + c.WaitLeader() + c.CheckClusterDCLocation() + // Wait for each DC's Local TSO Allocator leader + wg := sync.WaitGroup{} + for _, dcLocation := range dcLocations { + wg.Add(1) + go func(dc string) { + testutil.WaitUntilWithTestingT(t, func() bool { + leaderName := c.WaitAllocatorLeader(dc) + return leaderName != "" + }) + wg.Done() + }(dcLocation) + } + wg.Wait() +} + // GetCluster returns PD cluster. 
func (c *TestCluster) GetCluster() *metapb.Cluster { leader := c.GetLeader() diff --git a/tools/pd-ctl/pdctl/command/scheduler.go b/tools/pd-ctl/pdctl/command/scheduler.go index 5c334d86dae..0a0d9635021 100644 --- a/tools/pd-ctl/pdctl/command/scheduler.go +++ b/tools/pd-ctl/pdctl/command/scheduler.go @@ -104,6 +104,7 @@ func NewShowSchedulerCommand() *cobra.Command { Run: showSchedulerCommandFunc, } c.Flags().String("status", "", "the scheduler status value can be [paused | disabled]") + c.Flags().BoolP("timestamp", "t", false, "fetch the paused and resume timestamp for paused scheduler(s)") return c } @@ -116,6 +117,9 @@ func showSchedulerCommandFunc(cmd *cobra.Command, args []string) { url := schedulersPrefix if flag := cmd.Flag("status"); flag != nil && flag.Value.String() != "" { url = fmt.Sprintf("%s?status=%s", url, flag.Value.String()) + if tsFlag, _ := cmd.Flags().GetBool("timestamp"); tsFlag { + url += "&timestamp=true" + } } r, err := doRequest(cmd, url, http.MethodGet, http.Header{}) if err != nil {
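Throughout the hunks above, gocheck suite methods of the form (s *testClientSuite) TestXxx(c *C) become testify suite methods of the form (suite *clientTestSuite) TestXxx(). The go-test entry point that drives such a suite is not part of these hunks; what follows is a minimal sketch of the conventional testify wiring, assuming the standard layout (the package name, the entry-point name TestClientSuite, and the commented field list are illustrative, not taken from this PR):

package client_test

import (
	"testing"

	"github.com/stretchr/testify/suite"
)

// clientTestSuite embeds suite.Suite, which provides the suite.NoError,
// suite.Equal, suite.Less, ... assertions used in the tests above.
type clientTestSuite struct {
	suite.Suite
	// srv, grpcSvr, client, regionHeartbeat, reportBucket, ... would be
	// created in SetupSuite and torn down in TearDownSuite (hypothetical
	// field list, not shown in this diff).
}

// TestClientSuite is the single go-test entry point; testify discovers and
// runs every (suite *clientTestSuite) TestXxx() method by reflection.
func TestClientSuite(t *testing.T) {
	suite.Run(t, new(clientTestSuite))
}

Assertions hang off the embedded suite.Suite, and suite.T() bridges back to a plain *testing.T wherever a helper still wants one, which is how the tests above hand a *testing.T to testutil.WaitUntilWithTestingT.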
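The other recurring change is replacing testutil.WaitUntil(c, ...) with testutil.WaitUntilWithTestingT(t, ...), a polling helper that depends only on *testing.T. Its implementation is not shown in this diff; below is a minimal sketch of the shape implied by the call sites above. The names WaitUntilWithTestingT and WithSleepInterval come from those calls; the option type, the retry budget, and the default interval are assumptions.

package testutil

import (
	"testing"
	"time"
)

// waitConfig holds the polling parameters (hypothetical internal type).
type waitConfig struct {
	retryTimes    int
	sleepInterval time.Duration
}

// WaitOption tweaks one polling parameter (hypothetical option type).
type WaitOption func(*waitConfig)

// WithSleepInterval overrides the pause between two condition checks.
func WithSleepInterval(d time.Duration) WaitOption {
	return func(c *waitConfig) { c.sleepInterval = d }
}

// WaitUntilWithTestingT polls f until it returns true, failing the test once
// the retry budget is exhausted. Unlike the old check.C-based WaitUntil, it
// needs nothing but a *testing.T, so suite tests can pass suite.T() through.
func WaitUntilWithTestingT(t *testing.T, f func() bool, opts ...WaitOption) {
	t.Helper()
	cfg := &waitConfig{retryTimes: 200, sleepInterval: 100 * time.Millisecond} // defaults are assumptions
	for _, opt := range opts {
		opt(cfg)
	}
	for i := 0; i < cfg.retryTimes; i++ {
		if f() {
			return
		}
		time.Sleep(cfg.sleepInterval)
	}
	t.Fatal("wait timeout: condition never became true")
}

Keeping the helper free of check.C is what lets the same function serve both plain tests and suite methods while the migration is in flight.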