client/tso: use the TSO request pool at the tsoClient level to avoid data race (#8077) #8141

Merged
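Before the per-file diffs, here is a minimal, self-contained sketch of the pattern this PR moves to (illustrative names only — `owner`, `request`, `acquire` are not PD code): each owner keeps its own `sync.Pool`, every pooled object remembers which pool it came from, and the dispatcher re-acquires a fresh object whenever the one it holds belongs to a different (stale) pool.

```go
package main

import (
	"fmt"
	"sync"
)

// request is a pooled object that remembers the pool it was drawn from,
// mirroring the `pool *sync.Pool` field this PR adds to tsoRequest.
type request struct {
	id   int
	pool *sync.Pool
}

// owner plays the role of tsoClient: it owns its own request pool, so
// requests recycled through one owner can never be handed out by another.
type owner struct {
	reqPool *sync.Pool
}

func newOwner() *owner {
	return &owner{
		reqPool: &sync.Pool{
			New: func() any { return &request{} },
		},
	}
}

// acquire draws a request from this owner's pool, resets it, and tags it
// with the pool it must eventually be returned to.
func (o *owner) acquire(id int) *request {
	req := o.reqPool.Get().(*request)
	req.id = id
	req.pool = o.reqPool
	return req
}

// release puts the request back into whichever pool it came from.
func release(req *request) {
	req.pool.Put(req)
}

func main() {
	oldClient, newClient := newOwner(), newOwner()

	req := oldClient.acquire(1)
	// If the owner has been swapped out (e.g. the TSO client was re-created
	// during a retry), a request tagged with the stale pool must not be
	// reused directly; re-acquire from the current owner's pool instead.
	if req.pool != newClient.reqPool {
		release(req)               // give it back to the old pool
		req = newClient.acquire(2) // and start fresh from the new pool
	}
	fmt.Println(req.id) // 2
	release(req)
}
```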
33 changes: 14 additions & 19 deletions client/client.go
@@ -791,34 +791,19 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFuture
defer span.Finish()
}

req := c.getTSORequest(ctx, dcLocation)
if err := c.dispatchTSORequestWithRetry(req); err != nil {
req.tryDone(err)
}
return req
}

func (c *client) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest {
req := tsoReqPool.Get().(*tsoRequest)
// Set needed fields in the request before using it.
req.start = time.Now()
req.clientCtx = c.ctx
req.requestCtx = ctx
req.physical = 0
req.logical = 0
req.dcLocation = dcLocation
return req
return c.dispatchTSORequestWithRetry(ctx, dcLocation)
}

const (
dispatchRetryDelay = 50 * time.Millisecond
dispatchRetryCount = 2
)

func (c *client) dispatchTSORequestWithRetry(req *tsoRequest) error {
func (c *client) dispatchTSORequestWithRetry(ctx context.Context, dcLocation string) TSFuture {
var (
retryable bool
err error
req *tsoRequest
)
for i := 0; i < dispatchRetryCount; i++ {
// Do not delay for the first time.
@@ -831,12 +816,22 @@ func (c *client) dispatchTSORequestWithRetry(req *tsoRequest) error {
err = errs.ErrClientGetTSO.FastGenByArgs("tso client is nil")
continue
}
// Get a new request from the pool if it's nil or not from the current pool.
if req == nil || req.pool != tsoClient.tsoReqPool {
req = tsoClient.getTSORequest(ctx, dcLocation)
}
retryable, err = tsoClient.dispatchRequest(req)
if !retryable {
break
}
}
return err
if err != nil {
if req == nil {
return newTSORequestFastFail(err)
}
req.tryDone(err)
}
return req
}

func (c *client) GetTS(ctx context.Context) (physical int64, logical int64, err error) {
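From the caller's point of view nothing changes here: `GetLocalTSAsync` (and `GetTSAsync`, which wraps it) still returns a `TSFuture` that is never nil, because a dispatch failure now produces a fast-fail future instead of a half-initialized request. A hedged usage sketch, assuming a `pd.Client` has already been constructed elsewhere:

```go
package example

import (
	"context"
	"fmt"

	pd "github.com/tikv/pd/client"
)

// fetchTS sketches the caller-facing contract after this change: the async
// call always returns a non-nil TSFuture (a pooled tsoRequest on success, a
// fast-fail future when dispatch keeps failing), and Wait surfaces either
// the timestamp or the error.
func fetchTS(ctx context.Context, cli pd.Client) error {
	future := cli.GetTSAsync(ctx) // never nil, even when dispatch fails
	physical, logical, err := future.Wait()
	if err != nil {
		return err
	}
	fmt.Printf("physical=%d logical=%d\n", physical, logical)
	return nil
}
```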
61 changes: 29 additions & 32 deletions client/tso_client.go
@@ -43,33 +43,6 @@ type TSOClient interface {
GetMinTS(ctx context.Context) (int64, int64, error)
}

type tsoRequest struct {
start time.Time
clientCtx context.Context
requestCtx context.Context
done chan error
physical int64
logical int64
dcLocation string
}

var tsoReqPool = sync.Pool{
New: func() any {
return &tsoRequest{
done: make(chan error, 1),
physical: 0,
logical: 0,
}
},
}

func (req *tsoRequest) tryDone(err error) {
select {
case req.done <- err:
default:
}
}

type tsoClient struct {
ctx context.Context
cancel context.CancelFunc
@@ -84,6 +57,8 @@ type tsoClient struct {
// tso allocator leader is switched.
tsoAllocServingURLSwitchedCallback []func()

// tsoReqPool is the pool to recycle `*tsoRequest`.
tsoReqPool *sync.Pool
// tsoDispatcher is used to dispatch different TSO requests to
// the corresponding dc-location TSO channel.
tsoDispatcher sync.Map // Same as map[string]*tsoDispatcher
@@ -104,11 +79,20 @@ func newTSOClient(
) *tsoClient {
ctx, cancel := context.WithCancel(ctx)
c := &tsoClient{
ctx: ctx,
cancel: cancel,
option: option,
svcDiscovery: svcDiscovery,
tsoStreamBuilderFactory: factory,
ctx: ctx,
cancel: cancel,
option: option,
svcDiscovery: svcDiscovery,
tsoStreamBuilderFactory: factory,
tsoReqPool: &sync.Pool{
New: func() any {
return &tsoRequest{
done: make(chan error, 1),
physical: 0,
logical: 0,
}
},
},
checkTSDeadlineCh: make(chan struct{}),
checkTSODispatcherCh: make(chan struct{}, 1),
updateTSOConnectionCtxsCh: make(chan struct{}, 1),
@@ -155,6 +139,19 @@ func (c *tsoClient) Close() {
log.Info("tso client is closed")
}

func (c *tsoClient) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest {
req := c.tsoReqPool.Get().(*tsoRequest)
// Set needed fields in the request before using it.
req.start = time.Now()
req.pool = c.tsoReqPool
req.requestCtx = ctx
req.clientCtx = c.ctx
req.physical = 0
req.logical = 0
req.dcLocation = dcLocation
return req
}

// GetTSOAllocators returns {dc-location -> TSO allocator leader URL} connection map
func (c *tsoClient) GetTSOAllocators() *sync.Map {
return &c.tsoAllocators
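A note on the Get-side discipline in `getTSORequest` above: every field except the buffered `done` channel is overwritten before the recycled request is handed out, so state from a previous use cannot leak into the next one. A small illustrative sketch of that reset-on-Get pattern (the names below are made up, not PD code):

```go
package example

import (
	"sync"
	"time"
)

// pooledReq keeps only its done channel across reuses; everything else is
// reset each time it is taken from the pool, like tsoRequest above.
type pooledReq struct {
	start time.Time
	value int64
	done  chan error // buffered, so a non-blocking tryDone-style send never stalls
}

// reqSource stands in for tsoClient as the owner of the pool.
type reqSource struct {
	pool *sync.Pool
}

func newReqSource() *reqSource {
	return &reqSource{
		pool: &sync.Pool{
			New: func() any { return &pooledReq{done: make(chan error, 1)} },
		},
	}
}

// get mirrors getTSORequest: take from the pool, then reset every reusable field.
func (s *reqSource) get() *pooledReq {
	r := s.pool.Get().(*pooledReq)
	r.start = time.Now()
	r.value = 0
	return r
}

// put returns the request once its result has been consumed.
func (s *reqSource) put(r *pooledReq) {
	s.pool.Put(r)
}
```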
32 changes: 0 additions & 32 deletions client/tso_dispatcher.go
@@ -115,38 +115,6 @@ func (c *tsoClient) dispatchRequest(request *tsoRequest) (bool, error) {
return false, nil
}

// TSFuture is a future which promises to return a TSO.
type TSFuture interface {
// Wait gets the physical and logical time, it would block caller if data is not available yet.
Wait() (int64, int64, error)
}

func (req *tsoRequest) Wait() (physical int64, logical int64, err error) {
// If tso command duration is observed very high, the reason could be it
// takes too long for Wait() be called.
start := time.Now()
cmdDurationTSOAsyncWait.Observe(start.Sub(req.start).Seconds())
select {
case err = <-req.done:
defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqDone").End()
err = errors.WithStack(err)
defer tsoReqPool.Put(req)
if err != nil {
cmdFailDurationTSO.Observe(time.Since(req.start).Seconds())
return 0, 0, err
}
physical, logical = req.physical, req.logical
now := time.Now()
cmdDurationWait.Observe(now.Sub(start).Seconds())
cmdDurationTSO.Observe(now.Sub(req.start).Seconds())
return
case <-req.requestCtx.Done():
return 0, 0, errors.WithStack(req.requestCtx.Err())
case <-req.clientCtx.Done():
return 0, 0, errors.WithStack(req.clientCtx.Err())
}
}

func (c *tsoClient) updateTSODispatcher() {
// Set up the new TSO dispatcher and batch controller.
c.GetTSOAllocators().Range(func(dcLocationKey, _ any) bool {
96 changes: 96 additions & 0 deletions client/tso_request.go
@@ -0,0 +1,96 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
"context"
"runtime/trace"
"sync"
"time"

"github.com/pingcap/errors"
)

// TSFuture is a future which promises to return a TSO.
type TSFuture interface {
// Wait gets the physical and logical time, it would block caller if data is not available yet.
Wait() (int64, int64, error)
}

var (
_ TSFuture = (*tsoRequest)(nil)
_ TSFuture = (*tsoRequestFastFail)(nil)
)

type tsoRequest struct {
requestCtx context.Context
clientCtx context.Context
done chan error
physical int64
logical int64
dcLocation string

// Runtime fields.
start time.Time
pool *sync.Pool
}

// tryDone tries to send the result to the channel, it will not block.
func (req *tsoRequest) tryDone(err error) {
select {
case req.done <- err:
default:
}
}

// Wait will block until the TSO result is ready.
func (req *tsoRequest) Wait() (physical int64, logical int64, err error) {
// If tso command duration is observed very high, the reason could be it
// takes too long for Wait() be called.
start := time.Now()
cmdDurationTSOAsyncWait.Observe(start.Sub(req.start).Seconds())
select {
case err = <-req.done:
defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqDone").End()
defer req.pool.Put(req)
err = errors.WithStack(err)
if err != nil {
cmdFailDurationTSO.Observe(time.Since(req.start).Seconds())
return 0, 0, err
}
physical, logical = req.physical, req.logical
now := time.Now()
cmdDurationWait.Observe(now.Sub(start).Seconds())
cmdDurationTSO.Observe(now.Sub(req.start).Seconds())
return
case <-req.requestCtx.Done():
return 0, 0, errors.WithStack(req.requestCtx.Err())
case <-req.clientCtx.Done():
return 0, 0, errors.WithStack(req.clientCtx.Err())
}
}

type tsoRequestFastFail struct {
err error
}

func newTSORequestFastFail(err error) *tsoRequestFastFail {
return &tsoRequestFastFail{err}
}

// Wait returns the error directly.
func (req *tsoRequestFastFail) Wait() (physical int64, logical int64, err error) {
return 0, 0, req.err
}
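To show how the two implementations fit together, here is a stand-alone illustrative sketch (toy types, not PD code): callers consume a TSFuture-shaped value without knowing whether dispatch succeeded, because a fast-fail future simply returns its stored error from Wait.

```go
package main

import (
	"errors"
	"fmt"
)

// future mirrors the TSFuture contract above; callers never need to know
// whether dispatch succeeded or fast-failed.
type future interface {
	Wait() (int64, int64, error)
}

// okFuture stands in for a pooled tsoRequest whose result has arrived.
type okFuture struct{ physical, logical int64 }

func (f okFuture) Wait() (int64, int64, error) { return f.physical, f.logical, nil }

// failFuture stands in for tsoRequestFastFail: Wait returns the stored error
// immediately instead of blocking on a done channel.
type failFuture struct{ err error }

func (f failFuture) Wait() (int64, int64, error) { return 0, 0, f.err }

func consume(f future) {
	if physical, logical, err := f.Wait(); err != nil {
		fmt.Println("tso failed:", err)
	} else {
		fmt.Println("tso:", physical, logical)
	}
}

func main() {
	consume(okFuture{physical: 42, logical: 7})
	consume(failFuture{err: errors.New("tso client is nil")})
}
```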