From bb1e5115b0e3179aebc48c6866e28750717a9411 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Fri, 30 Jun 2023 16:32:12 +0800 Subject: [PATCH] tso: fix memory leak introduced by timer.After (#6730) close tikv/pd#6719, ref tikv/pd#6720 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/timerpool/pool.go | 43 ++++++++++++++++++ pkg/timerpool/pool_test.go | 70 +++++++++++++++++++++++++++++ pkg/utils/tsoutil/tso_dispatcher.go | 17 ++++--- server/grpc_service.go | 12 +++-- 4 files changed, 133 insertions(+), 9 deletions(-) create mode 100644 pkg/timerpool/pool.go create mode 100644 pkg/timerpool/pool_test.go diff --git a/pkg/timerpool/pool.go b/pkg/timerpool/pool.go new file mode 100644 index 00000000000..28ffacfc629 --- /dev/null +++ b/pkg/timerpool/pool.go @@ -0,0 +1,43 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Note: This file is copied from https://go-review.googlesource.com/c/go/+/276133 + +package timerpool + +import ( + "sync" + "time" +) + +// GlobalTimerPool is a global pool for reusing *time.Timer. +var GlobalTimerPool TimerPool + +// TimerPool is a wrapper of sync.Pool which caches *time.Timer for reuse. +type TimerPool struct { + pool sync.Pool +} + +// Get returns a timer with a given duration. +func (tp *TimerPool) Get(d time.Duration) *time.Timer { + if v := tp.pool.Get(); v != nil { + timer := v.(*time.Timer) + timer.Reset(d) + return timer + } + return time.NewTimer(d) +} + +// Put tries to call timer.Stop() before putting it back into pool, +// if the timer.Stop() returns false (it has either already expired or been stopped), +// have a shot at draining the channel with residual time if there is one. +func (tp *TimerPool) Put(timer *time.Timer) { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + tp.pool.Put(timer) +} diff --git a/pkg/timerpool/pool_test.go b/pkg/timerpool/pool_test.go new file mode 100644 index 00000000000..d6dffc723a9 --- /dev/null +++ b/pkg/timerpool/pool_test.go @@ -0,0 +1,70 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Note: This file is copied from https://go-review.googlesource.com/c/go/+/276133 + +package timerpool + +import ( + "testing" + "time" +) + +func TestTimerPool(t *testing.T) { + var tp TimerPool + + for i := 0; i < 100; i++ { + timer := tp.Get(20 * time.Millisecond) + + select { + case <-timer.C: + t.Errorf("timer expired too early") + continue + default: + } + + select { + case <-time.After(100 * time.Millisecond): + t.Errorf("timer didn't expire on time") + case <-timer.C: + } + + tp.Put(timer) + } +} + +const timeout = 10 * time.Millisecond + +func BenchmarkTimerUtilization(b *testing.B) { + b.Run("TimerWithPool", func(b *testing.B) { + for i := 0; i < b.N; i++ { + t := GlobalTimerPool.Get(timeout) + GlobalTimerPool.Put(t) + } + }) + b.Run("TimerWithoutPool", func(b *testing.B) { + for i := 0; i < b.N; i++ { + t := time.NewTimer(timeout) + t.Stop() + } + }) +} + +func BenchmarkTimerPoolParallel(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + t := GlobalTimerPool.Get(timeout) + GlobalTimerPool.Put(t) + } + }) +} + +func BenchmarkTimerNativeParallel(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + t := time.NewTimer(timeout) + t.Stop() + } + }) +} diff --git a/pkg/utils/tsoutil/tso_dispatcher.go b/pkg/utils/tsoutil/tso_dispatcher.go index 69baf4b1e41..f9585ba5cdd 100644 --- a/pkg/utils/tsoutil/tso_dispatcher.go +++ b/pkg/utils/tsoutil/tso_dispatcher.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/timerpool" "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/logutil" "go.uber.org/zap" @@ -197,7 +198,7 @@ func (s *TSODispatcher) finishRequest(requests []Request, physical, firstLogical // TSDeadline is used to watch the deadline of each tso request. type TSDeadline struct { - timer <-chan time.Time + timer *time.Timer done chan struct{} cancel context.CancelFunc } @@ -208,8 +209,9 @@ func NewTSDeadline( done chan struct{}, cancel context.CancelFunc, ) *TSDeadline { + timer := timerpool.GlobalTimerPool.Get(timeout) return &TSDeadline{ - timer: time.After(timeout), + timer: timer, done: done, cancel: cancel, } @@ -224,13 +226,15 @@ func WatchTSDeadline(ctx context.Context, tsDeadlineCh <-chan *TSDeadline) { select { case d := <-tsDeadlineCh: select { - case <-d.timer: + case <-d.timer.C: log.Error("tso proxy request processing is canceled due to timeout", errs.ZapError(errs.ErrProxyTSOTimeout)) d.cancel() + timerpool.GlobalTimerPool.Put(d.timer) case <-d.done: - continue + timerpool.GlobalTimerPool.Put(d.timer) case <-ctx.Done(): + timerpool.GlobalTimerPool.Put(d.timer) return } case <-ctx.Done(): @@ -241,11 +245,12 @@ func WatchTSDeadline(ctx context.Context, tsDeadlineCh <-chan *TSDeadline) { func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) { defer logutil.LogPanic() - + timer := time.NewTimer(3 * time.Second) + defer timer.Stop() select { case <-done: return - case <-time.After(3 * time.Second): + case <-timer.C: cancel() case <-streamCtx.Done(): } diff --git a/server/grpc_service.go b/server/grpc_service.go index 7c28bcbccac..468a40490d1 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -606,13 +606,15 @@ func (s *tsoServer) Send(m *pdpb.TsoResponse) error { }) done <- s.stream.Send(m) }() + timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout) + defer timer.Stop() select { case err := <-done: if err != nil { atomic.StoreInt32(&s.closed, 1) } return errors.WithStack(err) - case <-time.After(tsoutil.DefaultTSOProxyTimeout): + case <-timer.C: atomic.StoreInt32(&s.closed, 1) return ErrForwardTSOTimeout } @@ -633,6 +635,8 @@ func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) { request, err := s.stream.Recv() requestCh <- &pdpbTSORequest{request: request, err: err} }() + timer := time.NewTimer(timeout) + defer timer.Stop() select { case req := <-requestCh: if req.err != nil { @@ -640,7 +644,7 @@ func (s *tsoServer) Recv(timeout time.Duration) (*pdpb.TsoRequest, error) { return nil, errors.WithStack(req.err) } return req.request, nil - case <-time.After(timeout): + case <-timer.C: atomic.StoreInt32(&s.closed, 1) return nil, ErrTSOProxyRecvFromClientTimeout } @@ -2201,10 +2205,12 @@ func forwardReportBucketClientToServer(forwardStream pdpb.PD_ReportBucketsClient // TODO: If goroutine here timeout when tso stream created successfully, we need to handle it correctly. func checkStream(streamCtx context.Context, cancel context.CancelFunc, done chan struct{}) { defer logutil.LogPanic() + timer := time.NewTimer(3 * time.Second) + defer timer.Stop() select { case <-done: return - case <-time.After(3 * time.Second): + case <-timer.C: cancel() case <-streamCtx.Done(): }