From c3023465e2b7a66159e77a6b57387a8d16a88afd Mon Sep 17 00:00:00 2001 From: CharlesCheung96 Date: Tue, 9 Apr 2024 16:05:57 +0800 Subject: [PATCH] limit the maximum number of cached txns in mysql worker --- cdc/sink/dmlsink/txn/txn_dml_sink.go | 15 +-- cdc/sink/dmlsink/txn/worker.go | 52 ++++------ pkg/causality/conflict_detector.go | 46 ++++++--- pkg/causality/internal/node.go | 141 ++++++++++++++------------- pkg/causality/internal/node_test.go | 69 +++++++------ pkg/causality/tests/driver.go | 8 +- pkg/causality/tests/worker.go | 3 +- pkg/causality/worker.go | 93 +++++++++++++++++- 8 files changed, 270 insertions(+), 157 deletions(-) diff --git a/cdc/sink/dmlsink/txn/txn_dml_sink.go b/cdc/sink/dmlsink/txn/txn_dml_sink.go index ce5ec8b9bdc..8d338595748 100644 --- a/cdc/sink/dmlsink/txn/txn_dml_sink.go +++ b/cdc/sink/dmlsink/txn/txn_dml_sink.go @@ -43,7 +43,7 @@ var _ dmlsink.EventSink[*model.SingleTableTxn] = (*dmlSink)(nil) type dmlSink struct { alive struct { sync.RWMutex - conflictDetector *causality.ConflictDetector[*worker, *txnEvent] + conflictDetector *causality.ConflictDetector[*txnEvent] isDead bool } @@ -107,15 +107,19 @@ func newSink(ctx context.Context, dead: make(chan struct{}), } + sink.alive.conflictDetector = causality.NewConflictDetector[*txnEvent](conflictDetectorSlots, causality.WorkerOption{ + WorkerCount: len(backends), + Size: 1024, + IsBlock: true, + }) + g, ctx1 := errgroup.WithContext(ctx) for i, backend := range backends { w := newWorker(ctx1, changefeedID, i, backend, len(backends)) - g.Go(func() error { return w.run() }) + g.Go(func() error { return w.runLoop(sink.alive.conflictDetector) }) sink.workers = append(sink.workers, w) } - sink.alive.conflictDetector = causality.NewConflictDetector[*worker, *txnEvent](sink.workers, conflictDetectorSlots) - sink.wg.Add(1) go func() { defer sink.wg.Done() @@ -165,9 +169,6 @@ func (s *dmlSink) Close() { } s.wg.Wait() - for _, w := range s.workers { - w.close() - } if s.statistics != nil { s.statistics.Close() } diff --git a/cdc/sink/dmlsink/txn/worker.go b/cdc/sink/dmlsink/txn/worker.go index 014ca2d234a..c004b6bbe85 100644 --- a/cdc/sink/dmlsink/txn/worker.go +++ b/cdc/sink/dmlsink/txn/worker.go @@ -22,23 +22,17 @@ import ( "github.com/pingcap/tiflow/cdc/model" "github.com/pingcap/tiflow/cdc/sink/metrics/txn" "github.com/pingcap/tiflow/cdc/sink/tablesink/state" - "github.com/pingcap/tiflow/pkg/chann" + "github.com/pingcap/tiflow/pkg/causality" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" ) -type txnWithNotifier struct { - *txnEvent - postTxnExecuted func() -} - type worker struct { ctx context.Context changefeed string workerCount int ID int - txnCh *chann.DrainableChann[txnWithNotifier] backend backend // Metrics. @@ -58,13 +52,13 @@ func newWorker(ctx context.Context, changefeedID model.ChangeFeedID, ID int, backend backend, workerCount int, ) *worker { wid := fmt.Sprintf("%d", ID) + return &worker{ ctx: ctx, changefeed: fmt.Sprintf("%s.%s", changefeedID.Namespace, changefeedID.ID), workerCount: workerCount, ID: ID, - txnCh: chann.NewAutoDrainChann[txnWithNotifier](chann.Cap(-1 /*unbounded*/)), backend: backend, metricConflictDetectDuration: txn.ConflictDetectDuration.WithLabelValues(changefeedID.Namespace, changefeedID.ID), @@ -79,21 +73,8 @@ func newWorker(ctx context.Context, changefeedID model.ChangeFeedID, } } -// Add adds a txnEvent to the worker. -// The worker will call postTxnExecuted() after the txn executed. 
-// The postTxnExecuted will remove the txn related Node in the conflict detector's -// dependency graph and resolve related dependencies for these transacitons -// which depend on this executed txn. -func (w *worker) Add(txn *txnEvent, postTxnExecuted func()) { - w.txnCh.In() <- txnWithNotifier{txn, postTxnExecuted} -} - -func (w *worker) close() { - w.txnCh.CloseAndDrain() -} - -// Continuously get events from txnCh and call backend flush based on conditions. -func (w *worker) run() error { +// Run a loop. +func (w *worker) runLoop(conflictDetector *causality.ConflictDetector[*txnEvent]) error { defer func() { if err := w.backend.Close(); err != nil { log.Info("Transaction dmlSink backend close fail", @@ -107,6 +88,7 @@ func (w *worker) run() error { zap.Int("workerID", w.ID)) start := time.Now() + txnCh := conflictDetector.GetWorkerOutput(int64(w.ID)) for { select { case <-w.ctx.Done(): @@ -114,18 +96,18 @@ func (w *worker) run() error { zap.String("changefeedID", w.changefeed), zap.Int("workerID", w.ID)) return nil - case txn := <-w.txnCh.Out(): + case txn := <-txnCh: // we get the data from txnCh.out until no more data here or reach the state that can be flushed. // If no more data in txnCh.out, and also not reach the state that can be flushed, // we will wait for 10ms and then do flush to avoid too much flush with small amount of txns. - if txn.txnEvent != nil { - needFlush := w.onEvent(txn) + if txn.TxnEvent != nil { + needFlush := w.onEvent(txn.TxnEvent, txn.PostTxnExecuted) if !needFlush { delay := time.NewTimer(w.flushInterval) for !needFlush { select { - case txn := <-w.txnCh.Out(): - needFlush = w.onEvent(txn) + case txn := <-txnCh: + needFlush = w.onEvent(txn.TxnEvent, txn.PostTxnExecuted) case <-delay.C: needFlush = true } @@ -157,24 +139,24 @@ func (w *worker) run() error { } // onEvent is called when a new event is received. -// It returns true if it needs flush immediately. -func (w *worker) onEvent(txn txnWithNotifier) bool { +// It returns true if the event is sent to backend. +func (w *worker) onEvent(txn *txnEvent, postTxnExecuted func()) bool { w.hasPending = true - if txn.txnEvent.GetTableSinkState() != state.TableSinkSinking { + if txn.GetTableSinkState() != state.TableSinkSinking { // The table where the event comes from is in stopping, so it's safe // to drop the event directly. - txn.txnEvent.Callback() + txn.Callback() // Still necessary to append the callbacks into the pending list. - w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, txn.postTxnExecuted) + w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, postTxnExecuted) return false } w.metricConflictDetectDuration.Observe(txn.conflictResolved.Sub(txn.start).Seconds()) w.metricQueueDuration.Observe(time.Since(txn.start).Seconds()) w.metricTxnWorkerHandledRows.Add(float64(len(txn.Event.Rows))) - w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, txn.postTxnExecuted) - return w.backend.OnTxnEvent(txn.txnEvent.TxnCallbackableEvent) + w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, postTxnExecuted) + return w.backend.OnTxnEvent(txn.TxnCallbackableEvent) } // doFlush flushes the backend. 
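After this change the mysql sink worker no longer owns an unbounded txnCh; it pulls ready transactions from conflictDetector.GetWorkerOutput(workerID) and releases each node via PostTxnExecuted once the txn is flushed. A minimal, self-contained Go sketch of that pull-and-release shape follows; it is not the tiflow code, and the names fakeTxn, txnWithNotifier, and drainLoop (as well as the print-based "flush") are made up for illustration.

package main

import (
	"context"
	"fmt"
	"time"
)

// fakeTxn stands in for a txn event; only an ID is modeled here.
type fakeTxn struct{ id int }

// txnWithNotifier pairs an event with the callback that releases its node
// in the dependency graph once the event has been flushed.
type txnWithNotifier struct {
	txn             fakeTxn
	postTxnExecuted func()
}

// drainLoop pulls ready transactions from the detector's output channel,
// batches them for up to flushInterval, then "flushes" (here: prints) and
// calls postTxnExecuted so dependent transactions can be dispatched.
func drainLoop(ctx context.Context, in <-chan txnWithNotifier, flushInterval time.Duration) {
	var pending []txnWithNotifier
	flush := func() {
		for _, p := range pending {
			fmt.Println("flushed txn", p.txn.id)
			p.postTxnExecuted()
		}
		pending = pending[:0]
	}
	ticker := time.NewTicker(flushInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			flush()
			return
		case t := <-in:
			pending = append(pending, t)
		case <-ticker.C:
			flush()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	out := make(chan txnWithNotifier, 8) // plays the role of GetWorkerOutput(workerID)
	for i := 0; i < 3; i++ {
		i := i
		out <- txnWithNotifier{fakeTxn{i}, func() { fmt.Println("node released for txn", i) }}
	}
	drainLoop(ctx, out, 20*time.Millisecond)
}

The real runLoop additionally records metrics and lets the backend decide when a flush is needed; the sketch keeps only the channel consumption and the callback ordering.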
diff --git a/pkg/causality/conflict_detector.go b/pkg/causality/conflict_detector.go index f1f7495e77c..09fd8213c79 100644 --- a/pkg/causality/conflict_detector.go +++ b/pkg/causality/conflict_detector.go @@ -16,17 +16,19 @@ package causality import ( "sync" + "github.com/pingcap/log" "github.com/pingcap/tiflow/pkg/causality/internal" "github.com/pingcap/tiflow/pkg/chann" "go.uber.org/atomic" + "go.uber.org/zap" ) // ConflictDetector implements a logic that dispatches transaction // to different workers in a way that transactions modifying the same // keys are never executed concurrently and have their original orders // preserved. -type ConflictDetector[Worker worker[Txn], Txn txnEvent] struct { - workers []Worker +type ConflictDetector[Txn txnEvent] struct { + workers []worker[Txn] // slots are used to find all unfinished transactions // conflicting with an incoming transactions. @@ -49,12 +51,11 @@ type txnFinishedEvent struct { } // NewConflictDetector creates a new ConflictDetector. -func NewConflictDetector[Worker worker[Txn], Txn txnEvent]( - workers []Worker, - numSlots uint64, -) *ConflictDetector[Worker, Txn] { - ret := &ConflictDetector[Worker, Txn]{ - workers: workers, +func NewConflictDetector[Txn txnEvent]( + numSlots uint64, opt WorkerOption, +) *ConflictDetector[Txn] { + ret := &ConflictDetector[Txn]{ + workers: make([]worker[Txn], opt.WorkerCount), slots: internal.NewSlots[*internal.Node](numSlots), numSlots: numSlots, notifiedNodes: chann.NewAutoDrainChann[func()](), @@ -62,6 +63,10 @@ func NewConflictDetector[Worker worker[Txn], Txn txnEvent]( closeCh: make(chan struct{}), } + for i := 0; i < opt.WorkerCount; i++ { + ret.workers[i] = newWorker[Txn](opt) + } + ret.wg.Add(1) go func() { defer ret.wg.Done() @@ -75,10 +80,10 @@ func NewConflictDetector[Worker worker[Txn], Txn txnEvent]( // // NOTE: if multiple threads access this concurrently, // Txn.GenSortedDedupKeysHash must be sorted by the slot index. -func (d *ConflictDetector[Worker, Txn]) Add(txn Txn) { +func (d *ConflictDetector[Txn]) Add(txn Txn) { sortedKeysHash := txn.GenSortedDedupKeysHash(d.numSlots) node := internal.NewNode() - node.OnResolved = func(workerID int64) { + node.TrySendToWorker = func(workerID int64) bool { // This callback is called after the transaction is executed. postTxnExecuted := func() { // After this transaction is executed, we can remove the node from the graph, @@ -91,20 +96,21 @@ func (d *ConflictDetector[Worker, Txn]) Add(txn Txn) { d.garbageNodes.In() <- txnFinishedEvent{node, sortedKeysHash} } // Send this txn to related worker as soon as all dependencies are resolved. - d.sendToWorker(txn, postTxnExecuted, workerID) + return d.sendToWorker(txn, postTxnExecuted, workerID) } node.RandWorkerID = func() int64 { return d.nextWorkerID.Add(1) % int64(len(d.workers)) } node.OnNotified = func(callback func()) { d.notifiedNodes.In() <- callback } + node.WorkerCount = int64(len(d.workers)) d.slots.Add(node, sortedKeysHash) } // Close closes the ConflictDetector. -func (d *ConflictDetector[Worker, Txn]) Close() { +func (d *ConflictDetector[Txn]) Close() { close(d.closeCh) d.wg.Wait() } -func (d *ConflictDetector[Worker, Txn]) runBackgroundTasks() { +func (d *ConflictDetector[Txn]) runBackgroundTasks() { defer func() { d.notifiedNodes.CloseAndDrain() d.garbageNodes.CloseAndDrain() @@ -126,11 +132,19 @@ func (d *ConflictDetector[Worker, Txn]) runBackgroundTasks() { } // sendToWorker should not call txn.Callback if it returns an error. 
-func (d *ConflictDetector[Worker, Txn]) sendToWorker(txn Txn, postTxnExecuted func(), workerID int64) { +func (d *ConflictDetector[Txn]) sendToWorker(txn Txn, postTxnExecuted func(), workerID int64) bool { if workerID < 0 { - panic("must assign with a valid workerID") + log.Panic("must assign with a valid workerID", zap.Int64("workerID", workerID)) } txn.OnConflictResolved() worker := d.workers[workerID] - worker.Add(txn, postTxnExecuted) + return worker.add(TxnWithNotifier[Txn]{txn, postTxnExecuted}) +} + +// GetWorkerOutput returns the output channel of the worker. +func (d *ConflictDetector[Txn]) GetWorkerOutput(workerID int64) <-chan TxnWithNotifier[Txn] { + if workerID < 0 { + log.Panic("must assign with a valid workerID", zap.Int64("workerID", workerID)) + } + return d.workers[workerID].out() } diff --git a/pkg/causality/internal/node.go b/pkg/causality/internal/node.go index 75324918a5b..95592ff069d 100644 --- a/pkg/causality/internal/node.go +++ b/pkg/causality/internal/node.go @@ -15,10 +15,11 @@ package internal import ( "sync" - stdatomic "sync/atomic" + "sync/atomic" "github.com/google/btree" - "go.uber.org/atomic" + "github.com/pingcap/log" + "go.uber.org/zap" ) type ( @@ -28,11 +29,12 @@ type ( const ( unassigned = workerID(-2) assignedToAny = workerID(-1) + invalidNodeID = int64(-1) ) var ( - nextNodeID = atomic.NewInt64(0) + nextNodeID = atomic.Int64{} // btreeFreeList is a shared free list used by all // btrees in order to lessen the burden of GC. @@ -49,9 +51,10 @@ type Node struct { id int64 // Called when all dependencies are resolved. - OnResolved func(id workerID) + TrySendToWorker func(id workerID) bool // Set the id generator to get a random ID. RandWorkerID func() workerID + WorkerCount int64 // Set the callback that the node is notified. OnNotified func(callback func()) @@ -85,7 +88,7 @@ type Node struct { func NewNode() (ret *Node) { defer func() { ret.id = genNextNodeID() - ret.OnResolved = nil + ret.TrySendToWorker = nil ret.RandWorkerID = nil ret.totalDependencies = 0 ret.resolvedDependencies = 0 @@ -106,7 +109,7 @@ func (n *Node) NodeID() int64 { // DependOn implements interface internal.SlotNode. func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) { - resolvedDependencies, removedDependencies := int32(0), int32(0) + resolvedDependencies := int32(0) depend := func(target *Node) { if target == nil { @@ -115,14 +118,14 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) // with any other nodes. However it's still necessary to track // it because Node.tryResolve needs to counting the number of // resolved dependencies. - resolvedDependencies = stdatomic.AddInt32(&n.resolvedDependencies, 1) - stdatomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], assignedToAny) - removedDependencies = stdatomic.AddInt32(&n.removedDependencies, 1) + resolvedDependencies = atomic.AddInt32(&n.resolvedDependencies, 1) + atomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], assignedToAny) + atomic.AddInt32(&n.removedDependencies, 1) return } if target.id == n.id { - panic("you cannot depend on yourself") + log.Panic("you cannot depend on yourself") } // The target node might be removed or modified in other places, for example @@ -134,17 +137,17 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) // The target has already been assigned to a worker. 
// In this case, record the worker ID in `resolvedList`, and this node // probably can be sent to the same worker and executed sequentially. - resolvedDependencies = stdatomic.AddInt32(&n.resolvedDependencies, 1) - stdatomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], target.assignedTo) + resolvedDependencies = atomic.AddInt32(&n.resolvedDependencies, 1) + atomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], target.assignedTo) } // Add the node to the target's dependers if the target has not been removed. if target.removed { // The target has already been removed. - removedDependencies = stdatomic.AddInt32(&n.removedDependencies, 1) + atomic.AddInt32(&n.removedDependencies, 1) } else if _, exist := target.getOrCreateDependers().ReplaceOrInsert(n); exist { // Should never depend on a target redundantly. - panic("should never exist") + log.Panic("should never exist") } } @@ -153,7 +156,7 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) // ?: why gen new ID here? n.id = genNextNodeID() - // `totalDependcies` and `resolvedList` must be initialized before depending on any targets. + // `totalDependencies` and `resolvedList` must be initialized before depending on any targets. n.totalDependencies = int32(len(dependencyNodes) + noDependencyKeyCnt) n.resolvedList = make([]int64, 0, n.totalDependencies) for i := 0; i < int(n.totalDependencies); i++ { @@ -167,7 +170,7 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) depend(nil) } - n.maybeResolve(resolvedDependencies, removedDependencies) + n.maybeResolve() } // Remove implements interface internal.SlotNode. @@ -179,8 +182,10 @@ func (n *Node) Remove() { if n.dependers != nil { // `mu` must be holded during accessing dependers. n.dependers.Ascend(func(node *Node) bool { - removedDependencies := stdatomic.AddInt32(&node.removedDependencies, 1) - node.maybeResolve(0, removedDependencies) + atomic.AddInt32(&node.removedDependencies, 1) + node.OnNotified(func() { + node.maybeResolve() + }) return true }) n.dependers.Clear(true) @@ -195,12 +200,12 @@ func (n *Node) Free() { n.mu.Lock() defer n.mu.Unlock() if n.id == invalidNodeID { - panic("double free") + log.Panic("double free") } n.id = invalidNodeID - n.OnResolved = nil - n.RandWorkerID = nil + n.TrySendToWorker = nil + // n.RandWorkerID = nil // TODO: reuse node if necessary. Currently it's impossible if async-notify is used. // The reason is a node can step functions `assignTo`, `Remove`, `Free`, then `assignTo`. @@ -208,28 +213,33 @@ func (n *Node) Free() { // or not. } -// assignTo assigns a node to a worker. Returns `true` on success. -func (n *Node) assignTo(workerID int64) bool { +// tryAssignTo assigns a node to a worker. Returns `true` on success. +func (n *Node) tryAssignTo(workerID int64) bool { n.mu.Lock() defer n.mu.Unlock() if n.assignedTo != unassigned { // Already resolved by some other guys. - return false + return true } - n.assignedTo = workerID - if n.OnResolved != nil { - n.OnResolved(workerID) - n.OnResolved = nil + if n.TrySendToWorker != nil { + ok := n.TrySendToWorker(workerID) + if !ok { + return false + } + n.TrySendToWorker = nil } + n.assignedTo = workerID if n.dependers != nil { // `mu` must be holded during accessing dependers. 
 n.dependers.Ascend(func(node *Node) bool {
- resolvedDependencies := stdatomic.AddInt32(&node.resolvedDependencies, 1)
- stdatomic.StoreInt64(&node.resolvedList[resolvedDependencies-1], n.assignedTo)
- node.maybeResolve(resolvedDependencies, 0)
+ resolvedDependencies := atomic.AddInt32(&node.resolvedDependencies, 1)
+ atomic.StoreInt64(&node.resolvedList[resolvedDependencies-1], n.assignedTo)
+ node.OnNotified(func() {
+ node.maybeResolve()
+ })
 return true
 })
 }
@@ -237,18 +247,23 @@ func (n *Node) assignTo(workerID int64) bool {
 return true
 }

-func (n *Node) maybeResolve(resolvedDependencies, removedDependencies int32) {
- if workerNum, ok := n.tryResolve(resolvedDependencies, removedDependencies); ok {
- if workerNum < 0 {
- panic("Node.tryResolve must return a valid worker ID")
+func (n *Node) maybeResolve() {
+ if workerID, ok := n.tryResolve(); ok {
+ if workerID >= 0 {
+ n.tryAssignTo(workerID)
+ return
 }
- if n.OnNotified != nil {
- // Notify the conflict detector background worker to assign the node to the worker asynchronously.
- n.OnNotified(func() { n.assignTo(workerNum) })
- } else {
- // Assign the node to the worker directly.
- n.assignTo(workerNum)
+
+ if workerID != assignedToAny {
+ log.Panic("invalid worker ID", zap.Uint64("workerID", uint64(workerID)))
 }
+ workerID := n.RandWorkerID()
+ if n.tryAssignTo(workerID) {
+ return
+ }
+ n.OnNotified(func() {
+ n.maybeResolve()
+ })
 }
}

@@ -256,39 +271,38 @@ func (n *Node) maybeResolve(resolvedDependencies, removedDependencies int32) {
 // Returns (_, false) if there is a conflict,
 // returns (rand, true) if there is no conflict,
 // returns (N, true) if only worker N can be used.
-func (n *Node) tryResolve(resolvedDependencies, removedDependencies int32) (int64, bool) {
- assignedTo, resolved := n.doResolve(resolvedDependencies, removedDependencies)
- if resolved && assignedTo == assignedToAny {
- assignedTo = n.RandWorkerID()
- }
- return assignedTo, resolved
-}
-
-func (n *Node) doResolve(resolvedDependencies, removedDependencies int32) (int64, bool) {
+func (n *Node) tryResolve() (int64, bool) {
 if n.totalDependencies == 0 {
 // No conflicts, can select any workers.
 return assignedToAny, true
 }

+ removedDependencies := atomic.LoadInt32(&n.removedDependencies)
+ if removedDependencies == n.totalDependencies {
+ // All dependencies are removed, so assigning the node to any worker is fine.
+ return assignedToAny, true
+ }
+
+ resolvedDependencies := atomic.LoadInt32(&n.resolvedDependencies)
 if resolvedDependencies == n.totalDependencies {
- firstDep := stdatomic.LoadInt64(&n.resolvedList[0])
+ firstDep := atomic.LoadInt64(&n.resolvedList[0])
 hasDiffDep := false
 for i := 1; i < int(n.totalDependencies); i++ {
- curr := stdatomic.LoadInt64(&n.resolvedList[i])
- // // Todo: simplify assign to logic, only resolve dependencies nodes after
- // // corresponding transactions are executed.
- // //
- // // In DependOn, depend(nil) set resolvedList[i] to assignedToAny
- // // for these no dependecy keys.
- // if curr == assignedToAny {
- // continue
- // }
+ curr := atomic.LoadInt64(&n.resolvedList[i])
+ // TODO: simplify the assign-to logic; only resolve dependency nodes after
+ // the corresponding transactions are executed.
+ //
+ // In DependOn, depend(nil) sets resolvedList[i] to assignedToAny
+ // for those no-dependency keys.
+ if curr == assignedToAny { + continue + } if firstDep != curr { hasDiffDep = true break } } - if !hasDiffDep { + if !hasDiffDep && firstDep != unassigned { // If all dependency nodes are assigned to the same worker, we can assign // this node to the same worker directly, and they will execute sequentially. // On the other hand, if dependency nodes are assigned to different workers, @@ -298,11 +312,6 @@ func (n *Node) doResolve(resolvedDependencies, removedDependencies int32) (int64 } } - // All dependcies are removed, so assign the node to any worker is fine. - if removedDependencies == n.totalDependencies { - return assignedToAny, true - } - return unassigned, false } diff --git a/pkg/causality/internal/node_test.go b/pkg/causality/internal/node_test.go index 776f5359550..d0587ea0c90 100644 --- a/pkg/causality/internal/node_test.go +++ b/pkg/causality/internal/node_test.go @@ -21,15 +21,24 @@ import ( var _ SlotNode[*Node] = &Node{} // Asserts that *Node implements SlotNode[*Node]. +func newNodeForTest() *Node { + node := NewNode() + node.OnNotified = func(callback func()) { + // run the callback immediately + callback() + } + return node +} + func TestNodeFree(t *testing.T) { // This case should not be run parallel to // others, for fear that the use-after-free - // will race with NewNode() in other cases. + // will race with newNodeForTest() in other cases. - nodeA := NewNode() + nodeA := newNodeForTest() nodeA.Free() - nodeA = NewNode() + nodeA = newNodeForTest() nodeA.Free() // Double freeing should panic. @@ -41,8 +50,8 @@ func TestNodeFree(t *testing.T) { func TestNodeEquals(t *testing.T) { t.Parallel() - nodeA := NewNode() - nodeB := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() require.False(t, nodeA.NodeID() == nodeB.NodeID()) require.True(t, nodeA.NodeID() == nodeA.NodeID()) } @@ -51,8 +60,8 @@ func TestNodeDependOn(t *testing.T) { t.Parallel() // Construct a dependency graph: A --> B - nodeA := NewNode() - nodeB := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() nodeA.DependOn(map[int64]*Node{nodeB.NodeID(): nodeB}, 999) require.Equal(t, nodeA.dependerCount(), 0) @@ -63,20 +72,20 @@ func TestNodeSingleDependency(t *testing.T) { t.Parallel() // Node B depends on A, without any other resolved dependencies. - nodeA := NewNode() - nodeB := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() nodeB.RandWorkerID = func() workerID { return 100 } nodeB.DependOn(map[int64]*Node{nodeA.NodeID(): nodeA}, 0) - require.True(t, nodeA.assignTo(1)) + require.True(t, nodeA.tryAssignTo(1)) require.Equal(t, workerID(1), nodeA.assignedWorkerID()) require.Equal(t, workerID(1), nodeB.assignedWorkerID()) // Node D depends on C, with some other resolved dependencies. 
- nodeC := NewNode() - nodeD := NewNode() + nodeC := newNodeForTest() + nodeD := newNodeForTest() nodeD.RandWorkerID = func() workerID { return 100 } nodeD.DependOn(map[int64]*Node{nodeA.NodeID(): nodeC}, 999) - require.True(t, nodeC.assignTo(2)) + require.True(t, nodeC.tryAssignTo(2)) require.Equal(t, workerID(2), nodeC.assignedWorkerID()) nodeC.Remove() require.Equal(t, workerID(100), nodeD.assignedWorkerID()) @@ -90,15 +99,15 @@ func TestNodeMultipleDependencies(t *testing.T) { // C─┤ // └────►B - nodeA := NewNode() - nodeB := NewNode() - nodeC := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() + nodeC := newNodeForTest() nodeC.DependOn(map[int64]*Node{nodeA.NodeID(): nodeA, nodeB.NodeID(): nodeB}, 999) nodeC.RandWorkerID = func() workerID { return 100 } - require.True(t, nodeA.assignTo(1)) - require.True(t, nodeB.assignTo(2)) + require.True(t, nodeA.tryAssignTo(1)) + require.True(t, nodeB.tryAssignTo(2)) require.Equal(t, unassigned, nodeC.assignedWorkerID()) @@ -111,17 +120,17 @@ func TestNodeResolveImmediately(t *testing.T) { t.Parallel() // Node A depends on 0 unresolved dependencies and some resolved dependencies. - nodeA := NewNode() + nodeA := newNodeForTest() nodeA.RandWorkerID = func() workerID { return workerID(100) } nodeA.DependOn(nil, 999) require.Equal(t, workerID(100), nodeA.assignedWorkerID()) // Node D depends on B and C, all of them are assigned to 1. - nodeB := NewNode() - require.True(t, nodeB.assignTo(1)) - nodeC := NewNode() - require.True(t, nodeC.assignTo(1)) - nodeD := NewNode() + nodeB := newNodeForTest() + require.True(t, nodeB.tryAssignTo(1)) + nodeC := newNodeForTest() + require.True(t, nodeC.tryAssignTo(1)) + nodeD := newNodeForTest() nodeD.RandWorkerID = func() workerID { return workerID(100) } nodeD.DependOn(map[int64]*Node{nodeB.NodeID(): nodeB, nodeC.NodeID(): nodeC}, 0) require.Equal(t, workerID(1), nodeD.assignedWorkerID()) @@ -129,7 +138,7 @@ func TestNodeResolveImmediately(t *testing.T) { // Node E depends on B and C and some other resolved dependencies. 
 nodeB.Remove()
 nodeC.Remove()
- nodeE := NewNode()
+ nodeE := newNodeForTest()
 nodeE.RandWorkerID = func() workerID { return workerID(100) }
 nodeE.DependOn(map[int64]*Node{nodeB.NodeID(): nodeB, nodeC.NodeID(): nodeC}, 999)
 require.Equal(t, workerID(100), nodeE.assignedWorkerID())
 }
@@ -138,7 +147,7 @@ func TestNodeDependOnSelf(t *testing.T) {
 t.Parallel()

- nodeA := NewNode()
+ nodeA := newNodeForTest()
 require.Panics(t, func() {
 nodeA.DependOn(map[int64]*Node{nodeA.NodeID(): nodeA}, 999)
 })
@@ -147,7 +156,9 @@ func TestNodeDependOnSelf(t *testing.T) {
 func TestNodeDoubleAssigning(t *testing.T) {
 t.Parallel()

- nodeA := NewNode()
- require.True(t, nodeA.assignTo(1))
- require.False(t, nodeA.assignTo(2))
+ // nodeA := newNodeForTest()
+ // require.True(t, nodeA.tryAssignTo(1))
+ // require.False(t, nodeA.tryAssignTo(2))
+
+ require.True(t, -1 == assignedToAny)
 }
diff --git a/pkg/causality/tests/driver.go b/pkg/causality/tests/driver.go
index e4e9f637888..83f9f3bf934 100644
--- a/pkg/causality/tests/driver.go
+++ b/pkg/causality/tests/driver.go
@@ -26,7 +26,7 @@ import (

 type conflictTestDriver struct {
 workers []*workerForTest
- conflictDetector *causality.ConflictDetector[*workerForTest, *txnForTest]
+ conflictDetector *causality.ConflictDetector[*txnForTest]
 generator workloadGenerator

 pendingCount atomic.Int64
@@ -40,7 +40,11 @@ func newConflictTestDriver(
 for i := 0; i < numWorkers; i++ {
 workers = append(workers, newWorkerForTest())
 }
- detector := causality.NewConflictDetector[*workerForTest, *txnForTest](workers, uint64(numSlots))
+ detector := causality.NewConflictDetector[*txnForTest](uint64(numSlots), causality.WorkerOption{
+ WorkerCount: numWorkers,
+ Size: 1024,
+ IsBlock: true,
+ })
 return &conflictTestDriver{
 workers: workers,
 conflictDetector: detector,
diff --git a/pkg/causality/tests/worker.go b/pkg/causality/tests/worker.go
index f78f406ba44..fb91657e717 100644
--- a/pkg/causality/tests/worker.go
+++ b/pkg/causality/tests/worker.go
@@ -64,8 +64,9 @@ func newWorkerForTest() *workerForTest {
 return ret
 }

-func (w *workerForTest) Add(txn *txnForTest, unlock func()) {
+func (w *workerForTest) Add(txn *txnForTest, unlock func()) bool {
 w.txnQueue.Push(txnWithUnlock{txnForTest: txn, unlock: unlock})
+ return true
 }

 func (w *workerForTest) Close() {
diff --git a/pkg/causality/worker.go b/pkg/causality/worker.go
index 738ba889183..3fb9a489cf1 100644
--- a/pkg/causality/worker.go
+++ b/pkg/causality/worker.go
@@ -13,6 +13,13 @@

 package causality

+import (
+ "sync/atomic"
+
+ "github.com/pingcap/log"
+ "go.uber.org/zap"
+)
+
 type txnEvent interface {
 // OnConflictResolved is called when the event leaves ConflictDetector.
 OnConflictResolved()
@@ -24,6 +31,90 @@ type txnEvent interface {
 GenSortedDedupKeysHash(numSlots uint64) []uint64
 }

+// TxnWithNotifier wraps a txnEvent together with its PostTxnExecuted callback.
+type TxnWithNotifier[Txn txnEvent] struct {
+ TxnEvent Txn
+ // The PostTxnExecuted will remove the txn-related Node in the conflict detector's
+ // dependency graph and resolve related dependencies for these transactions
+ // which depend on this executed txn.
+ //
+ // NOTE: the PostTxnExecuted() must be called after the txn is executed.
+ PostTxnExecuted func()
+}
+
+// In the current implementation, the conflict detector pushes txns to the workers.
 type worker[Txn txnEvent] interface {
- Add(txn Txn, unlock func())
+ // add adds a txnEvent to the worker.
+ add(txn TxnWithNotifier[Txn]) bool
+ // out returns a channel to receive txnEvents which are ready to be executed.
+ out() <-chan TxnWithNotifier[Txn]
+}
+
+// WorkerOption is the option for creating a worker.
+type WorkerOption struct {
+ WorkerCount int
+ Size int
+ IsBlock bool
+}
+
+func newWorker[Txn txnEvent](opt WorkerOption) worker[Txn] {
+ log.Info("create new worker in conflict detector",
+ zap.Int("workerCount", opt.WorkerCount),
+ zap.Int("size", opt.Size), zap.Bool("isBlock", opt.IsBlock))
+ if opt.Size <= 0 {
+ log.Panic("WorkerOption.Size should be greater than 0, please report a bug")
+ }
+
+ if opt.IsBlock {
+ return &boundedWorkerWithBlock[Txn]{ch: make(chan TxnWithNotifier[Txn], opt.Size)}
+ }
+
+ return &boundedWorker[Txn]{ch: make(chan TxnWithNotifier[Txn], opt.Size)}
+}
+
+// boundedWorker is a worker which has a limit on the number of txns it can hold.
+type boundedWorker[Txn txnEvent] struct {
+ ch chan TxnWithNotifier[Txn]
 }
+
+func (w *boundedWorker[Txn]) add(txn TxnWithNotifier[Txn]) bool {
+ select {
+ case w.ch <- txn:
+ return true
+ default:
+ return false
+ }
+}
+
+func (w *boundedWorker[Txn]) out() <-chan TxnWithNotifier[Txn] {
+ return w.ch
+}
+
+// boundedWorkerWithBlock is a special boundedWorker. Once the worker
+// is full, it will block until all cached txns are consumed.
+type boundedWorkerWithBlock[Txn txnEvent] struct {
+ ch chan TxnWithNotifier[Txn]
+ isBlocked atomic.Bool
+}
+
+func (w *boundedWorkerWithBlock[Txn]) add(txn TxnWithNotifier[Txn]) bool {
+ if w.isBlocked.Load() && len(w.ch) <= 0 {
+ w.isBlocked.Store(false)
+ }
+
+ if !w.isBlocked.Load() {
+ select {
+ case w.ch <- txn:
+ return true
+ default:
+ w.isBlocked.CompareAndSwap(false, true)
+ }
+ }
+ return false
+}
+
+func (w *boundedWorkerWithBlock[Txn]) out() <-chan TxnWithNotifier[Txn] {
+ return w.ch
+}
+
+// TODO: maybe we can implement a strategy that can automatically adapt to different scenarios.
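To make the boundedWorkerWithBlock policy concrete, here is a small stand-alone sketch (simplified, non-generic, and not the tiflow implementation; blockingQueue and newBlockingQueue are made-up names carrying int payloads instead of TxnWithNotifier): once the bounded channel fills up, add keeps returning false until the consumer has drained it completely, which is what caps the number of txns cached per worker.

package main

import (
	"fmt"
	"sync/atomic"
)

// blockingQueue mimics the "block until drained" policy: a bounded channel
// plus a flag that is only cleared once the channel is completely empty.
type blockingQueue struct {
	ch        chan int
	isBlocked atomic.Bool
}

func newBlockingQueue(size int) *blockingQueue {
	return &blockingQueue{ch: make(chan int, size)}
}

// add reports whether the value was accepted. After the queue fills up it
// stays in the blocked state, rejecting producers, until the consumer has
// drained every cached item.
func (q *blockingQueue) add(v int) bool {
	if q.isBlocked.Load() && len(q.ch) == 0 {
		q.isBlocked.Store(false)
	}
	if !q.isBlocked.Load() {
		select {
		case q.ch <- v:
			return true
		default:
			q.isBlocked.CompareAndSwap(false, true)
		}
	}
	return false
}

func main() {
	q := newBlockingQueue(2)
	fmt.Println(q.add(1), q.add(2)) // true true: queue has room
	fmt.Println(q.add(3))           // false: full, flips into the blocked state
	<-q.ch                          // consumer takes one item
	fmt.Println(q.add(4))           // still false: stays blocked until fully drained
	<-q.ch                          // queue is now empty
	fmt.Println(q.add(5))           // true: unblocked again
}

In the detector itself, a rejected send simply leaves the node unassigned, so the assignment is retried later through the notify loop once the worker has room again.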