diff --git a/cdc/sink/dmlsink/txn/txn_dml_sink.go b/cdc/sink/dmlsink/txn/txn_dml_sink.go index ce5ec8b9bdc..616c7f2da7a 100644 --- a/cdc/sink/dmlsink/txn/txn_dml_sink.go +++ b/cdc/sink/dmlsink/txn/txn_dml_sink.go @@ -43,7 +43,7 @@ var _ dmlsink.EventSink[*model.SingleTableTxn] = (*dmlSink)(nil) type dmlSink struct { alive struct { sync.RWMutex - conflictDetector *causality.ConflictDetector[*worker, *txnEvent] + conflictDetector *causality.ConflictDetector[*txnEvent] isDead bool } @@ -107,15 +107,19 @@ func newSink(ctx context.Context, dead: make(chan struct{}), } + sink.alive.conflictDetector = causality.NewConflictDetector[*txnEvent](conflictDetectorSlots, causality.WorkerOption{ + WorkerCount: len(backends), + CacheSize: 1024, + IsBlock: true, + }) + g, ctx1 := errgroup.WithContext(ctx) for i, backend := range backends { w := newWorker(ctx1, changefeedID, i, backend, len(backends)) - g.Go(func() error { return w.run() }) + g.Go(func() error { return w.runLoop(sink.alive.conflictDetector) }) sink.workers = append(sink.workers, w) } - sink.alive.conflictDetector = causality.NewConflictDetector[*worker, *txnEvent](sink.workers, conflictDetectorSlots) - sink.wg.Add(1) go func() { defer sink.wg.Done() @@ -165,9 +169,6 @@ func (s *dmlSink) Close() { } s.wg.Wait() - for _, w := range s.workers { - w.close() - } if s.statistics != nil { s.statistics.Close() } diff --git a/cdc/sink/dmlsink/txn/worker.go b/cdc/sink/dmlsink/txn/worker.go index 014ca2d234a..6bfe08581f3 100644 --- a/cdc/sink/dmlsink/txn/worker.go +++ b/cdc/sink/dmlsink/txn/worker.go @@ -22,23 +22,17 @@ import ( "github.com/pingcap/tiflow/cdc/model" "github.com/pingcap/tiflow/cdc/sink/metrics/txn" "github.com/pingcap/tiflow/cdc/sink/tablesink/state" - "github.com/pingcap/tiflow/pkg/chann" + "github.com/pingcap/tiflow/pkg/causality" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" ) -type txnWithNotifier struct { - *txnEvent - postTxnExecuted func() -} - type worker struct { ctx 
context.Context changefeed string workerCount int ID int - txnCh *chann.DrainableChann[txnWithNotifier] backend backend // Metrics. @@ -58,13 +52,13 @@ func newWorker(ctx context.Context, changefeedID model.ChangeFeedID, ID int, backend backend, workerCount int, ) *worker { wid := fmt.Sprintf("%d", ID) + return &worker{ ctx: ctx, changefeed: fmt.Sprintf("%s.%s", changefeedID.Namespace, changefeedID.ID), workerCount: workerCount, ID: ID, - txnCh: chann.NewAutoDrainChann[txnWithNotifier](chann.Cap(-1 /*unbounded*/)), backend: backend, metricConflictDetectDuration: txn.ConflictDetectDuration.WithLabelValues(changefeedID.Namespace, changefeedID.ID), @@ -79,21 +73,8 @@ func newWorker(ctx context.Context, changefeedID model.ChangeFeedID, } } -// Add adds a txnEvent to the worker. -// The worker will call postTxnExecuted() after the txn executed. -// The postTxnExecuted will remove the txn related Node in the conflict detector's -// dependency graph and resolve related dependencies for these transacitons -// which depend on this executed txn. -func (w *worker) Add(txn *txnEvent, postTxnExecuted func()) { - w.txnCh.In() <- txnWithNotifier{txn, postTxnExecuted} -} - -func (w *worker) close() { - w.txnCh.CloseAndDrain() -} - -// Continuously get events from txnCh and call backend flush based on conditions. -func (w *worker) run() error { +// runLoop continuously gets txns from the conflict detector's output channel for this worker and flushes the backend based on conditions. 
+func (w *worker) runLoop(conflictDetector *causality.ConflictDetector[*txnEvent]) error { defer func() { if err := w.backend.Close(); err != nil { log.Info("Transaction dmlSink backend close fail", @@ -107,6 +88,7 @@ func (w *worker) run() error { zap.Int("workerID", w.ID)) start := time.Now() + txnCh := conflictDetector.GetOutChByWorkerID(int64(w.ID)) for { select { case <-w.ctx.Done(): @@ -114,18 +96,18 @@ func (w *worker) run() error { zap.String("changefeedID", w.changefeed), zap.Int("workerID", w.ID)) return nil - case txn := <-w.txnCh.Out(): + case txn := <-txnCh: // we get the data from txnCh.out until no more data here or reach the state that can be flushed. // If no more data in txnCh.out, and also not reach the state that can be flushed, // we will wait for 10ms and then do flush to avoid too much flush with small amount of txns. - if txn.txnEvent != nil { - needFlush := w.onEvent(txn) + if txn.TxnEvent != nil { + needFlush := w.onEvent(txn.TxnEvent, txn.PostTxnExecuted) if !needFlush { delay := time.NewTimer(w.flushInterval) for !needFlush { select { - case txn := <-w.txnCh.Out(): - needFlush = w.onEvent(txn) + case txn := <-txnCh: + needFlush = w.onEvent(txn.TxnEvent, txn.PostTxnExecuted) case <-delay.C: needFlush = true } @@ -157,24 +139,24 @@ func (w *worker) run() error { } // onEvent is called when a new event is received. -// It returns true if it needs flush immediately. -func (w *worker) onEvent(txn txnWithNotifier) bool { +// It returns true if the event is sent to backend. +func (w *worker) onEvent(txn *txnEvent, postTxnExecuted func()) bool { w.hasPending = true - if txn.txnEvent.GetTableSinkState() != state.TableSinkSinking { + if txn.GetTableSinkState() != state.TableSinkSinking { // The table where the event comes from is in stopping, so it's safe // to drop the event directly. - txn.txnEvent.Callback() + txn.Callback() // Still necessary to append the callbacks into the pending list. 
- w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, txn.postTxnExecuted) + w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, postTxnExecuted) return false } w.metricConflictDetectDuration.Observe(txn.conflictResolved.Sub(txn.start).Seconds()) w.metricQueueDuration.Observe(time.Since(txn.start).Seconds()) w.metricTxnWorkerHandledRows.Add(float64(len(txn.Event.Rows))) - w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, txn.postTxnExecuted) - return w.backend.OnTxnEvent(txn.txnEvent.TxnCallbackableEvent) + w.postTxnExecutedCallbacks = append(w.postTxnExecutedCallbacks, postTxnExecuted) + return w.backend.OnTxnEvent(txn.TxnCallbackableEvent) } // doFlush flushes the backend. diff --git a/pkg/causality/conflict_detector.go b/pkg/causality/conflict_detector.go index f1f7495e77c..1942e7ff19f 100644 --- a/pkg/causality/conflict_detector.go +++ b/pkg/causality/conflict_detector.go @@ -16,17 +16,19 @@ package causality import ( "sync" + "github.com/pingcap/log" "github.com/pingcap/tiflow/pkg/causality/internal" "github.com/pingcap/tiflow/pkg/chann" "go.uber.org/atomic" + "go.uber.org/zap" ) // ConflictDetector implements a logic that dispatches transaction // to different workers in a way that transactions modifying the same // keys are never executed concurrently and have their original orders // preserved. -type ConflictDetector[Worker worker[Txn], Txn txnEvent] struct { - workers []Worker +type ConflictDetector[Txn txnEvent] struct { + workers []workerCache[Txn] // slots are used to find all unfinished transactions // conflicting with an incoming transactions. @@ -38,29 +40,26 @@ type ConflictDetector[Worker worker[Txn], Txn txnEvent] struct { // Used to run a background goroutine to GC or notify nodes. 
notifiedNodes *chann.DrainableChann[func()] - garbageNodes *chann.DrainableChann[txnFinishedEvent] + garbageNodes *chann.DrainableChann[*internal.Node] wg sync.WaitGroup closeCh chan struct{} } -type txnFinishedEvent struct { - node *internal.Node - sortedKeysHash []uint64 -} - // NewConflictDetector creates a new ConflictDetector. -func NewConflictDetector[Worker worker[Txn], Txn txnEvent]( - workers []Worker, - numSlots uint64, -) *ConflictDetector[Worker, Txn] { - ret := &ConflictDetector[Worker, Txn]{ - workers: workers, +func NewConflictDetector[Txn txnEvent]( + numSlots uint64, opt WorkerOption, +) *ConflictDetector[Txn] { + ret := &ConflictDetector[Txn]{ + workers: make([]workerCache[Txn], opt.WorkerCount), slots: internal.NewSlots[*internal.Node](numSlots), numSlots: numSlots, notifiedNodes: chann.NewAutoDrainChann[func()](), - garbageNodes: chann.NewAutoDrainChann[txnFinishedEvent](), + garbageNodes: chann.NewAutoDrainChann[*internal.Node](), closeCh: make(chan struct{}), } + for i := 0; i < opt.WorkerCount; i++ { + ret.workers[i] = newWorker[Txn](opt) + } ret.wg.Add(1) go func() { @@ -75,12 +74,12 @@ func NewConflictDetector[Worker worker[Txn], Txn txnEvent]( // // NOTE: if multiple threads access this concurrently, // Txn.GenSortedDedupKeysHash must be sorted by the slot index. -func (d *ConflictDetector[Worker, Txn]) Add(txn Txn) { +func (d *ConflictDetector[Txn]) Add(txn Txn) { sortedKeysHash := txn.GenSortedDedupKeysHash(d.numSlots) - node := internal.NewNode() - node.OnResolved = func(workerID int64) { - // This callback is called after the transaction is executed. - postTxnExecuted := func() { + node := internal.NewNode(sortedKeysHash) + txnWithNotifier := TxnWithNotifier[Txn]{ + TxnEvent: txn, + PostTxnExecuted: func() { // After this transaction is executed, we can remove the node from the graph, // and resolve related dependencies for these transacitons which depend on this // executed transaction. 
@@ -88,23 +87,25 @@ func (d *ConflictDetector[Worker, Txn]) Add(txn Txn) { // Send this node to garbageNodes to GC it from the slots if this node is still // occupied related slots. - d.garbageNodes.In() <- txnFinishedEvent{node, sortedKeysHash} - } - // Send this txn to related worker as soon as all dependencies are resolved. - d.sendToWorker(txn, postTxnExecuted, workerID) + d.garbageNodes.In() <- node + }, + } + node.TrySendToWorker = func(workerID int64) bool { + // Try sending this txn to related worker as soon as all dependencies are resolved. + return d.sendToWorker(txnWithNotifier, workerID) } node.RandWorkerID = func() int64 { return d.nextWorkerID.Add(1) % int64(len(d.workers)) } node.OnNotified = func(callback func()) { d.notifiedNodes.In() <- callback } - d.slots.Add(node, sortedKeysHash) + d.slots.Add(node) } // Close closes the ConflictDetector. -func (d *ConflictDetector[Worker, Txn]) Close() { +func (d *ConflictDetector[Txn]) Close() { close(d.closeCh) d.wg.Wait() } -func (d *ConflictDetector[Worker, Txn]) runBackgroundTasks() { +func (d *ConflictDetector[Txn]) runBackgroundTasks() { defer func() { d.notifiedNodes.CloseAndDrain() d.garbageNodes.CloseAndDrain() @@ -117,20 +118,31 @@ func (d *ConflictDetector[Worker, Txn]) runBackgroundTasks() { if notifyCallback != nil { notifyCallback() } - case event := <-d.garbageNodes.Out(): - if event.node != nil { - d.slots.Free(event.node, event.sortedKeysHash) + case node := <-d.garbageNodes.Out(): + if node != nil { + d.slots.Free(node) } } } } // sendToWorker should not call txn.Callback if it returns an error. 
-func (d *ConflictDetector[Worker, Txn]) sendToWorker(txn Txn, postTxnExecuted func(), workerID int64) { +func (d *ConflictDetector[Txn]) sendToWorker(txn TxnWithNotifier[Txn], workerID int64) bool { if workerID < 0 { - panic("must assign with a valid workerID") + log.Panic("must assign with a valid workerID", zap.Int64("workerID", workerID)) } - txn.OnConflictResolved() worker := d.workers[workerID] - worker.Add(txn, postTxnExecuted) + ok := worker.add(txn) + if ok { + txn.TxnEvent.OnConflictResolved() + } + return ok +} + +// GetOutChByWorkerID returns the output channel of the worker. +func (d *ConflictDetector[Txn]) GetOutChByWorkerID(workerID int64) <-chan TxnWithNotifier[Txn] { + if workerID < 0 { + log.Panic("must assign with a valid workerID", zap.Int64("workerID", workerID)) + } + return d.workers[workerID].out() } diff --git a/pkg/causality/internal/node.go b/pkg/causality/internal/node.go index 75324918a5b..942c827fb6c 100644 --- a/pkg/causality/internal/node.go +++ b/pkg/causality/internal/node.go @@ -15,10 +15,11 @@ package internal import ( "sync" - stdatomic "sync/atomic" + "sync/atomic" "github.com/google/btree" - "go.uber.org/atomic" + "github.com/pingcap/log" + "go.uber.org/zap" ) type ( @@ -28,11 +29,12 @@ type ( const ( unassigned = workerID(-2) assignedToAny = workerID(-1) + invalidNodeID = int64(-1) ) var ( - nextNodeID = atomic.NewInt64(0) + nextNodeID = atomic.Int64{} // btreeFreeList is a shared free list used by all // btrees in order to lessen the burden of GC. @@ -46,10 +48,11 @@ var ( // in conflict detection. type Node struct { // Immutable fields. - id int64 + id int64 + sortedDedupKeysHash []uint64 // Called when all dependencies are resolved. - OnResolved func(id workerID) + TrySendToWorker func(id workerID) bool // Set the id generator to get a random ID. RandWorkerID func() workerID // Set the callback that the node is notified. @@ -82,10 +85,11 @@ type Node struct { } // NewNode creates a new node. 
-func NewNode() (ret *Node) { +func NewNode(sortedDedupKeysHash []uint64) (ret *Node) { defer func() { ret.id = genNextNodeID() - ret.OnResolved = nil + ret.sortedDedupKeysHash = sortedDedupKeysHash + ret.TrySendToWorker = nil ret.RandWorkerID = nil ret.totalDependencies = 0 ret.resolvedDependencies = 0 @@ -104,9 +108,14 @@ func (n *Node) NodeID() int64 { return n.id } +// Hashes implements interface internal.SlotNode. +func (n *Node) Hashes() []uint64 { + return n.sortedDedupKeysHash +} + // DependOn implements interface internal.SlotNode. func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) { - resolvedDependencies, removedDependencies := int32(0), int32(0) + resolvedDependencies := int32(0) depend := func(target *Node) { if target == nil { @@ -115,14 +124,14 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) // with any other nodes. However it's still necessary to track // it because Node.tryResolve needs to counting the number of // resolved dependencies. - resolvedDependencies = stdatomic.AddInt32(&n.resolvedDependencies, 1) - stdatomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], assignedToAny) - removedDependencies = stdatomic.AddInt32(&n.removedDependencies, 1) + resolvedDependencies = atomic.AddInt32(&n.resolvedDependencies, 1) + atomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], assignedToAny) + atomic.AddInt32(&n.removedDependencies, 1) return } if target.id == n.id { - panic("you cannot depend on yourself") + log.Panic("node cannot depend on itself") } // The target node might be removed or modified in other places, for example @@ -134,17 +143,17 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) // The target has already been assigned to a worker. // In this case, record the worker ID in `resolvedList`, and this node // probably can be sent to the same worker and executed sequentially. 
- resolvedDependencies = stdatomic.AddInt32(&n.resolvedDependencies, 1) - stdatomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], target.assignedTo) + resolvedDependencies = atomic.AddInt32(&n.resolvedDependencies, 1) + atomic.StoreInt64(&n.resolvedList[resolvedDependencies-1], target.assignedTo) } // Add the node to the target's dependers if the target has not been removed. if target.removed { // The target has already been removed. - removedDependencies = stdatomic.AddInt32(&n.removedDependencies, 1) + atomic.AddInt32(&n.removedDependencies, 1) } else if _, exist := target.getOrCreateDependers().ReplaceOrInsert(n); exist { // Should never depend on a target redundantly. - panic("should never exist") + log.Panic("should never exist") } } @@ -153,7 +162,7 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) // ?: why gen new ID here? n.id = genNextNodeID() - // `totalDependcies` and `resolvedList` must be initialized before depending on any targets. + // `totalDependencies` and `resolvedList` must be initialized before depending on any targets. n.totalDependencies = int32(len(dependencyNodes) + noDependencyKeyCnt) n.resolvedList = make([]int64, 0, n.totalDependencies) for i := 0; i < int(n.totalDependencies); i++ { @@ -167,7 +176,7 @@ func (n *Node) DependOn(dependencyNodes map[int64]*Node, noDependencyKeyCnt int) depend(nil) } - n.maybeResolve(resolvedDependencies, removedDependencies) + n.maybeResolve() } // Remove implements interface internal.SlotNode. @@ -179,8 +188,8 @@ func (n *Node) Remove() { if n.dependers != nil { // `mu` must be holded during accessing dependers. 
n.dependers.Ascend(func(node *Node) bool { - removedDependencies := stdatomic.AddInt32(&node.removedDependencies, 1) - node.maybeResolve(0, removedDependencies) + atomic.AddInt32(&node.removedDependencies, 1) + node.OnNotified(node.maybeResolve) return true }) n.dependers.Clear(true) @@ -195,12 +204,12 @@ func (n *Node) Free() { n.mu.Lock() defer n.mu.Unlock() if n.id == invalidNodeID { - panic("double free") + log.Panic("double free") } n.id = invalidNodeID - n.OnResolved = nil - n.RandWorkerID = nil + n.TrySendToWorker = nil + // n.RandWorkerID = nil // TODO: reuse node if necessary. Currently it's impossible if async-notify is used. // The reason is a node can step functions `assignTo`, `Remove`, `Free`, then `assignTo`. @@ -208,28 +217,31 @@ func (n *Node) Free() { // or not. } -// assignTo assigns a node to a worker. Returns `true` on success. -func (n *Node) assignTo(workerID int64) bool { +// tryAssignTo assigns a node to a worker. Returns `true` on success. +func (n *Node) tryAssignTo(workerID int64) bool { n.mu.Lock() defer n.mu.Unlock() if n.assignedTo != unassigned { // Already resolved by some other guys. - return false + return true } - n.assignedTo = workerID - if n.OnResolved != nil { - n.OnResolved(workerID) - n.OnResolved = nil + if n.TrySendToWorker != nil { + ok := n.TrySendToWorker(workerID) + if !ok { + return false + } + n.TrySendToWorker = nil } + n.assignedTo = workerID if n.dependers != nil { // `mu` must be holded during accessing dependers. 
n.dependers.Ascend(func(node *Node) bool { - resolvedDependencies := stdatomic.AddInt32(&node.resolvedDependencies, 1) - stdatomic.StoreInt64(&node.resolvedList[resolvedDependencies-1], n.assignedTo) - node.maybeResolve(resolvedDependencies, 0) + resolvedDependencies := atomic.AddInt32(&node.resolvedDependencies, 1) + atomic.StoreInt64(&node.resolvedList[resolvedDependencies-1], n.assignedTo) + node.OnNotified(node.maybeResolve) return true }) } @@ -237,17 +249,21 @@ func (n *Node) assignTo(workerID int64) bool { return true } -func (n *Node) maybeResolve(resolvedDependencies, removedDependencies int32) { - if workerNum, ok := n.tryResolve(resolvedDependencies, removedDependencies); ok { - if workerNum < 0 { - panic("Node.tryResolve must return a valid worker ID") +func (n *Node) maybeResolve() { + if workerID, ok := n.tryResolve(); ok { + if workerID == unassigned { + log.Panic("invalid worker ID", zap.Uint64("workerID", uint64(workerID))) } - if n.OnNotified != nil { - // Notify the conflict detector background worker to assign the node to the worker asynchronously. - n.OnNotified(func() { n.assignTo(workerNum) }) - } else { - // Assign the node to the worker directly. - n.assignTo(workerNum) + + if workerID >= 0 { + n.tryAssignTo(workerID) + return + } + + workerID := n.RandWorkerID() + if !n.tryAssignTo(workerID) { + // If the worker is full, we need to try to assign to another worker. + n.OnNotified(n.maybeResolve) } } } @@ -256,39 +272,38 @@ func (n *Node) maybeResolve(resolvedDependencies, removedDependencies int32) { // Returns (_, false) if there is a conflict, // returns (rand, true) if there is no conflict, // returns (N, true) if only worker N can be used. 
-func (n *Node) tryResolve(resolvedDependencies, removedDependencies int32) (int64, bool) { - assignedTo, resolved := n.doResolve(resolvedDependencies, removedDependencies) - if resolved && assignedTo == assignedToAny { - assignedTo = n.RandWorkerID() - } - return assignedTo, resolved -} - -func (n *Node) doResolve(resolvedDependencies, removedDependencies int32) (int64, bool) { +func (n *Node) tryResolve() (int64, bool) { if n.totalDependencies == 0 { // No conflicts, can select any workers. return assignedToAny, true } + removedDependencies := atomic.LoadInt32(&n.removedDependencies) + if removedDependencies == n.totalDependencies { + // All dependencies are removed, so assigning the node to any worker is fine. + return assignedToAny, true + } + + resolvedDependencies := atomic.LoadInt32(&n.resolvedDependencies) if resolvedDependencies == n.totalDependencies { - firstDep := stdatomic.LoadInt64(&n.resolvedList[0]) + firstDep := atomic.LoadInt64(&n.resolvedList[0]) hasDiffDep := false for i := 1; i < int(n.totalDependencies); i++ { - curr := stdatomic.LoadInt64(&n.resolvedList[i]) - // // Todo: simplify assign to logic, only resolve dependencies nodes after - // // corresponding transactions are executed. - // // - // // In DependOn, depend(nil) set resolvedList[i] to assignedToAny - // // for these no dependecy keys. - // if curr == assignedToAny { - // continue - // } + curr := atomic.LoadInt64(&n.resolvedList[i]) + // Todo: simplify assign to logic, only resolve dependencies nodes after + // corresponding transactions are executed. + // + // In DependOn, depend(nil) sets resolvedList[i] to assignedToAny + // for the keys that have no dependency. + if curr == assignedToAny { + continue + } if firstDep != curr { hasDiffDep = true break } } - if !hasDiffDep { + if !hasDiffDep && firstDep != unassigned { // If all dependency nodes are assigned to the same worker, we can assign // this node to the same worker directly, and they will execute sequentially. 
// On the other hand, if dependency nodes are assigned to different workers, @@ -298,11 +313,6 @@ func (n *Node) doResolve(resolvedDependencies, removedDependencies int32) (int64 } } - // All dependcies are removed, so assign the node to any worker is fine. - if removedDependencies == n.totalDependencies { - return assignedToAny, true - } - return unassigned, false } diff --git a/pkg/causality/internal/node_test.go b/pkg/causality/internal/node_test.go index 776f5359550..ad9c5f6d47d 100644 --- a/pkg/causality/internal/node_test.go +++ b/pkg/causality/internal/node_test.go @@ -21,15 +21,24 @@ import ( var _ SlotNode[*Node] = &Node{} // Asserts that *Node implements SlotNode[*Node]. +func newNodeForTest() *Node { + node := NewNode(nil) + node.OnNotified = func(callback func()) { + // run the callback immediately + callback() + } + return node +} + func TestNodeFree(t *testing.T) { // This case should not be run parallel to // others, for fear that the use-after-free - // will race with NewNode() in other cases. + // will race with newNodeForTest() in other cases. - nodeA := NewNode() + nodeA := newNodeForTest() nodeA.Free() - nodeA = NewNode() + nodeA = newNodeForTest() nodeA.Free() // Double freeing should panic. @@ -41,8 +50,8 @@ func TestNodeFree(t *testing.T) { func TestNodeEquals(t *testing.T) { t.Parallel() - nodeA := NewNode() - nodeB := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() require.False(t, nodeA.NodeID() == nodeB.NodeID()) require.True(t, nodeA.NodeID() == nodeA.NodeID()) } @@ -51,8 +60,8 @@ func TestNodeDependOn(t *testing.T) { t.Parallel() // Construct a dependency graph: A --> B - nodeA := NewNode() - nodeB := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() nodeA.DependOn(map[int64]*Node{nodeB.NodeID(): nodeB}, 999) require.Equal(t, nodeA.dependerCount(), 0) @@ -63,20 +72,20 @@ func TestNodeSingleDependency(t *testing.T) { t.Parallel() // Node B depends on A, without any other resolved dependencies. 
- nodeA := NewNode() - nodeB := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() nodeB.RandWorkerID = func() workerID { return 100 } nodeB.DependOn(map[int64]*Node{nodeA.NodeID(): nodeA}, 0) - require.True(t, nodeA.assignTo(1)) + require.True(t, nodeA.tryAssignTo(1)) require.Equal(t, workerID(1), nodeA.assignedWorkerID()) require.Equal(t, workerID(1), nodeB.assignedWorkerID()) // Node D depends on C, with some other resolved dependencies. - nodeC := NewNode() - nodeD := NewNode() + nodeC := newNodeForTest() + nodeD := newNodeForTest() nodeD.RandWorkerID = func() workerID { return 100 } nodeD.DependOn(map[int64]*Node{nodeA.NodeID(): nodeC}, 999) - require.True(t, nodeC.assignTo(2)) + require.True(t, nodeC.tryAssignTo(2)) require.Equal(t, workerID(2), nodeC.assignedWorkerID()) nodeC.Remove() require.Equal(t, workerID(100), nodeD.assignedWorkerID()) @@ -90,15 +99,15 @@ func TestNodeMultipleDependencies(t *testing.T) { // C─┤ // └────►B - nodeA := NewNode() - nodeB := NewNode() - nodeC := NewNode() + nodeA := newNodeForTest() + nodeB := newNodeForTest() + nodeC := newNodeForTest() nodeC.DependOn(map[int64]*Node{nodeA.NodeID(): nodeA, nodeB.NodeID(): nodeB}, 999) nodeC.RandWorkerID = func() workerID { return 100 } - require.True(t, nodeA.assignTo(1)) - require.True(t, nodeB.assignTo(2)) + require.True(t, nodeA.tryAssignTo(1)) + require.True(t, nodeB.tryAssignTo(2)) require.Equal(t, unassigned, nodeC.assignedWorkerID()) @@ -111,17 +120,17 @@ func TestNodeResolveImmediately(t *testing.T) { t.Parallel() // Node A depends on 0 unresolved dependencies and some resolved dependencies. - nodeA := NewNode() + nodeA := newNodeForTest() nodeA.RandWorkerID = func() workerID { return workerID(100) } nodeA.DependOn(nil, 999) require.Equal(t, workerID(100), nodeA.assignedWorkerID()) // Node D depends on B and C, all of them are assigned to 1. 
- nodeB := NewNode() - require.True(t, nodeB.assignTo(1)) - nodeC := NewNode() - require.True(t, nodeC.assignTo(1)) - nodeD := NewNode() + nodeB := newNodeForTest() + require.True(t, nodeB.tryAssignTo(1)) + nodeC := newNodeForTest() + require.True(t, nodeC.tryAssignTo(1)) + nodeD := newNodeForTest() nodeD.RandWorkerID = func() workerID { return workerID(100) } nodeD.DependOn(map[int64]*Node{nodeB.NodeID(): nodeB, nodeC.NodeID(): nodeC}, 0) require.Equal(t, workerID(1), nodeD.assignedWorkerID()) @@ -129,7 +138,7 @@ func TestNodeResolveImmediately(t *testing.T) { // Node E depends on B and C and some other resolved dependencies. nodeB.Remove() nodeC.Remove() - nodeE := NewNode() + nodeE := newNodeForTest() nodeE.RandWorkerID = func() workerID { return workerID(100) } nodeE.DependOn(map[int64]*Node{nodeB.NodeID(): nodeB, nodeC.NodeID(): nodeC}, 999) require.Equal(t, workerID(100), nodeE.assignedWorkerID()) @@ -138,7 +147,7 @@ func TestNodeResolveImmediately(t *testing.T) { func TestNodeDependOnSelf(t *testing.T) { t.Parallel() - nodeA := NewNode() + nodeA := newNodeForTest() require.Panics(t, func() { nodeA.DependOn(map[int64]*Node{nodeA.NodeID(): nodeA}, 999) }) @@ -147,7 +156,9 @@ func TestNodeDependOnSelf(t *testing.T) { func TestNodeDoubleAssigning(t *testing.T) { t.Parallel() - nodeA := NewNode() - require.True(t, nodeA.assignTo(1)) - require.False(t, nodeA.assignTo(2)) + // nodeA := newNodeForTest() + // require.True(t, nodeA.tryAssignTo(1)) + // require.False(t, nodeA.tryAssignTo(2)) + + require.True(t, -1 == assignedToAny) } diff --git a/pkg/causality/internal/slots.go b/pkg/causality/internal/slots.go index d6d131e02c8..23ec22bdcf2 100644 --- a/pkg/causality/internal/slots.go +++ b/pkg/causality/internal/slots.go @@ -27,6 +27,8 @@ type slot[E SlotNode[E]] struct { type SlotNode[T any] interface { // NodeID tells the node's ID. NodeID() int64 + // Hashes returns the sorted and deduped hashes of the node. 
+ Hashes() []uint64 // Construct a dependency on `others`. DependOn(dependencyNodes map[int64]T, noDependencyKeyCnt int) // Remove the node itself and notify all dependers. @@ -56,7 +58,8 @@ func NewSlots[E SlotNode[E]](numSlots uint64) *Slots[E] { } // Add adds an elem to the slots and calls DependOn for elem. -func (s *Slots[E]) Add(elem E, hashes []uint64) { +func (s *Slots[E]) Add(elem E) { + hashes := elem.Hashes() dependencyNodes := make(map[int64]E, len(hashes)) noDependecyCnt := 0 @@ -101,7 +104,8 @@ func (s *Slots[E]) Add(elem E, hashes []uint64) { } // Free removes an element from the Slots. -func (s *Slots[E]) Free(elem E, hashes []uint64) { +func (s *Slots[E]) Free(elem E) { + hashes := elem.Hashes() for _, hash := range hashes { slotIdx := getSlot(hash, s.numSlots) s.slots[slotIdx].mu.Lock() diff --git a/pkg/causality/internal/slots_test.go b/pkg/causality/internal/slots_test.go index 42983909b25..025cf19dc7f 100644 --- a/pkg/causality/internal/slots_test.go +++ b/pkg/causality/internal/slots_test.go @@ -30,14 +30,14 @@ func TestSlotsTrivial(t *testing.T) { nodes := make([]*Node, 0, 1000) for i := 0; i < count; i++ { - node := NewNode() + node := NewNode([]uint64{1, 2, 3, 4, 5}) node.RandWorkerID = func() workerID { return 100 } - slots.Add(node, []uint64{1, 2, 3, 4, 5}) + slots.Add(node) nodes = append(nodes, node) } for i := 0; i < count; i++ { - slots.Free(nodes[i], []uint64{1, 2, 3, 4, 5}) + slots.Free(nodes[i]) } require.Equal(t, 0, len(slots.slots[1].nodes)) @@ -55,7 +55,7 @@ func TestSlotsConcurrentOps(t *testing.T) { freeNodeChan := make(chan *Node, N) inuseNodeChan := make(chan *Node, N) newNode := func() *Node { - node := NewNode() + node := NewNode([]uint64{1, 9, 17, 25, 33}) node.RandWorkerID = func() workerID { return 100 } return node } @@ -77,7 +77,7 @@ func TestSlotsConcurrentOps(t *testing.T) { return case node := <-freeNodeChan: // keys belong to the same slot after hash, since slot num is 8 - slots.Add(node, []uint64{1, 9, 17, 25, 
33}) + slots.Add(node) inuseNodeChan <- node } } @@ -92,7 +92,7 @@ func TestSlotsConcurrentOps(t *testing.T) { return case node := <-inuseNodeChan: // keys belong to the same slot after hash, since slot num is 8 - slots.Free(node, []uint64{1, 9, 17, 25, 33}) + slots.Free(node) freeNodeChan <- newNode() } } diff --git a/pkg/causality/tests/driver.go b/pkg/causality/tests/driver.go index e4e9f637888..2d25b9cab5a 100644 --- a/pkg/causality/tests/driver.go +++ b/pkg/causality/tests/driver.go @@ -26,7 +26,7 @@ import ( type conflictTestDriver struct { workers []*workerForTest - conflictDetector *causality.ConflictDetector[*workerForTest, *txnForTest] + conflictDetector *causality.ConflictDetector[*txnForTest] generator workloadGenerator pendingCount atomic.Int64 @@ -40,7 +40,11 @@ func newConflictTestDriver( for i := 0; i < numWorkers; i++ { workers = append(workers, newWorkerForTest()) } - detector := causality.NewConflictDetector[*workerForTest, *txnForTest](workers, uint64(numSlots)) + detector := causality.NewConflictDetector[*txnForTest](uint64(numSlots), causality.WorkerOption{ + WorkerCount: numWorkers, + CacheSize: 1024, + IsBlock: true, + }) return &conflictTestDriver{ workers: workers, conflictDetector: detector, diff --git a/pkg/causality/tests/worker.go b/pkg/causality/tests/worker.go index f78f406ba44..fb91657e717 100644 --- a/pkg/causality/tests/worker.go +++ b/pkg/causality/tests/worker.go @@ -64,8 +64,9 @@ func newWorkerForTest() *workerForTest { return ret } -func (w *workerForTest) Add(txn *txnForTest, unlock func()) { +func (w *workerForTest) Add(txn *txnForTest, unlock func()) bool { w.txnQueue.Push(txnWithUnlock{txnForTest: txn, unlock: unlock}) + return true } func (w *workerForTest) Close() { diff --git a/pkg/causality/worker.go b/pkg/causality/worker.go index 738ba889183..d49ac3b4ac7 100644 --- a/pkg/causality/worker.go +++ b/pkg/causality/worker.go @@ -13,6 +13,13 @@ package causality +import ( + "sync/atomic" + + "github.com/pingcap/log" + 
"go.uber.org/zap" +) + type txnEvent interface { // OnConflictResolved is called when the event leaves ConflictDetector. OnConflictResolved() @@ -24,6 +31,91 @@ type txnEvent interface { GenSortedDedupKeysHash(numSlots uint64) []uint64 } -type worker[Txn txnEvent] interface { - Add(txn Txn, unlock func()) +// TxnWithNotifier is a wrapper of txnEvent with a PostTxnExecuted. +type TxnWithNotifier[Txn txnEvent] struct { + TxnEvent Txn + // The PostTxnExecuted will remove the txn related Node in the conflict detector's + // dependency graph and resolve related dependencies for these transactions + // which depend on this executed txn. + // + // NOTE: the PostTxnExecuted() must be called after the txn is executed. + PostTxnExecuted func() +} + +// WorkerOption is the option for creating a worker. +type WorkerOption struct { + WorkerCount int + // CacheSize controls the max number of txns a worker can hold. + CacheSize int + // IsBlock indicates whether the worker should block when the cache is full. + IsBlock bool +} + +// In current implementation, the conflict detector will push txn to the workerCache. +type workerCache[Txn txnEvent] interface { + // add adds an event to the workerCache. + add(txn TxnWithNotifier[Txn]) bool + // out returns a channel to receive events which are ready to be executed. + out() <-chan TxnWithNotifier[Txn] +} + +func newWorker[Txn txnEvent](opt WorkerOption) workerCache[Txn] { + log.Info("create new worker cache in conflict detector", + zap.Int("workerCount", opt.WorkerCount), + zap.Int("cacheSize", opt.CacheSize), zap.Bool("isBlock", opt.IsBlock)) + if opt.CacheSize <= 0 { + log.Panic("WorkerOption.CacheSize should be greater than 0, please report a bug") + } + + if opt.IsBlock { + return &boundedWorkerWithBlock[Txn]{ch: make(chan TxnWithNotifier[Txn], opt.CacheSize)} + } + return &boundedWorker[Txn]{ch: make(chan TxnWithNotifier[Txn], opt.CacheSize)} +} + +// boundedWorker is a worker which has a limit on the number of txns it can hold. 
+type boundedWorker[Txn txnEvent] struct { + ch chan TxnWithNotifier[Txn] } + +func (w *boundedWorker[Txn]) add(txn TxnWithNotifier[Txn]) bool { + select { + case w.ch <- txn: + return true + default: + return false + } +} + +func (w *boundedWorker[Txn]) out() <-chan TxnWithNotifier[Txn] { + return w.ch +} + +// boundedWorkerWithBlock is a special boundedWorker. Once the worker +// is full, it will block until all cached txns are consumed. +type boundedWorkerWithBlock[Txn txnEvent] struct { + ch chan TxnWithNotifier[Txn] + isBlocked atomic.Bool +} + +func (w *boundedWorkerWithBlock[Txn]) add(txn TxnWithNotifier[Txn]) bool { + if w.isBlocked.Load() && len(w.ch) <= 0 { + w.isBlocked.Store(false) + } + + if !w.isBlocked.Load() { + select { + case w.ch <- txn: + return true + default: + w.isBlocked.CompareAndSwap(false, true) + } + } + return false +} + +func (w *boundedWorkerWithBlock[Txn]) out() <-chan TxnWithNotifier[Txn] { + return w.ch +} + +// TODO: maybe we can implement a strategy that can automatically adapt to different scenarios