diff --git a/client/client_test.go b/client/client_test.go index 7b9470ace5b..c80b78bb96b 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,38 +16,33 @@ package pd import ( "context" + "reflect" "testing" "time" - . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "github.com/tikv/pd/client/testutil" "go.uber.org/goleak" "google.golang.org/grpc" ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -var _ = Suite(&testClientSuite{}) - -type testClientSuite struct{} - -func (s *testClientSuite) TestTsLessEqual(c *C) { - c.Assert(tsLessEqual(9, 9, 9, 9), IsTrue) - c.Assert(tsLessEqual(8, 9, 9, 8), IsTrue) - c.Assert(tsLessEqual(9, 8, 8, 9), IsFalse) - c.Assert(tsLessEqual(9, 8, 9, 6), IsFalse) - c.Assert(tsLessEqual(9, 6, 9, 8), IsTrue) +func TestTsLessEqual(t *testing.T) { + re := require.New(t) + re.True(tsLessEqual(9, 9, 9, 9)) + re.True(tsLessEqual(8, 9, 9, 8)) + re.False(tsLessEqual(9, 8, 8, 9)) + re.False(tsLessEqual(9, 8, 9, 6)) + re.True(tsLessEqual(9, 6, 9, 8)) } -func (s *testClientSuite) TestUpdateURLs(c *C) { +func TestUpdateURLs(t *testing.T) { + re := require.New(t) members := []*pdpb.Member{ {Name: "pd4", ClientUrls: []string{"tmp://pd4"}}, {Name: "pd1", ClientUrls: []string{"tmp://pd1"}}, @@ -63,40 +58,35 @@ func (s *testClientSuite) TestUpdateURLs(c *C) { cli := &baseClient{option: newOption()} cli.urls.Store([]string{}) cli.updateURLs(members[1:]) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) cli.updateURLs(members[1:]) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2]}), cli.GetURLs())) cli.updateURLs(members) - c.Assert(cli.GetURLs(), DeepEquals, getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]})) + re.True(reflect.DeepEqual(getURLs([]*pdpb.Member{members[1], members[3], members[2], members[0]}), cli.GetURLs())) } const testClientURL = "tmp://test.url:5255" -var _ = Suite(&testClientCtxSuite{}) - -type testClientCtxSuite struct{} - -func (s *testClientCtxSuite) TestClientCtx(c *C) { +func TestClientCtx(t *testing.T) { + re := require.New(t) start := time.Now() ctx, cancel := context.WithTimeout(context.TODO(), time.Second*3) defer cancel() _, err := NewClientWithContext(ctx, []string{testClientURL}, SecurityOption{}) - c.Assert(err, NotNil) - c.Assert(time.Since(start), Less, time.Second*5) + re.Error(err) + re.Less(time.Since(start), time.Second*5) } -func (s *testClientCtxSuite) TestClientWithRetry(c *C) { +func TestClientWithRetry(t *testing.T) { + re := require.New(t) start := time.Now() _, err := NewClientWithContext(context.TODO(), []string{testClientURL}, SecurityOption{}, WithMaxErrorRetry(5)) - c.Assert(err, NotNil) - c.Assert(time.Since(start), Less, time.Second*10) + re.Error(err) + re.Less(time.Since(start), time.Second*10) } -var _ = Suite(&testClientDialOptionSuite{}) - -type testClientDialOptionSuite struct{} - -func (s *testClientDialOptionSuite) TestGRPCDialOption(c *C) { +func TestGRPCDialOption(t *testing.T) { + re := require.New(t) start := time.Now() ctx, cancel := context.WithTimeout(context.TODO(), 500*time.Millisecond) defer cancel() @@ -111,15 +101,12 @@ func (s 
*testClientDialOptionSuite) TestGRPCDialOption(c *C) { cli.urls.Store([]string{testClientURL}) cli.option.gRPCDialOptions = []grpc.DialOption{grpc.WithBlock()} err := cli.updateMember() - c.Assert(err, NotNil) - c.Assert(time.Since(start), Greater, 500*time.Millisecond) + re.Error(err) + re.Greater(time.Since(start), 500*time.Millisecond) } -var _ = Suite(&testTsoRequestSuite{}) - -type testTsoRequestSuite struct{} - -func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { +func TestTsoRequestWait(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) req := &tsoRequest{ done: make(chan error, 1), @@ -130,7 +117,7 @@ func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { } cancel() _, _, err := req.Wait() - c.Assert(errors.Cause(err), Equals, context.Canceled) + re.ErrorIs(errors.Cause(err), context.Canceled) ctx, cancel = context.WithCancel(context.Background()) req = &tsoRequest{ @@ -142,5 +129,5 @@ func (s *testTsoRequestSuite) TestTsoRequestWait(c *C) { } cancel() _, _, err = req.Wait() - c.Assert(errors.Cause(err), Equals, context.Canceled) + re.ErrorIs(errors.Cause(err), context.Canceled) } diff --git a/client/go.mod b/client/go.mod index 893cada680d..56380b0b51c 100644 --- a/client/go.mod +++ b/client/go.mod @@ -4,12 +4,12 @@ go 1.16 require ( github.com/opentracing/opentracing-go v1.2.0 - github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee github.com/prometheus/client_golang v1.11.0 + github.com/stretchr/testify v1.7.0 go.uber.org/goleak v1.1.11 go.uber.org/zap v1.20.0 google.golang.org/grpc v1.43.0 diff --git a/client/go.sum b/client/go.sum index 6682bdb2893..90019b9b382 100644 --- a/client/go.sum +++ b/client/go.sum @@ -69,7 +69,6 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.12.1/go.mod h1:8XEsbTttt/W+VvjtQhLACqCisSPWTxCZ7sBRjU6iH9c= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= @@ -84,8 +83,10 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -97,9 +98,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= -github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= -github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0 h1:HVl5539r48eA+uDuX/ziBmQCxzT1pGrzWbKuXT46Bq0= -github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0/go.mod h1:PYMCGwN0JHjoqGr3HrZoD+b8Tgx8bKnArhSq8YVzUMc= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= @@ -108,7 +106,6 @@ github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 h1:C3N3itkduZXDZ github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00/go.mod h1:4qGtCB0QK0wBzKtFEGDhxXnSnbQApw1gc9siScUl8ew= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a h1:TxdHGOFeNa1q1mVv6TgReayf26iI4F8PQUm6RnZ/V/E= github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= -github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee h1:VO2t6IBpfvW34TdtD/G10VvnGqjLic1jzOuHjUb5VqM= github.com/pingcap/log v0.0.0-20211215031037-e024ba4eb0ee/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -137,7 +134,6 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -154,8 +150,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= -go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -163,21 +157,14 @@ go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= -go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= -go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.12.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.20.0 h1:N4oPlghZwYG55MlU6LXk/Zp00FVNE9X9wrYO8CEs4lc= go.uber.org/zap v1.20.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -186,7 +173,6 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -244,10 +230,7 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod 
h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191107010934-f79515f33823/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -290,8 +273,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -307,4 +290,3 @@ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/client/option_test.go b/client/option_test.go index b3d044bbd1b..2a7f7824e12 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -15,43 +15,41 @@ package pd import ( + "testing" "time" - . "github.com/pingcap/check" + "github.com/stretchr/testify/require" "github.com/tikv/pd/client/testutil" ) -var _ = Suite(&testClientOptionSuite{}) - -type testClientOptionSuite struct{} - -func (s *testClientSuite) TestDynamicOptionChange(c *C) { +func TestDynamicOptionChange(t *testing.T) { + re := require.New(t) o := newOption() // Check the default value setting. - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, defaultMaxTSOBatchWaitInterval) - c.Assert(o.getEnableTSOFollowerProxy(), Equals, defaultEnableTSOFollowerProxy) + re.Equal(defaultMaxTSOBatchWaitInterval, o.getMaxTSOBatchWaitInterval()) + re.Equal(defaultEnableTSOFollowerProxy, o.getEnableTSOFollowerProxy()) // Check the invalid value setting. 
- c.Assert(o.setMaxTSOBatchWaitInterval(time.Second), NotNil) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, defaultMaxTSOBatchWaitInterval) + re.NotNil(o.setMaxTSOBatchWaitInterval(time.Second)) + re.Equal(defaultMaxTSOBatchWaitInterval, o.getMaxTSOBatchWaitInterval()) expectInterval := time.Millisecond o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectInterval = time.Duration(float64(time.Millisecond) * 0.5) o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectInterval = time.Duration(float64(time.Millisecond) * 1.5) o.setMaxTSOBatchWaitInterval(expectInterval) - c.Assert(o.getMaxTSOBatchWaitInterval(), Equals, expectInterval) + re.Equal(expectInterval, o.getMaxTSOBatchWaitInterval()) expectBool := true o.setEnableTSOFollowerProxy(expectBool) // Check the value changing notification. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntil(t, func() bool { <-o.enableTSOFollowerProxyCh return true }) - c.Assert(o.getEnableTSOFollowerProxy(), Equals, expectBool) + re.Equal(expectBool, o.getEnableTSOFollowerProxy()) // Check whether any data will be sent to the channel. // It will panic if the test fails. close(o.enableTSOFollowerProxyCh) diff --git a/client/testutil/testutil.go b/client/testutil/testutil.go index 3627566ecfb..095a31ae74a 100644 --- a/client/testutil/testutil.go +++ b/client/testutil/testutil.go @@ -15,9 +15,8 @@ package testutil import ( + "testing" "time" - - "github.com/pingcap/check" ) const ( @@ -45,8 +44,8 @@ func WithSleepInterval(sleep time.Duration) WaitOption { } // WaitUntil repeatedly evaluates f() for a period of time, util it returns true. -func WaitUntil(c *check.C, f func() bool, opts ...WaitOption) { - c.Log("wait start") +func WaitUntil(t *testing.T, f func() bool, opts ...WaitOption) { + t.Log("wait start") option := &WaitOp{ retryTimes: waitMaxRetry, sleepInterval: waitRetrySleep, @@ -60,5 +59,5 @@ func WaitUntil(c *check.C, f func() bool, opts ...WaitOption) { } time.Sleep(option.sleepInterval) } - c.Fatal("wait timeout") + t.Fatal("wait timeout") } diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index c3c917d7b3a..dfb209c648d 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -17,10 +17,12 @@ package testutil import ( "os" "strings" + "testing" "time" "github.com/pingcap/check" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -71,6 +73,26 @@ func WaitUntil(c *check.C, f CheckFunc, opts ...WaitOption) { c.Fatal("wait timeout") } +// WaitUntilWithTestingT repeatedly evaluates f() for a period of time, until it returns true. +// NOTICE: this is a temporary function that will be used to replace `WaitUntil` later. +func WaitUntilWithTestingT(t *testing.T, f CheckFunc, opts ...WaitOption) { + t.Log("wait start") + option := &WaitOp{ + retryTimes: waitMaxRetry, + sleepInterval: waitRetrySleep, + } + for _, opt := range opts { + opt(option) + } + for i := 0; i < option.retryTimes; i++ { + if f() { + return + } + time.Sleep(option.sleepInterval) + } + t.Fatal("wait timeout") +} + // NewRequestHeader creates a new request header.
func NewRequestHeader(clusterID uint64) *pdpb.RequestHeader { return &pdpb.RequestHeader{ @@ -86,6 +108,15 @@ func MustNewGrpcClient(c *check.C, addr string) pdpb.PDClient { return pdpb.NewPDClient(conn) } +// MustNewGrpcClientWithTestify must create a new grpc client. +// NOTICE: this is a temporary function that will be used to replace `MustNewGrpcClient` later. +func MustNewGrpcClientWithTestify(re *require.Assertions, addr string) pdpb.PDClient { + conn, err := grpc.Dial(strings.TrimPrefix(addr, "http://"), grpc.WithInsecure()) + + re.NoError(err) + return pdpb.NewPDClient(conn) +} + // CleanServer is used to clean data directory. func CleanServer(dataDir string) { // Clean data directory diff --git a/scripts/check-testing-t.sh b/scripts/check-testing-t.sh index 0697a007480..6d107b5a0d1 100755 --- a/scripts/check-testing-t.sh +++ b/scripts/check-testing-t.sh @@ -1,5 +1,7 @@ #!/bin/bash +# TODO: remove this script after migrating all tests to the new test framework. + # Check if there are any packages foget to add `TestingT` when use "github.com/pingcap/check". res=$(diff <(grep -rl --include=\*_test.go "github.com/pingcap/check" . | xargs -L 1 dirname | sort -u) \ @@ -13,7 +15,7 @@ fi # Check if there are duplicated `TestingT` in package. -res=$(grep -r --include=\*_test.go "TestingT(" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) +res=$(grep -r --include=\*_test.go "TestingT(t)" . | cut -f1 | xargs -L 1 dirname | sort | uniq -d) if [ "$res" ]; then echo "following packages may have duplicated TestingT:" diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 8b195eaa587..3afda979c44 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -21,16 +21,18 @@ import ( "fmt" "math" "path" + "reflect" "sort" "sync" "testing" "time" "github.com/gogo/protobuf/proto" - . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/mock/mockid" @@ -50,30 +52,10 @@ const ( tsoRequestRound = 30 ) -func Test(t *testing.T) { - TestingT(t) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...)
} -var _ = Suite(&clientTestSuite{}) - -type clientTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clientTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *clientTestSuite) TearDownSuite(c *C) { - s.cancel() -} - type client interface { GetLeaderAddr() string ScheduleCheckLeader() @@ -81,75 +63,81 @@ type client interface { GetAllocatorLeaderURLs() map[string]string } -func (s *clientTestSuite) TestClientLeaderChange(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) +func TestClientLeaderChange(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) var ts1, ts2 uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { p1, l1, err := cli.GetTS(context.TODO()) if err == nil { ts1 = tsoutil.ComposeTS(p1, l1) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(ts1), IsTrue) + re.True(cluster.CheckTSOUnique(ts1)) leader := cluster.GetLeader() - waitLeader(c, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) err = cluster.GetServer(leader).Stop() - c.Assert(err, IsNil) + re.NoError(err) leader = cluster.WaitLeader() - c.Assert(leader, Not(Equals), "") - waitLeader(c, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) + re.NotEmpty(leader) + waitLeader(t, cli.(client), cluster.GetServer(leader).GetConfig().ClientUrls) // Check TS won't fall back after leader changed. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { p2, l2, err := cli.GetTS(context.TODO()) if err == nil { ts2 = tsoutil.ComposeTS(p2, l2) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(ts2), IsTrue) - c.Assert(ts1, Less, ts2) + re.True(cluster.CheckTSOUnique(ts2)) + re.Less(ts1, ts2) // Check URL list. cli.Close() urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - c.Assert(urls, DeepEquals, endpoints) + re.True(reflect.DeepEqual(endpoints, urls)) } -func (s *clientTestSuite) TestLeaderTransfer(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestLeaderTransfer(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - c.Assert(cluster.CheckTSOUnique(lastTS), IsTrue) + re.True(cluster.CheckTSOUnique(lastTS)) // Start a goroutine the make sure TS won't fall back. 
quit := make(chan struct{}) @@ -167,8 +155,8 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { ts := tsoutil.ComposeTS(physical, logical) - c.Assert(cluster.CheckTSOUnique(ts), IsTrue) - c.Assert(lastTS, Less, ts) + re.True(cluster.CheckTSOUnique(ts)) + re.Less(lastTS, ts) lastTS = ts } time.Sleep(time.Millisecond) @@ -179,69 +167,75 @@ func (s *clientTestSuite) TestLeaderTransfer(c *C) { for i := 0; i < 5; i++ { oldLeaderName := cluster.WaitLeader() err := cluster.GetServer(oldLeaderName).ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) newLeaderName := cluster.WaitLeader() - c.Assert(newLeaderName, Not(Equals), oldLeaderName) + re.NotEqual(oldLeaderName, newLeaderName) } close(quit) wg.Wait() } // More details can be found in this issue: https://github.com/tikv/pd/issues/4884 -func (s *clientTestSuite) TestUpdateAfterResetTSO(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 2) - c.Assert(err, IsNil) +func TestUpdateAfterResetTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 2) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) // Transfer leader to trigger the TSO resetting. - c.Assert(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/updateAfterResetTSO", "return(true)")) oldLeaderName := cluster.WaitLeader() err = cluster.GetServer(oldLeaderName).ResignLeader() - c.Assert(err, IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO"), IsNil) + re.NoError(err) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/updateAfterResetTSO")) newLeaderName := cluster.WaitLeader() - c.Assert(newLeaderName, Not(Equals), oldLeaderName) + re.NotEqual(oldLeaderName, newLeaderName) // Request a new TSO. - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) // Transfer leader back. - c.Assert(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/tso/delaySyncTimestamp", `return(true)`)) err = cluster.GetServer(newLeaderName).ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) // Should NOT panic here. 
- testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { _, _, err := cli.GetTS(context.TODO()) return err == nil }) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/tso/delaySyncTimestamp")) } -func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { +func TestTSOAllocatorLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() err = cluster.RunInitialServers() - c.Assert(err, IsNil) - cluster.WaitAllLeaders(c, dcLocationConfig) + re.NoError(err) + cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) var ( testServers = cluster.GetServers() @@ -255,13 +249,13 @@ func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { var allocatorLeaderMap = make(map[string]string) for _, dcLocation := range dcLocationConfig { var pdName string - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { pdName = cluster.WaitAllocatorLeader(dcLocation) return len(pdName) > 0 }) allocatorLeaderMap[dcLocation] = pdName } - cli := setupCli(c, s.ctx, endpoints) + cli := setupCli(re, ctx, endpoints) // Check allocator leaders URL map. cli.Close() @@ -270,27 +264,30 @@ func (s *clientTestSuite) TestTSOAllocatorLeader(c *C) { urls := cli.(client).GetURLs() sort.Strings(urls) sort.Strings(endpoints) - c.Assert(urls, DeepEquals, endpoints) + re.True(reflect.DeepEqual(endpoints, urls)) continue } pdName, exist := allocatorLeaderMap[dcLocation] - c.Assert(exist, IsTrue) - c.Assert(len(pdName), Greater, 0) + re.True(exist) + re.Greater(len(pdName), 0) pdURL, exist := endpointsMap[pdName] - c.Assert(exist, IsTrue) - c.Assert(len(pdURL), Greater, 0) - c.Assert(url, Equals, pdURL) + re.True(exist) + re.Greater(len(pdURL), 0) + re.Equal(pdURL, url) } } -func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) +func TestTSOFollowerProxy(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli1 := setupCli(c, s.ctx, endpoints) - cli2 := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli1 := setupCli(re, ctx, endpoints) + cli2 := setupCli(re, ctx, endpoints) cli2.UpdateOption(pd.EnableTSOFollowerProxy, true) var wg sync.WaitGroup @@ -301,15 +298,15 @@ func (s *clientTestSuite) TestTSOFollowerProxy(c *C) { var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli2.GetTS(context.Background()) - c.Assert(err, IsNil) + re.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts // After requesting with the follower proxy, request with the leader directly. 
physical, logical, err = cli1.GetTS(context.Background()) - c.Assert(err, IsNil) + re.NoError(err) ts = tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts } }() } wg.Wait() } -func (s *clientTestSuite) TestGlobalAndLocalTSO(c *C) { +func TestGlobalAndLocalTSO(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dcLocationConfig := map[string]string{ "pd1": "dc-1", "pd2": "dc-2", "pd3": "dc-3", } dcLocationNum := len(dcLocationConfig) - cluster, err := tests.NewTestCluster(s.ctx, dcLocationNum, func(conf *config.Config, serverName string) { + cluster, err := tests.NewTestCluster(ctx, dcLocationNum, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = dcLocationConfig[serverName] }) - c.Assert(err, IsNil) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) // Wait for all nodes becoming healthy. time.Sleep(time.Second * 5) // Join a new dc-location - pd4, err := cluster.Join(s.ctx, func(conf *config.Config, serverName string) { + pd4, err := cluster.Join(ctx, func(conf *config.Config, serverName string) { conf.EnableLocalTSO = true conf.Labels[config.ZoneLabel] = "dc-4" }) - c.Assert(err, IsNil) + re.NoError(err) err = pd4.Run() - c.Assert(err, IsNil) + re.NoError(err) dcLocationConfig["pd4"] = "dc-4" cluster.CheckClusterDCLocation() - cluster.WaitAllLeaders(c, dcLocationConfig) + cluster.WaitAllLeadersWithTestingT(t, dcLocationConfig) // Test a nonexistent dc-location for Local TSO p, l, err := cli.GetLocalTS(context.TODO(), "nonexistent-dc") - c.Assert(p, Equals, int64(0)) - c.Assert(l, Equals, int64(0)) - c.Assert(err, NotNil) - c.Assert(err, ErrorMatches, ".*unknown dc-location.*") + re.Equal(int64(0), p) + re.Equal(int64(0), l) + re.Error(err) + re.Contains(err.Error(), "unknown dc-location") wg := &sync.WaitGroup{} - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) // assert global tso after resign leader - c.Assert(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/skipUpdateMember", `return(true)`)) err = cluster.ResignLeader() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() - _, _, err = cli.GetTS(s.ctx) - c.Assert(err, NotNil) - c.Assert(pd.IsLeaderChange(err), IsTrue) - _, _, err = cli.GetTS(s.ctx) - c.Assert(err, IsNil) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember"), IsNil) + _, _, err = cli.GetTS(ctx) + re.Error(err) + re.True(pd.IsLeaderChange(err)) + _, _, err = cli.GetTS(ctx) + re.NoError(err) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/skipUpdateMember")) // Test the TSO follower proxy while enabling the Local TSO. cli.UpdateOption(pd.EnableTSOFollowerProxy, true) // Sleep a while here to prevent from canceling the ongoing TSO request.
time.Sleep(time.Millisecond * 50) - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) cli.UpdateOption(pd.EnableTSOFollowerProxy, false) time.Sleep(time.Millisecond * 50) - requestGlobalAndLocalTSO(c, wg, dcLocationConfig, cli) + requestGlobalAndLocalTSO(re, wg, dcLocationConfig, cli) } -func requestGlobalAndLocalTSO(c *C, wg *sync.WaitGroup, dcLocationConfig map[string]string, cli pd.Client) { +func requestGlobalAndLocalTSO( + re *require.Assertions, + wg *sync.WaitGroup, + dcLocationConfig map[string]string, + cli pd.Client, +) { for _, dcLocation := range dcLocationConfig { wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -390,131 +395,143 @@ func requestGlobalAndLocalTSO(c *C, wg *sync.WaitGroup, dcLocationConfig map[str var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { globalPhysical1, globalLogical1, err := cli.GetTS(context.TODO()) - c.Assert(err, IsNil) + re.NoError(err) globalTS1 := tsoutil.ComposeTS(globalPhysical1, globalLogical1) localPhysical, localLogical, err := cli.GetLocalTS(context.TODO(), dc) - c.Assert(err, IsNil) + re.NoError(err) localTS := tsoutil.ComposeTS(localPhysical, localLogical) globalPhysical2, globalLogical2, err := cli.GetTS(context.TODO()) - c.Assert(err, IsNil) + re.NoError(err) globalTS2 := tsoutil.ComposeTS(globalPhysical2, globalLogical2) - c.Assert(lastTS, Less, globalTS1) - c.Assert(globalTS1, Less, localTS) - c.Assert(localTS, Less, globalTS2) + re.Less(lastTS, globalTS1) + re.Less(globalTS1, localTS) + re.Less(localTS, globalTS2) lastTS = globalTS2 } - c.Assert(lastTS, Greater, uint64(0)) + re.Greater(lastTS, uint64(0)) }(dcLocation) } } wg.Wait() } -func (s *clientTestSuite) TestCustomTimeout(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestCustomTimeout(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithCustomTimeoutOption(1*time.Second)) start := time.Now() - c.Assert(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/server/customTimeout", "return(true)")) _, err = cli.GetAllStores(context.TODO()) - c.Assert(failpoint.Disable("github.com/tikv/pd/server/customTimeout"), IsNil) - c.Assert(err, NotNil) - c.Assert(time.Since(start), GreaterEqual, 1*time.Second) - c.Assert(time.Since(start), Less, 2*time.Second) + re.Nil(failpoint.Disable("github.com/tikv/pd/server/customTimeout")) + re.Error(err) + re.GreaterOrEqual(time.Since(start), 1*time.Second) + re.Less(time.Since(start), 2*time.Second) } -func (s *clientTestSuite) TestGetRegionFromFollowerClient(c *C) { +func TestGetRegionFromFollowerClient(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := 
setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)")) time.Sleep(200 * time.Millisecond) r, err := cli.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) - c.Assert(r, NotNil) + re.NoError(err) + re.NotNil(r) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1")) time.Sleep(200 * time.Millisecond) r, err = cli.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) - c.Assert(r, NotNil) + re.NoError(err) + re.NotNil(r) } // case 1: unreachable -> normal -func (s *clientTestSuite) TestGetTsoFromFollowerClient1(c *C) { +func TestGetTsoFromFollowerClient1(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - lastTS = checkTS(c, cli, lastTS) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork"), IsNil) + lastTS = checkTS(re, cli, lastTS) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(2 * time.Second) - checkTS(c, cli, lastTS) + checkTS(re, cli, lastTS) } // case 2: unreachable -> leader transfer -> normal -func (s *clientTestSuite) TestGetTsoFromFollowerClient2(c *C) { +func TestGetTsoFromFollowerClient2(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pd.LeaderHealthCheckInterval = 100 * time.Millisecond - cluster, err := tests.NewTestCluster(s.ctx, 3) - c.Assert(err, IsNil) + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints, pd.WithForwardingOption(true)) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) - c.Assert(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)"), IsNil) + re.Nil(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) var lastTS uint64 - testutil.WaitUntil(c, func() bool { + testutil.WaitUntilWithTestingT(t, func() bool { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { lastTS = tsoutil.ComposeTS(physical, logical) return true } - c.Log(err) + t.Log(err) return false }) - lastTS = checkTS(c, cli, lastTS) - c.Assert(cluster.GetServer(cluster.GetLeader()).ResignLeader(), IsNil) + lastTS = checkTS(re, cli, lastTS) + 
re.NoError(cluster.GetServer(cluster.GetLeader()).ResignLeader()) cluster.WaitLeader() - lastTS = checkTS(c, cli, lastTS) + lastTS = checkTS(re, cli, lastTS) - c.Assert(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork"), IsNil) + re.Nil(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) time.Sleep(5 * time.Second) - checkTS(c, cli, lastTS) + checkTS(re, cli, lastTS) } -func checkTS(c *C, cli pd.Client, lastTS uint64) uint64 { +func checkTS(re *require.Assertions, cli pd.Client, lastTS uint64) uint64 { for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli.GetTS(context.TODO()) if err == nil { ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + re.Less(lastTS, ts) lastTS = ts } time.Sleep(time.Millisecond) @@ -522,12 +539,12 @@ func checkTS(c *C, cli pd.Client, lastTS uint64) uint64 { return lastTS } -func (s *clientTestSuite) runServer(c *C, cluster *tests.TestCluster) []string { +func runServer(re *require.Assertions, cluster *tests.TestCluster) []string { err := cluster.RunInitialServers() - c.Assert(err, IsNil) + re.NoError(err) cluster.WaitLeader() leaderServer := cluster.GetServer(cluster.GetLeader()) - c.Assert(leaderServer.BootstrapCluster(), IsNil) + re.NoError(leaderServer.BootstrapCluster()) testServers := cluster.GetServers() endpoints := make([]string, 0, len(testServers)) @@ -537,32 +554,80 @@ func (s *clientTestSuite) runServer(c *C, cluster *tests.TestCluster) []string { return endpoints } -func setupCli(c *C, ctx context.Context, endpoints []string, opts ...pd.ClientOption) pd.Client { +func setupCli(re *require.Assertions, ctx context.Context, endpoints []string, opts ...pd.ClientOption) pd.Client { cli, err := pd.NewClientWithContext(ctx, endpoints, pd.SecurityOption{}, opts...) 
- c.Assert(err, IsNil) + re.NoError(err) return cli } -func waitLeader(c *C, cli client, leader string) { - testutil.WaitUntil(c, func() bool { +func waitLeader(t *testing.T, cli client, leader string) { + testutil.WaitUntilWithTestingT(t, func() bool { cli.ScheduleCheckLeader() return cli.GetLeaderAddr() == leader }) } -func (s *clientTestSuite) TestCloseClient(c *C) { - cluster, err := tests.NewTestCluster(s.ctx, 1) - c.Assert(err, IsNil) +func TestConfigTTLAfterTransferLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) + defer cluster.Destroy() + err = cluster.RunInitialServers() + re.NoError(err) + leader := cluster.GetServer(cluster.WaitLeader()) + re.NoError(leader.BootstrapCluster()) + addr := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=5", leader.GetAddr()) + postData, err := json.Marshal(map[string]interface{}{ + "schedule.max-snapshot-count": 999, + "schedule.enable-location-replacement": false, + "schedule.max-merge-region-size": 999, + "schedule.max-merge-region-keys": 999, + "schedule.scheduler-max-waiting-operator": 999, + "schedule.leader-schedule-limit": 999, + "schedule.region-schedule-limit": 999, + "schedule.hot-region-schedule-limit": 999, + "schedule.replica-schedule-limit": 999, + "schedule.merge-schedule-limit": 999, + }) + re.NoError(err) + resp, err := leader.GetHTTPClient().Post(addr, "application/json", bytes.NewBuffer(postData)) + re.NoError(err) + resp.Body.Close() + time.Sleep(2 * time.Second) + re.NoError(leader.Destroy()) + time.Sleep(2 * time.Second) + leader = cluster.GetServer(cluster.WaitLeader()) + re.NotNil(leader) + options := leader.GetPersistOptions() + re.NotNil(options) + re.Equal(uint64(999), options.GetMaxSnapshotCount()) + re.False(options.IsLocationReplacementEnabled()) + re.Equal(uint64(999), options.GetMaxMergeRegionSize()) + re.Equal(uint64(999), options.GetMaxMergeRegionKeys()) + re.Equal(uint64(999), options.GetSchedulerMaxWaitingOperator()) + re.Equal(uint64(999), options.GetLeaderScheduleLimit()) + re.Equal(uint64(999), options.GetRegionScheduleLimit()) + re.Equal(uint64(999), options.GetHotRegionScheduleLimit()) + re.Equal(uint64(999), options.GetReplicaScheduleLimit()) + re.Equal(uint64(999), options.GetMergeScheduleLimit()) +} + +func TestCloseClient(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + re.NoError(err) defer cluster.Destroy() - endpoints := s.runServer(c, cluster) - cli := setupCli(c, s.ctx, endpoints) + endpoints := runServer(re, cluster) + cli := setupCli(re, ctx, endpoints) cli.GetTSAsync(context.TODO()) time.Sleep(time.Second) cli.Close() } -var _ = Suite(&testClientSuite{}) - type idAllocator struct { allocator *mockid.IDAllocator } @@ -605,7 +670,8 @@ var ( } ) -type testClientSuite struct { +type clientTestSuite struct { + suite.Suite cleanup server.CleanupFunc ctx context.Context clean context.CancelFunc @@ -617,38 +683,34 @@ type testClientSuite struct { reportBucket pdpb.PD_ReportBucketsClient } -func checkerWithNilAssert(c *C) *assertutil.Checker { - checker := assertutil.NewChecker(c.FailNow) - checker.IsNil = func(obtained interface{}) { - c.Assert(obtained, IsNil) - } - return checker +func TestClientTestSuite(t *testing.T) { + suite.Run(t, new(clientTestSuite)) } -func (s *testClientSuite) SetUpSuite(c *C) { +func (suite *clientTestSuite) SetupSuite() { var err
error - s.srv, s.cleanup, err = server.NewTestServer(checkerWithNilAssert(c)) - c.Assert(err, IsNil) - s.grpcPDClient = testutil.MustNewGrpcClient(c, s.srv.GetAddr()) - s.grpcSvr = &server.GrpcServer{Server: s.srv} - - mustWaitLeader(c, map[string]*server.Server{s.srv.GetAddr(): s.srv}) - bootstrapServer(c, newHeader(s.srv), s.grpcPDClient) - - s.ctx, s.clean = context.WithCancel(context.Background()) - s.client = setupCli(c, s.ctx, s.srv.GetEndpoints()) - - c.Assert(err, IsNil) - s.regionHeartbeat, err = s.grpcPDClient.RegionHeartbeat(s.ctx) - c.Assert(err, IsNil) - s.reportBucket, err = s.grpcPDClient.ReportBuckets(s.ctx) - c.Assert(err, IsNil) - cluster := s.srv.GetRaftCluster() - c.Assert(cluster, NotNil) + re := suite.Require() + suite.srv, suite.cleanup, err = server.NewTestServer(suite.checkerWithNilAssert()) + suite.NoError(err) + suite.grpcPDClient = testutil.MustNewGrpcClientWithTestify(re, suite.srv.GetAddr()) + suite.grpcSvr = &server.GrpcServer{Server: suite.srv} + + suite.mustWaitLeader(map[string]*server.Server{suite.srv.GetAddr(): suite.srv}) + suite.bootstrapServer(newHeader(suite.srv), suite.grpcPDClient) + + suite.ctx, suite.clean = context.WithCancel(context.Background()) + suite.client = setupCli(re, suite.ctx, suite.srv.GetEndpoints()) + + suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx) + suite.NoError(err) + suite.reportBucket, err = suite.grpcPDClient.ReportBuckets(suite.ctx) + suite.NoError(err) + cluster := suite.srv.GetRaftCluster() + suite.NotNil(cluster) now := time.Now().UnixNano() for _, store := range stores { - s.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: newHeader(s.srv), + suite.grpcSvr.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: newHeader(suite.srv), Store: &metapb.Store{ Id: store.Id, Address: store.Address, @@ -660,13 +722,23 @@ func (s *testClientSuite) SetUpSuite(c *C) { config.EnableRegionBucket = true } -func (s *testClientSuite) TearDownSuite(c *C) { - s.client.Close() - s.clean() - s.cleanup() +func (suite *clientTestSuite) TearDownSuite() { + suite.client.Close() + suite.clean() + suite.cleanup() } -func mustWaitLeader(c *C, svrs map[string]*server.Server) *server.Server { +func (suite *clientTestSuite) checkerWithNilAssert() *assertutil.Checker { + checker := assertutil.NewChecker(func() { + suite.FailNow("should be nil") + }) + checker.IsNil = func(obtained interface{}) { + suite.Nil(obtained) + } + return checker +} + +func (suite *clientTestSuite) mustWaitLeader(svrs map[string]*server.Server) *server.Server { for i := 0; i < 500; i++ { for _, s := range svrs { if !s.IsClosed() && s.GetMember().IsLeader() { @@ -675,7 +747,7 @@ func mustWaitLeader(c *C, svrs map[string]*server.Server) *server.Server { } time.Sleep(100 * time.Millisecond) } - c.Fatal("no leader") + suite.FailNow("no leader") return nil } @@ -685,7 +757,7 @@ func newHeader(srv *server.Server) *pdpb.RequestHeader { } } -func bootstrapServer(c *C, header *pdpb.RequestHeader, client pdpb.PDClient) { +func (suite *clientTestSuite) bootstrapServer(header *pdpb.RequestHeader, client pdpb.PDClient) { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -701,10 +773,10 @@ func bootstrapServer(c *C, header *pdpb.RequestHeader, client pdpb.PDClient) { Region: region, } _, err := client.Bootstrap(context.Background(), req) - c.Assert(err, IsNil) + suite.NoError(err) } -func (s *testClientSuite) TestNormalTSO(c *C) { +func (suite *clientTestSuite) TestNormalTSO() { var wg sync.WaitGroup 
wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -712,10 +784,10 @@ func (s *testClientSuite) TestNormalTSO(c *C) { defer wg.Done() var lastTS uint64 for i := 0; i < tsoRequestRound; i++ { - physical, logical, err := s.client.GetTS(context.Background()) - c.Assert(err, IsNil) + physical, logical, err := suite.client.GetTS(context.Background()) + suite.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Less, ts) + suite.Less(lastTS, ts) lastTS = ts } }() @@ -723,7 +795,7 @@ func (s *testClientSuite) TestNormalTSO(c *C) { wg.Wait() } -func (s *testClientSuite) TestGetTSAsync(c *C) { +func (suite *clientTestSuite) TestGetTSAsync() { var wg sync.WaitGroup wg.Add(tsoRequestConcurrencyNumber) for i := 0; i < tsoRequestConcurrencyNumber; i++ { @@ -731,14 +803,14 @@ func (s *testClientSuite) TestGetTSAsync(c *C) { defer wg.Done() tsFutures := make([]pd.TSFuture, tsoRequestRound) for i := range tsFutures { - tsFutures[i] = s.client.GetTSAsync(context.Background()) + tsFutures[i] = suite.client.GetTSAsync(context.Background()) } var lastTS uint64 = math.MaxUint64 for i := len(tsFutures) - 1; i >= 0; i-- { physical, logical, err := tsFutures[i].Wait() - c.Assert(err, IsNil) + suite.NoError(err) ts := tsoutil.ComposeTS(physical, logical) - c.Assert(lastTS, Greater, ts) + suite.Greater(lastTS, ts) lastTS = ts } }() @@ -746,7 +818,7 @@ func (s *testClientSuite) TestGetTSAsync(c *C) { wg.Wait() } -func (s *testClientSuite) TestGetRegion(c *C) { +func (suite *clientTestSuite) TestGetRegion() { regionID := regionIDAllocator.alloc() region := &metapb.Region{ Id: regionID, @@ -757,24 +829,25 @@ func (s *testClientSuite) TestGetRegion(c *C) { Peers: peers, } req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: region, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a")) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a")) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Meta, DeepEquals, region) && - c.Check(r.Leader, DeepEquals, peers[0]) && - c.Check(r.Buckets, IsNil) + return reflect.DeepEqual(region, r.Meta) && + reflect.DeepEqual(peers[0], r.Leader) && + r.Buckets == nil }) breq := &pdpb.ReportBucketsRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Buckets: &metapb.Buckets{ RegionId: regionID, Version: 1, @@ -790,30 +863,29 @@ func (s *testClientSuite) TestGetRegion(c *C) { }, }, } - c.Assert(s.reportBucket.Send(breq), IsNil) - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) - c.Assert(err, IsNil) + suite.NoError(suite.reportBucket.Send(breq)) + testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Buckets, NotNil) + return r.Buckets != nil }) - config := s.srv.GetRaftCluster().GetStoreConfig() + config := suite.srv.GetRaftCluster().GetStoreConfig() config.EnableRegionBucket = false - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) - c.Assert(err, IsNil) + 
testutil.WaitUntilWithTestingT(t, func() bool { + r, err := suite.client.GetRegion(context.Background(), []byte("a"), pd.WithBuckets()) + suite.NoError(err) if r == nil { return false } - return c.Check(r.Buckets, IsNil) + return r.Buckets == nil }) config.EnableRegionBucket = true - c.Succeed() } -func (s *testClientSuite) TestGetPrevRegion(c *C) { +func (suite *clientTestSuite) TestGetPrevRegion() { regionLen := 10 regions := make([]*metapb.Region, 0, regionLen) for i := 0; i < regionLen; i++ { @@ -830,29 +902,28 @@ func (s *testClientSuite) TestGetPrevRegion(c *C) { } regions = append(regions, r) req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: r, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) } time.Sleep(500 * time.Millisecond) for i := 0; i < 20; i++ { - testutil.WaitUntil(c, func() bool { - r, err := s.client.GetPrevRegion(context.Background(), []byte{byte(i)}) - c.Assert(err, IsNil) + testutil.WaitUntilWithTestingT(suite.T(), func() bool { + r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)}) + suite.NoError(err) if i > 0 && i < regionLen { - return c.Check(r.Leader, DeepEquals, peers[0]) && - c.Check(r.Meta, DeepEquals, regions[i-1]) + return reflect.DeepEqual(peers[0], r.Leader) && + reflect.DeepEqual(regions[i-1], r.Meta) } - return c.Check(r, IsNil) + return r == nil }) } - c.Succeed() } -func (s *testClientSuite) TestScanRegions(c *C) { +func (suite *clientTestSuite) TestScanRegions() { regionLen := 10 regions := make([]*metapb.Region, 0, regionLen) for i := 0; i < regionLen; i++ { @@ -869,53 +940,54 @@ func (s *testClientSuite) TestScanRegions(c *C) { } regions = append(regions, r) req := &pdpb.RegionHeartbeatRequest{ - Header: newHeader(s.srv), + Header: newHeader(suite.srv), Region: r, Leader: peers[0], } - err := s.regionHeartbeat.Send(req) - c.Assert(err, IsNil) + err := suite.regionHeartbeat.Send(req) + suite.NoError(err) } // Wait for region heartbeats. - testutil.WaitUntil(c, func() bool { - scanRegions, err := s.client.ScanRegions(context.Background(), []byte{0}, nil, 10) + t := suite.T() + testutil.WaitUntilWithTestingT(t, func() bool { + scanRegions, err := suite.client.ScanRegions(context.Background(), []byte{0}, nil, 10) return err == nil && len(scanRegions) == 10 }) // Set leader of region3 to nil. region3 := core.NewRegionInfo(regions[3], nil) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region3) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region3) // Add down peer for region4. region4 := core.NewRegionInfo(regions[4], regions[4].Peers[0], core.WithDownPeers([]*pdpb.PeerStats{{Peer: regions[4].Peers[1]}})) - s.srv.GetRaftCluster().HandleRegionHeartbeat(region4) + suite.srv.GetRaftCluster().HandleRegionHeartbeat(region4) // Add pending peers for region5. 
 	region5 := core.NewRegionInfo(regions[5], regions[5].Peers[0], core.WithPendingPeers([]*metapb.Peer{regions[5].Peers[1], regions[5].Peers[2]}))
-	s.srv.GetRaftCluster().HandleRegionHeartbeat(region5)
+	suite.srv.GetRaftCluster().HandleRegionHeartbeat(region5)

 	check := func(start, end []byte, limit int, expect []*metapb.Region) {
-		scanRegions, err := s.client.ScanRegions(context.Background(), start, end, limit)
-		c.Assert(err, IsNil)
-		c.Assert(scanRegions, HasLen, len(expect))
-		c.Log("scanRegions", scanRegions)
-		c.Log("expect", expect)
+		scanRegions, err := suite.client.ScanRegions(context.Background(), start, end, limit)
+		suite.NoError(err)
+		suite.Len(scanRegions, len(expect))
+		t.Log("scanRegions", scanRegions)
+		t.Log("expect", expect)
 		for i := range expect {
-			c.Assert(scanRegions[i].Meta, DeepEquals, expect[i])
+			suite.True(reflect.DeepEqual(expect[i], scanRegions[i].Meta))

 			if scanRegions[i].Meta.GetId() == region3.GetID() {
-				c.Assert(scanRegions[i].Leader, DeepEquals, &metapb.Peer{})
+				suite.True(reflect.DeepEqual(&metapb.Peer{}, scanRegions[i].Leader))
 			} else {
-				c.Assert(scanRegions[i].Leader, DeepEquals, expect[i].Peers[0])
+				suite.True(reflect.DeepEqual(expect[i].Peers[0], scanRegions[i].Leader))
 			}

 			if scanRegions[i].Meta.GetId() == region4.GetID() {
-				c.Assert(scanRegions[i].DownPeers, DeepEquals, []*metapb.Peer{expect[i].Peers[1]})
+				suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1]}, scanRegions[i].DownPeers))
 			}
 			if scanRegions[i].Meta.GetId() == region5.GetID() {
-				c.Assert(scanRegions[i].PendingPeers, DeepEquals, []*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]})
+				suite.True(reflect.DeepEqual([]*metapb.Peer{expect[i].Peers[1], expect[i].Peers[2]}, scanRegions[i].PendingPeers))
 			}
 		}
 	}
@@ -927,7 +999,7 @@ func (s *testClientSuite) TestScanRegions(c *C) {
 	check([]byte{1}, []byte{6}, 2, regions[1:3])
 }

-func (s *testClientSuite) TestGetRegionByID(c *C) {
+func (suite *clientTestSuite) TestGetRegionByID() {
 	regionID := regionIDAllocator.alloc()
 	region := &metapb.Region{
 		Id: regionID,
@@ -938,125 +1010,125 @@ func (s *testClientSuite) TestGetRegionByID(c *C) {
 		Peers: peers,
 	}
 	req := &pdpb.RegionHeartbeatRequest{
-		Header: newHeader(s.srv),
+		Header: newHeader(suite.srv),
 		Region: region,
 		Leader: peers[0],
 	}
-	err := s.regionHeartbeat.Send(req)
-	c.Assert(err, IsNil)
+	err := suite.regionHeartbeat.Send(req)
+	suite.NoError(err)

-	testutil.WaitUntil(c, func() bool {
-		r, err := s.client.GetRegionByID(context.Background(), regionID)
-		c.Assert(err, IsNil)
+	testutil.WaitUntilWithTestingT(suite.T(), func() bool {
+		r, err := suite.client.GetRegionByID(context.Background(), regionID)
+		suite.NoError(err)
 		if r == nil {
 			return false
 		}
-		return c.Check(r.Meta, DeepEquals, region) &&
-			c.Check(r.Leader, DeepEquals, peers[0])
+		return reflect.DeepEqual(region, r.Meta) &&
+			reflect.DeepEqual(peers[0], r.Leader)
 	})
-	c.Succeed()
 }

-func (s *testClientSuite) TestGetStore(c *C) {
-	cluster := s.srv.GetRaftCluster()
-	c.Assert(cluster, NotNil)
+func (suite *clientTestSuite) TestGetStore() {
+	cluster := suite.srv.GetRaftCluster()
+	suite.NotNil(cluster)
 	store := stores[0]

 	// Get an up store should be OK.
-	n, err := s.client.GetStore(context.Background(), store.GetId())
-	c.Assert(err, IsNil)
-	c.Assert(n, DeepEquals, store)
+	n, err := suite.client.GetStore(context.Background(), store.GetId())
+	suite.NoError(err)
+	suite.True(reflect.DeepEqual(store, n))

-	stores, err := s.client.GetAllStores(context.Background())
-	c.Assert(err, IsNil)
-	c.Assert(stores, DeepEquals, stores)
+	actualStores, err := suite.client.GetAllStores(context.Background())
+	suite.NoError(err)
+	suite.Len(actualStores, len(stores))
+	stores = actualStores

 	// Mark the store as offline.
 	err = cluster.RemoveStore(store.GetId(), false)
-	c.Assert(err, IsNil)
+	suite.NoError(err)
 	offlineStore := proto.Clone(store).(*metapb.Store)
 	offlineStore.State = metapb.StoreState_Offline
 	offlineStore.NodeState = metapb.NodeState_Removing

 	// Get an offline store should be OK.
-	n, err = s.client.GetStore(context.Background(), store.GetId())
-	c.Assert(err, IsNil)
-	c.Assert(n, DeepEquals, offlineStore)
+	n, err = suite.client.GetStore(context.Background(), store.GetId())
+	suite.NoError(err)
+	suite.True(reflect.DeepEqual(offlineStore, n))

 	// Should return offline stores.
 	contains := false
-	stores, err = s.client.GetAllStores(context.Background())
-	c.Assert(err, IsNil)
+	stores, err = suite.client.GetAllStores(context.Background())
+	suite.NoError(err)
 	for _, store := range stores {
 		if store.GetId() == offlineStore.GetId() {
 			contains = true
-			c.Assert(store, DeepEquals, offlineStore)
+			suite.True(reflect.DeepEqual(offlineStore, store))
 		}
 	}
-	c.Assert(contains, IsTrue)
+	suite.True(contains)

 	// Mark the store as physically destroyed and offline.
 	err = cluster.RemoveStore(store.GetId(), true)
-	c.Assert(err, IsNil)
+	suite.NoError(err)
 	physicallyDestroyedStoreID := store.GetId()

 	// Get a physically destroyed and offline store
 	// It should be Tombstone(become Tombstone automically) or Offline
-	n, err = s.client.GetStore(context.Background(), physicallyDestroyedStoreID)
-	c.Assert(err, IsNil)
+	n, err = suite.client.GetStore(context.Background(), physicallyDestroyedStoreID)
+	suite.NoError(err)
 	if n != nil { // store is still offline and physically destroyed
-		c.Assert(n.GetNodeState(), Equals, metapb.NodeState_Removing)
-		c.Assert(n.PhysicallyDestroyed, IsTrue)
+		suite.Equal(metapb.NodeState_Removing, n.GetNodeState())
+		suite.True(n.PhysicallyDestroyed)
 	}

 	// Should return tombstone stores.
 	contains = false
-	stores, err = s.client.GetAllStores(context.Background())
-	c.Assert(err, IsNil)
+	stores, err = suite.client.GetAllStores(context.Background())
+	suite.NoError(err)
 	for _, store := range stores {
 		if store.GetId() == physicallyDestroyedStoreID {
 			contains = true
-			c.Assert(store.GetState(), Not(Equals), metapb.StoreState_Up)
-			c.Assert(store.PhysicallyDestroyed, IsTrue)
+			suite.NotEqual(metapb.StoreState_Up, store.GetState())
+			suite.True(store.PhysicallyDestroyed)
 		}
 	}
-	c.Assert(contains, IsTrue)
+	suite.True(contains)

 	// Should not return tombstone stores.
-	stores, err = s.client.GetAllStores(context.Background(), pd.WithExcludeTombstone())
-	c.Assert(err, IsNil)
+	stores, err = suite.client.GetAllStores(context.Background(), pd.WithExcludeTombstone())
+	suite.NoError(err)
 	for _, store := range stores {
 		if store.GetId() == physicallyDestroyedStoreID {
-			c.Assert(store.GetState(), Equals, metapb.StoreState_Offline)
-			c.Assert(store.PhysicallyDestroyed, IsTrue)
+			suite.Equal(metapb.StoreState_Offline, store.GetState())
+			suite.True(store.PhysicallyDestroyed)
 		}
 	}
 }

-func (s *testClientSuite) checkGCSafePoint(c *C, expectedSafePoint uint64) {
+func (suite *clientTestSuite) checkGCSafePoint(expectedSafePoint uint64) {
 	req := &pdpb.GetGCSafePointRequest{
-		Header: newHeader(s.srv),
+		Header: newHeader(suite.srv),
 	}
-	resp, err := s.grpcSvr.GetGCSafePoint(context.Background(), req)
-	c.Assert(err, IsNil)
-	c.Assert(resp.SafePoint, Equals, expectedSafePoint)
+	resp, err := suite.grpcSvr.GetGCSafePoint(context.Background(), req)
+	suite.NoError(err)
+	suite.Equal(expectedSafePoint, resp.SafePoint)
 }

-func (s *testClientSuite) TestUpdateGCSafePoint(c *C) {
-	s.checkGCSafePoint(c, 0)
+func (suite *clientTestSuite) TestUpdateGCSafePoint() {
+	suite.checkGCSafePoint(0)
 	for _, safePoint := range []uint64{0, 1, 2, 3, 233, 23333, 233333333333, math.MaxUint64} {
-		newSafePoint, err := s.client.UpdateGCSafePoint(context.Background(), safePoint)
-		c.Assert(err, IsNil)
-		c.Assert(newSafePoint, Equals, safePoint)
-		s.checkGCSafePoint(c, safePoint)
+		newSafePoint, err := suite.client.UpdateGCSafePoint(context.Background(), safePoint)
+		suite.NoError(err)
+		suite.Equal(safePoint, newSafePoint)
+		suite.checkGCSafePoint(safePoint)
 	}

 	// If the new safe point is less than the old one, it should not be updated.
-	newSafePoint, err := s.client.UpdateGCSafePoint(context.Background(), 1)
-	c.Assert(newSafePoint, Equals, uint64(math.MaxUint64))
-	c.Assert(err, IsNil)
-	s.checkGCSafePoint(c, math.MaxUint64)
+	newSafePoint, err := suite.client.UpdateGCSafePoint(context.Background(), 1)
+	suite.Equal(uint64(math.MaxUint64), newSafePoint)
+	suite.NoError(err)
+	suite.checkGCSafePoint(math.MaxUint64)
 }

-func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) {
+func (suite *clientTestSuite) TestUpdateServiceGCSafePoint() {
 	serviceSafePoints := []struct {
 		ServiceID string
 		TTL int64
@@ -1067,105 +1139,105 @@ func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) {
 		{"c", 1000, 3},
 	}
 	for _, ssp := range serviceSafePoints {
-		min, err := s.client.UpdateServiceGCSafePoint(context.Background(),
+		min, err := suite.client.UpdateServiceGCSafePoint(context.Background(),
 			ssp.ServiceID, 1000, ssp.SafePoint)
-		c.Assert(err, IsNil)
+		suite.NoError(err)
 		// An service safepoint of ID "gc_worker" is automatically initialized as 0
-		c.Assert(min, Equals, uint64(0))
+		suite.Equal(uint64(0), min)
 	}

-	min, err := s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err := suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"gc_worker", math.MaxInt64, 10)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(1))
+	suite.NoError(err)
+	suite.Equal(uint64(1), min)

-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"a", 1000, 4)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(2))
+	suite.NoError(err)
+	suite.Equal(uint64(2), min)

-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"b", -100, 2)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(3))
+	suite.NoError(err)
+	suite.Equal(uint64(3), min)

 	// Minimum safepoint does not regress
-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"b", 1000, 2)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(3))
+	suite.NoError(err)
+	suite.Equal(uint64(3), min)

 	// Update only the TTL of the minimum safepoint
-	oldMinSsp, err := s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(oldMinSsp.ServiceID, Equals, "c")
-	c.Assert(oldMinSsp.SafePoint, Equals, uint64(3))
-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	oldMinSsp, err := suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal("c", oldMinSsp.ServiceID)
+	suite.Equal(uint64(3), oldMinSsp.SafePoint)
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"c", 2000, 3)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(3))
-	minSsp, err := s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(minSsp.ServiceID, Equals, "c")
-	c.Assert(oldMinSsp.SafePoint, Equals, uint64(3))
-	c.Assert(minSsp.ExpiredAt-oldMinSsp.ExpiredAt, GreaterEqual, int64(1000))
+	suite.NoError(err)
+	suite.Equal(uint64(3), min)
+	minSsp, err := suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal("c", minSsp.ServiceID)
+	suite.Equal(uint64(3), oldMinSsp.SafePoint)
+	suite.GreaterOrEqual(minSsp.ExpiredAt-oldMinSsp.ExpiredAt, int64(1000))

 	// Shrinking TTL is also allowed
-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"c", 1, 3)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(3))
-	minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(minSsp.ServiceID, Equals, "c")
-	c.Assert(minSsp.ExpiredAt, Less, oldMinSsp.ExpiredAt)
+	suite.NoError(err)
+	suite.Equal(uint64(3), min)
+	minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal("c", minSsp.ServiceID)
+	suite.Less(minSsp.ExpiredAt, oldMinSsp.ExpiredAt)

 	// TTL can be infinite (represented by math.MaxInt64)
-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"c", math.MaxInt64, 3)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(3))
-	minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(minSsp.ServiceID, Equals, "c")
-	c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64))
+	suite.NoError(err)
+	suite.Equal(uint64(3), min)
+	minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal("c", minSsp.ServiceID)
+	suite.Equal(minSsp.ExpiredAt, int64(math.MaxInt64))

 	// Delete "a" and "c"
-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"c", -1, 3)
-	c.Assert(err, IsNil)
-	c.Assert(min, Equals, uint64(4))
-	min, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	suite.NoError(err)
+	suite.Equal(uint64(4), min)
+	min, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"a", -1, 4)
-	c.Assert(err, IsNil)
+	suite.NoError(err)
 	// Now gc_worker is the only remaining service safe point.
-	c.Assert(min, Equals, uint64(10))
+	suite.Equal(uint64(10), min)

 	// gc_worker cannot be deleted.
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"gc_worker", -1, 10)
-	c.Assert(err, NotNil)
+	suite.Error(err)

 	// Cannot set non-infinity TTL for gc_worker
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"gc_worker", 10000000, 10)
-	c.Assert(err, NotNil)
+	suite.Error(err)

 	// Service safepoint must have a non-empty ID
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"", 1000, 15)
-	c.Assert(err, NotNil)
+	suite.Error(err)

 	// Put some other safepoints to test fixing gc_worker's safepoint when there exists other safepoints.
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"a", 1000, 11)
-	c.Assert(err, IsNil)
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	suite.NoError(err)
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"b", 1000, 12)
-	c.Assert(err, IsNil)
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	suite.NoError(err)
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"c", 1000, 13)
-	c.Assert(err, IsNil)
+	suite.NoError(err)

 	// Force set invalid ttl to gc_worker
 	gcWorkerKey := path.Join("gc", "safe_point", "service", "gc_worker")
@@ -1176,38 +1248,38 @@ func (s *testClientSuite) TestUpdateServiceGCSafePoint(c *C) {
 			SafePoint: 10,
 		}
 		value, err := json.Marshal(gcWorkerSsp)
-		c.Assert(err, IsNil)
-		err = s.srv.GetStorage().Save(gcWorkerKey, string(value))
-		c.Assert(err, IsNil)
+		suite.NoError(err)
+		err = suite.srv.GetStorage().Save(gcWorkerKey, string(value))
+		suite.NoError(err)
 	}

-	minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(minSsp.ServiceID, Equals, "gc_worker")
-	c.Assert(minSsp.SafePoint, Equals, uint64(10))
-	c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64))
+	minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal("gc_worker", minSsp.ServiceID)
+	suite.Equal(uint64(10), minSsp.SafePoint)
+	suite.Equal(int64(math.MaxInt64), minSsp.ExpiredAt)

 	// Force delete gc_worker, then the min service safepoint is 11 of "a".
-	err = s.srv.GetStorage().Remove(gcWorkerKey)
-	c.Assert(err, IsNil)
-	minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(minSsp.SafePoint, Equals, uint64(11))
+	err = suite.srv.GetStorage().Remove(gcWorkerKey)
+	suite.NoError(err)
+	minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal(uint64(11), minSsp.SafePoint)
 	// After calling LoadMinServiceGCS when "gc_worker"'s service safepoint is missing, "gc_worker"'s service safepoint
 	// will be newly created.
 	// Increase "a" so that "gc_worker" is the only minimum that will be returned by LoadMinServiceGCSafePoint.
-	_, err = s.client.UpdateServiceGCSafePoint(context.Background(),
+	_, err = suite.client.UpdateServiceGCSafePoint(context.Background(),
 		"a", 1000, 14)
-	c.Assert(err, IsNil)
+	suite.NoError(err)

-	minSsp, err = s.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
-	c.Assert(err, IsNil)
-	c.Assert(minSsp.ServiceID, Equals, "gc_worker")
-	c.Assert(minSsp.SafePoint, Equals, uint64(11))
-	c.Assert(minSsp.ExpiredAt, Equals, int64(math.MaxInt64))
+	minSsp, err = suite.srv.GetStorage().LoadMinServiceGCSafePoint(time.Now())
+	suite.NoError(err)
+	suite.Equal("gc_worker", minSsp.ServiceID)
+	suite.Equal(uint64(11), minSsp.SafePoint)
+	suite.Equal(int64(math.MaxInt64), minSsp.ExpiredAt)
 }

-func (s *testClientSuite) TestScatterRegion(c *C) {
+func (suite *clientTestSuite) TestScatterRegion() {
 	regionID := regionIDAllocator.alloc()
 	region := &metapb.Region{
 		Id: regionID,
@@ -1220,106 +1292,46 @@ func (s *testClientSuite) TestScatterRegion(c *C) {
 		EndKey: []byte("ggg"),
 	}
 	req := &pdpb.RegionHeartbeatRequest{
-		Header: newHeader(s.srv),
+		Header: newHeader(suite.srv),
 		Region: region,
 		Leader: peers[0],
 	}
-	err := s.regionHeartbeat.Send(req)
+	err := suite.regionHeartbeat.Send(req)
 	regionsID := []uint64{regionID}
-	c.Assert(err, IsNil)
+	suite.NoError(err)

 	// Test interface `ScatterRegions`.
-	testutil.WaitUntil(c, func() bool {
-		scatterResp, err := s.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1))
-		if c.Check(err, NotNil) {
+	t := suite.T()
+	testutil.WaitUntilWithTestingT(t, func() bool {
+		scatterResp, err := suite.client.ScatterRegions(context.Background(), regionsID, pd.WithGroup("test"), pd.WithRetry(1))
+		if err != nil {
 			return false
 		}
-		if c.Check(scatterResp.FinishedPercentage, Not(Equals), uint64(100)) {
+		if scatterResp.FinishedPercentage != uint64(100) {
 			return false
 		}
-		resp, err := s.client.GetOperator(context.Background(), regionID)
-		if c.Check(err, NotNil) {
+		resp, err := suite.client.GetOperator(context.Background(), regionID)
+		if err != nil {
 			return false
 		}
-		return c.Check(resp.GetRegionId(), Equals, regionID) && c.Check(string(resp.GetDesc()), Equals, "scatter-region") && c.Check(resp.GetStatus(), Equals, pdpb.OperatorStatus_RUNNING)
+		return resp.GetRegionId() == regionID &&
+			string(resp.GetDesc()) == "scatter-region" &&
+			resp.GetStatus() == pdpb.OperatorStatus_RUNNING
 	}, testutil.WithSleepInterval(1*time.Second))

 	// Test interface `ScatterRegion`.
 	// TODO: Deprecate interface `ScatterRegion`.
-	testutil.WaitUntil(c, func() bool {
-		err := s.client.ScatterRegion(context.Background(), regionID)
-		if c.Check(err, NotNil) {
+	testutil.WaitUntilWithTestingT(t, func() bool {
+		err := suite.client.ScatterRegion(context.Background(), regionID)
+		if err != nil {
 			fmt.Println(err)
 			return false
 		}
-		resp, err := s.client.GetOperator(context.Background(), regionID)
-		if c.Check(err, NotNil) {
+		resp, err := suite.client.GetOperator(context.Background(), regionID)
+		if err != nil {
 			return false
 		}
-		return c.Check(resp.GetRegionId(), Equals, regionID) && c.Check(string(resp.GetDesc()), Equals, "scatter-region") && c.Check(resp.GetStatus(), Equals, pdpb.OperatorStatus_RUNNING)
+		return resp.GetRegionId() == regionID &&
+			string(resp.GetDesc()) == "scatter-region" &&
+			resp.GetStatus() == pdpb.OperatorStatus_RUNNING
 	}, testutil.WithSleepInterval(1*time.Second))
-
-	c.Succeed()
-}
-
-type testConfigTTLSuite struct {
-	ctx context.Context
-	cancel context.CancelFunc
-}
-
-func (s *testConfigTTLSuite) SetUpSuite(c *C) {
-	s.ctx, s.cancel = context.WithCancel(context.Background())
-	server.EnableZap = true
-}
-
-func (s *testConfigTTLSuite) TearDownSuite(c *C) {
-	s.cancel()
-}
-
-var _ = SerialSuites(&testConfigTTLSuite{})
-
-var ttlConfig = map[string]interface{}{
-	"schedule.max-snapshot-count": 999,
-	"schedule.enable-location-replacement": false,
-	"schedule.max-merge-region-size": 999,
-	"schedule.max-merge-region-keys": 999,
-	"schedule.scheduler-max-waiting-operator": 999,
-	"schedule.leader-schedule-limit": 999,
-	"schedule.region-schedule-limit": 999,
-	"schedule.hot-region-schedule-limit": 999,
-	"schedule.replica-schedule-limit": 999,
-	"schedule.merge-schedule-limit": 999,
-}
-
-func assertTTLConfig(c *C, options *config.PersistOptions, checker Checker) {
-	c.Assert(options.GetMaxSnapshotCount(), checker, uint64(999))
-	c.Assert(options.IsLocationReplacementEnabled(), checker, false)
-	c.Assert(options.GetMaxMergeRegionSize(), checker, uint64(999))
-	c.Assert(options.GetMaxMergeRegionKeys(), checker, uint64(999))
-	c.Assert(options.GetSchedulerMaxWaitingOperator(), checker, uint64(999))
-	c.Assert(options.GetLeaderScheduleLimit(), checker, uint64(999))
-	c.Assert(options.GetRegionScheduleLimit(), checker, uint64(999))
-	c.Assert(options.GetHotRegionScheduleLimit(), checker, uint64(999))
-	c.Assert(options.GetReplicaScheduleLimit(), checker, uint64(999))
-	c.Assert(options.GetMergeScheduleLimit(), checker, uint64(999))
-}
-
-func (s *testConfigTTLSuite) TestConfigTTLAfterTransferLeader(c *C) {
-	cluster, err := tests.NewTestCluster(s.ctx, 3)
-	c.Assert(err, IsNil)
-	defer cluster.Destroy()
-	err = cluster.RunInitialServers()
-	c.Assert(err, IsNil)
-	leader := cluster.GetServer(cluster.WaitLeader())
-	c.Assert(leader.BootstrapCluster(), IsNil)
-	addr := fmt.Sprintf("%s/pd/api/v1/config?ttlSecond=5", leader.GetAddr())
-	postData, err := json.Marshal(ttlConfig)
-	c.Assert(err, IsNil)
-	resp, err := leader.GetHTTPClient().Post(addr, "application/json", bytes.NewBuffer(postData))
-	resp.Body.Close()
-	c.Assert(err, IsNil)
-	time.Sleep(2 * time.Second)
-	_ = leader.Destroy()
-	time.Sleep(2 * time.Second)
-	leader = cluster.GetServer(cluster.WaitLeader())
-	assertTTLConfig(c, leader.GetPersistOptions(), Equals)
 }
diff --git a/tests/client/client_tls_test.go b/tests/client/client_tls_test.go
index 3fbca17c835..48a6fec3d2d 100644
--- a/tests/client/client_tls_test.go
+++ b/tests/client/client_tls_test.go
@@ -22,21 +22,19 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"testing"
 	"time"

-	. "github.com/pingcap/check"
"github.com/pingcap/check" + "github.com/stretchr/testify/require" pd "github.com/tikv/pd/client" "github.com/tikv/pd/pkg/grpcutil" "github.com/tikv/pd/pkg/netutil" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" "go.etcd.io/etcd/pkg/transport" "google.golang.org/grpc" ) -var _ = Suite(&clientTLSTestSuite{}) - var ( testTLSInfo = transport.TLSInfo{ KeyFile: "./cert/pd-server-key.pem", @@ -57,49 +55,38 @@ var ( } ) -type clientTLSTestSuite struct { - ctx context.Context - cancel context.CancelFunc -} - -func (s *clientTLSTestSuite) SetUpSuite(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) - server.EnableZap = true -} - -func (s *clientTLSTestSuite) TearDownSuite(c *C) { - s.cancel() -} - // TestTLSReloadAtomicReplace ensures server reloads expired/valid certs // when all certs are atomically replaced by directory renaming. // And expects server to reject client requests, and vice versa. -func (s *clientTLSTestSuite) TestTLSReloadAtomicReplace(c *C) { +func TestTLSReloadAtomicReplace(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() tmpDir, err := os.MkdirTemp(os.TempDir(), "cert-tmp") - c.Assert(err, IsNil) + re.NoError(err) os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir) certsDir, err := os.MkdirTemp(os.TempDir(), "cert-to-load") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(certsDir) certsDirExp, err := os.MkdirTemp(os.TempDir(), "cert-expired") - c.Assert(err, IsNil) + re.NoError(err) defer os.RemoveAll(certsDirExp) cloneFunc := func() transport.TLSInfo { tlsInfo, terr := copyTLSFiles(testTLSInfo, certsDir) - c.Assert(terr, IsNil) + re.NoError(terr) _, err = copyTLSFiles(testTLSInfoExpired, certsDirExp) - c.Assert(err, IsNil) + re.NoError(err) return tlsInfo } replaceFunc := func() { err = os.Rename(certsDir, tmpDir) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDirExp, certsDir) - c.Assert(err, IsNil) + re.NoError(err) // after rename, // 'certsDir' contains expired certs // 'tmpDir' contains valid certs @@ -107,25 +94,26 @@ func (s *clientTLSTestSuite) TestTLSReloadAtomicReplace(c *C) { } revertFunc := func() { err = os.Rename(tmpDir, certsDirExp) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDir, tmpDir) - c.Assert(err, IsNil) + re.NoError(err) err = os.Rename(certsDirExp, certsDir) - c.Assert(err, IsNil) + re.NoError(err) } - s.testTLSReload(c, cloneFunc, replaceFunc, revertFunc) + testTLSReload(re, ctx, cloneFunc, replaceFunc, revertFunc) } -func (s *clientTLSTestSuite) testTLSReload( - c *C, +func testTLSReload( + re *require.Assertions, + ctx context.Context, cloneFunc func() transport.TLSInfo, replaceFunc func(), revertFunc func()) { tlsInfo := cloneFunc() // 1. 
-	clus, err := tests.NewTestCluster(s.ctx, 1, func(conf *config.Config, serverName string) {
+	clus, err := tests.NewTestCluster(ctx, 1, func(conf *config.Config, serverName string) {
 		conf.Security.TLSConfig = grpcutil.TLSConfig{
 			KeyPath: tlsInfo.KeyFile,
 			CertPath: tlsInfo.CertFile,
@@ -137,10 +125,10 @@ func (s *clientTLSTestSuite) testTLSReload(
 		conf.PeerUrls = strings.ReplaceAll(conf.PeerUrls, "http", "https")
 		conf.InitialCluster = strings.ReplaceAll(conf.InitialCluster, "http", "https")
 	})
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	defer clus.Destroy()

 	err = clus.RunInitialServers()
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	clus.WaitLeader()
 	testServers := clus.GetServers()
@@ -148,20 +136,20 @@ func (s *clientTLSTestSuite) testTLSReload(
 	for _, s := range testServers {
 		endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls)
 		tlsConfig, err := s.GetConfig().Security.ToTLSConfig()
-		c.Assert(err, IsNil)
+		re.NoError(err)
 		httpClient := &http.Client{
 			Transport: &http.Transport{
 				DisableKeepAlives: true,
 				TLSClientConfig: tlsConfig,
 			},
 		}
-		c.Assert(netutil.IsEnableHTTPS(httpClient), IsTrue)
+		re.True(netutil.IsEnableHTTPS(httpClient))
 	}
 	// 2. concurrent client dialing while certs become expired
 	errc := make(chan error, 1)
 	go func() {
 		for {
-			dctx, dcancel := context.WithTimeout(s.ctx, time.Second)
+			dctx, dcancel := context.WithTimeout(ctx, time.Second)
 			cli, err := pd.NewClientWithContext(dctx, endpoints, pd.SecurityOption{
 				CAPath: testClientTLSInfo.TrustedCAFile,
 				CertPath: testClientTLSInfo.CertFile,
@@ -183,46 +171,46 @@ func (s *clientTLSTestSuite) testTLSReload(
 	// 4. expect dial time-out when loading expired certs
 	select {
 	case cerr := <-errc:
-		c.Assert(strings.Contains(cerr.Error(), "failed to get cluster id"), IsTrue)
+		re.Contains(cerr.Error(), "failed to get cluster id")
 	case <-time.After(5 * time.Second):
-		c.Fatal("failed to receive dial timeout error")
+		re.FailNow("failed to receive dial timeout error")
 	}

 	// 5. replace expired certs back with valid ones
 	revertFunc()

 	// 6. new requests should trigger listener to reload valid certs
-	dctx, dcancel := context.WithTimeout(s.ctx, 5*time.Second)
+	dctx, dcancel := context.WithTimeout(ctx, 5*time.Second)
 	cli, err := pd.NewClientWithContext(dctx, endpoints, pd.SecurityOption{
 		CAPath: testClientTLSInfo.TrustedCAFile,
 		CertPath: testClientTLSInfo.CertFile,
 		KeyPath: testClientTLSInfo.KeyFile,
 	}, pd.WithGRPCDialOptions(grpc.WithBlock()))
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	dcancel()
 	cli.Close()

 	// 7. test use raw bytes to init tls config
-	caData, certData, keyData := loadTLSContent(c,
+	caData, certData, keyData := loadTLSContent(re,
 		testClientTLSInfo.TrustedCAFile, testClientTLSInfo.CertFile, testClientTLSInfo.KeyFile)

-	ctx1, cancel1 := context.WithTimeout(s.ctx, 2*time.Second)
+	ctx1, cancel1 := context.WithTimeout(ctx, 2*time.Second)
 	_, err = pd.NewClientWithContext(ctx1, endpoints, pd.SecurityOption{
 		SSLCABytes: caData,
 		SSLCertBytes: certData,
 		SSLKEYBytes: keyData,
 	}, pd.WithGRPCDialOptions(grpc.WithBlock()))
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	cancel1()
 }

-func loadTLSContent(c *C, caPath, certPath, keyPath string) (caData, certData, keyData []byte) {
+func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (caData, certData, keyData []byte) {
 	var err error
 	caData, err = os.ReadFile(caPath)
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	certData, err = os.ReadFile(certPath)
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	keyData, err = os.ReadFile(keyPath)
-	c.Assert(err, IsNil)
+	re.NoError(err)
 	return
 }
@@ -245,6 +233,7 @@ func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) {
 	}
 	return ci, nil
 }
+
 func copyFile(src, dst string) error {
 	f, err := os.Open(src)
 	if err != nil {
diff --git a/tests/client/go.mod b/tests/client/go.mod
index 93fb9d96eaa..9d539193d52 100644
--- a/tests/client/go.mod
+++ b/tests/client/go.mod
@@ -5,9 +5,9 @@ go 1.16
 require (
 	github.com/gogo/protobuf v1.3.2
 	github.com/golang/protobuf v1.5.2 // indirect
-	github.com/pingcap/check v0.0.0-20211026125417-57bd13f7b5f0
 	github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00
 	github.com/pingcap/kvproto v0.0.0-20220510035547-0e2f26c0a46a
+	github.com/stretchr/testify v1.7.0
 	github.com/tikv/pd v0.0.0-00010101000000-000000000000
 	github.com/tikv/pd/client v0.0.0-00010101000000-000000000000
 	go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738
diff --git a/tests/cluster.go b/tests/cluster.go
index 2061668f393..3b0e10a02e6 100644
--- a/tests/cluster.go
+++ b/tests/cluster.go
@@ -19,6 +19,7 @@ import (
 	"net/http"
 	"os"
 	"sync"
+	"testing"
 	"time"

 	"github.com/coreos/go-semver/semver"
@@ -622,6 +623,26 @@ func (c *TestCluster) WaitAllLeaders(testC *check.C, dcLocations map[string]stri
 	wg.Wait()
 }

+// WaitAllLeadersWithTestingT will block and wait for the election of PD leader and all Local TSO Allocator leaders.
+// NOTICE: this is a temporary function that will be used to replace `WaitAllLeaders` later.
+func (c *TestCluster) WaitAllLeadersWithTestingT(t *testing.T, dcLocations map[string]string) {
+	c.WaitLeader()
+	c.CheckClusterDCLocation()
+	// Wait for each DC's Local TSO Allocator leader
+	wg := sync.WaitGroup{}
+	for _, dcLocation := range dcLocations {
+		wg.Add(1)
+		go func(dc string) {
+			testutil.WaitUntilWithTestingT(t, func() bool {
+				leaderName := c.WaitAllocatorLeader(dc)
+				return leaderName != ""
+			})
+			wg.Done()
+		}(dcLocation)
+	}
+	wg.Wait()
+}
+
 // GetCluster returns PD cluster.
 func (c *TestCluster) GetCluster() *metapb.Cluster {
 	leader := c.GetLeader()