forked from cockroachdb/cockroach
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
133326: num32: add num32 package r=andy-kimball a=andy-kimball Move float32 numeric functions into new util/num32 package. This package will be used by built-in SQL functions as well as the vector indexing library. It deliberately uses simple float32 input/output types in order to stay as decoupled as possible from CRDB-specific types. Epic: CRDB-42943 Release note: None Co-authored-by: Andrew Kimball <[email protected]>
- Loading branch information
Showing
10 changed files
with
260 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") | ||
|
||
go_library( | ||
name = "num32", | ||
srcs = [ | ||
"doc.go", | ||
"vec.go", | ||
], | ||
importpath = "github.com/cockroachdb/cockroach/pkg/util/num32", | ||
visibility = ["//visibility:public"], | ||
deps = ["@com_github_cockroachdb_errors//:errors"], | ||
) | ||
|
||
go_test( | ||
name = "num32_test", | ||
srcs = ["vec_test.go"], | ||
embed = [":num32"], | ||
deps = ["@com_github_stretchr_testify//require"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright 2024 The Cockroach Authors. | ||
// | ||
// Use of this software is governed by the CockroachDB Software License | ||
// included in the /LICENSE file. | ||
|
||
/* | ||
Package num32 contains basic numeric functions that operate on scalar, vector, | ||
and matrix float32 values. Inputs and outputs deliberately use simple float | ||
types so that they can be used in multiple contexts. It uses the gonum library | ||
when possible, since it offers assembly language implementations of various | ||
useful primitives. | ||
Using the same convention as gonum, when a slice is being modified in place, it | ||
has the name dst and the function does not return a value. | ||
Where possible, functions in this package are written with the assumption that | ||
the caller prevents bad input. They will panic with assertion errors if this is | ||
not the case, rather than returning error values. Callers should generally have | ||
panic recovery logic further up the stack to gracefully handle these assertions, | ||
as they indicate buggy code. | ||
*/ | ||
package num32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// Copyright 2024 The Cockroach Authors. | ||
// | ||
// Use of this software is governed by the CockroachDB Software License | ||
// included in the /LICENSE file. | ||
|
||
package num32 | ||
|
||
import ( | ||
"math" | ||
|
||
"github.com/cockroachdb/errors" | ||
) | ||
|
||
// L1Distance returns the L1 norm of s - t, which is the Manhattan distance | ||
// between the two vectors. | ||
func L1Distance(s []float32, t []float32) float32 { | ||
checkDims(s, t) | ||
var distance float32 | ||
for i := range s { | ||
diff := s[i] - t[i] | ||
distance += float32(math.Abs(float64(diff))) | ||
} | ||
return distance | ||
} | ||
|
||
// L2SquaredDistance returns the squared L2 norm of s - t, which is the squared | ||
// Euclidean distance between the two vectors. Comparing squared distance is | ||
// equivalent to comparing distance, but the squared distance avoids an | ||
// expensive square-root operation. | ||
func L2SquaredDistance(s, t []float32) float32 { | ||
checkDims(s, t) | ||
var distance float32 | ||
for i := range s { | ||
diff := s[i] - t[i] | ||
distance += diff * diff | ||
} | ||
return distance | ||
} | ||
|
||
// InnerProduct returns the inner product of t1 and t2, also called the dot | ||
// product. | ||
func InnerProduct(s []float32, t []float32) float32 { | ||
checkDims(s, t) | ||
var distance float32 | ||
for i := range s { | ||
distance += s[i] * t[i] | ||
} | ||
return distance | ||
} | ||
|
||
func checkDims(v []float32, v2 []float32) { | ||
if len(v) != len(v2) { | ||
panic(errors.AssertionFailedf("different vector dimensions %d and %d", len(v), len(v2))) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright 2024 The Cockroach Authors. | ||
// | ||
// Use of this software is governed by the CockroachDB Software License | ||
// included in the /LICENSE file. | ||
|
||
package num32 | ||
|
||
import ( | ||
"math" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
var NaN32 = float32(math.NaN()) | ||
var Inf32 = float32(math.Inf(1)) | ||
|
||
func TestDistances(t *testing.T) { | ||
// Test L1, L2, Cosine distance. | ||
testCases := []struct { | ||
v1 []float32 | ||
v2 []float32 | ||
l1 float32 | ||
l2s float32 | ||
panics bool | ||
}{ | ||
{v1: []float32{}, v2: []float32{}, l1: 0, l2s: 0}, | ||
{v1: []float32{1, 2, 3}, v2: []float32{4, 5, 6}, l1: 9, l2s: 27}, | ||
{v1: []float32{-1, -2, -3}, v2: []float32{-4, -5, -6}, l1: 9, l2s: 27}, | ||
{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 3}, l1: 0, l2s: 0}, | ||
{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 4}, l1: 1, l2s: 1}, | ||
{v1: []float32{NaN32}, v2: []float32{1}, l1: NaN32, l2s: NaN32}, | ||
{v1: []float32{Inf32}, v2: []float32{1}, l1: Inf32, l2s: Inf32}, | ||
{v1: []float32{1, 2}, v2: []float32{3, 4, 5}, panics: true}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
if !tc.panics { | ||
l1 := L1Distance(tc.v1, tc.v2) | ||
l2s := L2SquaredDistance(tc.v1, tc.v2) | ||
require.InDelta(t, tc.l1, l1, 0.000001) | ||
require.InDelta(t, tc.l2s, l2s, 0.000001) | ||
} else { | ||
require.Panics(t, func() { L1Distance(tc.v1, tc.v2) }) | ||
require.Panics(t, func() { L2SquaredDistance(tc.v1, tc.v2) }) | ||
} | ||
} | ||
} | ||
|
||
func TestInnerProduct(t *testing.T) { | ||
// Test inner product and negative inner product | ||
testCases := []struct { | ||
v1 []float32 | ||
v2 []float32 | ||
ip float32 | ||
panics bool | ||
}{ | ||
{v1: []float32{}, v2: []float32{}, ip: 0}, | ||
{v1: []float32{1, 2, 3}, v2: []float32{4, 5, 6}, ip: 32}, | ||
{v1: []float32{-1, -2, -3}, v2: []float32{-4, -5, -6}, ip: 32}, | ||
{v1: []float32{0, 0, 0}, v2: []float32{0, 0, 0}, ip: 0}, | ||
{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 3}, ip: 14}, | ||
{v1: []float32{1, 2, 3}, v2: []float32{1, 2, 4}, ip: 17}, | ||
{v1: []float32{NaN32}, v2: []float32{1}, ip: NaN32}, | ||
{v1: []float32{Inf32}, v2: []float32{1}, ip: Inf32}, | ||
{v1: []float32{1, 2}, v2: []float32{3, 4, 5}, panics: true}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
if !tc.panics { | ||
ip := InnerProduct(tc.v1, tc.v2) | ||
require.InDelta(t, tc.ip, ip, 0.000001) | ||
} else { | ||
require.Panics(t, func() { InnerProduct(tc.v1, tc.v2) }) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.