Skip to content

Commit

Permalink
executor: add some memory tracker in HashJoin (#33918)
Browse files Browse the repository at this point in the history
ref #33877
  • Loading branch information
wshwsh12 authored Apr 19, 2022
1 parent 4d06631 commit d63a5fd
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 29 deletions.
17 changes: 5 additions & 12 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,6 @@ type baseHashAggWorker struct {
BInMap int // indicate there are 2^BInMap buckets in Golang Map.
}

const (
// ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162.
// defBucketMemoryUsage = bucketSize*(1+unsafe.Sizeof(string) + unsafe.Sizeof(slice))+2*ptrSize
// The bucket size may be changed by golang implement in the future.
defBucketMemoryUsage = 8*(1+16+24) + 16
)

func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc,
maxChunkSize int, memTrack *memory.Tracker) baseHashAggWorker {
baseWorker := baseHashAggWorker{
Expand Down Expand Up @@ -332,7 +325,7 @@ func (e *HashAggExec) initForUnparallelExec() {
e.partialResultMap = make(aggPartialResultMapper)
e.bInMap = 0
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(defBucketMemoryUsage*(1<<e.bInMap) + setSize)
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1<<e.bInMap) + setSize)
e.groupKeyBuffer = make([][]byte, 0, 8)
e.childResult = newFirstChunk(e.children[0])
e.memTracker.Consume(e.childResult.MemoryUsage())
Expand Down Expand Up @@ -395,7 +388,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
}
// There is a bucket in the empty partialResultsMap.
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << w.BInMap))
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.PartialStats = append(e.stats.PartialStats, w.stats)
Expand Down Expand Up @@ -425,7 +418,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
groupKeys: make([][]byte, 0, 8),
}
// There is a bucket in the empty partialResultsMap.
e.memTracker.Consume(defBucketMemoryUsage*(1<<w.BInMap) + setSize)
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1<<w.BInMap) + setSize)
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.FinalStats = append(e.stats.FinalStats, w.stats)
Expand Down Expand Up @@ -615,7 +608,7 @@ func (w *baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, group
allMemDelta += int64(len(groupKey[i]))
// Map will expand when count > bucketNum * loadFactor. The memory usage will double.
if len(mapper) > (1<<w.BInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
w.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
w.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << w.BInMap))
w.BInMap++
}
}
Expand Down Expand Up @@ -1084,7 +1077,7 @@ func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResul
allMemDelta += int64(len(groupKey))
// Map will expand when count > bucketNum * loadFactor. The memory usage will double.
if len(e.partialResultMap) > (1<<e.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
e.memTracker.Consume(defBucketMemoryUsage * (1 << e.bInMap))
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << e.bInMap))
e.bInMap++
}
}
Expand Down
22 changes: 13 additions & 9 deletions executor/concurrent_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package executor

import (
"sync"

"github.com/pingcap/tidb/util/hack"
)

// ShardCount controls the shard maps within the concurrent map
Expand All @@ -28,14 +30,15 @@ type concurrentMap []*concurrentMapShared
// A "thread" safe string to anything map.
type concurrentMapShared struct {
items map[uint64]*entry
sync.RWMutex // Read Write mutex, guards access to internal map.
sync.RWMutex // Read Write mutex, guards access to internal map.
bInMap int64 // indicate there are 2^bInMap buckets in items
}

// newConcurrentMap creates a new concurrent map.
func newConcurrentMap() concurrentMap {
m := make(concurrentMap, ShardCount)
for i := 0; i < ShardCount; i++ {
m[i] = &concurrentMapShared{items: make(map[uint64]*entry)}
m[i] = &concurrentMapShared{items: make(map[uint64]*entry), bInMap: 0}
}
return m
}
Expand All @@ -46,17 +49,18 @@ func (m concurrentMap) getShard(hashKey uint64) *concurrentMapShared {
}

// Insert inserts a value in a shard safely
func (m concurrentMap) Insert(key uint64, value *entry) {
func (m concurrentMap) Insert(key uint64, value *entry) (memDelta int64) {
shard := m.getShard(key)
shard.Lock()
v, ok := shard.items[key]
if !ok {
shard.items[key] = value
} else {
value.next = v
shard.items[key] = value
oldValue := shard.items[key]
value.next = oldValue
shard.items[key] = value
if len(shard.items) > (1<<shard.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
memDelta = hack.DefBucketMemoryUsageForMapIntToPtr * (1 << shard.bInMap)
shard.bInMap++
}
shard.Unlock()
return memDelta
}

// UpsertCb : Callback to return new element to be inserted into the map
Expand Down
35 changes: 35 additions & 0 deletions executor/concurrent_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ package executor

import (
"sync"
"sync/atomic"
"testing"

"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -65,3 +67,36 @@ func TestConcurrentMap(t *testing.T) {
_, ok = m.Get(uint64(mod + 1))
require.False(t, ok)
}

// TestConcurrentMapMemoryUsage checks the memory delta reported by
// concurrentMap.Insert when two goroutines insert into the map concurrently.
func TestConcurrentMapMemoryUsage(t *testing.T) {
	m := newConcurrentMap()
	// Total number of keys inserted: 1024 * LoadFactorNum / LoadFactorDen.
	const iterations = 1024 * hack.LoadFactorNum / hack.LoadFactorDen
	var (
		memUsage int64
		wg       sync.WaitGroup
	)

	// insertRange inserts the keys i*ShardCount for i in [from, to) and, once
	// finished, adds the sum of the deltas returned by Insert to memUsage.
	insertRange := func(from, to int) {
		defer wg.Done()
		var delta int64
		for i := from; i < to; i++ {
			delta += m.Insert(uint64(i*ShardCount), &entry{chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(i)}, nil})
		}
		atomic.AddInt64(&memUsage, delta)
	}

	// Each goroutine handles one half of the key range.
	wg.Add(2)
	go insertRange(0, iterations/2)
	go insertRange(iterations/2, iterations)
	wg.Wait()

	// The first bucket's memory usage is recorded by concurrentMapHashTable, so only the memory delta is checked here.
	require.Equal(t, int64(1023)*hack.DefBucketMemoryUsageForMapIntToPtr, memUsage)
	require.Equal(t, int64(10), m.getShard(0).bInMap)
}
52 changes: 47 additions & 5 deletions executor/hash_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"hash/fnv"
"sync/atomic"
"time"
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/sessionctx"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/memory"
)

Expand Down Expand Up @@ -83,6 +85,7 @@ type hashRowContainer struct {
hashTable baseHashTable

rowContainer *chunk.RowContainer
memTracker *memory.Tracker
}

func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContext, allTypes []*types.FieldType) *hashRowContainer {
Expand All @@ -94,7 +97,9 @@ func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContex
stat: new(hashStatistic),
hashTable: newConcurrentMapHashTable(),
rowContainer: rc,
memTracker: memory.NewTracker(memory.LabelForRowContainer, -1),
}
rc.GetMemTracker().AttachTo(c.GetMemTracker())
return c
}

Expand Down Expand Up @@ -186,6 +191,7 @@ func (c *hashRowContainer) PutChunkSelected(chk *chunk.Chunk, selected, ignoreNu
rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(i)}
c.hashTable.Put(key, rowPtr)
}
c.GetMemTracker().Consume(c.hashTable.GetAndCleanMemoryDelta())
return nil
}

Expand Down Expand Up @@ -219,7 +225,7 @@ func (c *hashRowContainer) Close() error {
}

// GetMemTracker returns the underlying memory usage tracker in hashRowContainer.
func (c *hashRowContainer) GetMemTracker() *memory.Tracker { return c.rowContainer.GetMemTracker() }
func (c *hashRowContainer) GetMemTracker() *memory.Tracker { return c.memTracker }

// GetDiskTracker returns the underlying disk usage tracker in hashRowContainer.
func (c *hashRowContainer) GetDiskTracker() *disk.Tracker { return c.rowContainer.GetDiskTracker() }
Expand Down Expand Up @@ -251,7 +257,7 @@ func newEntryStore() *entryStore {
return es
}

func (es *entryStore) GetStore() (e *entry) {
func (es *entryStore) GetStore() (e *entry, memDelta int64) {
sliceIdx := uint32(len(es.slices) - 1)
slice := es.slices[sliceIdx]
if es.cursor >= cap(slice) {
Expand All @@ -263,6 +269,7 @@ func (es *entryStore) GetStore() (e *entry) {
es.slices = append(es.slices, slice)
sliceIdx++
es.cursor = 0
memDelta = int64(unsafe.Sizeof(entry{})) * int64(size)
}
e = &es.slices[sliceIdx][es.cursor]
es.cursor++
Expand All @@ -273,6 +280,9 @@ type baseHashTable interface {
Put(hashKey uint64, rowPtr chunk.RowPtr)
Get(hashKey uint64) (rowPtrs []chunk.RowPtr)
Len() uint64
// GetAndCleanMemoryDelta returns the memory delta of the baseHashTable and then resets it to zero.
// The returned value is the memory delta accumulated since the previous call to GetAndCleanMemoryDelta().
GetAndCleanMemoryDelta() int64
}

// TODO (fangzhuhe) remove unsafeHashTable later if it not used anymore
Expand All @@ -283,6 +293,9 @@ type unsafeHashTable struct {
hashMap map[uint64]*entry
entryStore *entryStore
length uint64

bInMap int64 // indicate there are 2^bInMap buckets in hashMap
memDelta int64 // the memory delta of the unsafeHashTable since the last calling GetAndCleanMemoryDelta()
}

// newUnsafeHashTable creates a new unsafeHashTable. estCount means the estimated size of the hashMap.
Expand All @@ -297,11 +310,16 @@ func newUnsafeHashTable(estCount int) *unsafeHashTable {
// Put puts the key/rowPtr pairs to the unsafeHashTable, multiple rowPtrs are stored in a list.
func (ht *unsafeHashTable) Put(hashKey uint64, rowPtr chunk.RowPtr) {
oldEntry := ht.hashMap[hashKey]
newEntry := ht.entryStore.GetStore()
newEntry, memDelta := ht.entryStore.GetStore()
newEntry.ptr = rowPtr
newEntry.next = oldEntry
ht.hashMap[hashKey] = newEntry
if len(ht.hashMap) > (1<<ht.bInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
memDelta += hack.DefBucketMemoryUsageForMapIntToPtr * (1 << ht.bInMap)
ht.bInMap++
}
ht.length++
ht.memDelta += memDelta
}

// Get gets the values of the "key" and appends them to "values".
Expand All @@ -318,11 +336,19 @@ func (ht *unsafeHashTable) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) {
// if the same key is put more than once.
func (ht *unsafeHashTable) Len() uint64 { return ht.length }

// GetAndCleanMemoryDelta returns the memory delta currently recorded on the
// unsafeHashTable and resets the recorded value to zero.
func (ht *unsafeHashTable) GetAndCleanMemoryDelta() int64 {
	var delta int64
	// Read the current value and zero the field in one tuple assignment.
	delta, ht.memDelta = ht.memDelta, 0
	return delta
}

// concurrentMapHashTable is a concurrent hash table built on concurrentMap.
type concurrentMapHashTable struct {
// hashMap is the concurrentMap that Put inserts entries into.
hashMap concurrentMap
// entryStore provides the entry objects that Put fills in and inserts into hashMap.
entryStore *entryStore
// length is incremented atomically on every Put.
length uint64
memDelta int64 // the memory delta of the concurrentMapHashTable since the last call to GetAndCleanMemoryDelta(); updated atomically
}

// newConcurrentMapHashTable creates a concurrentMapHashTable
Expand All @@ -331,6 +357,7 @@ func newConcurrentMapHashTable() *concurrentMapHashTable {
ht.hashMap = newConcurrentMap()
ht.entryStore = newEntryStore()
ht.length = 0
ht.memDelta = hack.DefBucketMemoryUsageForMapIntToPtr + int64(unsafe.Sizeof(entry{}))*initialEntrySliceLen
return ht
}

Expand All @@ -341,10 +368,13 @@ func (ht *concurrentMapHashTable) Len() uint64 {

// Put puts the key/rowPtr pairs to the concurrentMapHashTable, multiple rowPtrs are stored in a list.
func (ht *concurrentMapHashTable) Put(hashKey uint64, rowPtr chunk.RowPtr) {
newEntry := ht.entryStore.GetStore()
newEntry, memDelta := ht.entryStore.GetStore()
newEntry.ptr = rowPtr
newEntry.next = nil
ht.hashMap.Insert(hashKey, newEntry)
memDelta += ht.hashMap.Insert(hashKey, newEntry)
if memDelta != 0 {
atomic.AddInt64(&ht.memDelta, memDelta)
}
atomic.AddUint64(&ht.length, 1)
}

Expand All @@ -357,3 +387,15 @@ func (ht *concurrentMapHashTable) Get(hashKey uint64) (rowPtrs []chunk.RowPtr) {
}
return
}

// GetAndCleanMemoryDelta returns the memory delta recorded on the concurrentMapHashTable and resets it
// to zero, so every call only reports what was accumulated since the previous call.
func (ht *concurrentMapHashTable) GetAndCleanMemoryDelta() int64 {
	// SwapInt64 stores zero and returns the previous value in one atomic step, which is what the
	// former Load + CompareAndSwap retry loop achieved: a delta added concurrently is either
	// returned by this call or left in place for the next one, never lost.
	return atomic.SwapInt64(&ht.memDelta, 0)
}
23 changes: 21 additions & 2 deletions executor/hash_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ import (
"fmt"
"hash"
"hash/fnv"
"sync"
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -140,8 +142,9 @@ func testHashRowContainer(t *testing.T, hashFunc func() hash.Hash64, spill bool)
require.NoError(t, err)
rowContainer.ActionSpill().(*chunk.SpillDiskAction).WaitForTest()
require.Equal(t, spill, rowContainer.alreadySpilledSafeForTest())
require.Equal(t, spill, rowContainer.GetMemTracker().BytesConsumed() == 0)
require.Equal(t, !spill, rowContainer.GetMemTracker().BytesConsumed() > 0)
require.Equal(t, spill, rowContainer.rowContainer.GetMemTracker().BytesConsumed() == 0)
require.Equal(t, !spill, rowContainer.rowContainer.GetMemTracker().BytesConsumed() > 0)
require.True(t, rowContainer.GetMemTracker().BytesConsumed() > 0) // hashtable need memory
if rowContainer.alreadySpilledSafeForTest() {
require.NotNil(t, rowContainer.GetDiskTracker())
require.True(t, rowContainer.GetDiskTracker().BytesConsumed() > 0)
Expand All @@ -162,3 +165,19 @@ func testHashRowContainer(t *testing.T, hashFunc func() hash.Hash64, spill bool)
require.Equal(t, chk1.GetRow(1).GetDatumRow(colTypes), matched[1].GetDatumRow(colTypes))
return rowContainer, copiedRC
}

// TestConcurrentMapHashTableMemoryUsage puts entries into a concurrentMapHashTable one after another and
// checks the value reported by GetAndCleanMemoryDelta, then checks that a second fetch reports zero.
func TestConcurrentMapHashTableMemoryUsage(t *testing.T) {
m := newConcurrentMapHashTable()
const iterations = 1024 * hack.LoadFactorNum / hack.LoadFactorDen // 6656
// NOTE(review): wg is created and Add(2) is called, but nothing in this test calls Done or Wait on it;
// the insertions below run sequentially on the test goroutine.
wg := &sync.WaitGroup{}
wg.Add(2)
// Note: Now concurrentMapHashTable doesn't support inserting in parallel.
for i := 0; i < iterations; i++ {
// Add entry to map.
m.Put(uint64(i*ShardCount), chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(i)})
}
// Expected map part of the delta: 1024 times the per-bucket estimate for a uint64 -> pointer map.
mapMemoryExpected := int64(1024) * hack.DefBucketMemoryUsageForMapIntToPtr
// Expected entry part of the delta: 16 bytes per slot over slices of 64, 128, ..., 4096 slots.
// Presumably this mirrors how entryStore grows its slices — verify against newEntryStore/GetStore.
entryMemoryExpected := 16 * int64(64+128+256+512+1024+2048+4096)
require.Equal(t, mapMemoryExpected+entryMemoryExpected, m.GetAndCleanMemoryDelta())
// The delta is cleared by the fetch above, so a second fetch must report zero.
require.Equal(t, int64(0), m.GetAndCleanMemoryDelta())
}
3 changes: 2 additions & 1 deletion util/chunk/row_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ func NewRowContainer(fieldType []*types.FieldType, chunkSize int) *RowContainer
},
fieldType: fieldType,
chunkSize: chunkSize,
memTracker: li.memTracker,
memTracker: memory.NewTracker(memory.LabelForRowContainer, -1),
diskTracker: disk.NewTracker(memory.LabelForRowContainer, -1),
}
li.GetMemTracker().AttachTo(rc.GetMemTracker())
return rc
}

Expand Down
9 changes: 9 additions & 0 deletions util/hack/hack.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ const (
// LoadFactorDen is the denominator of load factor
LoadFactorDen = 2
)

// Estimated memory usage, in bytes, of one bucket of a Go map, per key/value type combination.
const (
// DefBucketMemoryUsageForMapStrToSlice is the estimated size of one bucket of a map from string to slice:
// bucketSize*(1+unsafe.Sizeof(string)+unsafe.Sizeof(slice)) + 2*ptrSize.
// ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162.
// The literals below use bucketSize = 8, 16 bytes for a string, 24 bytes for a slice and 8 bytes for a pointer.
// The bucket layout may be changed by the Go implementation in the future.
DefBucketMemoryUsageForMapStrToSlice = 8*(1+16+24) + 16
// DefBucketMemoryUsageForMapIntToPtr is the estimated size of one bucket of a map from uint64 to pointer:
// bucketSize*(1+unsafe.Sizeof(uint64)+unsafe.Sizeof(pointer)) + 2*ptrSize.
// The literals below use bucketSize = 8, 8 bytes for a uint64 and 8 bytes for a pointer.
DefBucketMemoryUsageForMapIntToPtr = 8*(1+8+8) + 16
)

0 comments on commit d63a5fd

Please sign in to comment.