diff --git a/pkg/utils/tsoutil/tso_dispatcher.go b/pkg/utils/tsoutil/tso_dispatcher.go index 152a3996538..69baf4b1e41 100644 --- a/pkg/utils/tsoutil/tso_dispatcher.go +++ b/pkg/utils/tsoutil/tso_dispatcher.go @@ -58,7 +58,7 @@ func NewTSODispatcher(tsoProxyHandleDuration, tsoProxyBatchSize prometheus.Histo return tsoDispatcher } -// DispatchRequest is the entry point for dispatching/forwarding a tso request to the detination host +// DispatchRequest is the entry point for dispatching/forwarding a tso request to the destination host func (s *TSODispatcher) DispatchRequest( ctx context.Context, req Request, @@ -69,9 +69,9 @@ func (s *TSODispatcher) DispatchRequest( val, loaded := s.dispatchChs.LoadOrStore(req.getForwardedHost(), make(chan Request, maxMergeRequests)) reqCh := val.(chan Request) if !loaded { - tsDeadlineCh := make(chan deadline, 1) + tsDeadlineCh := make(chan *TSDeadline, 1) go s.dispatch(ctx, tsoProtoFactory, req.getForwardedHost(), req.getClientConn(), reqCh, tsDeadlineCh, doneCh, errCh, tsoPrimaryWatchers...) - go watchTSDeadline(ctx, tsDeadlineCh) + go WatchTSDeadline(ctx, tsDeadlineCh) } reqCh <- req } @@ -82,7 +82,7 @@ func (s *TSODispatcher) dispatch( forwardedHost string, clientConn *grpc.ClientConn, tsoRequestCh <-chan Request, - tsDeadlineCh chan<- deadline, + tsDeadlineCh chan<- *TSDeadline, doneCh <-chan struct{}, errCh chan<- error, tsoPrimaryWatchers ...*etcdutil.LoopWatcher) { @@ -121,11 +121,7 @@ func (s *TSODispatcher) dispatch( requests[i] = <-tsoRequestCh } done := make(chan struct{}) - dl := deadline{ - timer: time.After(DefaultTSOProxyTimeout), - done: done, - cancel: cancel, - } + dl := NewTSDeadline(DefaultTSOProxyTimeout, done, cancel) select { case tsDeadlineCh <- dl: case <-dispatcherCtx.Done(): @@ -199,13 +195,28 @@ func (s *TSODispatcher) finishRequest(requests []Request, physical, firstLogical return nil } -type deadline struct { +// TSDeadline is used to watch the deadline of each tso request. +type TSDeadline struct { timer <-chan time.Time done chan struct{} cancel context.CancelFunc } -func watchTSDeadline(ctx context.Context, tsDeadlineCh <-chan deadline) { +// NewTSDeadline creates a new TSDeadline. +func NewTSDeadline( + timeout time.Duration, + done chan struct{}, + cancel context.CancelFunc, +) *TSDeadline { + return &TSDeadline{ + timer: time.After(timeout), + done: done, + cancel: cancel, + } +} + +// WatchTSDeadline watches the deadline of each tso request. +func WatchTSDeadline(ctx context.Context, tsDeadlineCh <-chan *TSDeadline) { defer logutil.LogPanic() ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/server/grpc_service.go b/server/grpc_service.go index 3096357654b..1badabb19d8 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -406,16 +406,17 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { var ( server = &tsoServer{stream: stream} forwardStream tsopb.TSO_TsoClient - cancel context.CancelFunc + forwardCtx context.Context + cancelForward context.CancelFunc lastForwardedHost string ) defer func() { s.concurrentTSOProxyStreamings.Add(-1) - // cancel the forward stream - if cancel != nil { - cancel() + if cancelForward != nil { + cancelForward() } }() + maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) if maxConcurrentTSOProxyStreamings >= 0 { if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { @@ -423,6 +424,9 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { } } + tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) + go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) + for { select { case <-s.ctx.Done(): @@ -449,22 +453,24 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { return errors.WithStack(ErrNotFoundTSOAddr) } if forwardStream == nil || lastForwardedHost != forwardedHost { - if cancel != nil { - cancel() + if cancelForward != nil { + cancelForward() } clientConn, err := s.getDelegateClient(s.ctx, forwardedHost) if err != nil { return errors.WithStack(err) } - forwardStream, cancel, err = s.createTSOForwardStream(clientConn) + forwardStream, forwardCtx, cancelForward, err = + s.createTSOForwardStream(stream.Context(), clientConn) if err != nil { return errors.WithStack(err) } lastForwardedHost = forwardedHost } - tsopbResp, err := s.forwardTSORequestWithDeadLine(stream.Context(), request, forwardStream) + tsopbResp, err := s.forwardTSORequestWithDeadLine( + forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh) if err != nil { return errors.WithStack(err) } @@ -502,37 +508,39 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { } func (s *GrpcServer) forwardTSORequestWithDeadLine( - ctx context.Context, request *pdpb.TsoRequest, forwardStream tsopb.TSO_TsoClient, + forwardCtx context.Context, + cancelForward context.CancelFunc, + forwardStream tsopb.TSO_TsoClient, + request *pdpb.TsoRequest, + tsDeadlineCh chan<- *tsoutil.TSDeadline, ) (*tsopb.TsoResponse, error) { - defer logutil.LogPanic() - // Create a context with deadline for forwarding TSO request to TSO service. - ctxTimeout, cancel := context.WithTimeout(ctx, tsoutil.DefaultTSOProxyTimeout) - defer cancel() - - tsoProxyBatchSize.Observe(float64(request.GetCount())) + done := make(chan struct{}) + dl := tsoutil.NewTSDeadline(tsoutil.DefaultTSOProxyTimeout, done, cancelForward) + select { + case tsDeadlineCh <- dl: + case <-forwardCtx.Done(): + return nil, forwardCtx.Err() + } - // used to receive the result from doSomething function - tsoRespCh := make(chan *tsopbTSOResponse, 1) start := time.Now() - go s.forwardTSORequestAsync(ctxTimeout, request, forwardStream, tsoRespCh) - select { - case <-ctxTimeout.Done(): - tsoProxyForwardTimeoutCounter.Inc() - return nil, ErrForwardTSOTimeout - case tsoResp := <-tsoRespCh: - if tsoResp.err == nil { - tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) + resp, err := s.forwardTSORequest(forwardCtx, request, forwardStream) + close(done) + if err != nil { + if strings.Contains(err.Error(), errs.NotLeaderErr) { + s.tsoPrimaryWatcher.ForceLoad() } - return tsoResp.response, tsoResp.err + return nil, err } + tsoProxyBatchSize.Observe(float64(request.GetCount())) + tsoProxyHandleDuration.Observe(time.Since(start).Seconds()) + return resp, nil } -func (s *GrpcServer) forwardTSORequestAsync( - ctxTimeout context.Context, +func (s *GrpcServer) forwardTSORequest( + ctx context.Context, request *pdpb.TsoRequest, forwardStream tsopb.TSO_TsoClient, - tsoRespCh chan<- *tsopbTSOResponse, -) { +) (*tsopb.TsoResponse, error) { tsopbReq := &tsopb.TsoRequest{ Header: &tsopb.RequestHeader{ ClusterId: request.GetHeader().GetClusterId(), @@ -545,46 +553,32 @@ func (s *GrpcServer) forwardTSORequestAsync( } failpoint.Inject("tsoProxySendToTSOTimeout", func() { - <-ctxTimeout.Done() - failpoint.Return() + // block until watchDeadline routine cancels the context. + <-ctx.Done() }) - if err := forwardStream.Send(tsopbReq); err != nil { - select { - case <-ctxTimeout.Done(): - return - case tsoRespCh <- &tsopbTSOResponse{err: err}: - } - return - } - select { - case <-ctxTimeout.Done(): - return + case <-ctx.Done(): + return nil, ctx.Err() default: } + if err := forwardStream.Send(tsopbReq); err != nil { + return nil, err + } + failpoint.Inject("tsoProxyRecvFromTSOTimeout", func() { - <-ctxTimeout.Done() - failpoint.Return() + // block until watchDeadline routine cancels the context. + <-ctx.Done() }) - response, err := forwardStream.Recv() - if err != nil { - if strings.Contains(err.Error(), errs.NotLeaderErr) { - s.tsoPrimaryWatcher.ForceLoad() - } - } select { - case <-ctxTimeout.Done(): - return - case tsoRespCh <- &tsopbTSOResponse{response: response, err: err}: + case <-ctx.Done(): + return nil, ctx.Err() + default: } -} -type tsopbTSOResponse struct { - response *tsopb.TsoResponse - err error + return forwardStream.Recv() } // tsoServer wraps PD_TsoServer to ensure when any error @@ -2140,13 +2134,15 @@ func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatC } } -func (s *GrpcServer) createTSOForwardStream(client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.CancelFunc, error) { +func (s *GrpcServer) createTSOForwardStream( + ctx context.Context, client *grpc.ClientConn, +) (tsopb.TSO_TsoClient, context.Context, context.CancelFunc, error) { done := make(chan struct{}) - ctx, cancel := context.WithCancel(s.ctx) - go checkStream(ctx, cancel, done) - forwardStream, err := tsopb.NewTSOClient(client).Tso(ctx) + forwardCtx, cancelForward := context.WithCancel(ctx) + go checkStream(forwardCtx, cancelForward, done) + forwardStream, err := tsopb.NewTSOClient(client).Tso(forwardCtx) done <- struct{}{} - return forwardStream, cancel, err + return forwardStream, forwardCtx, cancelForward, err } func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) {