diff --git a/pkg/assertutil/assertutil.go b/pkg/assertutil/assertutil.go index 5da16155674..9eb2719b220 100644 --- a/pkg/assertutil/assertutil.go +++ b/pkg/assertutil/assertutil.go @@ -14,6 +14,8 @@ package assertutil +import "github.com/stretchr/testify/require" + // Checker accepts the injection of check functions and context from test files. // Any check function should be set before usage unless the test will fail. type Checker struct { @@ -21,11 +23,23 @@ type Checker struct { FailNow func() } -// NewChecker creates Checker with FailNow function. +// NewChecker creates Checker. func NewChecker() *Checker { return &Checker{} } +// CheckerWithNilAssert creates Checker with nil assert function. +func CheckerWithNilAssert(re *require.Assertions) *Checker { + checker := NewChecker() + checker.FailNow = func() { + re.FailNow("should be nil") + } + checker.IsNil = func(obtained interface{}) { + re.Nil(obtained) + } + return checker +} + // AssertNil calls the injected IsNil assertion. func (c *Checker) AssertNil(obtained interface{}) { if c.IsNil == nil { diff --git a/server/api/server_test.go b/server/api/server_test.go index 273f62cab54..b82dfc5ea21 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" @@ -78,7 +79,7 @@ var zapLogOnce sync.Once func mustNewCluster(re *require.Assertions, num int, opts ...func(cfg *config.Config)) ([]*config.Config, []*server.Server, cleanUpFunc) { ctx, cancel := context.WithCancel(context.Background()) svrs := make([]*server.Server, 0, num) - cfgs := server.NewTestMultiConfig(checkerWithNilAssert(re), num) + cfgs := server.NewTestMultiConfig(assertutil.CheckerWithNilAssert(re), num) ch := make(chan *server.Server, num) for _, cfg := range cfgs { diff --git a/server/api/version_test.go b/server/api/version_test.go index 41254649c34..9973a871b05 100644 --- a/server/api/version_test.go +++ b/server/api/version_test.go @@ -30,17 +30,6 @@ import ( "github.com/tikv/pd/server/config" ) -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} - func TestGetVersion(t *testing.T) { // TODO: enable it. t.Skip("Temporary disable. See issue: https://github.com/tikv/pd/issues/1893") @@ -51,7 +40,7 @@ func TestGetVersion(t *testing.T) { temp, _ := os.Create(fname) os.Stdout = temp - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) reqCh := make(chan struct{}) go func() { <-reqCh diff --git a/server/join/join_test.go b/server/join/join_test.go index b8f001b5398..1dbdd7d374f 100644 --- a/server/join/join_test.go +++ b/server/join/join_test.go @@ -23,21 +23,10 @@ import ( "github.com/tikv/pd/server" ) -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} - // A PD joins itself. func TestPDJoinsItself(t *testing.T) { re := require.New(t) - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) defer testutil.CleanServer(cfg.DataDir) cfg.Join = cfg.AdvertiseClientUrls re.Error(PrepareJoinCluster(cfg)) diff --git a/server/server_test.go b/server/server_test.go index 58f572fb2df..f520314a5b1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -63,22 +63,11 @@ func (suite *leaderServerTestSuite) mustWaitLeader(svrs []*Server) *Server { return leader } -func (suite *leaderServerTestSuite) checkerWithNilAssert() *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - suite.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - suite.Nil(obtained) - } - return checker -} - func (suite *leaderServerTestSuite) SetupSuite() { suite.ctx, suite.cancel = context.WithCancel(context.Background()) suite.svrs = make(map[string]*Server) - cfgs := NewTestMultiConfig(suite.checkerWithNilAssert(), 3) + cfgs := NewTestMultiConfig(assertutil.CheckerWithNilAssert(suite.Require()), 3) ch := make(chan *Server, 3) for i := 0; i < 3; i++ { @@ -153,7 +142,7 @@ func (suite *leaderServerTestSuite) newTestServersWithCfgs(ctx context.Context, func (suite *leaderServerTestSuite) TestCheckClusterID() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cfgs := NewTestMultiConfig(suite.checkerWithNilAssert(), 2) + cfgs := NewTestMultiConfig(assertutil.CheckerWithNilAssert(suite.Require()), 2) for i, cfg := range cfgs { cfg.DataDir = fmt.Sprintf("/tmp/test_pd_check_clusterID_%d", i) // Clean up before testing. @@ -209,7 +198,7 @@ func (suite *leaderServerTestSuite) TestRegisterServerHandler() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) @@ -248,7 +237,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) @@ -291,7 +280,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) @@ -334,7 +323,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() { } return mux, info, nil } - cfg := NewTestSingleConfig(suite.checkerWithNilAssert()) + cfg := NewTestSingleConfig(assertutil.CheckerWithNilAssert(suite.Require())) ctx, cancel := context.WithCancel(context.Background()) svr, err := CreateServer(ctx, cfg, mokHandler) suite.NoError(err) diff --git a/tests/client/client_test.go b/tests/client/client_test.go index 003b4f73c32..8e67cfb4949 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -690,7 +690,7 @@ func TestClientTestSuite(t *testing.T) { func (suite *clientTestSuite) SetupSuite() { var err error re := suite.Require() - suite.srv, suite.cleanup, err = server.NewTestServer(suite.checkerWithNilAssert()) + suite.srv, suite.cleanup, err = server.NewTestServer(assertutil.CheckerWithNilAssert(re)) suite.NoError(err) suite.grpcPDClient = testutil.MustNewGrpcClient(re, suite.srv.GetAddr()) suite.grpcSvr = &server.GrpcServer{Server: suite.srv} @@ -728,17 +728,6 @@ func (suite *clientTestSuite) TearDownSuite() { suite.cleanup() } -func (suite *clientTestSuite) checkerWithNilAssert() *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - suite.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - suite.Nil(obtained) - } - return checker -} - func (suite *clientTestSuite) mustWaitLeader(svrs map[string]*server.Server) *server.Server { for i := 0; i < 500; i++ { for _, s := range svrs { diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index c182c739403..a13fee11441 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/log" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/apiutil" + "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/server" cmd "github.com/tikv/pd/tools/pd-ctl/pdctl" @@ -47,7 +48,7 @@ func TestSendAndGetComponent(t *testing.T) { } return mux, info, nil } - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) ctx, cancel := context.WithCancel(context.Background()) svr, err := server.CreateServer(ctx, cfg, handler) re.NoError(err) diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index 9a6adf566a3..5691dde66ca 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/spf13/cobra" "github.com/stretchr/testify/require" - "github.com/tikv/pd/pkg/assertutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/core" @@ -122,14 +121,3 @@ func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, re.NoError(err) return r } - -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 552a2d0c221..229b4756045 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -42,17 +42,6 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m, testutil.LeakOptions...) } -func checkerWithNilAssert(re *require.Assertions) *assertutil.Checker { - checker := assertutil.NewChecker() - checker.FailNow = func() { - re.FailNow("should be nil") - } - checker.IsNil = func(obtained interface{}) { - re.Nil(obtained) - } - return checker -} - func TestMemberDelete(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) @@ -314,7 +303,7 @@ func TestGetLeader(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cfg := server.NewTestSingleConfig(checkerWithNilAssert(re)) + cfg := server.NewTestSingleConfig(assertutil.CheckerWithNilAssert(re)) wg := &sync.WaitGroup{} wg.Add(1) done := make(chan bool)