Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vecindex: redistribute vectors across level during split #135506

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/sql/vecindex/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ go_test(
"//pkg/util/log",
"//pkg/util/num32",
"//pkg/util/stop",
"//pkg/util/timeutil",
"//pkg/util/vector",
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
Expand Down
231 changes: 213 additions & 18 deletions pkg/sql/vecindex/fixup_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/internal"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/num32"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -334,18 +335,51 @@ func (fp *fixupProcessor) splitPartition(
tempLeftOffsets, tempRightOffsets := kmeans.Compute(&vectors, tempOffsets)

leftSplit, rightSplit := fp.splitPartitionData(
ctx, partition, &vectors, tempLeftOffsets, tempRightOffsets)
ctx, partition, vectors, tempLeftOffsets, tempRightOffsets)

if parentPartition != nil {
// De-link the splitting partition from its parent partition.
childKey := vecstore.ChildKey{PartitionKey: partitionKey}
_, err = fp.index.removeFromPartition(ctx, txn, parentPartitionKey, childKey)
count, err := fp.index.removeFromPartition(ctx, txn, parentPartitionKey, childKey)
if err != nil {
return errors.Wrapf(err, "removing splitting partition %d from its parent %d",
partitionKey, parentPartitionKey)
}

// TODO(andyk): Move vectors to/from split partition.
if count != 0 {
// Move any vectors to sibling partitions that have closer centroids.
// Lazily get parent vectors only if they're actually needed.
var parentVectors vector.Set
getParentVectors := func() (vector.Set, error) {
if parentVectors.Dims != 0 {
return parentVectors, nil
}
var err error
parentVectors, err = fp.getFullVectorsForPartition(
ctx, txn, parentPartitionKey, parentPartition)
return parentVectors, err
}

err = fp.moveVectorsToSiblings(
ctx, txn, parentPartitionKey, parentPartition, getParentVectors, partitionKey, &leftSplit)
if err != nil {
return err
}
err = fp.moveVectorsToSiblings(
ctx, txn, parentPartitionKey, parentPartition, getParentVectors, partitionKey, &rightSplit)
if err != nil {
return err
}

// Move any vectors at the same level that are closer to the new split
// centroids than they are to their own centroids.
if err = fp.linkNearbyVectors(ctx, txn, partitionKey, leftSplit.Partition); err != nil {
return err
}
if err = fp.linkNearbyVectors(ctx, txn, partitionKey, rightSplit.Partition); err != nil {
return err
}
}
}

// Insert the two new partitions into the index. This only adds their data
Expand Down Expand Up @@ -392,23 +426,19 @@ func (fp *fixupProcessor) splitPartition(
// Link the two new partitions into the K-means tree by inserting them
// into the parent level. This can trigger a further split, this time of
// the parent level.
fp.searchCtx = searchContext{
Ctx: ctx,
Workspace: fp.workspace,
Txn: txn,
Level: parentPartition.Level() + 1,
}
searchCtx := fp.reuseSearchContext(ctx, txn)
searchCtx.Level = parentPartition.Level() + 1

fp.searchCtx.Randomized = leftSplit.Partition.Centroid()
searchCtx.Randomized = leftSplit.Partition.Centroid()
childKey := vecstore.ChildKey{PartitionKey: leftPartitionKey}
err = fp.index.insertHelper(&fp.searchCtx, childKey, true /* allowRetry */)
err = fp.index.insertHelper(searchCtx, childKey, true /* allowRetry */)
if err != nil {
return errors.Wrapf(err, "inserting left partition for split of partition %d", partitionKey)
}

fp.searchCtx.Randomized = rightSplit.Partition.Centroid()
searchCtx.Randomized = rightSplit.Partition.Centroid()
childKey = vecstore.ChildKey{PartitionKey: rightPartitionKey}
err = fp.index.insertHelper(&fp.searchCtx, childKey, true /* allowRetry */)
err = fp.index.insertHelper(searchCtx, childKey, true /* allowRetry */)
if err != nil {
return errors.Wrapf(err, "inserting right partition for split of partition %d", partitionKey)
}
Expand All @@ -432,7 +462,7 @@ func (fp *fixupProcessor) splitPartition(
func (fp *fixupProcessor) splitPartitionData(
ctx context.Context,
splitPartition *vecstore.Partition,
vectors *vector.Set,
vectors vector.Set,
leftOffsets, rightOffsets []uint64,
) (leftSplit, rightSplit splitData) {
// Copy centroid distances and child keys so they can be split.
Expand Down Expand Up @@ -461,7 +491,8 @@ func (fp *fixupProcessor) splitPartitionData(

right := int(rightOffsets[ri])
if right >= len(leftOffsets) {
panic("expected equal number of left and right offsets that need to be swapped")
panic(errors.AssertionFailedf(
"expected equal number of left and right offsets that need to be swapped"))
}

// Swap vectors.
Expand All @@ -480,22 +511,172 @@ func (fp *fixupProcessor) splitPartitionData(
ri++
}

leftVectorSet := *vectors
leftVectorSet := vectors
rightVectorSet := leftVectorSet.SplitAt(len(leftOffsets))

leftCentroidDistances := centroidDistances[:len(leftOffsets):len(leftOffsets)]
leftChildKeys := childKeys[:len(leftOffsets):len(leftOffsets)]
leftSplit.Init(ctx, fp.index.quantizer, &leftVectorSet,
leftSplit.Init(ctx, fp.index.quantizer, leftVectorSet,
leftCentroidDistances, leftChildKeys, splitPartition.Level())

rightCentroidDistances := centroidDistances[len(leftOffsets):]
rightChildKeys := childKeys[len(leftOffsets):]
rightSplit.Init(ctx, fp.index.quantizer, &rightVectorSet,
rightSplit.Init(ctx, fp.index.quantizer, rightVectorSet,
rightCentroidDistances, rightChildKeys, splitPartition.Level())

return leftSplit, rightSplit
}

// moveVectorsToSiblings checks each vector in the new split partition to see if
// it's now closer to a sibling partition's centroid than it is to its own
// centroid. If that's true, then move the vector to the sibling partition. Pass
// function to lazily fetch parent vectors, as it's expensive and is only needed
// if vectors actually need to be moved.
func (fp *fixupProcessor) moveVectorsToSiblings(
ctx context.Context,
txn vecstore.Txn,
parentPartitionKey vecstore.PartitionKey,
parentPartition *vecstore.Partition,
getParentVectors func() (vector.Set, error),
oldPartitionKey vecstore.PartitionKey,
split *splitData,
) error {
for i := 0; i < split.Vectors.Count; i++ {
if split.Vectors.Count == 1 && split.Partition.Level() != vecstore.LeafLevel {
// Don't allow so many vectors to be moved that a non-leaf partition
// ends up empty. This would violate a key constraint that the K-means
// tree is always fully balanced.
break
}

vector := split.Vectors.At(i)

// If distance to new centroid is <= distance to old centroid, then skip.
newCentroidDistance := split.Partition.QuantizedSet().GetCentroidDistances()[i]
if newCentroidDistance <= split.OldCentroidDistances[i] {
continue
}

// Get the full vectors for the parent partition's children.
parentVectors, err := getParentVectors()
if err != nil {
return err
}

// Check whether the vector is closer to a sibling centroid than its own
// new centroid.
minDistanceOffset := -1
for parent := 0; parent < parentVectors.Count; parent++ {
squaredDistance := num32.L2Distance(parentVectors.At(parent), vector)
if squaredDistance < newCentroidDistance {
newCentroidDistance = squaredDistance
minDistanceOffset = parent
}
}
if minDistanceOffset == -1 {
continue
}

siblingPartitionKey := parentPartition.ChildKeys()[minDistanceOffset].PartitionKey
log.VEventf(ctx, 3, "moving vector from splitting partition %d to sibling partition %d",
oldPartitionKey, siblingPartitionKey)

// Found a sibling child partition that's closer, so insert the vector
// there instead.
childKey := split.Partition.ChildKeys()[i]
_, err = fp.index.addToPartition(ctx, txn, parentPartitionKey, siblingPartitionKey, vector, childKey)
if err != nil {
return errors.Wrapf(err, "moving vector to partition %d", siblingPartitionKey)
}

// Remove the vector's data from the new partition. The remove operation
// backfills data at the current index with data from the last index.
// Therefore, don't increment the iteration index, since the next item
// is in the same location as the last.
split.ReplaceWithLast(i)
i--
}

return nil
}

// linkNearbyVectors searches for vectors at the same level that are close to
// the given split partition's centroid. If they are closer than they are to
// their own centroid, then move them to the split partition.
func (fp *fixupProcessor) linkNearbyVectors(
ctx context.Context,
txn vecstore.Txn,
oldPartitionKey vecstore.PartitionKey,
partition *vecstore.Partition,
) error {
// TODO(andyk): Add way to filter search set in order to skip vectors deeper
// down in the search rather than afterwards.
searchCtx := fp.reuseSearchContext(ctx, txn)
searchCtx.Options = SearchOptions{ReturnVectors: true}
searchCtx.Level = partition.Level()
searchCtx.Randomized = partition.Centroid()

// Don't link more vectors than the number of remaining slots in the split
// partition, to avoid triggering another split.
maxResults := fp.index.options.MaxPartitionSize - partition.Count()
if maxResults < 1 {
return nil
}
searchSet := vecstore.SearchSet{MaxResults: maxResults}
err := fp.index.searchHelper(searchCtx, &searchSet, true /* allowRetry */)
if err != nil {
return err
}

tempVector := fp.workspace.AllocVector(fp.index.quantizer.GetRandomDims())
defer fp.workspace.FreeVector(tempVector)

// Filter the results.
results := searchSet.PopUnsortedResults()
for i := range results {
result := &results[i]

// Skip vectors that are closer to their own centroid than they are to
// the split partition's centroid.
if result.QuerySquaredDistance >= result.CentroidDistance*result.CentroidDistance {
continue
}

log.VEventf(ctx, 3, "linking vector from partition %d to splitting partition %d",
result.ChildKey.PartitionKey, oldPartitionKey)

// Leaf vectors from the primary index need to be randomized.
vector := result.Vector
if partition.Level() == vecstore.LeafLevel {
fp.index.quantizer.RandomizeVector(ctx, vector, tempVector, false /* invert */)
vector = tempVector
}

// Remove the vector from the other partition.
count, err := fp.index.removeFromPartition(ctx, txn, result.ParentPartitionKey, result.ChildKey)
if err != nil {
return err
}
if count == 0 && partition.Level() > vecstore.LeafLevel {
// Removing the vector will result in an empty non-leaf partition, which
// is not allowed, as the K-means tree would not be fully balanced. Add
// the vector back to the partition. This is a very rare case and that
// partition is likely to be merged away regardless.
_, err = fp.index.store.AddToPartition(
ctx, txn, result.ParentPartitionKey, vector, result.ChildKey)
if err != nil {
return err
}
continue
}

// Add the vector to the split partition.
partition.Add(ctx, vector, result.ChildKey)
}

return nil
}

// getFullVectorsForPartition fetches the full-size vectors (potentially
// randomized by the quantizer) that are quantized by the given partition.
func (fp *fixupProcessor) getFullVectorsForPartition(
Expand Down Expand Up @@ -543,3 +724,17 @@ func (fp *fixupProcessor) getFullVectorsForPartition(

return vectors, nil
}

// reuseSearchContext initializes the reusable search context, including reusing
// its temp slices.
func (fp *fixupProcessor) reuseSearchContext(ctx context.Context, txn vecstore.Txn) *searchContext {
fp.searchCtx = searchContext{
Ctx: ctx,
Workspace: fp.workspace,
Txn: txn,
tempKeys: fp.searchCtx.tempKeys,
tempCounts: fp.searchCtx.tempCounts,
tempVectorsWithKeys: fp.searchCtx.tempVectorsWithKeys,
}
return &fp.searchCtx
}
2 changes: 1 addition & 1 deletion pkg/sql/vecindex/fixup_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestSplitPartitionData(t *testing.T) {
tempVectors := vector.MakeSet(2)
tempVectors.AddSet(&vectors)
leftSplit, rightSplit := index.fixups.splitPartitionData(
ctx, splitPartition, &tempVectors, tc.leftOffsets, tc.rightOffsets)
ctx, splitPartition, tempVectors, tc.leftOffsets, tc.rightOffsets)

validate(&leftSplit, tc.expectedLeft)
validate(&rightSplit, tc.expectedRight)
Expand Down
13 changes: 6 additions & 7 deletions pkg/sql/vecindex/quantize/quantizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ type Quantizer interface {
GetRandomDims() int

// RandomizeVector optionally performs a random orthogonal transformation
// (ROT) on the input vector and writes it to the output vector. If
// invert=false, the input vector is "original" and the caller is
// responsible for allocating the "randomized" output vector, with length
// equal to GetRandomDims(). If invert=true, the input vector is
// "randomized" and the caller is responsible for allocating the "original"
// output vector.
// (ROT) on the input vector and writes it to the output vector. The caller
// is responsible for allocating the output vector with length equal to
// GetRandomDims(). If invert is true, then a previous ROT is reversed in
// order to recover the original vector. The caller is responsible for
// allocating the output vector with length equal to GetOriginalDims().
//
// Randomizing vectors distributes skew more evenly across dimensions and
// across vectors in a set. Distance and angle between any two vectors
Expand All @@ -49,7 +48,7 @@ type Quantizer interface {
//
// NOTE: This step may be a no-op for some quantization algorithms, which
// may simply copy the original slice to the randomized slice, unchanged.
RandomizeVector(ctx context.Context, original vector.T, randomized vector.T, invert bool)
RandomizeVector(ctx context.Context, input vector.T, output vector.T, invert bool)

// Quantize quantizes a set of input vectors and returns their compressed
// form as a quantized vector set. Input vectors should already have been
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/vecindex/quantize/rabitq.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ func (q *raBitQuantizer) GetRandomDims() int {

// RandomizeVector implements the Quantizer interface.
func (q *raBitQuantizer) RandomizeVector(
ctx context.Context, original vector.T, randomized vector.T, invert bool,
ctx context.Context, input vector.T, output vector.T, invert bool,
) {
if !invert {
num32.MulMatrixByVector(&q.rot, original, randomized, num32.NoTranspose)
num32.MulMatrixByVector(&q.rot, input, output, num32.NoTranspose)
} else {
num32.MulMatrixByVector(&q.rot, randomized, original, num32.Transpose)
num32.MulMatrixByVector(&q.rot, input, output, num32.Transpose)
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/vecindex/quantize/rabitq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestRaBitRandomizeVector(t *testing.T) {

// Ensure that inverting RandomizeVector recovers original vector.
randomizedInv := make([]float32, dims)
quantizer.RandomizeVector(ctx, randomizedInv, randomized.At(i), true /* invert */)
quantizer.RandomizeVector(ctx, randomized.At(i), randomizedInv, true /* invert */)
for j, val := range original.At(i) {
require.InDelta(t, val, randomizedInv[j], 0.00001)
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/sql/vecindex/quantize/unquantizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ func (q *unQuantizer) GetRandomDims() int {

// RandomizeVector implements the Quantizer interface.
func (q *unQuantizer) RandomizeVector(
ctx context.Context, original vector.T, randomized vector.T, invert bool,
ctx context.Context, input vector.T, output vector.T, invert bool,
) {
if len(original) != q.dims {
if len(input) != q.dims {
panic(errors.AssertionFailedf(
"original dimensions %d do not match quantizer dims %d", len(original), q.dims))
"input dimensions %d do not match quantizer dims %d", len(input), q.dims))
}
if len(randomized) != q.dims {
if len(output) != q.dims {
panic(errors.AssertionFailedf(
"randomized dimensions %d do not match quantizer dims %d", len(original), q.dims))
"output dimensions %d do not match quantizer dims %d", len(output), q.dims))
}
copy(randomized, original)
copy(output, input)
}

// Quantize implements the Quantizer interface.
Expand Down
Loading
Loading