diff --git a/client/client.go b/client/client.go index 70de1322488..f34f5897013 100644 --- a/client/client.go +++ b/client/client.go @@ -744,16 +744,18 @@ func (c *client) checkLeaderHealth(ctx context.Context) { if client := c.pdSvcDiscovery.GetServingEndpointClientConn(); client != nil { healthCli := healthpb.NewHealthClient(client) resp, err := healthCli.Check(ctx, &healthpb.HealthCheckRequest{Service: ""}) - rpcErr, ok := status.FromError(err) failpoint.Inject("unreachableNetwork1", func() { resp = nil err = status.New(codes.Unavailable, "unavailable").Err() }) + rpcErr, ok := status.FromError(err) if (ok && isNetworkError(rpcErr.Code())) || resp.GetStatus() != healthpb.HealthCheckResponse_SERVING { atomic.StoreInt32(&(c.leaderNetworkFailure), int32(1)) } else { atomic.StoreInt32(&(c.leaderNetworkFailure), int32(0)) } + } else { + atomic.StoreInt32(&(c.leaderNetworkFailure), int32(1)) } } diff --git a/client/grpcutil/grpcutil.go b/client/grpcutil/grpcutil.go index 125f1125721..fe149e76ecc 100644 --- a/client/grpcutil/grpcutil.go +++ b/client/grpcutil/grpcutil.go @@ -21,6 +21,8 @@ import ( "sync" "time" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/tlsutil" @@ -88,6 +90,12 @@ func GetOrCreateGRPCConn(ctx context.Context, clientConns *sync.Map, addr string dCtx, cancel := context.WithTimeout(ctx, dialTimeout) defer cancel() cc, err := GetClientConn(dCtx, addr, tlsConfig, opt...) + failpoint.Inject("unreachableNetwork2", func(val failpoint.Value) { + if val, ok := val.(string); ok && val == addr { + cc = nil + err = errors.Errorf("unreachable network") + } + }) if err != nil { return nil, err } diff --git a/client/pd_service_discovery.go b/client/pd_service_discovery.go index 98ddd611326..b75276adbe9 100644 --- a/client/pd_service_discovery.go +++ b/client/pd_service_discovery.go @@ -614,7 +614,6 @@ func (c *pdServiceDiscovery) switchLeader(addrs []string) error { if _, err := c.GetOrCreateGRPCConn(addr); err != nil { log.Warn("[pd] failed to connect leader", zap.String("leader", addr), errs.ZapError(err)) - return err } // Set PD leader and Global TSO Allocator (which is also the PD leader) c.leader.Store(addr) diff --git a/client/tso_client.go b/client/tso_client.go index 35d9388c72b..fc38ee8e5ba 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -171,9 +171,10 @@ func (c *tsoClient) GetTSOAllocatorClientConnByDCLocation(dcLocation string) (*g if !ok { panic(fmt.Sprintf("the allocator leader in %s should exist", dcLocation)) } + // todo: if we support local tso forward, we should get or create client conns. cc, ok := c.svcDiscovery.GetClientConns().Load(url) if !ok { - panic(fmt.Sprintf("the client connection of %s in %s should exist", url, dcLocation)) + return nil, url.(string) } return cc.(*grpc.ClientConn), url.(string) } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index e4c5bf3c77a..0de4dc3a49e 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -254,7 +254,7 @@ func (c *tsoClient) checkAllocator( requestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(0) }() cc, u := c.GetTSOAllocatorClientConnByDCLocation(dc) - healthCli := healthpb.NewHealthClient(cc) + var healthCli healthpb.HealthClient ticker := time.NewTicker(time.Second) defer ticker.Stop() for { @@ -263,20 +263,25 @@ func (c *tsoClient) checkAllocator( log.Info("[tso] the leader of the allocator leader is changed", zap.String("dc", dc), zap.String("origin", url), zap.String("new", u)) return } - healthCtx, healthCancel := context.WithTimeout(dispatcherCtx, c.option.timeout) - resp, err := healthCli.Check(healthCtx, &healthpb.HealthCheckRequest{Service: ""}) - failpoint.Inject("unreachableNetwork", func() { - resp.Status = healthpb.HealthCheckResponse_UNKNOWN - }) - healthCancel() - if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { - // create a stream of the original allocator - cctx, cancel := context.WithCancel(dispatcherCtx) - stream, err := c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) - if err == nil && stream != nil { - log.Info("[tso] recover the original tso stream since the network has become normal", zap.String("dc", dc), zap.String("url", url)) - updateAndClear(url, &tsoConnectionContext{url, stream, cctx, cancel}) - return + if healthCli == nil && cc != nil { + healthCli = healthpb.NewHealthClient(cc) + } + if healthCli != nil { + healthCtx, healthCancel := context.WithTimeout(dispatcherCtx, c.option.timeout) + resp, err := healthCli.Check(healthCtx, &healthpb.HealthCheckRequest{Service: ""}) + failpoint.Inject("unreachableNetwork", func() { + resp.Status = healthpb.HealthCheckResponse_UNKNOWN + }) + healthCancel() + if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { + // create a stream of the original allocator + cctx, cancel := context.WithCancel(dispatcherCtx) + stream, err := c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) + if err == nil && stream != nil { + log.Info("[tso] recover the original tso stream since the network has become normal", zap.String("dc", dc), zap.String("url", url)) + updateAndClear(url, &tsoConnectionContext{url, stream, cctx, cancel}) + return + } } } select { @@ -285,7 +290,7 @@ func (c *tsoClient) checkAllocator( case <-ticker.C: // To ensure we can get the latest allocator leader // and once the leader is changed, we can exit this function. - _, u = c.GetTSOAllocatorClientConnByDCLocation(dc) + cc, u = c.GetTSOAllocatorClientConnByDCLocation(dc) } } } @@ -597,29 +602,32 @@ func (c *tsoClient) tryConnectToTSO( for i := 0; i < maxRetryTimes; i++ { c.svcDiscovery.ScheduleCheckMemberChanged() cc, url = c.GetTSOAllocatorClientConnByDCLocation(dc) - cctx, cancel := context.WithCancel(dispatcherCtx) - stream, err = c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) - failpoint.Inject("unreachableNetwork", func() { - stream = nil - err = status.New(codes.Unavailable, "unavailable").Err() - }) - if stream != nil && err == nil { - updateAndClear(url, &tsoConnectionContext{url, stream, cctx, cancel}) - return nil - } - - if err != nil && c.option.enableForwarding { - // The reason we need to judge if the error code is equal to "Canceled" here is that - // when we create a stream we use a goroutine to manually control the timeout of the connection. - // There is no need to wait for the transport layer timeout which can reduce the time of unavailability. - // But it conflicts with the retry mechanism since we use the error code to decide if it is caused by network error. - // And actually the `Canceled` error can be regarded as a kind of network error in some way. - if rpcErr, ok := status.FromError(err); ok && (isNetworkError(rpcErr.Code()) || rpcErr.Code() == codes.Canceled) { - networkErrNum++ + if cc != nil { + cctx, cancel := context.WithCancel(dispatcherCtx) + stream, err = c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) + failpoint.Inject("unreachableNetwork", func() { + stream = nil + err = status.New(codes.Unavailable, "unavailable").Err() + }) + if stream != nil && err == nil { + updateAndClear(url, &tsoConnectionContext{url, stream, cctx, cancel}) + return nil } - } - cancel() + if err != nil && c.option.enableForwarding { + // The reason we need to judge if the error code is equal to "Canceled" here is that + // when we create a stream we use a goroutine to manually control the timeout of the connection. + // There is no need to wait for the transport layer timeout which can reduce the time of unavailability. + // But it conflicts with the retry mechanism since we use the error code to decide if it is caused by network error. + // And actually the `Canceled` error can be regarded as a kind of network error in some way. + if rpcErr, ok := status.FromError(err); ok && (isNetworkError(rpcErr.Code()) || rpcErr.Code() == codes.Canceled) { + networkErrNum++ + } + } + cancel() + } else { + networkErrNum++ + } select { case <-dispatcherCtx.Done(): return err diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 3834d9b53bf..bb4d6851fd0 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -518,7 +518,7 @@ func TestCustomTimeout(t *testing.T) { re.Less(time.Since(start), 2*time.Second) } -func TestGetRegionFromFollowerClient(t *testing.T) { +func TestGetRegionByFollowerForwarding(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -544,7 +544,7 @@ func TestGetRegionFromFollowerClient(t *testing.T) { } // case 1: unreachable -> normal -func TestGetTsoFromFollowerClient1(t *testing.T) { +func TestGetTsoByFollowerForwarding1(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -575,7 +575,7 @@ func TestGetTsoFromFollowerClient1(t *testing.T) { } // case 2: unreachable -> leader transfer -> normal -func TestGetTsoFromFollowerClient2(t *testing.T) { +func TestGetTsoByFollowerForwarding2(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -609,6 +609,101 @@ func TestGetTsoFromFollowerClient2(t *testing.T) { checkTS(re, cli, lastTS) } +// case 3: network partition between client and follower A -> transfer leader to follower A -> normal +func TestGetTsoAndRegionByFollowerForwarding(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pd.LeaderHealthCheckInterval = 100 * time.Millisecond + cluster, err := tests.NewTestCluster(ctx, 3) + re.NoError(err) + defer cluster.Destroy() + + endpoints := runServer(re, cluster) + re.NotEmpty(cluster.WaitLeader()) + leader := cluster.GetLeaderServer() + grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) + testutil.Eventually(re, func() bool { + regionHeartbeat, err := grpcPDClient.RegionHeartbeat(ctx) + re.NoError(err) + regionID := regionIDAllocator.alloc() + region := &metapb.Region{ + Id: regionID, + RegionEpoch: &metapb.RegionEpoch{ + ConfVer: 1, + Version: 1, + }, + Peers: peers, + } + req := &pdpb.RegionHeartbeatRequest{ + Header: newHeader(leader.GetServer()), + Region: region, + Leader: peers[0], + } + err = regionHeartbeat.Send(req) + re.NoError(err) + _, err = regionHeartbeat.Recv() + return err == nil + }) + follower := cluster.GetServer(cluster.GetFollower()) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/grpcutil/unreachableNetwork2", fmt.Sprintf("return(\"%s\")", follower.GetAddr()))) + + cli := setupCli(re, ctx, endpoints, pd.WithForwardingOption(true)) + var lastTS uint64 + testutil.Eventually(re, func() bool { + physical, logical, err := cli.GetTS(context.TODO()) + if err == nil { + lastTS = tsoutil.ComposeTS(physical, logical) + return true + } + t.Log(err) + return false + }) + lastTS = checkTS(re, cli, lastTS) + r, err := cli.GetRegion(context.Background(), []byte("a")) + re.NoError(err) + re.NotNil(r) + leader.GetServer().GetMember().ResignEtcdLeader(leader.GetServer().Context(), + leader.GetServer().Name(), follower.GetServer().Name()) + re.NotEmpty(cluster.WaitLeader()) + testutil.Eventually(re, func() bool { + physical, logical, err := cli.GetTS(context.TODO()) + if err == nil { + lastTS = tsoutil.ComposeTS(physical, logical) + return true + } + t.Log(err) + return false + }) + lastTS = checkTS(re, cli, lastTS) + testutil.Eventually(re, func() bool { + r, err = cli.GetRegion(context.Background(), []byte("a")) + if err == nil && r != nil { + return true + } + return false + }) + + re.NoError(failpoint.Disable("github.com/tikv/pd/client/grpcutil/unreachableNetwork2")) + testutil.Eventually(re, func() bool { + physical, logical, err := cli.GetTS(context.TODO()) + if err == nil { + lastTS = tsoutil.ComposeTS(physical, logical) + return true + } + t.Log(err) + return false + }) + lastTS = checkTS(re, cli, lastTS) + testutil.Eventually(re, func() bool { + r, err = cli.GetRegion(context.Background(), []byte("a")) + if err == nil && r != nil { + return true + } + return false + }) +} + func checkTS(re *require.Assertions, cli pd.Client, lastTS uint64) uint64 { for i := 0; i < tsoRequestRound; i++ { physical, logical, err := cli.GetTS(context.TODO())