From 8f1a5cb7ad44ae90e005422a4dd151bcda1a3c92 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 27 Feb 2024 16:01:14 +0800 Subject: [PATCH 1/5] util: add generics lfu Signed-off-by: Weizhen Wang --- pkg/util/lfu/key_set.go | 73 ++++++++ pkg/util/lfu/key_set_shard.go | 69 ++++++++ pkg/util/lfu/lfu_cache.go | 266 +++++++++++++++++++++++++++++ pkg/util/lfu/lfu_cache_test.go | 297 +++++++++++++++++++++++++++++++++ 4 files changed, 705 insertions(+) create mode 100644 pkg/util/lfu/key_set.go create mode 100644 pkg/util/lfu/key_set_shard.go create mode 100644 pkg/util/lfu/lfu_cache.go create mode 100644 pkg/util/lfu/lfu_cache_test.go diff --git a/pkg/util/lfu/key_set.go b/pkg/util/lfu/key_set.go new file mode 100644 index 0000000000000..4467dab30c592 --- /dev/null +++ b/pkg/util/lfu/key_set.go @@ -0,0 +1,73 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lfu + +import ( + "sync" + + "github.com/pingcap/tidb/pkg/statistics" + "golang.org/x/exp/maps" +) + +type keySet struct { + set map[int64]*statistics.Table + mu sync.RWMutex +} + +func (ks *keySet) Remove(key int64) int64 { + var cost int64 + ks.mu.Lock() + if table, ok := ks.set[key]; ok { + if table != nil { + cost = table.MemoryUsage().TotalTrackingMemUsage() + } + delete(ks.set, key) + } + ks.mu.Unlock() + return cost +} + +func (ks *keySet) Keys() []int64 { + ks.mu.RLock() + result := maps.Keys(ks.set) + ks.mu.RUnlock() + return result +} + +func (ks *keySet) Len() int { + ks.mu.RLock() + result := len(ks.set) + ks.mu.RUnlock() + return result +} + +func (ks *keySet) AddKeyValue(key int64, value *statistics.Table) { + ks.mu.Lock() + ks.set[key] = value + ks.mu.Unlock() +} + +func (ks *keySet) Get(key int64) (*statistics.Table, bool) { + ks.mu.RLock() + value, ok := ks.set[key] + ks.mu.RUnlock() + return value, ok +} + +func (ks *keySet) Clear() { + ks.mu.Lock() + ks.set = make(map[int64]*statistics.Table) + ks.mu.Unlock() +} diff --git a/pkg/util/lfu/key_set_shard.go b/pkg/util/lfu/key_set_shard.go new file mode 100644 index 0000000000000..f07990400001e --- /dev/null +++ b/pkg/util/lfu/key_set_shard.go @@ -0,0 +1,69 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lfu + +import ( + "github.com/pingcap/tidb/pkg/statistics" +) + +const keySetCnt = 256 + +type keySetShard struct { + resultKeySet [keySetCnt]keySet +} + +func newKeySetShard() *keySetShard { + result := keySetShard{} + for i := 0; i < keySetCnt; i++ { + result.resultKeySet[i] = keySet{ + set: make(map[int64]*statistics.Table), + } + } + return &result +} + +func (kss *keySetShard) Get(key int64) (*statistics.Table, bool) { + return kss.resultKeySet[key%keySetCnt].Get(key) +} + +func (kss *keySetShard) AddKeyValue(key int64, table *statistics.Table) { + kss.resultKeySet[key%keySetCnt].AddKeyValue(key, table) +} + +func (kss *keySetShard) Remove(key int64) { + kss.resultKeySet[key%keySetCnt].Remove(key) +} + +func (kss *keySetShard) Keys() []int64 { + result := make([]int64, 0, len(kss.resultKeySet)) + for idx := range kss.resultKeySet { + result = append(result, kss.resultKeySet[idx].Keys()...) + } + return result +} + +func (kss *keySetShard) Len() int { + result := 0 + for idx := range kss.resultKeySet { + result += kss.resultKeySet[idx].Len() + } + return result +} + +func (kss *keySetShard) Clear() { + for idx := range kss.resultKeySet { + kss.resultKeySet[idx].Clear() + } +} diff --git a/pkg/util/lfu/lfu_cache.go b/pkg/util/lfu/lfu_cache.go new file mode 100644 index 0000000000000..d4299d19b45be --- /dev/null +++ b/pkg/util/lfu/lfu_cache.go @@ -0,0 +1,266 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lfu + +import ( + "sync" + "sync/atomic" + + "github.com/dgraph-io/ristretto" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal" + "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/metrics" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" + "golang.org/x/exp/rand" +) + +// LFU is a LFU based on the ristretto.Cache +type LFU struct { + cache *ristretto.Cache + // This is a secondary cache layer used to store all tables, + // including those that have been evicted from the primary cache. + resultKeySet *keySetShard + cost atomic.Int64 + closed atomic.Bool + closeOnce sync.Once +} + +// NewLFU creates a new LFU cache. +func NewLFU(totalMemCost int64) (*LFU, error) { + cost, err := adjustMemCost(totalMemCost) + if err != nil { + return nil, err + } + if intest.InTest && totalMemCost == 0 { + // In test, we set the cost to 5MB to avoid using too many memory in the LFU's CM sketch. + cost = 5000000 + } + metrics.CapacityGauge.Set(float64(cost)) + result := &LFU{} + bufferItems := int64(64) + + cache, err := ristretto.NewCache( + &ristretto.Config{ + NumCounters: max(min(cost/128, 1_000_000), 10), // assume the cost per table stats is 128 + MaxCost: cost, + BufferItems: bufferItems, + OnEvict: result.onEvict, + OnExit: result.onExit, + OnReject: result.onReject, + IgnoreInternalCost: intest.InTest, + Metrics: intest.InTest, + }, + ) + if err != nil { + return nil, err + } + result.cache = cache + result.resultKeySet = newKeySetShard() + return result, err +} + +// adjustMemCost adjusts the memory cost according to the total memory cost. +// When the total memory cost is 0, the memory cost is set to half of the total memory. +func adjustMemCost(totalMemCost int64) (result int64, err error) { + if totalMemCost == 0 { + memTotal, err := memory.MemTotal() + if err != nil { + return 0, err + } + return int64(memTotal / 2), nil + } + return totalMemCost, nil +} + +// Get implements statsCacheInner +func (s *LFU) Get(tid int64) (*statistics.Table, bool) { + result, ok := s.cache.Get(tid) + if !ok { + return s.resultKeySet.Get(tid) + } + return result.(*statistics.Table), ok +} + +// Put implements statsCacheInner +func (s *LFU) Put(tblID int64, tbl *statistics.Table) bool { + cost := tbl.MemoryUsage().TotalTrackingMemUsage() + s.resultKeySet.AddKeyValue(tblID, tbl) + s.addCost(cost) + return s.cache.Set(tblID, tbl, cost) +} + +// Del implements statsCacheInner +func (s *LFU) Del(tblID int64) { + s.cache.Del(tblID) + s.resultKeySet.Remove(tblID) +} + +// Cost implements statsCacheInner +func (s *LFU) Cost() int64 { + return s.cost.Load() +} + +// Values implements statsCacheInner +func (s *LFU) Values() []*statistics.Table { + result := make([]*statistics.Table, 0, 512) + for _, k := range s.resultKeySet.Keys() { + if value, ok := s.resultKeySet.Get(k); ok { + result = append(result, value) + } + } + return result +} + +// DropEvicted drop stats for table column/index +func DropEvicted(item statistics.TableCacheItem) { + if !item.IsStatsInitialized() || + item.GetEvictedStatus() == statistics.AllEvicted { + return + } + item.DropUnnecessaryData() +} + +func (s *LFU) onReject(item *ristretto.Item) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Warn("panic in onReject", zap.Any("error", r), zap.Stack("stack")) + } + }() + s.dropMemory(item) + metrics.RejectCounter.Inc() +} + +func (s *LFU) onEvict(item *ristretto.Item) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Warn("panic in onEvict", zap.Any("error", r), zap.Stack("stack")) + } + }() + s.dropMemory(item) + metrics.EvictCounter.Inc() +} + +func (s *LFU) dropMemory(item *ristretto.Item) { + if item.Value == nil { + // Sometimes the same key may be passed to the "onEvict/onExit" + // function twice, and in the second invocation, the value is empty, + // so it should not be processed. + return + } + if s.closed.Load() { + return + } + // We do not need to calculate the cost during onEvict, + // because the onexit function is also called when the evict event occurs. + // TODO(hawkingrei): not copy the useless part. + table := item.Value.(*statistics.Table).Copy() + for _, column := range table.Columns { + DropEvicted(column) + } + for _, indix := range table.Indices { + DropEvicted(indix) + } + s.resultKeySet.AddKeyValue(int64(item.Key), table) + after := table.MemoryUsage().TotalTrackingMemUsage() + // why add before again? because the cost will be subtracted in onExit. + // in fact, it is after - before + s.addCost(after) + s.triggerEvict() +} + +func (s *LFU) triggerEvict() { + // When the memory usage of the cache exceeds the maximum value, Many item need to evict. But + // ristretto'c cache execute the evict operation when to write the cache. for we can evict as soon as possible, + // we will write some fake item to the cache. fake item have a negative key, and the value is nil. + if s.Cost() > s.cache.MaxCost() { + //nolint: gosec + s.cache.Set(-rand.Int(), nil, 0) + } +} + +func (s *LFU) onExit(val any) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Warn("panic in onExit", zap.Any("error", r), zap.Stack("stack")) + } + }() + if val == nil { + // Sometimes the same key may be passed to the "onEvict/onExit" function twice, + // and in the second invocation, the value is empty, so it should not be processed. + return + } + if s.closed.Load() { + return + } + // Subtract the memory usage of the table from the total memory usage. + s.addCost(-val.(*statistics.Table).MemoryUsage().TotalTrackingMemUsage()) +} + +// Len implements statsCacheInner +func (s *LFU) Len() int { + return s.resultKeySet.Len() +} + +// Copy implements statsCacheInner +func (s *LFU) Copy() internal.StatsCacheInner { + return s +} + +// SetCapacity implements statsCacheInner +func (s *LFU) SetCapacity(maxCost int64) { + cost, err := adjustMemCost(maxCost) + if err != nil { + logutil.BgLogger().Warn("adjustMemCost failed", zap.Error(err)) + return + } + s.cache.UpdateMaxCost(cost) + s.triggerEvict() + metrics.CapacityGauge.Set(float64(cost)) + metrics.CostGauge.Set(float64(s.Cost())) +} + +// wait blocks until all buffered writes have been applied. This ensures a call to Set() +// will be visible to future calls to Get(). it is only used for test. +func (s *LFU) wait() { + s.cache.Wait() +} + +func (s *LFU) metrics() *ristretto.Metrics { + return s.cache.Metrics +} + +// Close implements statsCacheInner +func (s *LFU) Close() { + s.closeOnce.Do(func() { + s.closed.Store(true) + s.Clear() + s.cache.Close() + s.cache.Wait() + }) +} + +// Clear implements statsCacheInner +func (s *LFU) Clear() { + s.cache.Clear() + s.resultKeySet.Clear() +} + +func (s *LFU) addCost(v int64) { + newv := s.cost.Add(v) + metrics.CostGauge.Set(float64(newv)) +} diff --git a/pkg/util/lfu/lfu_cache_test.go b/pkg/util/lfu/lfu_cache_test.go new file mode 100644 index 0000000000000..de690b8e06ca9 --- /dev/null +++ b/pkg/util/lfu/lfu_cache_test.go @@ -0,0 +1,297 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lfu + +import ( + "math/rand" + "sync" + "testing" + "time" + + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/testutil" + "github.com/stretchr/testify/require" +) + +var ( + mockCMSMemoryUsage = int64(4) +) + +func TestLFUPutGetDel(t *testing.T) { + capacity := int64(100) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + mockTable := testutil.NewMockStatisticsTable(1, 1, true, false, false) + mockTableID := int64(1) + lfu.Put(mockTableID, mockTable) + lfu.wait() + lfu.Del(mockTableID) + v, ok := lfu.Get(mockTableID) + require.False(t, ok) + require.Nil(t, v) + lfu.wait() + require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) + require.Equal(t, 0, len(lfu.Values())) +} + +func TestLFUFreshMemUsage(t *testing.T) { + lfu, err := NewLFU(10000) + require.NoError(t, err) + t1 := testutil.NewMockStatisticsTable(1, 1, true, false, false) + require.Equal(t, mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalMemUsage) + t2 := testutil.NewMockStatisticsTable(2, 2, true, false, false) + require.Equal(t, 2*mockCMSMemoryUsage+2*mockCMSMemoryUsage, t2.MemoryUsage().TotalMemUsage) + t3 := testutil.NewMockStatisticsTable(3, 3, true, false, false) + require.Equal(t, 3*mockCMSMemoryUsage+3*mockCMSMemoryUsage, t3.MemoryUsage().TotalMemUsage) + lfu.Put(int64(1), t1) + lfu.Put(int64(2), t2) + lfu.Put(int64(3), t3) + lfu.wait() + require.Equal(t, lfu.Cost(), 6*mockCMSMemoryUsage+6*mockCMSMemoryUsage) + t4 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + lfu.Put(int64(1), t4) + lfu.wait() + require.Equal(t, lfu.Cost(), 7*mockCMSMemoryUsage+6*mockCMSMemoryUsage) + t5 := testutil.NewMockStatisticsTable(2, 2, true, false, false) + lfu.Put(int64(1), t5) + lfu.wait() + require.Equal(t, lfu.Cost(), 7*mockCMSMemoryUsage+7*mockCMSMemoryUsage) + + t6 := testutil.NewMockStatisticsTable(1, 2, true, false, false) + lfu.Put(int64(1), t6) + require.Equal(t, lfu.Cost(), 7*mockCMSMemoryUsage+6*mockCMSMemoryUsage) + + t7 := testutil.NewMockStatisticsTable(1, 1, true, false, false) + lfu.Put(int64(1), t7) + require.Equal(t, lfu.Cost(), 6*mockCMSMemoryUsage+6*mockCMSMemoryUsage) + lfu.wait() + require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) +} + +func TestLFUPutTooBig(t *testing.T) { + lfu, err := NewLFU(1) + require.NoError(t, err) + mockTable := testutil.NewMockStatisticsTable(1, 1, true, false, false) + // put mockTable, the index should be evicted but the table still exists in the list. + lfu.Put(int64(1), mockTable) + _, ok := lfu.Get(int64(1)) + require.True(t, ok) + lfu.wait() + require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) +} + +func TestCacheLen(t *testing.T) { + capacity := int64(12) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + require.Equal(t, int64(12), t1.MemoryUsage().TotalTrackingMemUsage()) + lfu.Put(int64(1), t1) + t2 := testutil.NewMockStatisticsTable(1, 1, true, false, false) + // put t2, t1 should be evicted 2 items and still exists in the list + lfu.Put(int64(2), t2) + lfu.wait() + require.Equal(t, lfu.Len(), 2) + require.Equal(t, uint64(8), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) + + // put t3, t1/t2 should be evicted all items. but t1/t2 still exists in the list + t3 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + lfu.Put(int64(3), t3) + lfu.wait() + require.Equal(t, lfu.Len(), 3) + require.Equal(t, uint64(12), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) +} + +func TestLFUCachePutGetWithManyConcurrency(t *testing.T) { + // to test DATA RACE + capacity := int64(100000000000) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + var wg sync.WaitGroup + wg.Add(2000) + for i := 0; i < 1000; i++ { + go func(i int) { + defer wg.Done() + t1 := testutil.NewMockStatisticsTable(1, 1, true, false, false) + lfu.Put(int64(i), t1) + }(i) + go func(i int) { + defer wg.Done() + lfu.Get(int64(i)) + }(i) + } + wg.Wait() + lfu.wait() + require.Equal(t, lfu.Len(), 1000) + require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) + require.Equal(t, 1000, len(lfu.Values())) +} + +func TestLFUCachePutGetWithManyConcurrency2(t *testing.T) { + // to test DATA RACE + capacity := int64(100000000000) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + for n := 0; n < 1000; n++ { + t1 := testutil.NewMockStatisticsTable(1, 1, true, false, false) + lfu.Put(int64(n), t1) + } + }() + } + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + for n := 0; n < 1000; n++ { + lfu.Get(int64(n)) + } + }() + } + wg.Wait() + lfu.wait() + require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) + require.Equal(t, 1000, len(lfu.Values())) +} + +func TestLFUCachePutGetWithManyConcurrencyAndSmallConcurrency(t *testing.T) { + // to test DATA RACE + + capacity := int64(100) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + for c := 0; c < 1000; c++ { + for n := 0; n < 50; n++ { + t1 := testutil.NewMockStatisticsTable(1, 1, true, true, true) + lfu.Put(int64(n), t1) + } + } + }() + } + time.Sleep(1 * time.Second) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + for c := 0; c < 1000; c++ { + for n := 0; n < 50; n++ { + tbl, ok := lfu.Get(int64(n)) + require.True(t, ok) + checkTable(t, tbl) + } + } + }() + } + wg.Wait() + lfu.wait() + v, ok := lfu.Get(rand.Int63n(50)) + require.True(t, ok) + for _, c := range v.Columns { + require.Equal(t, c.GetEvictedStatus(), statistics.AllEvicted) + } + for _, i := range v.Indices { + require.Equal(t, i.GetEvictedStatus(), statistics.AllEvicted) + } +} + +func checkTable(t *testing.T, tbl *statistics.Table) { + for _, column := range tbl.Columns { + if column.GetEvictedStatus() == statistics.AllEvicted { + require.Nil(t, column.TopN) + require.Equal(t, 0, cap(column.Histogram.Buckets)) + } else { + require.NotNil(t, column.TopN) + require.Greater(t, cap(column.Histogram.Buckets), 0) + } + } + for _, idx := range tbl.Indices { + if idx.GetEvictedStatus() == statistics.AllEvicted { + require.Nil(t, idx.TopN) + require.Equal(t, 0, cap(idx.Histogram.Buckets)) + } else { + require.NotNil(t, idx.TopN) + require.Greater(t, cap(idx.Histogram.Buckets), 0) + } + } +} + +func TestLFUReject(t *testing.T) { + capacity := int64(100000000000) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + require.Equal(t, 2*mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalTrackingMemUsage()) + lfu.Put(1, t1) + lfu.wait() + require.Equal(t, lfu.Cost(), 2*mockCMSMemoryUsage+mockCMSMemoryUsage) + + lfu.SetCapacity(2*mockCMSMemoryUsage + mockCMSMemoryUsage - 1) + + t2 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + require.True(t, lfu.Put(2, t2)) + lfu.wait() + time.Sleep(3 * time.Second) + require.Equal(t, int64(0), lfu.Cost()) + require.Len(t, lfu.Values(), 2) + v, ok := lfu.Get(2) + require.True(t, ok) + for _, c := range v.Columns { + require.Equal(t, statistics.AllEvicted, c.GetEvictedStatus()) + } + for _, i := range v.Indices { + require.Equal(t, statistics.AllEvicted, i.GetEvictedStatus()) + } +} + +func TestMemoryControl(t *testing.T) { + capacity := int64(100000000000) + lfu, err := NewLFU(capacity) + require.NoError(t, err) + t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + require.Equal(t, 2*mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalTrackingMemUsage()) + lfu.Put(1, t1) + lfu.wait() + + for i := 2; i <= 1000; i++ { + t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) + require.Equal(t, 2*mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalTrackingMemUsage()) + lfu.Put(int64(i), t1) + } + require.Equal(t, 1000*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) + + for i := 1000; i > 990; i-- { + lfu.SetCapacity(int64(i-1) * (2*mockCMSMemoryUsage + mockCMSMemoryUsage)) + lfu.wait() + require.Equal(t, int64(i-1)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) + } + for i := 990; i > 100; i = i - 100 { + lfu.SetCapacity(int64(i-1) * (2*mockCMSMemoryUsage + mockCMSMemoryUsage)) + lfu.wait() + require.Equal(t, int64(i-1)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) + } + lfu.SetCapacity(int64(10) * (2*mockCMSMemoryUsage + mockCMSMemoryUsage)) + lfu.wait() + require.Equal(t, int64(10)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) + lfu.SetCapacity(0) + lfu.wait() + require.Equal(t, int64(10)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) +} From 758a3f48219fb622b70dc4998ee96599a9267eb6 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 27 Feb 2024 23:02:56 +0800 Subject: [PATCH 2/5] test Signed-off-by: Weizhen Wang --- .../handle/cache/internal/lfu/key_set.go | 73 ----- .../cache/internal/lfu/key_set_shard.go | 69 ---- pkg/util/lfu/BUILD.bazel | 24 ++ pkg/util/lfu/key.go | 39 +++ pkg/util/lfu/key_set.go | 30 +- pkg/util/lfu/key_set_shard.go | 36 +-- pkg/util/lfu/lfu_cache.go | 149 +++++---- pkg/util/lfu/lfu_cache_test.go | 297 ------------------ 8 files changed, 193 insertions(+), 524 deletions(-) delete mode 100644 pkg/statistics/handle/cache/internal/lfu/key_set.go delete mode 100644 pkg/statistics/handle/cache/internal/lfu/key_set_shard.go create mode 100644 pkg/util/lfu/BUILD.bazel create mode 100644 pkg/util/lfu/key.go delete mode 100644 pkg/util/lfu/lfu_cache_test.go diff --git a/pkg/statistics/handle/cache/internal/lfu/key_set.go b/pkg/statistics/handle/cache/internal/lfu/key_set.go deleted file mode 100644 index 4467dab30c592..0000000000000 --- a/pkg/statistics/handle/cache/internal/lfu/key_set.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lfu - -import ( - "sync" - - "github.com/pingcap/tidb/pkg/statistics" - "golang.org/x/exp/maps" -) - -type keySet struct { - set map[int64]*statistics.Table - mu sync.RWMutex -} - -func (ks *keySet) Remove(key int64) int64 { - var cost int64 - ks.mu.Lock() - if table, ok := ks.set[key]; ok { - if table != nil { - cost = table.MemoryUsage().TotalTrackingMemUsage() - } - delete(ks.set, key) - } - ks.mu.Unlock() - return cost -} - -func (ks *keySet) Keys() []int64 { - ks.mu.RLock() - result := maps.Keys(ks.set) - ks.mu.RUnlock() - return result -} - -func (ks *keySet) Len() int { - ks.mu.RLock() - result := len(ks.set) - ks.mu.RUnlock() - return result -} - -func (ks *keySet) AddKeyValue(key int64, value *statistics.Table) { - ks.mu.Lock() - ks.set[key] = value - ks.mu.Unlock() -} - -func (ks *keySet) Get(key int64) (*statistics.Table, bool) { - ks.mu.RLock() - value, ok := ks.set[key] - ks.mu.RUnlock() - return value, ok -} - -func (ks *keySet) Clear() { - ks.mu.Lock() - ks.set = make(map[int64]*statistics.Table) - ks.mu.Unlock() -} diff --git a/pkg/statistics/handle/cache/internal/lfu/key_set_shard.go b/pkg/statistics/handle/cache/internal/lfu/key_set_shard.go deleted file mode 100644 index f07990400001e..0000000000000 --- a/pkg/statistics/handle/cache/internal/lfu/key_set_shard.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lfu - -import ( - "github.com/pingcap/tidb/pkg/statistics" -) - -const keySetCnt = 256 - -type keySetShard struct { - resultKeySet [keySetCnt]keySet -} - -func newKeySetShard() *keySetShard { - result := keySetShard{} - for i := 0; i < keySetCnt; i++ { - result.resultKeySet[i] = keySet{ - set: make(map[int64]*statistics.Table), - } - } - return &result -} - -func (kss *keySetShard) Get(key int64) (*statistics.Table, bool) { - return kss.resultKeySet[key%keySetCnt].Get(key) -} - -func (kss *keySetShard) AddKeyValue(key int64, table *statistics.Table) { - kss.resultKeySet[key%keySetCnt].AddKeyValue(key, table) -} - -func (kss *keySetShard) Remove(key int64) { - kss.resultKeySet[key%keySetCnt].Remove(key) -} - -func (kss *keySetShard) Keys() []int64 { - result := make([]int64, 0, len(kss.resultKeySet)) - for idx := range kss.resultKeySet { - result = append(result, kss.resultKeySet[idx].Keys()...) - } - return result -} - -func (kss *keySetShard) Len() int { - result := 0 - for idx := range kss.resultKeySet { - result += kss.resultKeySet[idx].Len() - } - return result -} - -func (kss *keySetShard) Clear() { - for idx := range kss.resultKeySet { - kss.resultKeySet[idx].Clear() - } -} diff --git a/pkg/util/lfu/BUILD.bazel b/pkg/util/lfu/BUILD.bazel new file mode 100644 index 0000000000000..8631728600738 --- /dev/null +++ b/pkg/util/lfu/BUILD.bazel @@ -0,0 +1,24 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "lfu", + srcs = [ + "key.go", + "key_set.go", + "key_set_shard.go", + "lfu_cache.go", + ], + importpath = "github.com/pingcap/tidb/pkg/util/lfu", + visibility = ["//visibility:public"], + deps = [ + "//pkg/util/intest", + "//pkg/util/logutil", + "//pkg/util/memory", + "@com_github_cespare_xxhash_v2//:xxhash", + "@com_github_dgraph_io_ristretto//:ristretto", + "@com_github_prometheus_client_golang//prometheus", + "@org_golang_x_exp//maps", + "@org_golang_x_exp//rand", + "@org_uber_go_zap//:zap", + ], +) diff --git a/pkg/util/lfu/key.go b/pkg/util/lfu/key.go new file mode 100644 index 0000000000000..f2617c8494929 --- /dev/null +++ b/pkg/util/lfu/key.go @@ -0,0 +1,39 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lfu + +import "github.com/cespare/xxhash/v2" + +func KeyToHash(key interface{}) uint64 { + if key == nil { + return 0 + } + switch k := key.(type) { + case uint64: + return k + case string: + return xxhash.Sum64String(k) + case int: + return uint64(k) + case int32: + return uint64(k) + case uint32: + return uint64(k) + case int64: + return uint64(k) + default: + panic("Key type not supported") + } +} diff --git a/pkg/util/lfu/key_set.go b/pkg/util/lfu/key_set.go index 4467dab30c592..85c5454734462 100644 --- a/pkg/util/lfu/key_set.go +++ b/pkg/util/lfu/key_set.go @@ -17,21 +17,29 @@ package lfu import ( "sync" - "github.com/pingcap/tidb/pkg/statistics" "golang.org/x/exp/maps" ) -type keySet struct { - set map[int64]*statistics.Table +type K interface { + ~uint64 | ~string | ~int | ~int32 | ~uint32 | ~int64 +} +type V interface { + comparable + Copy() any + TotalTrackingMemUsage() int64 +} + +type keySet[k K, v V] struct { + set map[k]v mu sync.RWMutex } -func (ks *keySet) Remove(key int64) int64 { +func (ks *keySet[K, V]) Remove(key K) int64 { var cost int64 ks.mu.Lock() if table, ok := ks.set[key]; ok { if table != nil { - cost = table.MemoryUsage().TotalTrackingMemUsage() + cost = table.TotalTrackingMemUsage() } delete(ks.set, key) } @@ -39,35 +47,35 @@ func (ks *keySet) Remove(key int64) int64 { return cost } -func (ks *keySet) Keys() []int64 { +func (ks *keySet[K, V]) Keys() []K { ks.mu.RLock() result := maps.Keys(ks.set) ks.mu.RUnlock() return result } -func (ks *keySet) Len() int { +func (ks *keySet[K, V]) Len() int { ks.mu.RLock() result := len(ks.set) ks.mu.RUnlock() return result } -func (ks *keySet) AddKeyValue(key int64, value *statistics.Table) { +func (ks *keySet[K, V]) AddKeyValue(key K, value V) { ks.mu.Lock() ks.set[key] = value ks.mu.Unlock() } -func (ks *keySet) Get(key int64) (*statistics.Table, bool) { +func (ks *keySet[K, V]) Get(key K) (V, bool) { ks.mu.RLock() value, ok := ks.set[key] ks.mu.RUnlock() return value, ok } -func (ks *keySet) Clear() { +func (ks *keySet[K, V]) Clear() { ks.mu.Lock() - ks.set = make(map[int64]*statistics.Table) + ks.set = make(map[K]V) ks.mu.Unlock() } diff --git a/pkg/util/lfu/key_set_shard.go b/pkg/util/lfu/key_set_shard.go index f07990400001e..e6341e26da9f3 100644 --- a/pkg/util/lfu/key_set_shard.go +++ b/pkg/util/lfu/key_set_shard.go @@ -14,47 +14,43 @@ package lfu -import ( - "github.com/pingcap/tidb/pkg/statistics" -) - const keySetCnt = 256 -type keySetShard struct { - resultKeySet [keySetCnt]keySet +type keySetShard[k K, v V] struct { + resultKeySet [keySetCnt]keySet[k, v] } -func newKeySetShard() *keySetShard { - result := keySetShard{} +func newKeySetShard[k K, v V]() *keySetShard[k, v] { + result := keySetShard[k, v]{} for i := 0; i < keySetCnt; i++ { - result.resultKeySet[i] = keySet{ - set: make(map[int64]*statistics.Table), + result.resultKeySet[i] = keySet[k, v]{ + set: make(map[k]v), } } return &result } -func (kss *keySetShard) Get(key int64) (*statistics.Table, bool) { - return kss.resultKeySet[key%keySetCnt].Get(key) +func (kss *keySetShard[K, V]) Get(key K) (V, bool) { + return kss.resultKeySet[KeyToHash(key)%keySetCnt].Get(key) } -func (kss *keySetShard) AddKeyValue(key int64, table *statistics.Table) { - kss.resultKeySet[key%keySetCnt].AddKeyValue(key, table) +func (kss *keySetShard[K, V]) AddKeyValue(key K, table V) { + kss.resultKeySet[KeyToHash(key)%keySetCnt].AddKeyValue(key, table) } -func (kss *keySetShard) Remove(key int64) { - kss.resultKeySet[key%keySetCnt].Remove(key) +func (kss *keySetShard[K, V]) Remove(key K) { + kss.resultKeySet[KeyToHash(key)%keySetCnt].Remove(key) } -func (kss *keySetShard) Keys() []int64 { - result := make([]int64, 0, len(kss.resultKeySet)) +func (kss *keySetShard[K, V]) Keys() []K { + result := make([]K, 0, len(kss.resultKeySet)) for idx := range kss.resultKeySet { result = append(result, kss.resultKeySet[idx].Keys()...) } return result } -func (kss *keySetShard) Len() int { +func (kss *keySetShard[K, V]) Len() int { result := 0 for idx := range kss.resultKeySet { result += kss.resultKeySet[idx].Len() @@ -62,7 +58,7 @@ func (kss *keySetShard) Len() int { return result } -func (kss *keySetShard) Clear() { +func (kss *keySetShard[K, V]) Clear() { for idx := range kss.resultKeySet { kss.resultKeySet[idx].Clear() } diff --git a/pkg/util/lfu/lfu_cache.go b/pkg/util/lfu/lfu_cache.go index d4299d19b45be..a38a06038e19d 100644 --- a/pkg/util/lfu/lfu_cache.go +++ b/pkg/util/lfu/lfu_cache.go @@ -19,29 +19,47 @@ import ( "sync/atomic" "github.com/dgraph-io/ristretto" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal" - "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/metrics" "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/memory" + "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" "golang.org/x/exp/rand" ) // LFU is a LFU based on the ristretto.Cache -type LFU struct { +type LFU[k K, v V] struct { cache *ristretto.Cache // This is a secondary cache layer used to store all tables, // including those that have been evicted from the primary cache. - resultKeySet *keySetShard + resultKeySet *keySetShard[k, v] cost atomic.Int64 closed atomic.Bool closeOnce sync.Once + + // dropEvicted is to evict useless part of the table when to evict. + dropEvicted func(any) + + // missCounter is the counter of missing cache. + missCounter prometheus.Counter + // hitCounter is the counter of hitting cache. + hitCounter prometheus.Counter + // updateCounter is the counter of updating cache. + updateCounter prometheus.Counter + // delCounter is the counter of deleting cache. + delCounter prometheus.Counter + // evictCounter is the counter of evicting cache. + evictCounter prometheus.Counter + // rejectCounter is the counter of reject cache. + rejectCounter prometheus.Counter + // costGauge is the gauge of cost time. + costGauge prometheus.Gauge + // capacityGauge is the gauge of capacity. + capacityGauge prometheus.Gauge } // NewLFU creates a new LFU cache. -func NewLFU(totalMemCost int64) (*LFU, error) { +func NewLFU[k K, v V](totalMemCost int64, dropEvicted func(any), capacityGauge prometheus.Gauge) (*LFU[k, v], error) { cost, err := adjustMemCost(totalMemCost) if err != nil { return nil, err @@ -50,8 +68,8 @@ func NewLFU(totalMemCost int64) (*LFU, error) { // In test, we set the cost to 5MB to avoid using too many memory in the LFU's CM sketch. cost = 5000000 } - metrics.CapacityGauge.Set(float64(cost)) - result := &LFU{} + capacityGauge.Set(float64(cost)) + result := &LFU[k, v]{} bufferItems := int64(64) cache, err := ristretto.NewCache( @@ -70,7 +88,9 @@ func NewLFU(totalMemCost int64) (*LFU, error) { return nil, err } result.cache = cache - result.resultKeySet = newKeySetShard() + result.dropEvicted = dropEvicted + result.capacityGauge = capacityGauge + result.resultKeySet = newKeySetShard[k, v]() return result, err } @@ -88,36 +108,36 @@ func adjustMemCost(totalMemCost int64) (result int64, err error) { } // Get implements statsCacheInner -func (s *LFU) Get(tid int64) (*statistics.Table, bool) { +func (s *LFU[K, V]) Get(tid K) (V, bool) { result, ok := s.cache.Get(tid) if !ok { return s.resultKeySet.Get(tid) } - return result.(*statistics.Table), ok + return result.(V), ok } // Put implements statsCacheInner -func (s *LFU) Put(tblID int64, tbl *statistics.Table) bool { - cost := tbl.MemoryUsage().TotalTrackingMemUsage() +func (s *LFU[K, V]) Put(tblID K, tbl V) bool { + cost := tbl.TotalTrackingMemUsage() s.resultKeySet.AddKeyValue(tblID, tbl) s.addCost(cost) return s.cache.Set(tblID, tbl, cost) } // Del implements statsCacheInner -func (s *LFU) Del(tblID int64) { +func (s *LFU[K, V]) Del(tblID K) { s.cache.Del(tblID) s.resultKeySet.Remove(tblID) } // Cost implements statsCacheInner -func (s *LFU) Cost() int64 { +func (s *LFU[K, V]) Cost() int64 { return s.cost.Load() } // Values implements statsCacheInner -func (s *LFU) Values() []*statistics.Table { - result := make([]*statistics.Table, 0, 512) +func (s *LFU[K, V]) Values() []V { + result := make([]V, 0, 512) for _, k := range s.resultKeySet.Keys() { if value, ok := s.resultKeySet.Get(k); ok { result = append(result, value) @@ -126,36 +146,27 @@ func (s *LFU) Values() []*statistics.Table { return result } -// DropEvicted drop stats for table column/index -func DropEvicted(item statistics.TableCacheItem) { - if !item.IsStatsInitialized() || - item.GetEvictedStatus() == statistics.AllEvicted { - return - } - item.DropUnnecessaryData() -} - -func (s *LFU) onReject(item *ristretto.Item) { +func (s *LFU[K, V]) onReject(item *ristretto.Item) { defer func() { if r := recover(); r != nil { logutil.BgLogger().Warn("panic in onReject", zap.Any("error", r), zap.Stack("stack")) } }() s.dropMemory(item) - metrics.RejectCounter.Inc() + s.rejectCounter.Inc() } -func (s *LFU) onEvict(item *ristretto.Item) { +func (s *LFU[K, V]) onEvict(item *ristretto.Item) { defer func() { if r := recover(); r != nil { logutil.BgLogger().Warn("panic in onEvict", zap.Any("error", r), zap.Stack("stack")) } }() s.dropMemory(item) - metrics.EvictCounter.Inc() + s.evictCounter.Inc() } -func (s *LFU) dropMemory(item *ristretto.Item) { +func (s *LFU[K, V]) dropMemory(item *ristretto.Item) { if item.Value == nil { // Sometimes the same key may be passed to the "onEvict/onExit" // function twice, and in the second invocation, the value is empty, @@ -168,22 +179,17 @@ func (s *LFU) dropMemory(item *ristretto.Item) { // We do not need to calculate the cost during onEvict, // because the onexit function is also called when the evict event occurs. // TODO(hawkingrei): not copy the useless part. - table := item.Value.(*statistics.Table).Copy() - for _, column := range table.Columns { - DropEvicted(column) - } - for _, indix := range table.Indices { - DropEvicted(indix) - } - s.resultKeySet.AddKeyValue(int64(item.Key), table) - after := table.MemoryUsage().TotalTrackingMemUsage() + table := item.Value.(V).Copy().(V) + s.dropEvicted(table) + s.resultKeySet.AddKeyValue(K(item.Key), table) + after := table.TotalTrackingMemUsage() // why add before again? because the cost will be subtracted in onExit. // in fact, it is after - before s.addCost(after) s.triggerEvict() } -func (s *LFU) triggerEvict() { +func (s *LFU[K, V]) triggerEvict() { // When the memory usage of the cache exceeds the maximum value, Many item need to evict. But // ristretto'c cache execute the evict operation when to write the cache. for we can evict as soon as possible, // we will write some fake item to the cache. fake item have a negative key, and the value is nil. @@ -193,7 +199,7 @@ func (s *LFU) triggerEvict() { } } -func (s *LFU) onExit(val any) { +func (s *LFU[K, V]) onExit(val any) { defer func() { if r := recover(); r != nil { logutil.BgLogger().Warn("panic in onExit", zap.Any("error", r), zap.Stack("stack")) @@ -208,21 +214,21 @@ func (s *LFU) onExit(val any) { return } // Subtract the memory usage of the table from the total memory usage. - s.addCost(-val.(*statistics.Table).MemoryUsage().TotalTrackingMemUsage()) + s.addCost(-val.(V).TotalTrackingMemUsage()) } // Len implements statsCacheInner -func (s *LFU) Len() int { +func (s *LFU[K, V]) Len() int { return s.resultKeySet.Len() } // Copy implements statsCacheInner -func (s *LFU) Copy() internal.StatsCacheInner { +func (s *LFU[K, V]) Copy() *LFU[K, V] { return s } // SetCapacity implements statsCacheInner -func (s *LFU) SetCapacity(maxCost int64) { +func (s *LFU[K, V]) SetCapacity(maxCost int64) { cost, err := adjustMemCost(maxCost) if err != nil { logutil.BgLogger().Warn("adjustMemCost failed", zap.Error(err)) @@ -230,22 +236,22 @@ func (s *LFU) SetCapacity(maxCost int64) { } s.cache.UpdateMaxCost(cost) s.triggerEvict() - metrics.CapacityGauge.Set(float64(cost)) - metrics.CostGauge.Set(float64(s.Cost())) + s.capacityGauge.Set(float64(cost)) + s.costGauge.Set(float64(s.Cost())) } // wait blocks until all buffered writes have been applied. This ensures a call to Set() // will be visible to future calls to Get(). it is only used for test. -func (s *LFU) wait() { +func (s *LFU[K, V]) wait() { s.cache.Wait() } -func (s *LFU) metrics() *ristretto.Metrics { +func (s *LFU[K, V]) metrics() *ristretto.Metrics { return s.cache.Metrics } // Close implements statsCacheInner -func (s *LFU) Close() { +func (s *LFU[K, V]) Close() { s.closeOnce.Do(func() { s.closed.Store(true) s.Clear() @@ -255,12 +261,47 @@ func (s *LFU) Close() { } // Clear implements statsCacheInner -func (s *LFU) Clear() { +func (s *LFU[K, V]) Clear() { s.cache.Clear() s.resultKeySet.Clear() } -func (s *LFU) addCost(v int64) { +func (s *LFU[K, V]) addCost(v int64) { newv := s.cost.Add(v) - metrics.CostGauge.Set(float64(newv)) + s.costGauge.Set(float64(newv)) +} + +// RegisterMissCounter register MissCounter +func (s *LFU[K, V]) RegisterMissCounter(c prometheus.Counter) { + s.missCounter = c +} + +// RegisterHitCounter register HitCounter +func (s *LFU[K, V]) RegisterHitCounter(c prometheus.Counter) { + s.hitCounter = c +} + +// RegisterUpdateCounter register UpdateCounter +func (s *LFU[K, V]) RegisterUpdateCounter(c prometheus.Counter) { + s.updateCounter = c +} + +// RegisterDelCounter register DelCounter +func (s *LFU[K, V]) RegisterDelCounter(c prometheus.Counter) { + s.delCounter = c +} + +// RegisterEvictCounter register EvictCounter +func (s *LFU[K, V]) RegisterEvictCounter(c prometheus.Counter) { + s.evictCounter = c +} + +// RegisterRejectCounter register RejectCounter +func (s *LFU[K, V]) RegisterRejectCounter(c prometheus.Counter) { + s.rejectCounter = c +} + +// RegisterCostGauge register CostGauge +func (s *LFU[K, V]) RegisterCostGauge(g prometheus.Gauge) { + s.costGauge = g } diff --git a/pkg/util/lfu/lfu_cache_test.go b/pkg/util/lfu/lfu_cache_test.go deleted file mode 100644 index de690b8e06ca9..0000000000000 --- a/pkg/util/lfu/lfu_cache_test.go +++ /dev/null @@ -1,297 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package lfu - -import ( - "math/rand" - "sync" - "testing" - "time" - - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/testutil" - "github.com/stretchr/testify/require" -) - -var ( - mockCMSMemoryUsage = int64(4) -) - -func TestLFUPutGetDel(t *testing.T) { - capacity := int64(100) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - mockTable := testutil.NewMockStatisticsTable(1, 1, true, false, false) - mockTableID := int64(1) - lfu.Put(mockTableID, mockTable) - lfu.wait() - lfu.Del(mockTableID) - v, ok := lfu.Get(mockTableID) - require.False(t, ok) - require.Nil(t, v) - lfu.wait() - require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) - require.Equal(t, 0, len(lfu.Values())) -} - -func TestLFUFreshMemUsage(t *testing.T) { - lfu, err := NewLFU(10000) - require.NoError(t, err) - t1 := testutil.NewMockStatisticsTable(1, 1, true, false, false) - require.Equal(t, mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalMemUsage) - t2 := testutil.NewMockStatisticsTable(2, 2, true, false, false) - require.Equal(t, 2*mockCMSMemoryUsage+2*mockCMSMemoryUsage, t2.MemoryUsage().TotalMemUsage) - t3 := testutil.NewMockStatisticsTable(3, 3, true, false, false) - require.Equal(t, 3*mockCMSMemoryUsage+3*mockCMSMemoryUsage, t3.MemoryUsage().TotalMemUsage) - lfu.Put(int64(1), t1) - lfu.Put(int64(2), t2) - lfu.Put(int64(3), t3) - lfu.wait() - require.Equal(t, lfu.Cost(), 6*mockCMSMemoryUsage+6*mockCMSMemoryUsage) - t4 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - lfu.Put(int64(1), t4) - lfu.wait() - require.Equal(t, lfu.Cost(), 7*mockCMSMemoryUsage+6*mockCMSMemoryUsage) - t5 := testutil.NewMockStatisticsTable(2, 2, true, false, false) - lfu.Put(int64(1), t5) - lfu.wait() - require.Equal(t, lfu.Cost(), 7*mockCMSMemoryUsage+7*mockCMSMemoryUsage) - - t6 := testutil.NewMockStatisticsTable(1, 2, true, false, false) - lfu.Put(int64(1), t6) - require.Equal(t, lfu.Cost(), 7*mockCMSMemoryUsage+6*mockCMSMemoryUsage) - - t7 := testutil.NewMockStatisticsTable(1, 1, true, false, false) - lfu.Put(int64(1), t7) - require.Equal(t, lfu.Cost(), 6*mockCMSMemoryUsage+6*mockCMSMemoryUsage) - lfu.wait() - require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) -} - -func TestLFUPutTooBig(t *testing.T) { - lfu, err := NewLFU(1) - require.NoError(t, err) - mockTable := testutil.NewMockStatisticsTable(1, 1, true, false, false) - // put mockTable, the index should be evicted but the table still exists in the list. - lfu.Put(int64(1), mockTable) - _, ok := lfu.Get(int64(1)) - require.True(t, ok) - lfu.wait() - require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) -} - -func TestCacheLen(t *testing.T) { - capacity := int64(12) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - require.Equal(t, int64(12), t1.MemoryUsage().TotalTrackingMemUsage()) - lfu.Put(int64(1), t1) - t2 := testutil.NewMockStatisticsTable(1, 1, true, false, false) - // put t2, t1 should be evicted 2 items and still exists in the list - lfu.Put(int64(2), t2) - lfu.wait() - require.Equal(t, lfu.Len(), 2) - require.Equal(t, uint64(8), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) - - // put t3, t1/t2 should be evicted all items. but t1/t2 still exists in the list - t3 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - lfu.Put(int64(3), t3) - lfu.wait() - require.Equal(t, lfu.Len(), 3) - require.Equal(t, uint64(12), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) -} - -func TestLFUCachePutGetWithManyConcurrency(t *testing.T) { - // to test DATA RACE - capacity := int64(100000000000) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(2000) - for i := 0; i < 1000; i++ { - go func(i int) { - defer wg.Done() - t1 := testutil.NewMockStatisticsTable(1, 1, true, false, false) - lfu.Put(int64(i), t1) - }(i) - go func(i int) { - defer wg.Done() - lfu.Get(int64(i)) - }(i) - } - wg.Wait() - lfu.wait() - require.Equal(t, lfu.Len(), 1000) - require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) - require.Equal(t, 1000, len(lfu.Values())) -} - -func TestLFUCachePutGetWithManyConcurrency2(t *testing.T) { - // to test DATA RACE - capacity := int64(100000000000) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(10) - for i := 0; i < 5; i++ { - go func() { - defer wg.Done() - for n := 0; n < 1000; n++ { - t1 := testutil.NewMockStatisticsTable(1, 1, true, false, false) - lfu.Put(int64(n), t1) - } - }() - } - for i := 0; i < 5; i++ { - go func() { - defer wg.Done() - for n := 0; n < 1000; n++ { - lfu.Get(int64(n)) - } - }() - } - wg.Wait() - lfu.wait() - require.Equal(t, uint64(lfu.Cost()), lfu.metrics().CostAdded()-lfu.metrics().CostEvicted()) - require.Equal(t, 1000, len(lfu.Values())) -} - -func TestLFUCachePutGetWithManyConcurrencyAndSmallConcurrency(t *testing.T) { - // to test DATA RACE - - capacity := int64(100) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - var wg sync.WaitGroup - wg.Add(10) - for i := 0; i < 5; i++ { - go func() { - defer wg.Done() - for c := 0; c < 1000; c++ { - for n := 0; n < 50; n++ { - t1 := testutil.NewMockStatisticsTable(1, 1, true, true, true) - lfu.Put(int64(n), t1) - } - } - }() - } - time.Sleep(1 * time.Second) - for i := 0; i < 5; i++ { - go func() { - defer wg.Done() - for c := 0; c < 1000; c++ { - for n := 0; n < 50; n++ { - tbl, ok := lfu.Get(int64(n)) - require.True(t, ok) - checkTable(t, tbl) - } - } - }() - } - wg.Wait() - lfu.wait() - v, ok := lfu.Get(rand.Int63n(50)) - require.True(t, ok) - for _, c := range v.Columns { - require.Equal(t, c.GetEvictedStatus(), statistics.AllEvicted) - } - for _, i := range v.Indices { - require.Equal(t, i.GetEvictedStatus(), statistics.AllEvicted) - } -} - -func checkTable(t *testing.T, tbl *statistics.Table) { - for _, column := range tbl.Columns { - if column.GetEvictedStatus() == statistics.AllEvicted { - require.Nil(t, column.TopN) - require.Equal(t, 0, cap(column.Histogram.Buckets)) - } else { - require.NotNil(t, column.TopN) - require.Greater(t, cap(column.Histogram.Buckets), 0) - } - } - for _, idx := range tbl.Indices { - if idx.GetEvictedStatus() == statistics.AllEvicted { - require.Nil(t, idx.TopN) - require.Equal(t, 0, cap(idx.Histogram.Buckets)) - } else { - require.NotNil(t, idx.TopN) - require.Greater(t, cap(idx.Histogram.Buckets), 0) - } - } -} - -func TestLFUReject(t *testing.T) { - capacity := int64(100000000000) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - require.Equal(t, 2*mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalTrackingMemUsage()) - lfu.Put(1, t1) - lfu.wait() - require.Equal(t, lfu.Cost(), 2*mockCMSMemoryUsage+mockCMSMemoryUsage) - - lfu.SetCapacity(2*mockCMSMemoryUsage + mockCMSMemoryUsage - 1) - - t2 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - require.True(t, lfu.Put(2, t2)) - lfu.wait() - time.Sleep(3 * time.Second) - require.Equal(t, int64(0), lfu.Cost()) - require.Len(t, lfu.Values(), 2) - v, ok := lfu.Get(2) - require.True(t, ok) - for _, c := range v.Columns { - require.Equal(t, statistics.AllEvicted, c.GetEvictedStatus()) - } - for _, i := range v.Indices { - require.Equal(t, statistics.AllEvicted, i.GetEvictedStatus()) - } -} - -func TestMemoryControl(t *testing.T) { - capacity := int64(100000000000) - lfu, err := NewLFU(capacity) - require.NoError(t, err) - t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - require.Equal(t, 2*mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalTrackingMemUsage()) - lfu.Put(1, t1) - lfu.wait() - - for i := 2; i <= 1000; i++ { - t1 := testutil.NewMockStatisticsTable(2, 1, true, false, false) - require.Equal(t, 2*mockCMSMemoryUsage+mockCMSMemoryUsage, t1.MemoryUsage().TotalTrackingMemUsage()) - lfu.Put(int64(i), t1) - } - require.Equal(t, 1000*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) - - for i := 1000; i > 990; i-- { - lfu.SetCapacity(int64(i-1) * (2*mockCMSMemoryUsage + mockCMSMemoryUsage)) - lfu.wait() - require.Equal(t, int64(i-1)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) - } - for i := 990; i > 100; i = i - 100 { - lfu.SetCapacity(int64(i-1) * (2*mockCMSMemoryUsage + mockCMSMemoryUsage)) - lfu.wait() - require.Equal(t, int64(i-1)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) - } - lfu.SetCapacity(int64(10) * (2*mockCMSMemoryUsage + mockCMSMemoryUsage)) - lfu.wait() - require.Equal(t, int64(10)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) - lfu.SetCapacity(0) - lfu.wait() - require.Equal(t, int64(10)*(2*mockCMSMemoryUsage+mockCMSMemoryUsage), lfu.Cost()) -} From 6f5dee90f3cc7877367238a64938182a71b424a7 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 27 Feb 2024 23:35:37 +0800 Subject: [PATCH 3/5] test Signed-off-by: Weizhen Wang --- .../handle/cache/internal/lfu/BUILD.bazel | 13 +- .../handle/cache/internal/lfu/lfu_cache.go | 206 +++--------------- pkg/statistics/table.go | 13 ++ pkg/util/lfu/key_set.go | 8 +- pkg/util/lfu/lfu_cache.go | 6 +- 5 files changed, 55 insertions(+), 191 deletions(-) diff --git a/pkg/statistics/handle/cache/internal/lfu/BUILD.bazel b/pkg/statistics/handle/cache/internal/lfu/BUILD.bazel index 1e732998623ab..3786fa8a87779 100644 --- a/pkg/statistics/handle/cache/internal/lfu/BUILD.bazel +++ b/pkg/statistics/handle/cache/internal/lfu/BUILD.bazel @@ -2,24 +2,15 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "lfu", - srcs = [ - "key_set.go", - "key_set_shard.go", - "lfu_cache.go", - ], + srcs = ["lfu_cache.go"], importpath = "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/lfu", visibility = ["//pkg/statistics/handle/cache:__subpackages__"], deps = [ "//pkg/statistics", "//pkg/statistics/handle/cache/internal", "//pkg/statistics/handle/cache/internal/metrics", - "//pkg/util/intest", - "//pkg/util/logutil", - "//pkg/util/memory", + "//pkg/util/lfu", "@com_github_dgraph_io_ristretto//:ristretto", - "@org_golang_x_exp//maps", - "@org_golang_x_exp//rand", - "@org_uber_go_zap//:zap", ], ) diff --git a/pkg/statistics/handle/cache/internal/lfu/lfu_cache.go b/pkg/statistics/handle/cache/internal/lfu/lfu_cache.go index d4299d19b45be..202b91c2d5e67 100644 --- a/pkg/statistics/handle/cache/internal/lfu/lfu_cache.go +++ b/pkg/statistics/handle/cache/internal/lfu/lfu_cache.go @@ -15,223 +15,96 @@ package lfu import ( - "sync" - "sync/atomic" - "github.com/dgraph-io/ristretto" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal" "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/metrics" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" - "golang.org/x/exp/rand" + "github.com/pingcap/tidb/pkg/util/lfu" ) // LFU is a LFU based on the ristretto.Cache type LFU struct { - cache *ristretto.Cache - // This is a secondary cache layer used to store all tables, - // including those that have been evicted from the primary cache. - resultKeySet *keySetShard - cost atomic.Int64 - closed atomic.Bool - closeOnce sync.Once + cache *lfu.LFU[int64, *statistics.Table] } // NewLFU creates a new LFU cache. func NewLFU(totalMemCost int64) (*LFU, error) { - cost, err := adjustMemCost(totalMemCost) - if err != nil { - return nil, err - } - if intest.InTest && totalMemCost == 0 { - // In test, we set the cost to 5MB to avoid using too many memory in the LFU's CM sketch. - cost = 5000000 - } - metrics.CapacityGauge.Set(float64(cost)) - result := &LFU{} - bufferItems := int64(64) - - cache, err := ristretto.NewCache( - &ristretto.Config{ - NumCounters: max(min(cost/128, 1_000_000), 10), // assume the cost per table stats is 128 - MaxCost: cost, - BufferItems: bufferItems, - OnEvict: result.onEvict, - OnExit: result.onExit, - OnReject: result.onReject, - IgnoreInternalCost: intest.InTest, - Metrics: intest.InTest, - }, - ) + cache, err := lfu.NewLFU[int64, *statistics.Table](totalMemCost, DropEvicted, metrics.CapacityGauge) if err != nil { return nil, err } - result.cache = cache - result.resultKeySet = newKeySetShard() - return result, err -} - -// adjustMemCost adjusts the memory cost according to the total memory cost. -// When the total memory cost is 0, the memory cost is set to half of the total memory. -func adjustMemCost(totalMemCost int64) (result int64, err error) { - if totalMemCost == 0 { - memTotal, err := memory.MemTotal() - if err != nil { - return 0, err - } - return int64(memTotal / 2), nil - } - return totalMemCost, nil + cache.RegisterMissCounter(metrics.MissCounter) + cache.RegisterHitCounter(metrics.HitCounter) + cache.RegisterUpdateCounter(metrics.UpdateCounter) + cache.RegisterDelCounter(metrics.DelCounter) + cache.RegisterEvictCounter(metrics.EvictCounter) + cache.RegisterRejectCounter(metrics.RejectCounter) + cache.RegisterCostGauge(metrics.CostGauge) + return &LFU{ + cache: cache, + }, nil } // Get implements statsCacheInner func (s *LFU) Get(tid int64) (*statistics.Table, bool) { - result, ok := s.cache.Get(tid) - if !ok { - return s.resultKeySet.Get(tid) - } - return result.(*statistics.Table), ok + return s.cache.Get(tid) } // Put implements statsCacheInner func (s *LFU) Put(tblID int64, tbl *statistics.Table) bool { - cost := tbl.MemoryUsage().TotalTrackingMemUsage() - s.resultKeySet.AddKeyValue(tblID, tbl) - s.addCost(cost) - return s.cache.Set(tblID, tbl, cost) + return s.cache.Put(tblID, tbl) } // Del implements statsCacheInner func (s *LFU) Del(tblID int64) { s.cache.Del(tblID) - s.resultKeySet.Remove(tblID) } // Cost implements statsCacheInner func (s *LFU) Cost() int64 { - return s.cost.Load() + return s.cache.Cost() } // Values implements statsCacheInner func (s *LFU) Values() []*statistics.Table { - result := make([]*statistics.Table, 0, 512) - for _, k := range s.resultKeySet.Keys() { - if value, ok := s.resultKeySet.Get(k); ok { - result = append(result, value) - } - } - return result + return s.Values() } // DropEvicted drop stats for table column/index -func DropEvicted(item statistics.TableCacheItem) { - if !item.IsStatsInitialized() || - item.GetEvictedStatus() == statistics.AllEvicted { - return - } - item.DropUnnecessaryData() -} - -func (s *LFU) onReject(item *ristretto.Item) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Warn("panic in onReject", zap.Any("error", r), zap.Stack("stack")) - } - }() - s.dropMemory(item) - metrics.RejectCounter.Inc() -} - -func (s *LFU) onEvict(item *ristretto.Item) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Warn("panic in onEvict", zap.Any("error", r), zap.Stack("stack")) - } - }() - s.dropMemory(item) - metrics.EvictCounter.Inc() -} - -func (s *LFU) dropMemory(item *ristretto.Item) { - if item.Value == nil { - // Sometimes the same key may be passed to the "onEvict/onExit" - // function twice, and in the second invocation, the value is empty, - // so it should not be processed. - return - } - if s.closed.Load() { - return - } - // We do not need to calculate the cost during onEvict, - // because the onexit function is also called when the evict event occurs. - // TODO(hawkingrei): not copy the useless part. - table := item.Value.(*statistics.Table).Copy() - for _, column := range table.Columns { - DropEvicted(column) +func DropEvicted(table any) { + t := table.(*statistics.Table) + for _, column := range t.Columns { + dropEvicted(column) } - for _, indix := range table.Indices { - DropEvicted(indix) + for _, indix := range t.Indices { + dropEvicted(indix) } - s.resultKeySet.AddKeyValue(int64(item.Key), table) - after := table.MemoryUsage().TotalTrackingMemUsage() - // why add before again? because the cost will be subtracted in onExit. - // in fact, it is after - before - s.addCost(after) - s.triggerEvict() -} -func (s *LFU) triggerEvict() { - // When the memory usage of the cache exceeds the maximum value, Many item need to evict. But - // ristretto'c cache execute the evict operation when to write the cache. for we can evict as soon as possible, - // we will write some fake item to the cache. fake item have a negative key, and the value is nil. - if s.Cost() > s.cache.MaxCost() { - //nolint: gosec - s.cache.Set(-rand.Int(), nil, 0) - } } -func (s *LFU) onExit(val any) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Warn("panic in onExit", zap.Any("error", r), zap.Stack("stack")) - } - }() - if val == nil { - // Sometimes the same key may be passed to the "onEvict/onExit" function twice, - // and in the second invocation, the value is empty, so it should not be processed. - return - } - if s.closed.Load() { +// dropEvicted drop stats for table column/index +func dropEvicted(item statistics.TableCacheItem) { + if !item.IsStatsInitialized() || + item.GetEvictedStatus() == statistics.AllEvicted { return } - // Subtract the memory usage of the table from the total memory usage. - s.addCost(-val.(*statistics.Table).MemoryUsage().TotalTrackingMemUsage()) + item.DropUnnecessaryData() } // Len implements statsCacheInner func (s *LFU) Len() int { - return s.resultKeySet.Len() + return s.cache.Len() } // Copy implements statsCacheInner func (s *LFU) Copy() internal.StatsCacheInner { - return s + cache := s.cache.Copy() + return &LFU{cache: cache} } // SetCapacity implements statsCacheInner func (s *LFU) SetCapacity(maxCost int64) { - cost, err := adjustMemCost(maxCost) - if err != nil { - logutil.BgLogger().Warn("adjustMemCost failed", zap.Error(err)) - return - } - s.cache.UpdateMaxCost(cost) - s.triggerEvict() - metrics.CapacityGauge.Set(float64(cost)) - metrics.CostGauge.Set(float64(s.Cost())) + s.cache.SetCapacity(maxCost) } // wait blocks until all buffered writes have been applied. This ensures a call to Set() @@ -241,26 +114,15 @@ func (s *LFU) wait() { } func (s *LFU) metrics() *ristretto.Metrics { - return s.cache.Metrics + return s.cache.Metrics() } // Close implements statsCacheInner func (s *LFU) Close() { - s.closeOnce.Do(func() { - s.closed.Store(true) - s.Clear() - s.cache.Close() - s.cache.Wait() - }) + s.cache.Close() } // Clear implements statsCacheInner func (s *LFU) Clear() { s.cache.Clear() - s.resultKeySet.Clear() -} - -func (s *LFU) addCost(v int64) { - newv := s.cost.Add(v) - metrics.CostGauge.Set(float64(newv)) } diff --git a/pkg/statistics/table.go b/pkg/statistics/table.go index b1e62cbfa10ce..4864bc53810a7 100644 --- a/pkg/statistics/table.go +++ b/pkg/statistics/table.go @@ -70,6 +70,11 @@ type Table struct { TblInfoUpdateTS uint64 } +// DeepCopy implements the interface of LFU' key. +func (t *Table) DeepCopy() any { + return t.Copy() +} + // ExtendedStatsItem is the cached item of a mysql.stats_extended record. type ExtendedStatsItem struct { StringVals string @@ -277,6 +282,14 @@ func (t *Table) MemoryUsage() *TableMemoryUsage { return tMemUsage } +// TotalTrackingMemUsage return Total Tracking Mem Usage +func (t *Table) TotalTrackingMemUsage() int64 { + if t == nil { + return 0 + } + return t.MemoryUsage().TotalTrackingMemUsage() +} + // Copy copies the current table. func (t *Table) Copy() *Table { newHistColl := HistColl{ diff --git a/pkg/util/lfu/key_set.go b/pkg/util/lfu/key_set.go index 85c5454734462..35bc3ce1569c6 100644 --- a/pkg/util/lfu/key_set.go +++ b/pkg/util/lfu/key_set.go @@ -24,8 +24,7 @@ type K interface { ~uint64 | ~string | ~int | ~int32 | ~uint32 | ~int64 } type V interface { - comparable - Copy() any + DeepCopy() any TotalTrackingMemUsage() int64 } @@ -38,9 +37,8 @@ func (ks *keySet[K, V]) Remove(key K) int64 { var cost int64 ks.mu.Lock() if table, ok := ks.set[key]; ok { - if table != nil { - cost = table.TotalTrackingMemUsage() - } + // if table is nil, it still return 0. + cost = table.TotalTrackingMemUsage() delete(ks.set, key) } ks.mu.Unlock() diff --git a/pkg/util/lfu/lfu_cache.go b/pkg/util/lfu/lfu_cache.go index a38a06038e19d..de3c30ff4baa4 100644 --- a/pkg/util/lfu/lfu_cache.go +++ b/pkg/util/lfu/lfu_cache.go @@ -179,7 +179,7 @@ func (s *LFU[K, V]) dropMemory(item *ristretto.Item) { // We do not need to calculate the cost during onEvict, // because the onexit function is also called when the evict event occurs. // TODO(hawkingrei): not copy the useless part. - table := item.Value.(V).Copy().(V) + table := item.Value.(V).DeepCopy().(V) s.dropEvicted(table) s.resultKeySet.AddKeyValue(K(item.Key), table) after := table.TotalTrackingMemUsage() @@ -242,11 +242,11 @@ func (s *LFU[K, V]) SetCapacity(maxCost int64) { // wait blocks until all buffered writes have been applied. This ensures a call to Set() // will be visible to future calls to Get(). it is only used for test. -func (s *LFU[K, V]) wait() { +func (s *LFU[K, V]) Wait() { s.cache.Wait() } -func (s *LFU[K, V]) metrics() *ristretto.Metrics { +func (s *LFU[K, V]) Metrics() *ristretto.Metrics { return s.cache.Metrics } From 1bc5c6da8f9bbf02e6b22d8e90affc985788b8d5 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 27 Feb 2024 23:44:50 +0800 Subject: [PATCH 4/5] test Signed-off-by: Weizhen Wang --- pkg/util/lfu/key.go | 2 +- pkg/util/lfu/key_set.go | 9 ++++++--- pkg/util/lfu/key_set_shard.go | 10 +++++----- pkg/util/lfu/lfu_cache.go | 7 ++++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pkg/util/lfu/key.go b/pkg/util/lfu/key.go index f2617c8494929..aee420b083e6d 100644 --- a/pkg/util/lfu/key.go +++ b/pkg/util/lfu/key.go @@ -16,7 +16,7 @@ package lfu import "github.com/cespare/xxhash/v2" -func KeyToHash(key interface{}) uint64 { +func keyToHash(key interface{}) uint64 { if key == nil { return 0 } diff --git a/pkg/util/lfu/key_set.go b/pkg/util/lfu/key_set.go index 35bc3ce1569c6..c8a7def59c8d6 100644 --- a/pkg/util/lfu/key_set.go +++ b/pkg/util/lfu/key_set.go @@ -20,15 +20,18 @@ import ( "golang.org/x/exp/maps" ) -type K interface { +// Key is the key of cache +type Key interface { ~uint64 | ~string | ~int | ~int32 | ~uint32 | ~int64 } -type V interface { + +// Value is the value of cache +type Value interface { DeepCopy() any TotalTrackingMemUsage() int64 } -type keySet[k K, v V] struct { +type keySet[k Key, v Value] struct { set map[k]v mu sync.RWMutex } diff --git a/pkg/util/lfu/key_set_shard.go b/pkg/util/lfu/key_set_shard.go index e6341e26da9f3..85e26ceaec889 100644 --- a/pkg/util/lfu/key_set_shard.go +++ b/pkg/util/lfu/key_set_shard.go @@ -16,11 +16,11 @@ package lfu const keySetCnt = 256 -type keySetShard[k K, v V] struct { +type keySetShard[k Key, v Value] struct { resultKeySet [keySetCnt]keySet[k, v] } -func newKeySetShard[k K, v V]() *keySetShard[k, v] { +func newKeySetShard[k Key, v Value]() *keySetShard[k, v] { result := keySetShard[k, v]{} for i := 0; i < keySetCnt; i++ { result.resultKeySet[i] = keySet[k, v]{ @@ -31,15 +31,15 @@ func newKeySetShard[k K, v V]() *keySetShard[k, v] { } func (kss *keySetShard[K, V]) Get(key K) (V, bool) { - return kss.resultKeySet[KeyToHash(key)%keySetCnt].Get(key) + return kss.resultKeySet[keyToHash(key)%keySetCnt].Get(key) } func (kss *keySetShard[K, V]) AddKeyValue(key K, table V) { - kss.resultKeySet[KeyToHash(key)%keySetCnt].AddKeyValue(key, table) + kss.resultKeySet[keyToHash(key)%keySetCnt].AddKeyValue(key, table) } func (kss *keySetShard[K, V]) Remove(key K) { - kss.resultKeySet[KeyToHash(key)%keySetCnt].Remove(key) + kss.resultKeySet[keyToHash(key)%keySetCnt].Remove(key) } func (kss *keySetShard[K, V]) Keys() []K { diff --git a/pkg/util/lfu/lfu_cache.go b/pkg/util/lfu/lfu_cache.go index de3c30ff4baa4..44a57928cd1be 100644 --- a/pkg/util/lfu/lfu_cache.go +++ b/pkg/util/lfu/lfu_cache.go @@ -28,7 +28,7 @@ import ( ) // LFU is a LFU based on the ristretto.Cache -type LFU[k K, v V] struct { +type LFU[k Key, v Value] struct { cache *ristretto.Cache // This is a secondary cache layer used to store all tables, // including those that have been evicted from the primary cache. @@ -59,7 +59,7 @@ type LFU[k K, v V] struct { } // NewLFU creates a new LFU cache. -func NewLFU[k K, v V](totalMemCost int64, dropEvicted func(any), capacityGauge prometheus.Gauge) (*LFU[k, v], error) { +func NewLFU[k Key, v Value](totalMemCost int64, dropEvicted func(any), capacityGauge prometheus.Gauge) (*LFU[k, v], error) { cost, err := adjustMemCost(totalMemCost) if err != nil { return nil, err @@ -240,12 +240,13 @@ func (s *LFU[K, V]) SetCapacity(maxCost int64) { s.costGauge.Set(float64(s.Cost())) } -// wait blocks until all buffered writes have been applied. This ensures a call to Set() +// Wait blocks until all buffered writes have been applied. This ensures a call to Set() // will be visible to future calls to Get(). it is only used for test. func (s *LFU[K, V]) Wait() { s.cache.Wait() } +// Metrics is to get metrics. It is only used for test. func (s *LFU[K, V]) Metrics() *ristretto.Metrics { return s.cache.Metrics } From 23a165936ac3c54c4b9162eeb9f155d90bb7b0d1 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Wed, 28 Feb 2024 10:33:11 +0800 Subject: [PATCH 5/5] test Signed-off-by: Weizhen Wang --- pkg/util/lfu/key.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/util/lfu/key.go b/pkg/util/lfu/key.go index aee420b083e6d..bf2dc0971a805 100644 --- a/pkg/util/lfu/key.go +++ b/pkg/util/lfu/key.go @@ -16,7 +16,7 @@ package lfu import "github.com/cespare/xxhash/v2" -func keyToHash(key interface{}) uint64 { +func keyToHash(key any) uint64 { if key == nil { return 0 }