Skip to content

Commit

Permalink
Simplify TSO Proxy implementation by using one forward stream for one…
Browse files Browse the repository at this point in the history
… gRPC stream (tikv#6572)

close tikv#6549, ref tikv#6565

Simplify tso proxy implementation by using one forward stream for one grpc.ServerStream.
tikv#6565 is a longer-term solution for both follower batching and the TSO microservice.
It's well implemented, but it needs more time to bake, and we need a short-term workable solution for now.

Signed-off-by: Bin Shi <[email protected]>
  • Loading branch information
binshi-bing authored and rleungx committed Aug 2, 2023
1 parent 18032d0 commit 1d8d406
Show file tree
Hide file tree
Showing 4 changed files with 569 additions and 14 deletions.
15 changes: 15 additions & 0 deletions pkg/utils/etcdutil/etcdutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ const (
defaultLoadFromEtcdRetryTimes = int(defaultLoadDataFromEtcdTimeout / defaultLoadFromEtcdRetryInterval)
defaultLoadBatchSize = 400
defaultWatchChangeRetryInterval = 1 * time.Second
defaultForceLoadMinimalInterval = 200 * time.Millisecond
)

// LoopWatcher loads data from etcd and sets a watcher for it.
Expand All @@ -386,6 +387,11 @@ type LoopWatcher struct {
// postEventFn is used to call after handling all events.
postEventFn func() error

// forceLoadMu is used to ensure two force loads have minimal interval.
forceLoadMu sync.Mutex
// lastTimeForceLoad is used to record the last time force loading data from etcd.
lastTimeForceLoad time.Time

// loadTimeout is used to set the timeout for loading data from etcd.
loadTimeout time.Duration
// loadRetryTimes is used to set the retry times for loading data from etcd.
Expand Down Expand Up @@ -415,6 +421,7 @@ func NewLoopWatcher(ctx context.Context, wg *sync.WaitGroup, client *clientv3.Cl
deleteFn: deleteFn,
postEventFn: postEventFn,
opts: opts,
lastTimeForceLoad: time.Now(),
loadTimeout: defaultLoadDataFromEtcdTimeout,
loadRetryTimes: defaultLoadFromEtcdRetryTimes,
loadBatchSize: defaultLoadBatchSize,
Expand Down Expand Up @@ -611,6 +618,14 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error)

// ForceLoad forces to load the key.
func (lw *LoopWatcher) ForceLoad() {
lw.forceLoadMu.Lock()
if time.Since(lw.lastTimeForceLoad) < defaultForceLoadMinimalInterval {
lw.forceLoadMu.Unlock()
return
}
lw.lastTimeForceLoad = time.Now()
lw.forceLoadMu.Unlock()

select {
case lw.forceLoadCh <- struct{}{}:
default:
Expand Down
10 changes: 5 additions & 5 deletions pkg/utils/tsoutil/tso_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Request interface {
// getCount returns the count of timestamps to retrieve
getCount() uint32
// process sends request and receive response via stream.
// count defins the count of timestamps to retrieve.
// count defines the count of timestamps to retrieve.
process(forwardStream stream, count uint32, tsoProtoFactory ProtoFactory) (tsoResp, error)
// postProcess sends the response back to the sender of the request
postProcess(countSum, physical, firstLogical int64, suffixBits uint32) (int64, error)
Expand All @@ -50,7 +50,7 @@ type TSOProtoRequest struct {
stream tsopb.TSO_TsoServer
}

// NewTSOProtoRequest creats a TSOProtoRequest and returns as a Request
// NewTSOProtoRequest creates a TSOProtoRequest and returns as a Request
func NewTSOProtoRequest(forwardedHost string, clientConn *grpc.ClientConn, request *tsopb.TsoRequest, stream tsopb.TSO_TsoServer) Request {
tsoRequest := &TSOProtoRequest{
forwardedHost: forwardedHost,
Expand All @@ -77,7 +77,7 @@ func (r *TSOProtoRequest) getCount() uint32 {
}

// process sends request and receive response via stream.
// count defins the count of timestamps to retrieve.
// count defines the count of timestamps to retrieve.
func (r *TSOProtoRequest) process(forwardStream stream, count uint32, tsoProtoFactory ProtoFactory) (tsoResp, error) {
return forwardStream.process(r.request.GetHeader().GetClusterId(), count,
r.request.GetHeader().GetKeyspaceId(), r.request.GetHeader().GetKeyspaceGroupId(), r.request.GetDcLocation())
Expand Down Expand Up @@ -111,7 +111,7 @@ type PDProtoRequest struct {
stream pdpb.PD_TsoServer
}

// NewPDProtoRequest creats a PDProtoRequest and returns as a Request
// NewPDProtoRequest creates a PDProtoRequest and returns as a Request
func NewPDProtoRequest(forwardedHost string, clientConn *grpc.ClientConn, request *pdpb.TsoRequest, stream pdpb.PD_TsoServer) Request {
tsoRequest := &PDProtoRequest{
forwardedHost: forwardedHost,
Expand All @@ -138,7 +138,7 @@ func (r *PDProtoRequest) getCount() uint32 {
}

// process sends request and receive response via stream.
// count defins the count of timestamps to retrieve.
// count defines the count of timestamps to retrieve.
func (r *PDProtoRequest) process(forwardStream stream, count uint32, tsoProtoFactory ProtoFactory) (tsoResp, error) {
return forwardStream.process(r.request.GetHeader().GetClusterId(), count,
utils.DefaultKeyspaceID, utils.DefaultKeyspaceGroupID, r.request.GetDcLocation())
Expand Down
182 changes: 173 additions & 9 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ var (
ErrNotStarted = status.Errorf(codes.Unavailable, "server not started")
ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout")
ErrNotFoundTSOAddr = status.Errorf(codes.NotFound, "not found tso address")
ErrForwardTSOTimeout = status.Errorf(codes.DeadlineExceeded, "forward tso request timeout")
)

// GrpcServer wraps Server to provide grpc service.
Expand Down Expand Up @@ -324,6 +325,10 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb

// Tso implements gRPC PDServer.
func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
if s.IsAPIServiceMode() {
return s.forwardTSO(stream)
}

var (
doneCh chan struct{}
errCh chan error
Expand Down Expand Up @@ -361,15 +366,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
errCh = make(chan error)
}

var tsoProtoFactory tsoutil.ProtoFactory
if s.IsAPIServiceMode() {
tsoProtoFactory = s.tsoProtoFactory
} else {
tsoProtoFactory = s.pdProtoFactory
}

tsoRequest := tsoutil.NewPDProtoRequest(forwardedHost, clientConn, request, stream)
s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, tsoProtoFactory, doneCh, errCh, s.tsoPrimaryWatcher)
s.tsoDispatcher.DispatchRequest(ctx, tsoRequest, s.pdProtoFactory, doneCh, errCh, s.tsoPrimaryWatcher)
continue
}

Expand All @@ -379,7 +377,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
return status.Errorf(codes.Unknown, "server not started")
}
if request.GetHeader().GetClusterId() != s.clusterID {
return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
return status.Errorf(codes.FailedPrecondition,
"mismatch cluster id, need %d but got %d", s.clusterID, request.GetHeader().GetClusterId())
}
count := request.GetCount()
ts, err := s.tsoAllocatorManager.HandleRequest(request.GetDcLocation(), count)
Expand All @@ -398,6 +397,162 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
}
}

// forwardTSO forwards the TSO requests received on stream to the primary of
// the TSO service and relays every response back to the client.
//
// One forward stream is kept per server stream; it is (re)created whenever the
// TSO primary address differs from the one used for the previous request.
// Requests are handled strictly one at a time: receive from the client, send
// to the TSO service, wait for its response, convert it to the pdpb format and
// send it back.
//
// It returns nil once the client side of the stream reports io.EOF, and a
// non-nil error when s.ctx or the stream context is done, the request count is
// zero, no TSO primary address is found, or any send/receive fails.
func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
	var (
		server            = &tsoServer{stream: stream}
		forwardStream     tsopb.TSO_TsoClient
		cancel            context.CancelFunc
		lastForwardedHost string
	)
	defer func() {
		if forwardStream != nil {
			// Best effort: the stream is being torn down anyway.
			forwardStream.CloseSend()
		}
		// cancel the forward stream
		if cancel != nil {
			cancel()
		}
	}()

	for {
		// Non-blocking check so that a finished server or client context is
		// noticed before blocking on the next receive.
		select {
		case <-s.ctx.Done():
			return errors.WithStack(s.ctx.Err())
		case <-stream.Context().Done():
			return stream.Context().Err()
		default:
		}

		request, err := server.Recv()
		// tsoServer.Recv annotates stream errors with a stack trace, so a
		// direct comparison with io.EOF would miss a clean client close.
		// Compare the root cause instead.
		if errors.Cause(err) == io.EOF {
			return nil
		}
		if err != nil {
			return errors.WithStack(err)
		}
		if request.GetCount() == 0 {
			err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
			// status.Error, not Errorf: the message is data, not a format string.
			return status.Error(codes.Unknown, err.Error())
		}

		forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName)
		if !ok || len(forwardedHost) == 0 {
			return errors.WithStack(ErrNotFoundTSOAddr)
		}
		if forwardStream == nil || lastForwardedHost != forwardedHost {
			// The primary changed (or this is the first request): drop the
			// old forward stream before opening one to the new primary.
			if forwardStream != nil {
				forwardStream.CloseSend()
			}
			if cancel != nil {
				cancel()
			}

			clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
			if err != nil {
				return errors.WithStack(err)
			}
			forwardStream, cancel, err = s.createTSOForwardStream(clientConn)
			if err != nil {
				return errors.WithStack(err)
			}
			lastForwardedHost = forwardedHost
		}

		tsoReq := &tsopb.TsoRequest{
			Header: &tsopb.RequestHeader{
				ClusterId:       request.GetHeader().GetClusterId(),
				SenderId:        request.GetHeader().GetSenderId(),
				KeyspaceId:      utils.DefaultKeyspaceID,
				KeyspaceGroupId: utils.DefaultKeyspaceGroupID,
			},
			Count:      request.GetCount(),
			DcLocation: request.GetDcLocation(),
		}
		if err := forwardStream.Send(tsoReq); err != nil {
			return errors.WithStack(err)
		}

		tsopbResp, err := forwardStream.Recv()
		if err != nil {
			// The cached primary address is stale; ask the watcher to reload
			// it so that the next stream can find the new primary.
			if strings.Contains(err.Error(), errs.NotLeaderErr) {
				s.tsoPrimaryWatcher.ForceLoad()
			}
			return errors.WithStack(err)
		}

		// The error types defined for tsopb and pdpb are different, so we need to convert them.
		var pdpbErr *pdpb.Error
		if tsopbErr := tsopbResp.GetHeader().GetError(); tsopbErr != nil {
			// TODO: specify FORWARD FAILURE error type instead of UNKNOWN.
			errType := pdpb.ErrorType_UNKNOWN
			if tsopbErr.Type == tsopb.ErrorType_OK {
				errType = pdpb.ErrorType_OK
			}
			pdpbErr = &pdpb.Error{
				Type:    errType,
				Message: tsopbErr.GetMessage(),
			}
		}

		response := &pdpb.TsoResponse{
			Header: &pdpb.ResponseHeader{
				ClusterId: tsopbResp.GetHeader().GetClusterId(),
				Error:     pdpbErr,
			},
			Count:     tsopbResp.GetCount(),
			Timestamp: tsopbResp.GetTimestamp(),
		}
		if err := server.Send(response); err != nil {
			return errors.WithStack(err)
		}
	}
}

// tsoServer wraps PD_TsoServer to ensure when any error
// occurs on Send() or Recv(), both endpoints will be closed.
type tsoServer struct {
// stream is the wrapped client-facing TSO stream.
stream pdpb.PD_TsoServer
// closed is read and written atomically. It is set to 1 once Send or Recv
// fails (or Send times out); after that both methods return io.EOF without
// touching stream.
closed int32
}

// Send sends m to the client over the wrapped stream, waiting at most
// tsoutil.DefaultTSOProxyTimeout for the send to complete.
//
// It returns io.EOF if the server has already been marked closed,
// ErrForwardTSOTimeout if the send does not finish in time, and the
// stack-annotated stream error if the send itself fails. Any failure or
// timeout marks the server closed.
func (s *tsoServer) Send(m *pdpb.TsoResponse) error {
	if atomic.LoadInt32(&s.closed) == 1 {
		return io.EOF
	}
	// Buffered so the sending goroutine can always deliver its result and
	// exit, even when nobody is listening any more after a timeout.
	done := make(chan error, 1)
	go func() {
		defer logutil.LogPanic()
		done <- s.stream.Send(m)
	}()

	// Use an explicit timer instead of time.After so it is released as soon
	// as the send completes rather than lingering for the whole timeout on
	// every request.
	timer := time.NewTimer(tsoutil.DefaultTSOProxyTimeout)
	defer timer.Stop()

	select {
	case err := <-done:
		if err != nil {
			atomic.StoreInt32(&s.closed, 1)
			return errors.WithStack(err)
		}
		return nil
	case <-timer.C:
		atomic.StoreInt32(&s.closed, 1)
		return ErrForwardTSOTimeout
	}
}

// Recv reads the next TSO request from the wrapped client stream.
//
// It returns io.EOF without touching the stream if the server has already
// been marked closed. A failure of the underlying receive marks the server
// closed and is returned annotated with a stack trace.
func (s *tsoServer) Recv() (*pdpb.TsoRequest, error) {
	if atomic.LoadInt32(&s.closed) == 1 {
		return nil, io.EOF
	}

	request, recvErr := s.stream.Recv()
	if recvErr == nil {
		return request, nil
	}

	// Remember the failure so later Send/Recv calls short-circuit.
	atomic.StoreInt32(&s.closed, 1)
	return nil, errors.WithStack(recvErr)
}

func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context) (forwardedHost string, err error) {
if s.IsAPIServiceMode() {
var ok bool
Expand Down Expand Up @@ -1903,6 +2058,15 @@ func forwardRegionHeartbeatClientToServer(forwardStream pdpb.PD_RegionHeartbeatC
}
}

// createTSOForwardStream opens a TSO stream to the TSO service over client.
//
// The stream's context is derived from s.ctx, and the returned CancelFunc
// cancels it; the caller receives the CancelFunc even when opening the stream
// fails. A checkStream goroutine is started before the open attempt and is
// signalled once the attempt has returned.
// NOTE(review): checkStream is defined elsewhere; presumably it cancels the
// context if the open attempt hangs — confirm against its definition.
func (s *GrpcServer) createTSOForwardStream(client *grpc.ClientConn) (tsopb.TSO_TsoClient, context.CancelFunc, error) {
	streamCtx, cancel := context.WithCancel(s.ctx)
	opened := make(chan struct{})
	go checkStream(streamCtx, cancel, opened)

	tsoStream, err := tsopb.NewTSOClient(client).Tso(streamCtx)
	// Tell checkStream that the open attempt has returned.
	opened <- struct{}{}
	return tsoStream, cancel, err
}

func (s *GrpcServer) createReportBucketsForwardStream(client *grpc.ClientConn) (pdpb.PD_ReportBucketsClient, context.CancelFunc, error) {
done := make(chan struct{})
ctx, cancel := context.WithCancel(s.ctx)
Expand Down
Loading

0 comments on commit 1d8d406

Please sign in to comment.