Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

membuffer: fix memory leak in red-black tree (#1483) #1500

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ jobs:
uses: golangci/[email protected]
with:
version: v1.55.2
skip-go-installation: true

36 changes: 28 additions & 8 deletions internal/unionstore/memdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (db *MemDB) SelectValueHistory(key []byte, predicate func(value []byte) boo
// GetFlags returns the latest flags associated with key.
func (db *MemDB) GetFlags(key []byte) (kv.KeyFlags, error) {
x := db.traverse(key, false)
if x.isNull() {
if x.isNull() || x.isDeleted() {
return 0, tikverr.ErrNotExist
}
return x.getKeyFlags(), nil
Expand Down Expand Up @@ -337,17 +337,22 @@ func (db *MemDB) set(key []byte, value []byte, ops ...kv.FlagsOp) error {
// the NeedConstraintCheckInPrewrite flag is temporary,
// every write to the node removes the flag unless it's explicitly set.
// This set must be in the latest stage so no special processing is needed.
var flags kv.KeyFlags
flags := x.getKeyFlags()
if flags == 0 && x.vptr.isNull() && x.isDeleted() {
x.unmarkDelete()
db.count++
db.size += int(x.klen)
}
if value != nil {
flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...)
flags = kv.ApplyFlagsOps(flags, append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...)
} else {
// an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag.
flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...)
flags = kv.ApplyFlagsOps(flags, ops...)
}
if flags.AndPersistent() != 0 {
db.dirty = true
}
x.setKeyFlags(flags)
x.resetKeyFlags(flags)

if value == nil {
return nil
Expand Down Expand Up @@ -847,18 +852,33 @@ func (n *memdbNode) getKey() []byte {

const (
// bit 1 => red, bit 0 => black
nodeColorBit uint16 = 0x8000
nodeFlagsMask = ^nodeColorBit
nodeColorBit uint16 = 0x8000
// bit 1 => node is deleted, bit 0 => node is not deleted
// This flag is used to mark a node as deleted, so that we can reuse the node to avoid memory leak.
deleteFlag uint16 = 1 << 14
nodeFlagsMask = ^(nodeColorBit | deleteFlag)
)

func (n *memdbNode) getKeyFlags() kv.KeyFlags {
return kv.KeyFlags(n.flags & nodeFlagsMask)
}

func (n *memdbNode) setKeyFlags(f kv.KeyFlags) {
func (n *memdbNode) resetKeyFlags(f kv.KeyFlags) {
n.flags = (^nodeFlagsMask & n.flags) | uint16(f)
}

func (n *memdbNode) markDelete() {
n.flags = (nodeColorBit & n.flags) | deleteFlag
}

func (n *memdbNode) unmarkDelete() {
n.flags &= ^deleteFlag
}

func (n *memdbNode) isDeleted() bool {
return n.flags&deleteFlag != 0
}

// RemoveFromBuffer removes a record from the mem buffer. It should be only used for test.
func (db *MemDB) RemoveFromBuffer(key []byte) {
x := db.traverse(key, false)
Expand Down
6 changes: 4 additions & 2 deletions internal/unionstore/memdb_arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,11 @@ func (l *memdbVlog) revertToCheckpoint(db *MemDB, cp *MemDBCheckpoint) {
// If there are no flags associated with this key, we need to delete this node.
keptFlags := node.getKeyFlags().AndPersistent()
if keptFlags == 0 {
db.deleteNode(node)
node.markDelete()
db.count--
db.size -= int(node.klen)
} else {
node.setKeyFlags(keptFlags)
node.resetKeyFlags(keptFlags)
db.dirty = true
}
} else {
Expand Down
14 changes: 12 additions & 2 deletions internal/unionstore/memdb_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (i *MemdbIterator) init() {
}
}

if i.isFlagsOnly() && !i.includeFlags {
if (i.isFlagsOnly() && !i.includeFlags) || (!i.curr.isNull() && i.curr.isDeleted()) {
err := i.Next()
_ = err // memdbIterator will never fail
}
Expand All @@ -140,7 +140,7 @@ func (i *MemdbIterator) Flags() kv.KeyFlags {
func (i *MemdbIterator) UpdateFlags(ops ...kv.FlagsOp) {
origin := i.curr.getKeyFlags()
n := kv.ApplyFlagsOps(origin, ops...)
i.curr.setKeyFlags(n)
i.curr.resetKeyFlags(n)
}

// HasValue returns false if it is flags only.
Expand Down Expand Up @@ -175,6 +175,10 @@ func (i *MemdbIterator) Next() error {
i.curr = i.db.successor(i.curr)
}

if i.curr.isDeleted() {
continue
}

// We need to skip persistent flags only nodes.
if i.includeFlags || !i.isFlagsOnly() {
break
Expand All @@ -196,6 +200,9 @@ func (i *MemdbIterator) seekToFirst() {
}

i.curr = y
for !i.curr.isNull() && i.curr.isDeleted() {
i.curr = i.db.successor(i.curr)
}
}

func (i *MemdbIterator) seekToLast() {
Expand All @@ -208,6 +215,9 @@ func (i *MemdbIterator) seekToLast() {
}

i.curr = y
for !i.curr.isNull() && i.curr.isDeleted() {
i.curr = i.db.predecessor(i.curr)
}
}

func (i *MemdbIterator) seek(key []byte) {
Expand Down
22 changes: 22 additions & 0 deletions internal/unionstore/memdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ package unionstore
import (
"encoding/binary"
"fmt"
"strconv"
"strings"
"testing"

leveldb "github.com/pingcap/goleveldb/leveldb/memdb"
Expand Down Expand Up @@ -879,3 +881,23 @@ func TestSnapshotGetIter(t *testing.T) {
assert.Equal(iter.Value(), []byte{byte(50)})
}
}

func TestMemDBLeafFragmentation(t *testing.T) {
buffer := newMemDB()
assert := assert.New(t)
h := buffer.Staging()
mem := buffer.Mem()
for i := 0; i < 10; i++ {
for k := 0; k < 100; k++ {
buffer.Set([]byte(strings.Repeat(strconv.Itoa(k), 256)), []byte("value"))
}
cur := buffer.Mem()
if mem == 0 {
mem = cur
} else {
assert.LessOrEqual(cur, mem)
}
buffer.Cleanup(h)
h = buffer.Staging()
}
}
Loading