Skip to content

Commit

Permalink
tso: merge lastTS and lastArrivalTS into an atomic pointer (tikv#1054)
Browse files Browse the repository at this point in the history
* fix the issue that stale timestamp may be a future one

Signed-off-by: you06 <[email protected]>

* add regression test

Signed-off-by: you06 <[email protected]>

* lazy init lastTSO

Signed-off-by: you06 <[email protected]>

* fix panic

Signed-off-by: you06 <[email protected]>

* address comment

Signed-off-by: you06 <[email protected]>

---------

Signed-off-by: you06 <[email protected]>
  • Loading branch information
you06 authored Nov 15, 2023
1 parent 6659170 commit 7c96dfd
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 62 deletions.
18 changes: 3 additions & 15 deletions oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,8 @@ func NewEmptyPDOracle() oracle.Oracle {
func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lastTSPointer := lastTSInterface.(*uint64)
atomic.StoreUint64(lastTSPointer, ts)
lasTSArrivalInterface, _ := o.lastArrivalTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lasTSArrivalPointer := lasTSArrivalInterface.(*uint64)
atomic.StoreUint64(lasTSArrivalPointer, uint64(time.Now().Unix()*1000))
}
setEmptyPDOracleLastArrivalTs(oc, ts)
}

// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test.
func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
o.setLastArrivalTS(ts, oracle.GlobalTxnScope)
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{})
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts})
}
}
82 changes: 35 additions & 47 deletions oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ const slowDist = 30 * time.Millisecond
// pdOracle is an Oracle that uses a placement driver client as source.
type pdOracle struct {
c pd.Client
// txn_scope (string) -> lastTSPointer (*uint64)
// txn_scope (string) -> lastTSPointer (*atomic.Pointer[lastTSO])
lastTSMap sync.Map
// txn_scope (string) -> lastArrivalTSPointer (*uint64)
lastArrivalTSMap sync.Map
quit chan struct{}
quit chan struct{}
}

// lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched.
type lastTSO struct {
tso uint64
arrival uint64
}

// NewPdOracle create an Oracle that uses a pd client source.
Expand Down Expand Up @@ -163,63 +167,51 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
}
lastTSPointer := lastTSInterface.(*uint64)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
break
}
}
o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
}

func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
current := &lastTSO{
tso: ts,
arrival: o.getArrivalTimestamp(),
}
lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
pointer := &atomic.Pointer[lastTSO]{}
pointer.Store(current)
// do not handle the stored case, because it only runs once.
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, pointer)
}
lastTSPointer := lastTSInterface.(*uint64)
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
last := lastTSPointer.Load()
if current.tso <= last.tso || current.arrival <= last.arrival {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
if lastTSPointer.CompareAndSwap(last, current) {
return
}
}
}

func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
last, exist := o.getLastTSWithArrivalTS(txnScope)
if !exist {
return 0, false
}
return atomic.LoadUint64(lastTSInterface.(*uint64)), true
return last.tso, true
}

func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
return 0, false
return nil, false
}
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
last := lastTSPointer.Load()
if last == nil {
return nil, false
}
return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
return last, true
}

func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
Expand Down Expand Up @@ -293,22 +285,18 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac
}

func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
ts, ok := o.getLastTS(txnScope)
last, ok := o.getLastTSWithArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope)
}
arrivalTS, ok := o.getLastArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope)
}
ts, arrivalTS := last.tso, last.arrival
arrivalTime := oracle.GetTimeFromTS(arrivalTS)
physicalTime := oracle.GetTimeFromTS(ts)
if uint64(physicalTime.Unix()) <= prevSecond {
return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
}

staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second)))

staleTime := physicalTime.Add(time.Now().Add(-time.Duration(prevSecond) * time.Second).Sub(arrivalTime))
return oracle.GoTimeToTS(staleTime), nil
}

Expand Down
32 changes: 32 additions & 0 deletions oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,35 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
assert.NotNil(t, err)
assert.Regexp(t, ".*invalid prevSecond.*", err.Error())
}

func TestNonFutureStaleTSO(t *testing.T) {
o := oracles.NewEmptyPDOracle()
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now()))
for i := 0; i < 100; i++ {
time.Sleep(10 * time.Millisecond)
now := time.Now()
upperBound := now.Add(5 * time.Millisecond) // allow 5ms time drift

closeCh := make(chan struct{})
go func() {
time.Sleep(100 * time.Microsecond)
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now))
close(closeCh)
}()
CHECK:
for {
select {
case <-closeCh:
break CHECK
default:
ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 0)
assert.Nil(t, err)
staleTime := oracle.GetTimeFromTS(ts)
if staleTime.After(upperBound) && time.Since(now) < time.Millisecond /* only check staleTime within 1ms */ {
assert.Less(t, staleTime, upperBound, i)
t.FailNow()
}
}
}
}
}

0 comments on commit 7c96dfd

Please sign in to comment.