diff --git a/cdc/sink/flowcontrol/flow_control.go b/cdc/sink/flowcontrol/flow_control.go index 467fd796972..c6099fa6557 100644 --- a/cdc/sink/flowcontrol/flow_control.go +++ b/cdc/sink/flowcontrol/flow_control.go @@ -20,144 +20,36 @@ import ( "github.com/edwingeng/deque" "github.com/pingcap/errors" "github.com/pingcap/log" - cerrors "github.com/pingcap/tiflow/pkg/errors" "go.uber.org/zap" ) -// TableMemoryQuota is designed to curb the total memory consumption of processing -// the event streams in a table. -// A higher-level controller more suitable for direct use by the processor is TableFlowController. -type TableMemoryQuota struct { - Quota uint64 // should not be changed once intialized - - IsAborted uint32 - - mu sync.Mutex - Consumed uint64 - - cond *sync.Cond -} - -// NewTableMemoryQuota creates a new TableMemoryQuota -// quota: max advised memory consumption in bytes. -func NewTableMemoryQuota(quota uint64) *TableMemoryQuota { - ret := &TableMemoryQuota{ - Quota: quota, - mu: sync.Mutex{}, - Consumed: 0, - } - - ret.cond = sync.NewCond(&ret.mu) - return ret -} - -// ConsumeWithBlocking is called when a hard-limit is needed. The method will -// block until enough memory has been freed up by Release. -// blockCallBack will be called if the function will block. -// Should be used with care to prevent deadlock. -func (c *TableMemoryQuota) ConsumeWithBlocking(nBytes uint64, blockCallBack func() error) error { - if nBytes >= c.Quota { - return cerrors.ErrFlowControllerEventLargerThanQuota.GenWithStackByArgs(nBytes, c.Quota) - } - - c.mu.Lock() - if c.Consumed+nBytes >= c.Quota { - c.mu.Unlock() - err := blockCallBack() - if err != nil { - return errors.Trace(err) - } - } else { - c.mu.Unlock() - } - - c.mu.Lock() - defer c.mu.Unlock() - - for { - if atomic.LoadUint32(&c.IsAborted) == 1 { - return cerrors.ErrFlowControllerAborted.GenWithStackByArgs() - } - - if c.Consumed+nBytes < c.Quota { - break - } - c.cond.Wait() - } - - c.Consumed += nBytes - return nil -} - -// ForceConsume is called when blocking is not acceptable and the limit can be violated -// for the sake of avoid deadlock. It merely records the increased memory consumption. -func (c *TableMemoryQuota) ForceConsume(nBytes uint64) error { - c.mu.Lock() - defer c.mu.Unlock() - - if atomic.LoadUint32(&c.IsAborted) == 1 { - return cerrors.ErrFlowControllerAborted.GenWithStackByArgs() - } - - c.Consumed += nBytes - return nil -} - -// Release is called when a chuck of memory is done being used. -func (c *TableMemoryQuota) Release(nBytes uint64) { - c.mu.Lock() - - if c.Consumed < nBytes { - c.mu.Unlock() - log.Panic("TableMemoryQuota: releasing more than consumed, report a bug", - zap.Uint64("consumed", c.Consumed), - zap.Uint64("released", nBytes)) - } - - c.Consumed -= nBytes - if c.Consumed < c.Quota { - c.mu.Unlock() - c.cond.Signal() - return - } - - c.mu.Unlock() -} - -// Abort interrupts any ongoing ConsumeWithBlocking call -func (c *TableMemoryQuota) Abort() { - atomic.StoreUint32(&c.IsAborted, 1) - c.cond.Signal() -} - -// GetConsumption returns the current memory consumption -func (c *TableMemoryQuota) GetConsumption() uint64 { - c.mu.Lock() - defer c.mu.Unlock() - - return c.Consumed -} - // TableFlowController provides a convenient interface to control the memory consumption of a per table event stream type TableFlowController struct { - memoryQuota *TableMemoryQuota + memoryQuota *tableMemoryQuota - mu sync.Mutex - queue deque.Deque + queueMu struct { + sync.Mutex + queue deque.Deque + } lastCommitTs uint64 } type commitTsSizeEntry struct { - CommitTs uint64 - Size uint64 + commitTs uint64 + size uint64 } // NewTableFlowController creates a new TableFlowController func NewTableFlowController(quota uint64) *TableFlowController { return &TableFlowController{ - memoryQuota: NewTableMemoryQuota(quota), - queue: deque.NewDeque(), + memoryQuota: newTableMemoryQuota(quota), + queueMu: struct { + sync.Mutex + queue deque.Deque + }{ + queue: deque.NewDeque(), + }, } } @@ -174,27 +66,27 @@ func (c *TableFlowController) Consume(commitTs uint64, size uint64, blockCallBac if commitTs > lastCommitTs { atomic.StoreUint64(&c.lastCommitTs, commitTs) - err := c.memoryQuota.ConsumeWithBlocking(size, blockCallBack) + err := c.memoryQuota.consumeWithBlocking(size, blockCallBack) if err != nil { return errors.Trace(err) } } else { // Here commitTs == lastCommitTs, which means that we are not crossing - // a transaction boundary. In this situation, we use `ForceConsume` because + // a transaction boundary. In this situation, we use `forceConsume` because // blocking the event stream mid-transaction is highly likely to cause // a deadlock. // TODO fix this in the future, after we figure out how to elegantly support large txns. - err := c.memoryQuota.ForceConsume(size) + err := c.memoryQuota.forceConsume(size) if err != nil { return errors.Trace(err) } } - c.mu.Lock() - defer c.mu.Unlock() - c.queue.PushBack(&commitTsSizeEntry{ - CommitTs: commitTs, - Size: size, + c.queueMu.Lock() + defer c.queueMu.Unlock() + c.queueMu.queue.PushBack(&commitTsSizeEntry{ + commitTs: commitTs, + size: size, }) return nil @@ -204,26 +96,26 @@ func (c *TableFlowController) Consume(commitTs uint64, size uint64, blockCallBac func (c *TableFlowController) Release(resolvedTs uint64) { var nBytesToRelease uint64 - c.mu.Lock() - for c.queue.Len() > 0 { - if peeked := c.queue.Front().(*commitTsSizeEntry); peeked.CommitTs <= resolvedTs { - nBytesToRelease += peeked.Size - c.queue.PopFront() + c.queueMu.Lock() + for c.queueMu.queue.Len() > 0 { + if peeked := c.queueMu.queue.Front().(*commitTsSizeEntry); peeked.commitTs <= resolvedTs { + nBytesToRelease += peeked.size + c.queueMu.queue.PopFront() } else { break } } - c.mu.Unlock() + c.queueMu.Unlock() - c.memoryQuota.Release(nBytesToRelease) + c.memoryQuota.release(nBytesToRelease) } // Abort interrupts any ongoing Consume call func (c *TableFlowController) Abort() { - c.memoryQuota.Abort() + c.memoryQuota.abort() } // GetConsumption returns the current memory consumption func (c *TableFlowController) GetConsumption() uint64 { - return c.memoryQuota.GetConsumption() + return c.memoryQuota.getConsumption() } diff --git a/cdc/sink/flowcontrol/flow_control_test.go b/cdc/sink/flowcontrol/flow_control_test.go index d6ac5ddf04e..24f639fdf8a 100644 --- a/cdc/sink/flowcontrol/flow_control_test.go +++ b/cdc/sink/flowcontrol/flow_control_test.go @@ -21,15 +21,10 @@ import ( "testing" "time" - "github.com/pingcap/check" - "github.com/pingcap/tiflow/pkg/util/testleak" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) -type flowControlSuite struct{} - -var _ = check.Suite(&flowControlSuite{}) - func dummyCallBack() error { return nil } @@ -44,10 +39,10 @@ func (c *mockCallBacker) cb() error { return c.injectedErr } -func (s *flowControlSuite) TestMemoryQuotaBasic(c *check.C) { - defer testleak.AfterTest(c)() +func TestMemoryQuotaBasic(t *testing.T) { + t.Parallel() - controller := NewTableMemoryQuota(1024) + controller := newTableMemoryQuota(1024) sizeCh := make(chan uint64, 1024) var ( wg sync.WaitGroup @@ -60,10 +55,10 @@ func (s *flowControlSuite) TestMemoryQuotaBasic(c *check.C) { for i := 0; i < 100000; i++ { size := (rand.Int() % 128) + 128 - err := controller.ConsumeWithBlocking(uint64(size), dummyCallBack) - c.Assert(err, check.IsNil) + err := controller.consumeWithBlocking(uint64(size), dummyCallBack) + require.Nil(t, err) - c.Assert(atomic.AddUint64(&consumed, uint64(size)), check.Less, uint64(1024)) + require.Less(t, atomic.AddUint64(&consumed, uint64(size)), uint64(1024)) sizeCh <- uint64(size) } @@ -75,21 +70,21 @@ func (s *flowControlSuite) TestMemoryQuotaBasic(c *check.C) { defer wg.Done() for size := range sizeCh { - c.Assert(atomic.LoadUint64(&consumed), check.GreaterEqual, size) + require.GreaterOrEqual(t, atomic.LoadUint64(&consumed), size) atomic.AddUint64(&consumed, -size) - controller.Release(size) + controller.release(size) } }() wg.Wait() - c.Assert(atomic.LoadUint64(&consumed), check.Equals, uint64(0)) - c.Assert(controller.GetConsumption(), check.Equals, uint64(0)) + require.Equal(t, uint64(0), atomic.LoadUint64(&consumed)) + require.Equal(t, uint64(0), controller.getConsumption()) } -func (s *flowControlSuite) TestMemoryQuotaForceConsume(c *check.C) { - defer testleak.AfterTest(c)() +func TestMemoryQuotaForceConsume(t *testing.T) { + t.Parallel() - controller := NewTableMemoryQuota(1024) + controller := newTableMemoryQuota(1024) sizeCh := make(chan uint64, 1024) var ( wg sync.WaitGroup @@ -104,12 +99,12 @@ func (s *flowControlSuite) TestMemoryQuotaForceConsume(c *check.C) { size := (rand.Int() % 128) + 128 if rand.Int()%3 == 0 { - err := controller.ConsumeWithBlocking(uint64(size), dummyCallBack) - c.Assert(err, check.IsNil) - c.Assert(atomic.AddUint64(&consumed, uint64(size)), check.Less, uint64(1024)) + err := controller.consumeWithBlocking(uint64(size), dummyCallBack) + require.Nil(t, err) + require.Less(t, atomic.AddUint64(&consumed, uint64(size)), uint64(1024)) } else { - err := controller.ForceConsume(uint64(size)) - c.Assert(err, check.IsNil) + err := controller.forceConsume(uint64(size)) + require.Nil(t, err) atomic.AddUint64(&consumed, uint64(size)) } sizeCh <- uint64(size) @@ -123,47 +118,47 @@ func (s *flowControlSuite) TestMemoryQuotaForceConsume(c *check.C) { defer wg.Done() for size := range sizeCh { - c.Assert(atomic.LoadUint64(&consumed), check.GreaterEqual, size) + require.GreaterOrEqual(t, atomic.LoadUint64(&consumed), size) atomic.AddUint64(&consumed, -size) - controller.Release(size) + controller.release(size) } }() wg.Wait() - c.Assert(atomic.LoadUint64(&consumed), check.Equals, uint64(0)) + require.Equal(t, uint64(0), atomic.LoadUint64(&consumed)) } -// TestMemoryQuotaAbort verifies that Abort works -func (s *flowControlSuite) TestMemoryQuotaAbort(c *check.C) { - defer testleak.AfterTest(c)() +// TestMemoryQuotaAbort verifies that abort works +func TestMemoryQuotaAbort(t *testing.T) { + t.Parallel() - controller := NewTableMemoryQuota(1024) + controller := newTableMemoryQuota(1024) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - err := controller.ConsumeWithBlocking(700, dummyCallBack) - c.Assert(err, check.IsNil) + err := controller.consumeWithBlocking(700, dummyCallBack) + require.Nil(t, err) - err = controller.ConsumeWithBlocking(700, dummyCallBack) - c.Assert(err, check.ErrorMatches, ".*ErrFlowControllerAborted.*") + err = controller.consumeWithBlocking(700, dummyCallBack) + require.Regexp(t, ".*ErrFlowControllerAborted.*", err) - err = controller.ForceConsume(700) - c.Assert(err, check.ErrorMatches, ".*ErrFlowControllerAborted.*") + err = controller.forceConsume(700) + require.Regexp(t, ".*ErrFlowControllerAborted.*", err) }() time.Sleep(2 * time.Second) - controller.Abort() + controller.abort() wg.Wait() } // TestMemoryQuotaReleaseZero verifies that releasing 0 bytes is successful -func (s *flowControlSuite) TestMemoryQuotaReleaseZero(c *check.C) { - defer testleak.AfterTest(c)() +func TestMemoryQuotaReleaseZero(t *testing.T) { + t.Parallel() - controller := NewTableMemoryQuota(1024) - controller.Release(0) + controller := newTableMemoryQuota(1024) + controller.release(0) } type mockedEvent struct { @@ -171,8 +166,9 @@ type mockedEvent struct { size uint64 } -func (s *flowControlSuite) TestFlowControlBasic(c *check.C) { - defer testleak.AfterTest(c)() +func TestFlowControlBasic(t *testing.T) { + t.Parallel() + var consumedBytes uint64 ctx, cancel := context.WithTimeout(context.TODO(), time.Second*5) defer cancel() @@ -191,8 +187,8 @@ func (s *flowControlSuite) TestFlowControlBasic(c *check.C) { case <-ctx.Done(): return ctx.Err() case mockedRowsCh <- &commitTsSizeEntry{ - CommitTs: lastCommitTs, - Size: size, + commitTs: lastCommitTs, + size: size, }: } } @@ -217,10 +213,10 @@ func (s *flowControlSuite) TestFlowControlBasic(c *check.C) { break } - atomic.AddUint64(&consumedBytes, mockedRow.Size) + atomic.AddUint64(&consumedBytes, mockedRow.size) updatedResolvedTs := false - if resolvedTs != mockedRow.CommitTs { - c.Assert(resolvedTs, check.Less, mockedRow.CommitTs) + if resolvedTs != mockedRow.commitTs { + require.Less(t, resolvedTs, mockedRow.commitTs) select { case <-ctx.Done(): return ctx.Err() @@ -228,22 +224,22 @@ func (s *flowControlSuite) TestFlowControlBasic(c *check.C) { resolvedTs: resolvedTs, }: } - resolvedTs = mockedRow.CommitTs + resolvedTs = mockedRow.commitTs updatedResolvedTs = true } - err := flowController.Consume(mockedRow.CommitTs, mockedRow.Size, dummyCallBack) - c.Check(err, check.IsNil) + err := flowController.Consume(mockedRow.commitTs, mockedRow.size, dummyCallBack) + require.Nil(t, err) select { case <-ctx.Done(): return ctx.Err() case eventCh <- &mockedEvent{ - size: mockedRow.Size, + size: mockedRow.size, }: } if updatedResolvedTs { // new Txn - c.Assert(atomic.LoadUint64(&consumedBytes), check.Less, uint64(2048)) - c.Assert(flowController.GetConsumption(), check.Less, uint64(2048)) + require.Less(t, atomic.LoadUint64(&consumedBytes), uint64(2048)) + require.Less(t, flowController.GetConsumption(), uint64(2048)) } } select { @@ -280,12 +276,12 @@ func (s *flowControlSuite) TestFlowControlBasic(c *check.C) { return nil }) - c.Assert(errg.Wait(), check.IsNil) - c.Assert(atomic.LoadUint64(&consumedBytes), check.Equals, uint64(0)) + require.Nil(t, errg.Wait()) + require.Equal(t, uint64(0), atomic.LoadUint64(&consumedBytes)) } -func (s *flowControlSuite) TestFlowControlAbort(c *check.C) { - defer testleak.AfterTest(c)() +func TestFlowControlAbort(t *testing.T) { + t.Parallel() callBacker := &mockCallBacker{} controller := NewTableFlowController(1024) @@ -295,14 +291,14 @@ func (s *flowControlSuite) TestFlowControlAbort(c *check.C) { defer wg.Done() err := controller.Consume(1, 1000, callBacker.cb) - c.Assert(err, check.IsNil) - c.Assert(callBacker.timesCalled, check.Equals, 0) + require.Nil(t, err) + require.Equal(t, 0, callBacker.timesCalled) err = controller.Consume(2, 1000, callBacker.cb) - c.Assert(err, check.ErrorMatches, ".*ErrFlowControllerAborted.*") - c.Assert(callBacker.timesCalled, check.Equals, 1) + require.Regexp(t, ".*ErrFlowControllerAborted.*", err) + require.Equal(t, 1, callBacker.timesCalled) err = controller.Consume(2, 10, callBacker.cb) - c.Assert(err, check.ErrorMatches, ".*ErrFlowControllerAborted.*") - c.Assert(callBacker.timesCalled, check.Equals, 1) + require.Regexp(t, ".*ErrFlowControllerAborted.*", err) + require.Equal(t, 1, callBacker.timesCalled) }() time.Sleep(3 * time.Second) @@ -311,8 +307,9 @@ func (s *flowControlSuite) TestFlowControlAbort(c *check.C) { wg.Wait() } -func (s *flowControlSuite) TestFlowControlCallBack(c *check.C) { - defer testleak.AfterTest(c)() +func TestFlowControlCallBack(t *testing.T) { + t.Parallel() + var consumedBytes uint64 ctx, cancel := context.WithTimeout(context.TODO(), time.Second*5) defer cancel() @@ -331,8 +328,8 @@ func (s *flowControlSuite) TestFlowControlCallBack(c *check.C) { case <-ctx.Done(): return ctx.Err() case mockedRowsCh <- &commitTsSizeEntry{ - CommitTs: lastCommitTs, - Size: size, + commitTs: lastCommitTs, + size: size, }: } } @@ -357,8 +354,8 @@ func (s *flowControlSuite) TestFlowControlCallBack(c *check.C) { break } - atomic.AddUint64(&consumedBytes, mockedRow.Size) - err := flowController.Consume(mockedRow.CommitTs, mockedRow.Size, func() error { + atomic.AddUint64(&consumedBytes, mockedRow.size) + err := flowController.Consume(mockedRow.commitTs, mockedRow.size, func() error { select { case <-ctx.Done(): return ctx.Err() @@ -368,14 +365,14 @@ func (s *flowControlSuite) TestFlowControlCallBack(c *check.C) { } return nil }) - c.Assert(err, check.IsNil) - lastCRTs = mockedRow.CommitTs + require.Nil(t, err) + lastCRTs = mockedRow.commitTs select { case <-ctx.Done(): return ctx.Err() case eventCh <- &mockedEvent{ - size: mockedRow.Size, + size: mockedRow.size, }: } } @@ -413,12 +410,12 @@ func (s *flowControlSuite) TestFlowControlCallBack(c *check.C) { return nil }) - c.Assert(errg.Wait(), check.IsNil) - c.Assert(atomic.LoadUint64(&consumedBytes), check.Equals, uint64(0)) + require.Nil(t, errg.Wait()) + require.Equal(t, uint64(0), atomic.LoadUint64(&consumedBytes)) } -func (s *flowControlSuite) TestFlowControlCallBackNotBlockingRelease(c *check.C) { - defer testleak.AfterTest(c)() +func TestFlowControlCallBackNotBlockingRelease(t *testing.T) { + t.Parallel() var wg sync.WaitGroup controller := NewTableFlowController(512) @@ -430,10 +427,10 @@ func (s *flowControlSuite) TestFlowControlCallBackNotBlockingRelease(c *check.C) go func() { defer wg.Done() err := controller.Consume(1, 511, func() error { - c.Fatalf("unreachable") + t.Error("unreachable") return nil }) - c.Assert(err, check.IsNil) + require.Nil(t, err) var isBlocked int32 wg.Add(1) @@ -441,7 +438,7 @@ func (s *flowControlSuite) TestFlowControlCallBackNotBlockingRelease(c *check.C) defer wg.Done() <-time.After(time.Second * 1) // makes sure that this test case is valid - c.Assert(atomic.LoadInt32(&isBlocked), check.Equals, int32(1)) + require.Equal(t, int32(1), atomic.LoadInt32(&isBlocked)) controller.Release(1) cancel() }() @@ -453,14 +450,14 @@ func (s *flowControlSuite) TestFlowControlCallBackNotBlockingRelease(c *check.C) return ctx.Err() }) - c.Assert(err, check.ErrorMatches, ".*context canceled.*") + require.Regexp(t, ".*context canceled.*", err) }() wg.Wait() } -func (s *flowControlSuite) TestFlowControlCallBackError(c *check.C) { - defer testleak.AfterTest(c)() +func TestFlowControlCallBackError(t *testing.T) { + t.Parallel() var wg sync.WaitGroup controller := NewTableFlowController(512) @@ -472,15 +469,15 @@ func (s *flowControlSuite) TestFlowControlCallBackError(c *check.C) { go func() { defer wg.Done() err := controller.Consume(1, 511, func() error { - c.Fatalf("unreachable") + t.Error("unreachable") return nil }) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = controller.Consume(2, 511, func() error { <-ctx.Done() return ctx.Err() }) - c.Assert(err, check.ErrorMatches, ".*context canceled.*") + require.Regexp(t, ".*context canceled.*", err) }() time.Sleep(100 * time.Millisecond) @@ -489,15 +486,15 @@ func (s *flowControlSuite) TestFlowControlCallBackError(c *check.C) { wg.Wait() } -func (s *flowControlSuite) TestFlowControlConsumeLargerThanQuota(c *check.C) { - defer testleak.AfterTest(c)() +func TestFlowControlConsumeLargerThanQuota(t *testing.T) { + t.Parallel() controller := NewTableFlowController(1024) err := controller.Consume(1, 2048, func() error { - c.Fatalf("unreachable") + t.Error("unreachable") return nil }) - c.Assert(err, check.ErrorMatches, ".*ErrFlowControllerEventLargerThanQuota.*") + require.Regexp(t, ".*ErrFlowControllerEventLargerThanQuota.*", err) } func BenchmarkTableFlowController(B *testing.B) { @@ -518,8 +515,8 @@ func BenchmarkTableFlowController(B *testing.B) { case <-ctx.Done(): return ctx.Err() case mockedRowsCh <- &commitTsSizeEntry{ - CommitTs: lastCommitTs, - Size: size, + commitTs: lastCommitTs, + size: size, }: } } @@ -544,7 +541,7 @@ func BenchmarkTableFlowController(B *testing.B) { break } - if resolvedTs != mockedRow.CommitTs { + if resolvedTs != mockedRow.commitTs { select { case <-ctx.Done(): return ctx.Err() @@ -552,9 +549,9 @@ func BenchmarkTableFlowController(B *testing.B) { resolvedTs: resolvedTs, }: } - resolvedTs = mockedRow.CommitTs + resolvedTs = mockedRow.commitTs } - err := flowController.Consume(mockedRow.CommitTs, mockedRow.Size, dummyCallBack) + err := flowController.Consume(mockedRow.commitTs, mockedRow.size, dummyCallBack) if err != nil { B.Fatal(err) } @@ -562,7 +559,7 @@ func BenchmarkTableFlowController(B *testing.B) { case <-ctx.Done(): return ctx.Err() case eventCh <- &mockedEvent{ - size: mockedRow.Size, + size: mockedRow.size, }: } } diff --git a/cdc/sink/flowcontrol/main_test.go b/cdc/sink/flowcontrol/main_test.go new file mode 100644 index 00000000000..a340611e61b --- /dev/null +++ b/cdc/sink/flowcontrol/main_test.go @@ -0,0 +1,24 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package flowcontrol + +import ( + "testing" + + "github.com/pingcap/tiflow/pkg/leakutil" +) + +func TestMain(m *testing.M) { + leakutil.SetUpLeakTest(m) +} diff --git a/cdc/sink/flowcontrol/table_memory_quota.go b/cdc/sink/flowcontrol/table_memory_quota.go new file mode 100644 index 00000000000..7ca15e7857f --- /dev/null +++ b/cdc/sink/flowcontrol/table_memory_quota.go @@ -0,0 +1,138 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package flowcontrol + +import ( + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + cerrors "github.com/pingcap/tiflow/pkg/errors" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +// tableMemoryQuota is designed to curb the total memory consumption of processing +// the event streams in a table. +// A higher-level controller more suitable for direct use by the processor is TableFlowController. +type tableMemoryQuota struct { + quota uint64 // should not be changed once initialized + + isAborted atomic.Bool + + consumed struct { + sync.Mutex + bytes uint64 + } + + consumedCond *sync.Cond +} + +// newTableMemoryQuota creates a new tableMemoryQuota +// quota: max advised memory consumption in bytes. +func newTableMemoryQuota(quota uint64) *tableMemoryQuota { + ret := &tableMemoryQuota{ + quota: quota, + } + + ret.consumedCond = sync.NewCond(&ret.consumed) + return ret +} + +// consumeWithBlocking is called when a hard-limit is needed. The method will +// block until enough memory has been freed up by release. +// blockCallBack will be called if the function will block. +// Should be used with care to prevent deadlock. +func (c *tableMemoryQuota) consumeWithBlocking(nBytes uint64, blockCallBack func() error) error { + if nBytes >= c.quota { + return cerrors.ErrFlowControllerEventLargerThanQuota.GenWithStackByArgs(nBytes, c.quota) + } + + c.consumed.Lock() + if c.consumed.bytes+nBytes >= c.quota { + c.consumed.Unlock() + err := blockCallBack() + if err != nil { + return errors.Trace(err) + } + } else { + c.consumed.Unlock() + } + + c.consumed.Lock() + defer c.consumed.Unlock() + + for { + if c.isAborted.Load() { + return cerrors.ErrFlowControllerAborted.GenWithStackByArgs() + } + + if c.consumed.bytes+nBytes < c.quota { + break + } + c.consumedCond.Wait() + } + + c.consumed.bytes += nBytes + return nil +} + +// forceConsume is called when blocking is not acceptable and the limit can be violated +// for the sake of avoid deadlock. It merely records the increased memory consumption. +func (c *tableMemoryQuota) forceConsume(nBytes uint64) error { + c.consumed.Lock() + defer c.consumed.Unlock() + + if c.isAborted.Load() { + return cerrors.ErrFlowControllerAborted.GenWithStackByArgs() + } + + c.consumed.bytes += nBytes + return nil +} + +// release is called when a chuck of memory is done being used. +func (c *tableMemoryQuota) release(nBytes uint64) { + c.consumed.Lock() + + if c.consumed.bytes < nBytes { + c.consumed.Unlock() + log.Panic("tableMemoryQuota: releasing more than consumed, report a bug", + zap.Uint64("consumed", c.consumed.bytes), + zap.Uint64("released", nBytes)) + } + + c.consumed.bytes -= nBytes + if c.consumed.bytes < c.quota { + c.consumed.Unlock() + c.consumedCond.Signal() + return + } + + c.consumed.Unlock() +} + +// abort interrupts any ongoing consumeWithBlocking call +func (c *tableMemoryQuota) abort() { + c.isAborted.Store(true) + c.consumedCond.Signal() +} + +// getConsumption returns the current memory consumption +func (c *tableMemoryQuota) getConsumption() uint64 { + c.consumed.Lock() + defer c.consumed.Unlock() + + return c.consumed.bytes +}