Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

storage/pebbleiter: mangle unsafe buffers during positioning #96685

Merged
merged 3 commits into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions pkg/kv/kvserver/spanset/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,37 @@ func (i *MVCCIterator) SeekLT(key storage.MVCCKey) {
// Next is part of the storage.MVCCIterator interface.
func (i *MVCCIterator) Next() {
i.i.Next()
i.checkAllowed(roachpb.Span{Key: i.UnsafeKey().Key}, false)
i.checkAllowedCurrPosForward(false)
}

// Prev is part of the storage.MVCCIterator interface.
func (i *MVCCIterator) Prev() {
i.i.Prev()
i.checkAllowed(roachpb.Span{Key: i.UnsafeKey().Key}, false)
i.checkAllowedCurrPosForward(false)
}

// NextKey is part of the storage.MVCCIterator interface.
func (i *MVCCIterator) NextKey() {
i.i.NextKey()
i.checkAllowed(roachpb.Span{Key: i.UnsafeKey().Key}, false)
i.checkAllowedCurrPosForward(false)
}

// checkAllowedCurrPosForward checks the span starting at the current iterator
// position, if the current iterator position is valid.
func (i *MVCCIterator) checkAllowedCurrPosForward(errIfDisallowed bool) {
i.invalid = false
i.err = nil
if ok, _ := i.i.Valid(); !ok {
// If the iterator is invalid after the operation, there's nothing to
// check. We allow uses of iterators to exceed the declared span bounds
// as long as the iterator itself is configured with proper boundaries.
return
}
i.checkAllowedValidPos(roachpb.Span{Key: i.UnsafeKey().Key}, errIfDisallowed)
}

// checkAllowed checks the provided span if the current iterator position is
// valid.
func (i *MVCCIterator) checkAllowed(span roachpb.Span, errIfDisallowed bool) {
i.invalid = false
i.err = nil
Expand All @@ -126,6 +142,10 @@ func (i *MVCCIterator) checkAllowed(span roachpb.Span, errIfDisallowed bool) {
// as long as the iterator itself is configured with proper boundaries.
return
}
i.checkAllowedValidPos(span, errIfDisallowed)
}

func (i *MVCCIterator) checkAllowedValidPos(span roachpb.Span, errIfDisallowed bool) {
var err error
if i.spansOnly {
err = i.spans.CheckAllowed(SpanReadOnly, span)
Expand Down
5 changes: 4 additions & 1 deletion pkg/storage/pebbleiter/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ go_library(
}),
importpath = "github.com/cockroachdb/cockroach/pkg/storage/pebbleiter",
visibility = ["//visibility:public"],
deps = ["@com_github_cockroachdb_pebble//:pebble"],
deps = [
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_pebble//:pebble",
],
)

REMOVE_GO_BUILD_CONSTRAINTS = "cat $< | grep -v '//go:build' | grep -v '// +build' > $@"
Expand Down
129 changes: 127 additions & 2 deletions pkg/storage/pebbleiter/crdb_test_on.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

package pebbleiter

import "github.com/cockroachdb/pebble"
import (
"math/rand"

"github.com/cockroachdb/errors"
"github.com/cockroachdb/pebble"
)

// Iterator wraps the *pebble.Iterator in crdb_test builds with an assertionIter
// that detects when Close is called on the iterator twice. Double closes are
Expand All @@ -31,6 +36,23 @@ func MaybeWrap(iter *pebble.Iterator) Iterator {
type assertionIter struct {
*pebble.Iterator
closed bool
// unsafeBufs hold buffers used for returning values with short lifetimes to
// the caller. To assert that the client is respecting the lifetimes,
// assertionIter mangles the buffers as soon as the associated lifetime
// expires. This is the same technique applied by the unsafeMVCCIterator in
// pkg/storage, but this time applied at the API boundary between
// pkg/storage and Pebble.
//
// unsafeBufs holds two buffers per-key type and an index indicating which
// are currently in use. This is used to randomly switch to a different
// buffer, ensuring that the buffer(s) returned to the caller for the
// previous iterator position are garbage (as opposed to just state
// corresponding to the current iterator position).
unsafeBufs struct {
idx int
key [2][]byte
val [2][]byte
}
}

func (i *assertionIter) Clone(cloneOpts pebble.CloneOptions) (Iterator, error) {
Expand All @@ -43,8 +65,111 @@ func (i *assertionIter) Clone(cloneOpts pebble.CloneOptions) (Iterator, error) {

func (i *assertionIter) Close() error {
if i.closed {
panic("pebble.Iterator already closed")
panic(errors.AssertionFailedf("pebble.Iterator already closed"))
}
i.closed = true
return i.Iterator.Close()
}

func (i *assertionIter) Key() []byte {
if !i.Valid() {
panic(errors.AssertionFailedf("Key() called on !Valid() pebble.Iterator"))
}
idx := i.unsafeBufs.idx
i.unsafeBufs.key[idx] = append(i.unsafeBufs.key[idx][:0], i.Iterator.Key()...)
return i.unsafeBufs.key[idx]
}

func (i *assertionIter) Value() []byte {
if !i.Valid() {
panic(errors.AssertionFailedf("Value() called on !Valid() pebble.Iterator"))
}
idx := i.unsafeBufs.idx
i.unsafeBufs.val[idx] = append(i.unsafeBufs.val[idx][:0], i.Iterator.Value()...)
return i.unsafeBufs.val[idx]
}

func (i *assertionIter) LazyValue() pebble.LazyValue {
if !i.Valid() {
panic(errors.AssertionFailedf("LazyValue() called on !Valid() pebble.Iterator"))
}
return i.Iterator.LazyValue()
}

func (i *assertionIter) First() bool {
i.maybeMangleBufs()
return i.Iterator.First()
}

func (i *assertionIter) SeekGE(key []byte) bool {
i.maybeMangleBufs()
return i.Iterator.SeekGE(key)
}

func (i *assertionIter) SeekGEWithLimit(key []byte, limit []byte) pebble.IterValidityState {
i.maybeMangleBufs()
return i.Iterator.SeekGEWithLimit(key, limit)
}

func (i *assertionIter) SeekPrefixGE(key []byte) bool {
i.maybeMangleBufs()
return i.Iterator.SeekPrefixGE(key)
}

func (i *assertionIter) Next() bool {
i.maybeMangleBufs()
return i.Iterator.Next()
}

func (i *assertionIter) NextWithLimit(limit []byte) pebble.IterValidityState {
i.maybeMangleBufs()
return i.Iterator.NextWithLimit(limit)
}

func (i *assertionIter) NextPrefix() bool {
i.maybeMangleBufs()
return i.Iterator.NextPrefix()
}

func (i *assertionIter) Last() bool {
i.maybeMangleBufs()
return i.Iterator.Last()
}

func (i *assertionIter) SeekLT(key []byte) bool {
i.maybeMangleBufs()
return i.Iterator.SeekLT(key)
}

func (i *assertionIter) SeekLTWithLimit(key []byte, limit []byte) pebble.IterValidityState {
i.maybeMangleBufs()
return i.Iterator.SeekLTWithLimit(key, limit)
}

func (i *assertionIter) Prev() bool {
i.maybeMangleBufs()
return i.Iterator.Prev()
}

func (i *assertionIter) PrevWithLimit(limit []byte) pebble.IterValidityState {
i.maybeMangleBufs()
return i.Iterator.PrevWithLimit(limit)
}

// maybeMangleBufs trashes the contents of buffers used to return unsafe values
// to the caller. This is used to ensure that the client respects the Pebble
// iterator interface and the lifetimes of buffers it returns.
func (i *assertionIter) maybeMangleBufs() {
if rand.Intn(2) == 0 {
idx := i.unsafeBufs.idx
for _, b := range [...][]byte{i.unsafeBufs.key[idx], i.unsafeBufs.val[idx]} {
for i := range b {
b[i] = 0
}
}
if rand.Intn(2) == 0 {
// Switch to a new buffer for the next iterator position.
i.unsafeBufs.idx = (i.unsafeBufs.idx + 1) % 2
}
}
}
2 changes: 1 addition & 1 deletion pkg/storage/sst.go
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ func CheckSSTConflicts(
// 2) the ext iterator became invalid
// 3) both iterators changed keys and the sst iterator's key is further
// ahead.
if extChangedKeys && (!sstChangedKeys || (!extOK && sstOK) || extIter.UnsafeKey().Key.Compare(sstIter.UnsafeKey().Key) < 0) {
if sstOK && extChangedKeys && (!sstChangedKeys || !extOK || extIter.UnsafeKey().Key.Compare(sstIter.UnsafeKey().Key) < 0) {
extIter.SeekGE(MVCCKey{Key: sstIter.UnsafeKey().Key})
extOK, extErr = extIter.Valid()
}
Expand Down