From 5ce6bf1f099c9ea0c9f83020acf081ab67e0a5ca Mon Sep 17 00:00:00 2001 From: you06 Date: Mon, 25 Nov 2024 15:44:41 +0900 Subject: [PATCH] membuffer: fix memory leak in red-black tree (#1483) (#1498) Signed-off-by: you06 --- internal/unionstore/memdb.go | 36 +++++++++++++++++++++------ internal/unionstore/memdb_arena.go | 6 +++-- internal/unionstore/memdb_iterator.go | 14 +++++++++-- internal/unionstore/memdb_test.go | 22 ++++++++++++++++ 4 files changed, 66 insertions(+), 12 deletions(-) diff --git a/internal/unionstore/memdb.go b/internal/unionstore/memdb.go index 1560f51ce8..a3d1a78881 100644 --- a/internal/unionstore/memdb.go +++ b/internal/unionstore/memdb.go @@ -239,7 +239,7 @@ func (db *MemDB) SelectValueHistory(key []byte, predicate func(value []byte) boo // GetFlags returns the latest flags associated with key. func (db *MemDB) GetFlags(key []byte) (kv.KeyFlags, error) { x := db.traverse(key, false) - if x.isNull() { + if x.isNull() || x.isDeleted() { return 0, tikverr.ErrNotExist } return x.getKeyFlags(), nil @@ -337,17 +337,22 @@ func (db *MemDB) set(key []byte, value []byte, ops ...kv.FlagsOp) error { // the NeedConstraintCheckInPrewrite flag is temporary, // every write to the node removes the flag unless it's explicitly set. // This set must be in the latest stage so no special processing is needed. - var flags kv.KeyFlags + flags := x.getKeyFlags() + if flags == 0 && x.vptr.isNull() && x.isDeleted() { + x.unmarkDelete() + db.count++ + db.size += int(x.klen) + } if value != nil { - flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...) + flags = kv.ApplyFlagsOps(flags, append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...) } else { // an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag. - flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...) + flags = kv.ApplyFlagsOps(flags, ops...) } if flags.AndPersistent() != 0 { db.dirty = true } - x.setKeyFlags(flags) + x.resetKeyFlags(flags) if value == nil { return nil @@ -848,18 +853,33 @@ func (n *memdbNode) getKey() []byte { const ( // bit 1 => red, bit 0 => black - nodeColorBit uint16 = 0x8000 - nodeFlagsMask = ^nodeColorBit + nodeColorBit uint16 = 0x8000 + // bit 1 => node is deleted, bit 0 => node is not deleted + // This flag is used to mark a node as deleted, so that we can reuse the node to avoid memory leak. + deleteFlag uint16 = 1 << 14 + nodeFlagsMask = ^(nodeColorBit | deleteFlag) ) func (n *memdbNode) getKeyFlags() kv.KeyFlags { return kv.KeyFlags(n.flags & nodeFlagsMask) } -func (n *memdbNode) setKeyFlags(f kv.KeyFlags) { +func (n *memdbNode) resetKeyFlags(f kv.KeyFlags) { n.flags = (^nodeFlagsMask & n.flags) | uint16(f) } +func (n *memdbNode) markDelete() { + n.flags = (nodeColorBit & n.flags) | deleteFlag +} + +func (n *memdbNode) unmarkDelete() { + n.flags &= ^deleteFlag +} + +func (n *memdbNode) isDeleted() bool { + return n.flags&deleteFlag != 0 +} + // RemoveFromBuffer removes a record from the mem buffer. It should be only used for test. func (db *MemDB) RemoveFromBuffer(key []byte) { x := db.traverse(key, false) diff --git a/internal/unionstore/memdb_arena.go b/internal/unionstore/memdb_arena.go index c1db745afa..1c38aa1edf 100644 --- a/internal/unionstore/memdb_arena.go +++ b/internal/unionstore/memdb_arena.go @@ -384,9 +384,11 @@ func (l *memdbVlog) revertToCheckpoint(db *MemDB, cp *MemDBCheckpoint) { // If there are no flags associated with this key, we need to delete this node. keptFlags := node.getKeyFlags().AndPersistent() if keptFlags == 0 { - db.deleteNode(node) + node.markDelete() + db.count-- + db.size -= int(node.klen) } else { - node.setKeyFlags(keptFlags) + node.resetKeyFlags(keptFlags) db.dirty = true } } else { diff --git a/internal/unionstore/memdb_iterator.go b/internal/unionstore/memdb_iterator.go index 3b4bdfd8f0..a17c5436de 100644 --- a/internal/unionstore/memdb_iterator.go +++ b/internal/unionstore/memdb_iterator.go @@ -118,7 +118,7 @@ func (i *MemdbIterator) init() { } } - if i.isFlagsOnly() && !i.includeFlags { + if (i.isFlagsOnly() && !i.includeFlags) || (!i.curr.isNull() && i.curr.isDeleted()) { err := i.Next() _ = err // memdbIterator will never fail } @@ -141,7 +141,7 @@ func (i *MemdbIterator) Flags() kv.KeyFlags { func (i *MemdbIterator) UpdateFlags(ops ...kv.FlagsOp) { origin := i.curr.getKeyFlags() n := kv.ApplyFlagsOps(origin, ops...) - i.curr.setKeyFlags(n) + i.curr.resetKeyFlags(n) } // HasValue returns false if it is flags only. @@ -176,6 +176,10 @@ func (i *MemdbIterator) Next() error { i.curr = i.db.successor(i.curr) } + if i.curr.isDeleted() { + continue + } + // We need to skip persistent flags only nodes. if i.includeFlags || !i.isFlagsOnly() { break @@ -197,6 +201,9 @@ func (i *MemdbIterator) seekToFirst() { } i.curr = y + for !i.curr.isNull() && i.curr.isDeleted() { + i.curr = i.db.successor(i.curr) + } } func (i *MemdbIterator) seekToLast() { @@ -209,6 +216,9 @@ func (i *MemdbIterator) seekToLast() { } i.curr = y + for !i.curr.isNull() && i.curr.isDeleted() { + i.curr = i.db.predecessor(i.curr) + } } func (i *MemdbIterator) seek(key []byte) { diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 913aebc430..a8280126c8 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -39,6 +39,8 @@ package unionstore import ( "encoding/binary" "fmt" + "strconv" + "strings" "testing" leveldb "github.com/pingcap/goleveldb/leveldb/memdb" @@ -895,3 +897,23 @@ func TestSnapshotGetIter(t *testing.T) { assert.Equal(iter.Value(), []byte{byte(50)}) } } + +func TestMemDBLeafFragmentation(t *testing.T) { + buffer := newMemDB() + assert := assert.New(t) + h := buffer.Staging() + mem := buffer.Mem() + for i := 0; i < 10; i++ { + for k := 0; k < 100; k++ { + buffer.Set([]byte(strings.Repeat(strconv.Itoa(k), 256)), []byte("value")) + } + cur := buffer.Mem() + if mem == 0 { + mem = cur + } else { + assert.LessOrEqual(cur, mem) + } + buffer.Cleanup(h) + h = buffer.Staging() + } +}