Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
133326: num32: add num32 package r=andy-kimball a=andy-kimball

Move float32 numeric functions into new util/num32 package. This package will be used by built-in SQL functions as well as the vector indexing library. It deliberately uses simple float32 input/output types in order to stay as decoupled as possible from CRDB-specific types.

Epic: CRDB-42943

Release note: None

Co-authored-by: Andrew Kimball <[email protected]>
  • Loading branch information
craig[bot] and andy-kimball committed Oct 24, 2024
2 parents 25850c6 + 0ac9fb4 commit 3bab131
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 95 deletions.
3 changes: 3 additions & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ ALL_TESTS = [
"//pkg/util/mon:mon_test",
"//pkg/util/netutil/addr:addr_test",
"//pkg/util/netutil:netutil_test",
"//pkg/util/num32:num32_test",
"//pkg/util/optional:optional_test",
"//pkg/util/parquet:parquet_test",
"//pkg/util/pprofutil:pprofutil_test",
Expand Down Expand Up @@ -2543,6 +2544,8 @@ GO_TARGETS = [
"//pkg/util/netutil/addr:addr_test",
"//pkg/util/netutil:netutil",
"//pkg/util/netutil:netutil_test",
"//pkg/util/num32:num32",
"//pkg/util/num32:num32_test",
"//pkg/util/optional:optional",
"//pkg/util/optional:optional_test",
"//pkg/util/parquet:parquet",
Expand Down
19 changes: 19 additions & 0 deletions pkg/util/num32/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# Library target for the num32 package, built from doc.go and vec.go. Its only
# declared dependency is the cockroachdb/errors module.
go_library(
name = "num32",
srcs = [
"doc.go",
"vec.go",
],
importpath = "github.com/cockroachdb/cockroach/pkg/util/num32",
visibility = ["//visibility:public"],
deps = ["@com_github_cockroachdb_errors//:errors"],
)

# Test target for vec_test.go. It embeds the :num32 library and depends on
# testify's require package.
go_test(
name = "num32_test",
srcs = ["vec_test.go"],
embed = [":num32"],
deps = ["@com_github_stretchr_testify//require"],
)
22 changes: 22 additions & 0 deletions pkg/util/num32/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

/*
Package num32 contains basic numeric functions that operate on scalar, vector,
and matrix float32 values. Inputs and outputs deliberately use simple float
types so that they can be used in multiple contexts. It uses the gonum library
when possible, since it offers assembly language implementations of various
useful primitives.

Using the same convention as gonum, when a slice is being modified in place, it
has the name dst and the function does not return a value.

Where possible, functions in this package are written with the assumption that
the caller prevents bad input. They will panic with assertion errors if this is
not the case, rather than returning error values. Callers should generally have
panic recovery logic further up the stack to gracefully handle these assertions,
as they indicate buggy code.
*/
package num32
55 changes: 55 additions & 0 deletions pkg/util/num32/vec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package num32

import (
"math"

"github.com/cockroachdb/errors"
)

// L1Distance computes the Manhattan distance between s and t, i.e. the L1 norm
// of the element-wise difference s - t. The two vectors must have the same
// number of dimensions; checkDims panics otherwise.
func L1Distance(s []float32, t []float32) float32 {
	checkDims(s, t)
	var sum float32
	for i, sv := range s {
		// math.Abs operates on float64, so widen the float32 difference and
		// narrow the absolute value back before accumulating.
		delta := float64(sv - t[i])
		sum += float32(math.Abs(delta))
	}
	return sum
}

// L2SquaredDistance computes the squared Euclidean distance between s and t,
// i.e. the squared L2 norm of s - t. Ordering vectors by squared distance gives
// the same result as ordering by distance, while skipping the square root. The
// two vectors must have the same number of dimensions; checkDims panics
// otherwise.
func L2SquaredDistance(s, t []float32) float32 {
	checkDims(s, t)
	var sum float32
	for i, sv := range s {
		delta := sv - t[i]
		sum += delta * delta
	}
	return sum
}

// InnerProduct returns the inner product of s and t, also called the dot
// product: the sum of the element-wise products s[i] * t[i]. The two vectors
// must have the same number of dimensions; checkDims panics otherwise.
func InnerProduct(s []float32, t []float32) float32 {
	checkDims(s, t)
	var dot float32
	for i := range s {
		dot += s[i] * t[i]
	}
	return dot
}

// checkDims panics with an assertion error if v and v2 do not have the same
// length. It does nothing when the lengths match.
func checkDims(v []float32, v2 []float32) {
	if len(v) == len(v2) {
		return
	}
	panic(errors.AssertionFailedf("different vector dimensions %d and %d", len(v), len(v2)))
}
77 changes: 77 additions & 0 deletions pkg/util/num32/vec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package num32

import (
"math"
"testing"

"github.com/stretchr/testify/require"
)

// Special float32 values shared by the test cases in this file.
var (
	// NaN32 is a float32 not-a-number value.
	NaN32 = float32(math.NaN())
	// Inf32 is float32 positive infinity.
	Inf32 = float32(math.Inf(1))
)

// TestDistances checks L1Distance and L2SquaredDistance against the same table
// of inputs, including NaN, infinity, and mismatched-dimension cases.
func TestDistances(t *testing.T) {
	// Each case lists the expected L1 distance and squared L2 distance, or sets
	// panics when the two vectors have different dimensions.
	type distanceCase struct {
		v1     []float32
		v2     []float32
		l1     float32
		l2s    float32
		panics bool
	}
	cases := []distanceCase{
		{v1: []float32{}, v2: []float32{}, l1: 0, l2s: 0},
		{v1: []float32{1, 2, 3}, v2: []float32{4, 5, 6}, l1: 9, l2s: 27},
		{v1: []float32{-1, -2, -3}, v2: []float32{-4, -5, -6}, l1: 9, l2s: 27},
		{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 3}, l1: 0, l2s: 0},
		{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 4}, l1: 1, l2s: 1},
		{v1: []float32{NaN32}, v2: []float32{1}, l1: NaN32, l2s: NaN32},
		{v1: []float32{Inf32}, v2: []float32{1}, l1: Inf32, l2s: Inf32},
		{v1: []float32{1, 2}, v2: []float32{3, 4, 5}, panics: true},
	}

	const delta = 0.000001
	for _, tc := range cases {
		if tc.panics {
			require.Panics(t, func() { L1Distance(tc.v1, tc.v2) })
			require.Panics(t, func() { L2SquaredDistance(tc.v1, tc.v2) })
			continue
		}
		gotL1 := L1Distance(tc.v1, tc.v2)
		gotL2s := L2SquaredDistance(tc.v1, tc.v2)
		require.InDelta(t, tc.l1, gotL1, delta)
		require.InDelta(t, tc.l2s, gotL2s, delta)
	}
}

// TestInnerProduct checks InnerProduct against a table of inputs, including
// NaN, infinity, and mismatched-dimension cases.
func TestInnerProduct(t *testing.T) {
	// Each case lists the expected inner product, or sets panics when the two
	// vectors have different dimensions.
	type innerProductCase struct {
		v1     []float32
		v2     []float32
		ip     float32
		panics bool
	}
	cases := []innerProductCase{
		{v1: []float32{}, v2: []float32{}, ip: 0},
		{v1: []float32{1, 2, 3}, v2: []float32{4, 5, 6}, ip: 32},
		{v1: []float32{-1, -2, -3}, v2: []float32{-4, -5, -6}, ip: 32},
		{v1: []float32{0, 0, 0}, v2: []float32{0, 0, 0}, ip: 0},
		{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 3}, ip: 14},
		{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 4}, ip: 17},
		{v1: []float32{NaN32}, v2: []float32{1}, ip: NaN32},
		{v1: []float32{Inf32}, v2: []float32{1}, ip: Inf32},
		{v1: []float32{1, 2}, v2: []float32{3, 4, 5}, panics: true},
	}

	const delta = 0.000001
	for _, tc := range cases {
		if tc.panics {
			require.Panics(t, func() { InnerProduct(tc.v1, tc.v2) })
			continue
		}
		got := InnerProduct(tc.v1, tc.v2)
		require.InDelta(t, tc.ip, got, delta)
	}
}
1 change: 1 addition & 0 deletions pkg/util/vector/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ go_library(
"//pkg/sql/pgwire/pgcode",
"//pkg/sql/pgwire/pgerror",
"//pkg/util/encoding",
"//pkg/util/num32",
"@com_github_cockroachdb_errors//:errors",
],
)
Expand Down
61 changes: 30 additions & 31 deletions pkg/util/vector/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/util/encoding"
"github.com/cockroachdb/cockroach/pkg/util/num32"
)

// MaxDim is the maximum number of dimensions a vector can have.
Expand Down Expand Up @@ -137,54 +138,49 @@ func Decode(b []byte) (ret T, err error) {
return ret, nil
}

func checkDims(t T, t2 T) error {
if len(t) != len(t2) {
return pgerror.Newf(pgcode.DataException, "different vector dimensions %d and %d", len(t), len(t2))
}
return nil
}

// L1Distance returns the L1 (Manhattan) distance between t and t2.
func L1Distance(t T, t2 T) (float64, error) {
if err := checkDims(t, t2); err != nil {
return 0, err
}
var distance float32
for i := range len(t) {
diff := t[i] - t2[i]
distance += float32(math.Abs(float64(diff)))
}
return float64(distance), nil
return float64(num32.L1Distance(t, t2)), nil
}

// L2Distance returns the Euclidean distance between t and t2.
func L2Distance(t T, t2 T) (float64, error) {
if err := checkDims(t, t2); err != nil {
return 0, err
}
var distance float32
for i := range len(t) {
diff := t[i] - t2[i]
distance += diff * diff
}
// TODO(queries): check for overflow and validate intermediate result if needed.
return math.Sqrt(float64(distance)), nil
return math.Sqrt(float64(num32.L2SquaredDistance(t, t2))), nil
}

// CosDistance returns the cosine distance between t and t2.
// CosDistance returns the cosine distance between t and t2. This represents the
// similarity between the two vectors, ranging from 0 (most similar) to 2 (least
// similar). Only the angle between the vectors matters; the norms (magnitudes)
// are irrelevant.
func CosDistance(t T, t2 T) (float64, error) {
if err := checkDims(t, t2); err != nil {
return 0, err
}
var distance, normA, normB float32
for i := range len(t) {
distance += t[i] * t2[i]

// Compute the cosine of the angle between the two vectors as their dot
// product divided by the product of their norms:
// t·t2
// -----------
// ||t|| ||t2||
var dot, normA, normB float32
for i := range t {
dot += t[i] * t2[i]
normA += t[i] * t[i]
normB += t2[i] * t2[i]
}
// Use sqrt(a * b) over sqrt(a) * sqrt(b)
similarity := float64(distance) / math.Sqrt(float64(normA)*float64(normB))
/* Keep in range */

// Use sqrt(a * b) over sqrt(a) * sqrt(b) to compute norms.
similarity := float64(dot) / math.Sqrt(float64(normA)*float64(normB))

// Cosine distance = 1 - cosine similarity. Ensure that similarity always
// stays within [-1, 1] despite any floating point arithmetic error.
if similarity > 1 {
similarity = 1
} else if similarity < -1 {
Expand All @@ -198,11 +194,7 @@ func InnerProduct(t T, t2 T) (float64, error) {
if err := checkDims(t, t2); err != nil {
return 0, err
}
var distance float32
for i := range len(t) {
distance += t[i] * t2[i]
}
return float64(distance), nil
return float64(num32.InnerProduct(t, t2)), nil
}

// NegInnerProduct returns the negative inner product of t1 and t2.
Expand Down Expand Up @@ -290,3 +282,10 @@ func Random(rng *rand.Rand) T {
}
return v
}

// checkDims returns a DataException error if t and t2 have different lengths,
// and nil when the lengths match.
func checkDims(t T, t2 T) error {
	if len(t) == len(t2) {
		return nil
	}
	return pgerror.Newf(pgcode.DataException, "different vector dimensions %d and %d", len(t), len(t2))
}
9 changes: 4 additions & 5 deletions pkg/util/vector/vector_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,10 @@ func (vs *Set) AddZero(count int) {
}
}

// Remove removes the vector at the given offset from the set. This is an O(1)
// operation.
// NB: This operation changes the ordering of vectors in the set, with the last
// vector moved to the removed vector's offset.
func (vs *Set) Remove(offset int) {
// ReplaceWithLast removes the vector at the given offset from the set,
// replacing it with the last vector in the set. The modified set has one less
// element and the last vector's position changes.
func (vs *Set) ReplaceWithLast(offset int) {
targetStart := offset * vs.Dims
sourceEnd := len(vs.Data)
copy(vs.Data[targetStart:targetStart+vs.Dims], vs.Data[sourceEnd-vs.Dims:sourceEnd])
Expand Down
12 changes: 6 additions & 6 deletions pkg/util/vector/vector_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func TestVectorSet(t *testing.T) {
require.Equal(t, 8, vs.Count)
require.Equal(t, []float32{1, 2, 5, 3, 6, 6, 1, 2, 5, 3, 6, 6, 0, 0, 0, 0}, vs.Data)

// Remove.
vs.Remove(1)
vs.Remove(4)
vs.Remove(5)
// ReplaceWithLast.
vs.ReplaceWithLast(1)
vs.ReplaceWithLast(4)
vs.ReplaceWithLast(5)
require.Equal(t, 5, vs.Count)
require.Equal(t, []float32{1, 2, 0, 0, 6, 6, 1, 2, 0, 0}, vs.Data)

Expand Down Expand Up @@ -95,12 +95,12 @@ func TestVectorSet(t *testing.T) {
require.Panics(t, func() { vs11.SplitAt(-1) })
require.Panics(t, func() { vs11.AddZero(-1) })
require.Panics(t, func() { vs11.AddSet(nil) })
require.Panics(t, func() { vs11.Remove(-1) })
require.Panics(t, func() { vs11.ReplaceWithLast(-1) })

vs12 := MakeSet(2)
require.Panics(t, func() { vs12.At(0) })
require.Panics(t, func() { vs12.SplitAt(1) })
require.Panics(t, func() { vs12.Remove(0) })
require.Panics(t, func() { vs12.ReplaceWithLast(0) })

vs13 := MakeSet(-1)
require.Panics(t, func() { vs13.Add(v1) })
Expand Down
Loading

0 comments on commit 3bab131

Please sign in to comment.