From 261991609c97b26c886441509ed2565dd77707b3 Mon Sep 17 00:00:00 2001 From: crazycs520 Date: Tue, 17 Oct 2023 18:43:36 +0800 Subject: [PATCH] Fix batch client batchSendLoop panic (#1021) Signed-off-by: crazycs520 --- internal/client/client_batch.go | 12 ++++-- internal/locate/region_cache.go | 2 +- internal/locate/region_request_test.go | 51 ++++++++++++++++++++++++++ tikvrpc/tikvrpc.go | 15 ++++++-- util/misc.go | 2 +- 5 files changed, 72 insertions(+), 10 deletions(-) diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go index 6729821e9..12dcc2805 100644 --- a/internal/client/client_batch.go +++ b/internal/client/client_batch.go @@ -296,14 +296,18 @@ func (a *batchConn) fetchMorePendingRequests( const idleTimeout = 3 * time.Minute +// BatchSendLoopPanicCounter is only used for testing. +var BatchSendLoopPanicCounter int64 = 0 + func (a *batchConn) batchSendLoop(cfg config.TiKVClient) { defer func() { if r := recover(); r != nil { metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc() logutil.BgLogger().Error("batchSendLoop", - zap.Reflect("r", r), + zap.Any("r", r), zap.Stack("stack")) - logutil.BgLogger().Info("restart batchSendLoop") + atomic.AddInt64(&BatchSendLoopPanicCounter, 1) + logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter))) go a.batchSendLoop(cfg) } }() @@ -436,7 +440,7 @@ func (s *batchCommandsStream) recv() (resp *tikvpb.BatchCommandsResponse, err er if r := recover(); r != nil { metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc() logutil.BgLogger().Error("batchCommandsClient.recv panic", - zap.Reflect("r", r), + zap.Any("r", r), zap.Stack("stack")) err = errors.New("batch conn recv paniced") } @@ -604,7 +608,7 @@ func (c *batchCommandsClient) batchRecvLoop(cfg config.TiKVClient, tikvTransport if r := recover(); r != nil { metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc() logutil.BgLogger().Error("batchRecvLoop", - zap.Reflect("r", r), + zap.Any("r", r), zap.Stack("stack")) logutil.BgLogger().Info("restart batchRecvLoop") go c.batchRecvLoop(cfg, tikvTransportLayerLoad, streamClient) diff --git a/internal/locate/region_cache.go b/internal/locate/region_cache.go index 622e70aec..110b32aa5 100644 --- a/internal/locate/region_cache.go +++ b/internal/locate/region_cache.go @@ -508,7 +508,7 @@ func (c *RegionCache) checkAndResolve(needCheckStores []*Store, needCheck func(* r := recover() if r != nil { logutil.BgLogger().Error("panic in the checkAndResolve goroutine", - zap.Reflect("r", r), + zap.Any("r", r), zap.Stack("stack trace")) } }() diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go index c6ed1870d..b95f15de0 100644 --- a/internal/locate/region_request_test.go +++ b/internal/locate/region_request_test.go @@ -37,8 +37,10 @@ package locate import ( "context" "fmt" + "math/rand" "net" "sync" + "sync/atomic" "testing" "time" "unsafe" @@ -693,3 +695,52 @@ func (s *testRegionRequestToSingleStoreSuite) TestKVReadTimeoutWithDisableBatchC s.True(IsFakeRegionError(regionErr)) s.Equal(0, bo.GetTotalBackoffTimes()) // use kv read timeout will do fast retry, so backoff times should be 0. } + +func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() { + // This test should use `go test -race` to run. 
+	config.UpdateGlobal(func(conf *config.Config) {
+		conf.TiKVClient.MaxBatchSize = 128
+	})()
+
+	server, port := mock_server.StartMockTikvService()
+	s.True(port > 0)
+	rpcClient := client.NewRPCClient()
+	fnClient := &fnClient{fn: func(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (response *tikvrpc.Response, err error) {
+		return rpcClient.SendRequest(ctx, server.Addr(), req, timeout)
+	}}
+	tf := func(s *Store, bo *retry.Backoffer) livenessState {
+		return reachable
+	}
+
+	defer func() {
+		rpcClient.Close()
+		server.Stop()
+	}()
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := 0; j < 100; j++ {
+				ctx, cancel := context.WithCancel(context.Background())
+				bo := retry.NewBackofferWithVars(ctx, int(client.ReadTimeoutShort.Milliseconds()), nil)
+				region, err := s.cache.LocateRegionByID(bo, s.region)
+				s.Nil(err)
+				s.NotNil(region)
+				go func() {
+					// Mock a killed query or a request timeout by cancelling the context.
+					time.Sleep(time.Millisecond * time.Duration(rand.Intn(5)+1))
+					cancel()
+				}()
+				req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1})
+				regionRequestSender := NewRegionRequestSender(s.cache, fnClient)
+				regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf))
+				regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort)
+			}
+		}()
+	}
+	wg.Wait()
+	// batchSendLoop should not panic.
+	s.Equal(atomic.LoadInt64(&client.BatchSendLoopPanicCounter), int64(0))
+}
diff --git a/tikvrpc/tikvrpc.go b/tikvrpc/tikvrpc.go
index 0119e55db..6e1d2373d 100644
--- a/tikvrpc/tikvrpc.go
+++ b/tikvrpc/tikvrpc.go
@@ -699,13 +699,20 @@ type MPPStreamResponse struct {

 // SetContext set the Context field for the given req to the specified ctx.
 func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error {
-	ctx := &req.Context
 	if region != nil {
-		ctx.RegionId = region.Id
-		ctx.RegionEpoch = region.RegionEpoch
+		req.Context.RegionId = region.Id
+		req.Context.RegionEpoch = region.RegionEpoch
 	}
-	ctx.Peer = peer
+	req.Context.Peer = peer

+	// Shallow copy the context to avoid concurrent modification.
+	return AttachContext(req, req.Context)
+}
+
+// AttachContext attaches the given rpcCtx to the request.
+// The parameter `rpcCtx` is a `kvrpcpb.Context` value rather than a `*kvrpcpb.Context`, so each request gets its own shallow copy and is safe from concurrent modification.
+func AttachContext(req *Request, rpcCtx kvrpcpb.Context) error {
+	ctx := &rpcCtx
 	switch req.Type {
 	case CmdGet:
 		req.Get().Context = ctx
diff --git a/util/misc.go b/util/misc.go
index bd3e2b779..e324bf797 100644
--- a/util/misc.go
+++ b/util/misc.go
@@ -89,7 +89,7 @@ func WithRecovery(exec func(), recoverFn func(r interface{})) {
 	}
 	if r != nil {
 		logutil.BgLogger().Error("panic in the recoverable goroutine",
-			zap.Reflect("r", r),
+			zap.Any("r", r),
 			zap.Stack("stack trace"))
 	}
 }()
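
For reference (not part of the patch): the batchSendLoop panic addressed here stems from a request's `kvrpcpb.Context` being shared by pointer, so a retry or cancellation can rewrite region/peer fields while the batch send loop is still serializing the previous attempt. `AttachContext` avoids this by taking `kvrpcpb.Context` by value, so every attached context is an independent shallow copy. Below is a minimal, self-contained Go sketch of that copy-by-value idea; the `Ctx` type and `attach` helper are hypothetical stand-ins, not client-go APIs.

package main

import (
	"fmt"
	"sync"
)

// Ctx stands in for kvrpcpb.Context in this sketch.
type Ctx struct {
	RegionId uint64
	Peer     string
}

// attach mirrors AttachContext's `ctx := &rpcCtx`: the by-value parameter is
// already an independent copy, so the returned pointer is private to this request.
func attach(rpcCtx Ctx) *Ctx {
	return &rpcCtx
}

func main() {
	shared := Ctx{RegionId: 1, Peer: "store-1"}
	attached := attach(shared) // snapshot taken here

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// A retry path mutating the original no longer touches the attached copy.
		shared.RegionId = 2
		shared.Peer = "store-2"
	}()
	wg.Wait()

	fmt.Println(attached.RegionId, attached.Peer) // prints: 1 store-1
}

This is also why the new test drives SendReq from many goroutines while cancelling contexts at random moments and then asserts that BatchSendLoopPanicCounter is still zero: once each request carries its own context copy, the batch send loop has nothing left to race on.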