From 3e5a520bc6745332b32cdd69de0f5f25fe4b429e Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Sat, 16 Mar 2024 13:26:07 -0600 Subject: [PATCH] sql: implement pgvector datatype and evaluation Release note (sql change): implement pgvector encoding, decoding, and operators, without index acceleration. --- docs/generated/sql/bnf/stmt_block.bnf | 12 +- docs/generated/sql/functions.md | 19 ++ docs/generated/sql/operators.md | 23 ++ pkg/BUILD.bazel | 5 + pkg/ccl/changefeedccl/avro_test.go | 3 + pkg/ccl/changefeedccl/encoder_test.go | 2 +- .../logic_test/mixed_version_pgvector | 63 ++++ .../logictestccl/testdata/logic_test/vector | 103 +++++++ .../tests/3node-tenant/generated_test.go | 7 + .../cockroach-go-testserver-23.2/BUILD.bazel | 30 ++ .../generated_test.go | 86 ++++++ .../tests/fakedist-disk/BUILD.bazel | 2 +- .../tests/fakedist-disk/generated_test.go | 7 + .../tests/fakedist-vec-off/BUILD.bazel | 2 +- .../tests/fakedist-vec-off/generated_test.go | 7 + .../logictestccl/tests/fakedist/BUILD.bazel | 2 +- .../tests/fakedist/generated_test.go | 7 + .../local-legacy-schema-changer/BUILD.bazel | 2 +- .../generated_test.go | 7 + .../tests/local-read-committed/BUILD.bazel | 2 +- .../local-read-committed/generated_test.go | 7 + .../tests/local-vec-off/BUILD.bazel | 2 +- .../tests/local-vec-off/generated_test.go | 7 + pkg/ccl/logictestccl/tests/local/BUILD.bazel | 2 +- .../tests/local/generated_test.go | 7 + pkg/sql/alter_column_type.go | 2 +- pkg/sql/catalog/colinfo/BUILD.bazel | 2 + pkg/sql/catalog/colinfo/col_type_info.go | 19 +- .../catalog/colinfo/column_type_properties.go | 1 + pkg/sql/catalog/tabledesc/table.go | 2 +- pkg/sql/exec_util.go | 1 + pkg/sql/logictest/BUILD.bazel | 1 + .../logictest/testdata/logic_test/grant_table | 30 ++ .../logictest/testdata/logic_test/pg_catalog | 10 + .../logictest/testdata/logic_test/vectoross | 5 + .../tests/fakedist-disk/generated_test.go | 7 + .../tests/fakedist-vec-off/generated_test.go | 7 + .../tests/fakedist/generated_test.go 
| 7 + .../generated_test.go | 7 + .../tests/local-vec-off/generated_test.go | 7 + .../logictest/tests/local/generated_test.go | 7 + pkg/sql/oidext/oidext.go | 4 + pkg/sql/opt/operator.go | 37 +-- pkg/sql/opt/ops/scalar.opt | 24 ++ pkg/sql/opt/optbuilder/scalar.go | 6 + pkg/sql/parser/BUILD.bazel | 1 + pkg/sql/parser/sql.y | 45 ++- pkg/sql/parser/testdata/create_table | 16 + pkg/sql/parser/testdata/select_exprs | 40 +++ pkg/sql/pg_catalog.go | 1 + pkg/sql/pgwire/pgwirebase/BUILD.bazel | 1 + pkg/sql/pgwire/pgwirebase/encoding.go | 7 + pkg/sql/pgwire/types.go | 4 + pkg/sql/randgen/BUILD.bazel | 1 + pkg/sql/randgen/datum.go | 3 + pkg/sql/randgen/type.go | 8 +- pkg/sql/rowenc/encoded_datum.go | 6 +- pkg/sql/rowenc/encoded_datum_test.go | 2 +- pkg/sql/rowenc/keyside/BUILD.bazel | 1 + pkg/sql/rowenc/keyside/keyside_test.go | 9 +- pkg/sql/rowenc/valueside/BUILD.bazel | 1 + pkg/sql/rowenc/valueside/array.go | 7 + pkg/sql/rowenc/valueside/decode.go | 11 + pkg/sql/rowenc/valueside/encode.go | 7 + pkg/sql/rowenc/valueside/legacy.go | 20 ++ pkg/sql/scanner/scan.go | 20 ++ .../comparator_generated_test.go | 5 + pkg/sql/sem/builtins/BUILD.bazel | 2 + .../builtins/builtinconstants/constants.go | 1 + pkg/sql/sem/builtins/fixed_oids.go | 17 ++ pkg/sql/sem/builtins/pgvector_builtins.go | 142 +++++++++ pkg/sql/sem/cast/cast.go | 14 + pkg/sql/sem/cast/cast_map.go | 62 ++-- pkg/sql/sem/eval/BUILD.bazel | 1 + pkg/sql/sem/eval/binary_op.go | 64 ++++ pkg/sql/sem/eval/cast.go | 44 +++ pkg/sql/sem/eval/testdata/eval/vector | 143 +++++++++ pkg/sql/sem/eval/unsupported_types.go | 26 +- pkg/sql/sem/tree/BUILD.bazel | 1 + pkg/sql/sem/tree/constant.go | 2 + pkg/sql/sem/tree/datum.go | 118 ++++++- pkg/sql/sem/tree/eval.go | 53 ++++ pkg/sql/sem/tree/eval_binary_ops.go | 15 + pkg/sql/sem/tree/eval_expr_generated.go | 5 + pkg/sql/sem/tree/eval_op_generated.go | 36 +++ pkg/sql/sem/tree/expr.go | 2 + pkg/sql/sem/tree/parse_string.go | 2 + pkg/sql/sem/tree/treebin/binary_operator.go | 6 + 
pkg/sql/sem/tree/type_check.go | 6 + pkg/sql/sem/tree/walk.go | 3 + pkg/sql/stats/stats_test.go | 2 +- pkg/sql/types/oid.go | 3 + pkg/sql/types/types.go | 39 ++- pkg/sql/types/types.proto | 4 + pkg/util/encoding/encoding.go | 27 +- pkg/util/encoding/type_string.go | 3 + pkg/util/parquet/writer_bench_test.go | 2 +- pkg/util/parquet/writer_test.go | 2 +- pkg/util/vector/BUILD.bazel | 23 ++ pkg/util/vector/vector.go | 288 ++++++++++++++++++ pkg/util/vector/vector_test.go | 208 +++++++++++++ pkg/workload/rand/rand.go | 2 + 102 files changed, 2100 insertions(+), 86 deletions(-) create mode 100644 pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector create mode 100644 pkg/ccl/logictestccl/testdata/logic_test/vector create mode 100644 pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel create mode 100644 pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go create mode 100644 pkg/sql/logictest/testdata/logic_test/vectoross create mode 100644 pkg/sql/sem/builtins/pgvector_builtins.go create mode 100644 pkg/sql/sem/eval/testdata/eval/vector create mode 100644 pkg/util/vector/BUILD.bazel create mode 100644 pkg/util/vector/vector.go create mode 100644 pkg/util/vector/vector_test.go diff --git a/docs/generated/sql/bnf/stmt_block.bnf b/docs/generated/sql/bnf/stmt_block.bnf index 82cecd617097..b2955c979faf 100644 --- a/docs/generated/sql/bnf/stmt_block.bnf +++ b/docs/generated/sql/bnf/stmt_block.bnf @@ -1568,6 +1568,7 @@ col_name_keyword ::= | 'VALUES' | 'VARBIT' | 'VARCHAR' + | 'VECTOR' | 'VIRTUAL' | 'WORK' @@ -1755,7 +1756,7 @@ backup_options_list ::= ( backup_options ) ( ( ',' backup_options ) )* a_expr ::= - ( c_expr | '+' a_expr | '-' a_expr | '~' a_expr | 'SQRT' a_expr | 'CBRT' a_expr | qual_op a_expr | 'NOT' a_expr | 'NOT' a_expr | row 'OVERLAPS' row | 'DEFAULT' ) ( ( 'TYPECAST' cast_target | 'TYPEANNOTATE' typename | 'COLLATE' collation_name | 'AT' 'TIME' 'ZONE' a_expr | '+' a_expr | '-' a_expr | '*' a_expr | '/' a_expr | 
'FLOORDIV' a_expr | '%' a_expr | '^' a_expr | '#' a_expr | '&' a_expr | '|' a_expr | '<' a_expr | '>' a_expr | '?' a_expr | 'JSON_SOME_EXISTS' a_expr | 'JSON_ALL_EXISTS' a_expr | 'CONTAINS' a_expr | 'CONTAINED_BY' a_expr | '=' a_expr | 'CONCAT' a_expr | 'LSHIFT' a_expr | 'RSHIFT' a_expr | 'FETCHVAL' a_expr | 'FETCHTEXT' a_expr | 'FETCHVAL_PATH' a_expr | 'FETCHTEXT_PATH' a_expr | 'REMOVE_PATH' a_expr | 'INET_CONTAINED_BY_OR_EQUALS' a_expr | 'AND_AND' a_expr | 'AT_AT' a_expr | 'INET_CONTAINS_OR_EQUALS' a_expr | 'LESS_EQUALS' a_expr | 'GREATER_EQUALS' a_expr | 'NOT_EQUALS' a_expr | qual_op a_expr | 'AND' a_expr | 'OR' a_expr | 'LIKE' a_expr | 'LIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'LIKE' a_expr | 'NOT' 'LIKE' a_expr 'ESCAPE' a_expr | 'ILIKE' a_expr | 'ILIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'ILIKE' a_expr | 'NOT' 'ILIKE' a_expr 'ESCAPE' a_expr | 'SIMILAR' 'TO' a_expr | 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | '~' a_expr | 'NOT_REGMATCH' a_expr | 'REGIMATCH' a_expr | 'NOT_REGIMATCH' a_expr | 'IS' 'NAN' | 'IS' 'NOT' 'NAN' | 'IS' 'NULL' | 'ISNULL' | 'IS' 'NOT' 'NULL' | 'NOTNULL' | 'IS' 'TRUE' | 'IS' 'NOT' 'TRUE' | 'IS' 'FALSE' | 'IS' 'NOT' 'FALSE' | 'IS' 'UNKNOWN' | 'IS' 'NOT' 'UNKNOWN' | 'IS' 'DISTINCT' 'FROM' a_expr | 'IS' 'NOT' 'DISTINCT' 'FROM' a_expr | 'IS' 'OF' '(' type_list ')' | 'IS' 'NOT' 'OF' '(' type_list ')' | 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'NOT' 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'NOT' 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'IN' in_expr | 'NOT' 'IN' in_expr | subquery_op sub_type a_expr ) )* + ( c_expr | '+' a_expr | '-' a_expr | '~' a_expr | 'SQRT' a_expr | 'CBRT' a_expr | qual_op a_expr | 'NOT' a_expr | 'NOT' a_expr | row 'OVERLAPS' row | 'DEFAULT' ) ( ( 'TYPECAST' cast_target | 'TYPEANNOTATE' typename | 'COLLATE' collation_name | 'AT' 'TIME' 'ZONE' a_expr | '+' a_expr | '-' a_expr | '*' a_expr | '/' a_expr | 
'FLOORDIV' a_expr | '%' a_expr | '^' a_expr | '#' a_expr | '&' a_expr | '|' a_expr | '<' a_expr | '>' a_expr | '?' a_expr | 'JSON_SOME_EXISTS' a_expr | 'JSON_ALL_EXISTS' a_expr | 'CONTAINS' a_expr | 'CONTAINED_BY' a_expr | '=' a_expr | 'CONCAT' a_expr | 'LSHIFT' a_expr | 'RSHIFT' a_expr | 'FETCHVAL' a_expr | 'FETCHTEXT' a_expr | 'FETCHVAL_PATH' a_expr | 'FETCHTEXT_PATH' a_expr | 'REMOVE_PATH' a_expr | 'INET_CONTAINED_BY_OR_EQUALS' a_expr | 'AND_AND' a_expr | 'AT_AT' a_expr | 'DISTANCE' a_expr | 'COS_DISTANCE' a_expr | 'NEG_INNER_PRODUCT' a_expr | 'INET_CONTAINS_OR_EQUALS' a_expr | 'LESS_EQUALS' a_expr | 'GREATER_EQUALS' a_expr | 'NOT_EQUALS' a_expr | qual_op a_expr | 'AND' a_expr | 'OR' a_expr | 'LIKE' a_expr | 'LIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'LIKE' a_expr | 'NOT' 'LIKE' a_expr 'ESCAPE' a_expr | 'ILIKE' a_expr | 'ILIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'ILIKE' a_expr | 'NOT' 'ILIKE' a_expr 'ESCAPE' a_expr | 'SIMILAR' 'TO' a_expr | 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | '~' a_expr | 'NOT_REGMATCH' a_expr | 'REGIMATCH' a_expr | 'NOT_REGIMATCH' a_expr | 'IS' 'NAN' | 'IS' 'NOT' 'NAN' | 'IS' 'NULL' | 'ISNULL' | 'IS' 'NOT' 'NULL' | 'NOTNULL' | 'IS' 'TRUE' | 'IS' 'NOT' 'TRUE' | 'IS' 'FALSE' | 'IS' 'NOT' 'FALSE' | 'IS' 'UNKNOWN' | 'IS' 'NOT' 'UNKNOWN' | 'IS' 'DISTINCT' 'FROM' a_expr | 'IS' 'NOT' 'DISTINCT' 'FROM' a_expr | 'IS' 'OF' '(' type_list ')' | 'IS' 'NOT' 'OF' '(' type_list ')' | 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'NOT' 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'NOT' 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'IN' in_expr | 'NOT' 'IN' in_expr | subquery_op sub_type a_expr ) )* for_schedules_clause ::= 'FOR' 'SCHEDULES' select_stmt @@ -3139,6 +3140,9 @@ all_op ::= | 'NOT_REGIMATCH' | 'AND_AND' | 'AT_AT' + | 'DISTANCE' + | 'COS_DISTANCE' + | 'NEG_INNER_PRODUCT' | '~' | 'SQRT' | 'CBRT' @@ -3397,6 +3401,7 @@ const_typename ::= | 
character_with_length | const_datetime | const_geo + | const_vector interval_type ::= 'INTERVAL' @@ -4160,6 +4165,7 @@ bare_label_keywords ::= | 'VARCHAR' | 'VARIABLES' | 'VARIADIC' + | 'VECTOR' | 'VERIFY_BACKUP_TABLE_DATA' | 'VIEW' | 'VIEWACTIVITY' @@ -4313,6 +4319,10 @@ const_geo ::= | 'GEOMETRY' '(' geo_shape_type ',' signed_iconst ')' | 'GEOGRAPHY' '(' geo_shape_type ',' signed_iconst ')' +const_vector ::= + 'VECTOR' + | 'VECTOR' '(' iconst32 ')' + interval_qualifier ::= 'YEAR' | 'MONTH' diff --git a/docs/generated/sql/functions.md b/docs/generated/sql/functions.md index df9e17317f33..763f1c063d10 100644 --- a/docs/generated/sql/functions.md +++ b/docs/generated/sql/functions.md @@ -1279,6 +1279,25 @@ the locality flag on node startup. Returns an error if no region is set.

Stable +### PGVector functions + + + + + + + + + + +
Function → ReturnsDescriptionVolatility
cosine_distance(v1: vector, v2: vector) → float

Returns the cosine distance between the two vectors.

+
Immutable
inner_product(v1: vector, v2: vector) → float

Returns the inner product between the two vectors.

+
Immutable
l1_distance(v1: vector, v2: vector) → float

Returns the Manhattan distance between the two vectors.

+
Immutable
l2_distance(v1: vector, v2: vector) → float

Returns the Euclidean distance between the two vectors.

+
Immutable
vector_dims(vector: vector) → int

Returns the number of the dimensions in the vector.

+
Immutable
vector_norm(vector: vector) → float

Returns the Euclidean norm of the vector.

+
Immutable
+ ### STRING[] functions diff --git a/docs/generated/sql/operators.md b/docs/generated/sql/operators.md index 5df24f706f65..431da4df6506 100644 --- a/docs/generated/sql/operators.md +++ b/docs/generated/sql/operators.md @@ -55,6 +55,7 @@ +
interval * decimalinterval
interval * floatinterval
interval * intinterval
vector * vectorvector
@@ -89,6 +90,7 @@ +
+Return
timestamptz + intervaltimestamptz
timetz + datetimestamptz
timetz + intervaltimetz
vector + vectorvector
@@ -123,6 +125,7 @@ +
-Return
timestamptz - timestampinterval
timestamptz - timestamptzinterval
timetz - intervaltimetz
vector - vectorvector
@@ -213,6 +216,17 @@ + +
->Return
uuid < uuidbool
uuid[] < uuid[]bool
varbit < varbitbool
vector < vectorbool
+ + + + +
<#>Return
vector <#> vectorfloat
+ + + +
<->Return
vector <-> vectorfloat
@@ -278,6 +292,12 @@ + +
<<Return
uuid <= uuidbool
uuid[] <= uuid[]bool
varbit <= varbitbool
vector <= vectorbool
+ + + +
<=>Return
vector <=> vectorfloat
@@ -344,6 +364,7 @@ +
<@Return
uuid = uuidbool
uuid[] = uuid[]bool
varbit = varbitbool
vector = vectorbool
@@ -412,6 +433,7 @@ +
>>Return
tuple IN tuplebool
uuid IN tuplebool
varbit IN tuplebool
vector IN tuplebool
@@ -475,6 +497,7 @@ +
IS NOT DISTINCT FROMReturn
uuid IS NOT DISTINCT FROM uuidbool
uuid[] IS NOT DISTINCT FROM uuid[]bool
varbit IS NOT DISTINCT FROM varbitbool
vector IS NOT DISTINCT FROM vectorbool
void IS NOT DISTINCT FROM unknownbool
diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 14d74fb8aef9..2430998eab01 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -54,6 +54,7 @@ ALL_TESTS = [ "//pkg/ccl/logictestccl/tests/3node-tenant-multiregion:3node-tenant-multiregion_test", "//pkg/ccl/logictestccl/tests/3node-tenant:3node-tenant_test", "//pkg/ccl/logictestccl/tests/5node:5node_test", + "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2:cockroach-go-testserver-23_2_test", "//pkg/ccl/logictestccl/tests/fakedist-disk:fakedist-disk_test", "//pkg/ccl/logictestccl/tests/fakedist-vec-off:fakedist-vec-off_test", "//pkg/ccl/logictestccl/tests/fakedist:fakedist_test", @@ -754,6 +755,7 @@ ALL_TESTS = [ "//pkg/util/ulid:ulid_test", "//pkg/util/unique:unique_test", "//pkg/util/uuid:uuid_test", + "//pkg/util/vector:vector_test", "//pkg/util/version:version_test", "//pkg/util:util_test", "//pkg/workload/bank:bank_test", @@ -899,6 +901,7 @@ GO_TARGETS = [ "//pkg/ccl/logictestccl/tests/3node-tenant-multiregion:3node-tenant-multiregion_test", "//pkg/ccl/logictestccl/tests/3node-tenant:3node-tenant_test", "//pkg/ccl/logictestccl/tests/5node:5node_test", + "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2:cockroach-go-testserver-23_2_test", "//pkg/ccl/logictestccl/tests/fakedist-disk:fakedist-disk_test", "//pkg/ccl/logictestccl/tests/fakedist-vec-off:fakedist-vec-off_test", "//pkg/ccl/logictestccl/tests/fakedist:fakedist_test", @@ -2604,6 +2607,8 @@ GO_TARGETS = [ "//pkg/util/unique:unique_test", "//pkg/util/uuid:uuid", "//pkg/util/uuid:uuid_test", + "//pkg/util/vector:vector", + "//pkg/util/vector:vector_test", "//pkg/util/version:version", "//pkg/util/version:version_test", "//pkg/util:util", diff --git a/pkg/ccl/changefeedccl/avro_test.go b/pkg/ccl/changefeedccl/avro_test.go index a7b87e0290cf..c852b653b97f 100644 --- a/pkg/ccl/changefeedccl/avro_test.go +++ b/pkg/ccl/changefeedccl/avro_test.go @@ -320,6 +320,9 @@ func TestAvroSchema(t *testing.T) { case types.AnyFamily, types.OidFamily, 
types.TupleFamily: // These aren't expected to be needed for changefeeds. return true + case types.PGVectorFamily: + // We don't support PGVector in Avro yet. + return true case types.ArrayFamily: if !randgen.IsAllowedForArray(typ.ArrayContents()) { return true diff --git a/pkg/ccl/changefeedccl/encoder_test.go b/pkg/ccl/changefeedccl/encoder_test.go index eddb0d177a78..e0a15735f5b4 100644 --- a/pkg/ccl/changefeedccl/encoder_test.go +++ b/pkg/ccl/changefeedccl/encoder_test.go @@ -1167,7 +1167,7 @@ func TestJsonRountrip(t *testing.T) { switch typ { case types.Jsonb: // Unsupported by sql/catalog/colinfo - case types.TSQuery, types.TSVector: + case types.TSQuery, types.TSVector, types.PGVector: // Unsupported by pkg/sql/parser default: if arrayTyp.InternalType.ArrayContents == typ { diff --git a/pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector b/pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector new file mode 100644 index 000000000000..30d74d9dc1dd --- /dev/null +++ b/pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector @@ -0,0 +1,63 @@ +# LogicTest: cockroach-go-testserver-23.2 + +# Verify that all nodes are running the previous version. + +query T nodeidx=0 +SELECT crdb_internal.node_executable_version() +---- +23.2 + +query T nodeidx=1 +SELECT crdb_internal.node_executable_version() +---- +23.2 + +query T nodeidx=2 +SELECT crdb_internal.node_executable_version() +---- +23.2 + +statement error syntax error +CREATE TABLE t (v VECTOR(1)) + +# Upgrade one node to 24.2 + +upgrade 0 + +# Verify that node index 0 is now running 24.2 binary. 
+ +query T nodeidx=0 +SELECT crdb_internal.release_series(crdb_internal.node_executable_version()) +---- +24.2 + +statement error pg_vector not supported until version 24.2 +CREATE TABLE t (v VECTOR(1)) + +upgrade 1 + +upgrade 2 + +statement ok +SET CLUSTER SETTING version = crdb_internal.node_executable_version(); + +query T nodeidx=1 +SELECT crdb_internal.release_series(crdb_internal.node_executable_version()) +---- +24.2 + +query T nodeidx=2 +SELECT crdb_internal.release_series(crdb_internal.node_executable_version()) +---- +24.2 + +query B retry +SELECT crdb_internal.is_at_least_version('24.1-02') +---- +true + +# Note: the following statement would succeed if there cluster had an enterprise +# license, but the mixed version logic framework doesn't support adding one. +# This is tested normally in the vector ccl logic test. +statement error pgcode XXC02 use of vector datatype requires an enterprise license +CREATE TABLE t (v VECTOR(1)) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/vector b/pkg/ccl/logictestccl/testdata/logic_test/vector new file mode 100644 index 000000000000..153e6eeb1ca7 --- /dev/null +++ b/pkg/ccl/logictestccl/testdata/logic_test/vector @@ -0,0 +1,103 @@ +# LogicTest: !local-mixed-23.2 + +query F +SELECT '[1,2,3]'::vector <-> '[4,5,6]'::vector +---- +5.196152422706632 + +statement error pgcode 42601 dimensions for type vector must be at least 1 +CREATE TABLE v (v vector(0)) + +statement error pgcode 42601 dimensions for type vector cannot exceed 16000 +CREATE TABLE v (v vector(16001)) + +statement error column v is of type vector and thus is not indexable +CREATE TABLE v (v vector(2) PRIMARY KEY) + +statement ok +CREATE TABLE v (v vector); +CREATE TABLE v2 (v vector(2)) + +statement ok +INSERT INTO v VALUES('[1]'), ('[2,3]') + +query T rowsort +SELECT * FROM v +---- +[1] +[2,3] + +query T +SELECT * FROM v WHERE v = '[1,2]' +---- + +query error pgcode 22000 different vector dimensions 2 and 1 +SELECT l2_distance('[1,2]', '[1]') + 
+statement error pgcode 22000 expected 2 dimensions, not 1 +INSERT INTO v2 VALUES('[1]'), ('[2,3]') + +statement ok +INSERT INTO v2 VALUES('[1,2]'), ('[3,4]') + +query T rowsort +SELECT * FROM v2 +---- +[1,2] +[3,4] + +query T +SELECT * FROM v2 WHERE v = '[1,2]' +---- +[1,2] + +query TT +SELECT '[1,2]'::text::vector, ARRAY[1,2]::vector +---- +[1,2] [1,2] + +query error pgcode 22004 array must not contain nulls +SELECT ARRAY[1,2,null]::vector + +query error pgcode 22000 expected 1 dimensions, not 2 +select '[3,1]'::vector(1) + +query error pgcode 22000 NaN not allowed in vector +select '[3,NaN]'::vector + +query error pgcode 22000 infinite value not allowed in vector +select '[3,Inf]'::vector + +query error pgcode 22000 infinite value not allowed in vector +select '[3,-Inf]'::vector + +statement ok +CREATE TABLE x (a float[], b real[]) + +# Test implicit cast from vector to array. +statement ok +INSERT INTO x VALUES('[1,2]'::vector, '[3,4]'::vector) + +statement ok +CREATE TABLE v3 (v1 vector(1), v2 vector(1)); +INSERT INTO v3 VALUES +('[1]', '[2]'), +('[1]', '[-2]'), +(NULL, '[1]'), +('[1]', NULL) + +query FFFTTT rowsort +SELECT v1<->v2, v1<#>v2, v1<=>v2, v1+v2, v1-v2, v1*v2 FROM v3 +---- +1 -2 0 [3] [-1] [2] +3 2 2 [-1] [3] [-2] +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +query FFFFFI rowsort +SELECT l1_distance(v1,v2), l2_distance(v1,v2), cosine_distance(v1,v2), inner_product(v1,v2), vector_norm(v1), vector_dims(v1) FROM v3 +---- +1 1 0 2 1 1 +3 3 2 -2 1 1 +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL 1 1 diff --git a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go index 0469f99fc6b3..11be5fd5b63a 100644 --- a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go +++ b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go @@ -2825,6 +2825,13 @@ func TestTenantLogicCCL_unique_read_committed( runCCLLogicTest(t, "unique_read_committed") } +func 
TestTenantLogicCCL_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} + func TestTenantLogicCCL_zone_config_secondary_tenants( t *testing.T, ) { diff --git a/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel new file mode 100644 index 000000000000..a7aef9e5a479 --- /dev/null +++ b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel @@ -0,0 +1,30 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "cockroach-go-testserver-23_2_test", + size = "enormous", + srcs = ["generated_test.go"], + data = [ + "//c-deps:libgeos", # keep + "//pkg/ccl/logictestccl:testdata", # keep + "//pkg/cmd/cockroach-short", # keep + "//pkg/sql/logictest:cockroach_predecessor_version", # keep + ], + exec_properties = {"test.Pool": "large"}, + shard_count = 1, + tags = ["cpu:2"], + deps = [ + "//pkg/base", + "//pkg/build/bazel", + "//pkg/ccl", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql/logictest", + "//pkg/testutils/serverutils", + "//pkg/testutils/skip", + "//pkg/testutils/testcluster", + "//pkg/util/leaktest", + "//pkg/util/randutil", + ], +) diff --git a/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go new file mode 100644 index 000000000000..dda3b4feab45 --- /dev/null +++ b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go @@ -0,0 +1,86 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +// Code generated by generate-logictest, DO NOT EDIT. 
+ +package testcockroach_go_testserver_232 + +import ( + "os" + "path/filepath" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/build/bazel" + "github.com/cockroachdb/cockroach/pkg/ccl" + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/sql/logictest" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/randutil" +) + +const configIdx = 20 + +var cclLogicTestDir string + +func init() { + if bazel.BuiltWithBazel() { + var err error + cclLogicTestDir, err = bazel.Runfile("pkg/ccl/logictestccl/testdata/logic_test") + if err != nil { + panic(err) + } + } else { + cclLogicTestDir = "../../../../ccl/logictestccl/testdata/logic_test" + } +} + +func TestMain(m *testing.M) { + defer ccl.TestingEnableEnterprise()() + securityassets.SetLoader(securitytest.EmbeddedAssets) + randutil.SeedForTests() + serverutils.InitTestServerFactory(server.TestServerFactory) + serverutils.InitTestClusterFactory(testcluster.TestClusterFactory) + + defer serverutils.TestingSetDefaultTenantSelectionOverride( + base.TestIsForStuffThatShouldWorkWithSecondaryTenantsButDoesntYet(76378), + )() + + os.Exit(m.Run()) +} + +func runCCLLogicTest(t *testing.T, file string) { + skip.UnderDeadlock(t, "times out and/or hangs") + logictest.RunLogicTest(t, logictest.TestServerArgs{}, configIdx, filepath.Join(cclLogicTestDir, file)) +} + +// TestLogic_tmp runs any tests that are prefixed with "_", in which a dedicated +// test is not generated for. 
This allows developers to create and run temporary +// test files that are not checked into the repository, without repeatedly +// regenerating and reverting changes to this file, generated_test.go. +// +// TODO(mgartner): Add file filtering so that individual files can be run, +// instead of all files with the "_" prefix. +func TestLogic_tmp(t *testing.T) { + defer leaktest.AfterTest(t)() + var glob string + glob = filepath.Join(cclLogicTestDir, "_*") + logictest.RunLogicTests(t, logictest.TestServerArgs{}, configIdx, glob) +} + +func TestCCLLogic_mixed_version_pgvector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "mixed_version_pgvector") +} diff --git a/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel b/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel index c4bb0b39dfba..5297e82de11f 100644 --- a/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel +++ b/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel @@ -12,7 +12,7 @@ go_test( "//build/toolchains:is_heavy": {"test.Pool": "heavy"}, "//conditions:default": {"test.Pool": "large"}, }), - shard_count = 28, + shard_count = 29, tags = ["cpu:2"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go b/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go index a2fdeec1cc9c..d61025d5a97f 100644 --- a/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go +++ b/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go @@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed( defer leaktest.AfterTest(t)() runCCLLogicTest(t, "unique_read_committed") } + +func TestCCLLogic_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} diff --git a/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel b/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel index e7eb49f07487..d63444795f09 100644 --- a/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel +++ 
b/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel @@ -12,7 +12,7 @@ go_test( "//build/toolchains:is_heavy": {"test.Pool": "heavy"}, "//conditions:default": {"test.Pool": "large"}, }), - shard_count = 28, + shard_count = 29, tags = ["cpu:2"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go b/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go index 41b1bc49cbcc..c24a523488fc 100644 --- a/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go +++ b/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go @@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed( defer leaktest.AfterTest(t)() runCCLLogicTest(t, "unique_read_committed") } + +func TestCCLLogic_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} diff --git a/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel b/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel index 9f97dfa4f620..4c8d25f92413 100644 --- a/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel +++ b/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel @@ -12,7 +12,7 @@ go_test( "//build/toolchains:is_heavy": {"test.Pool": "heavy"}, "//conditions:default": {"test.Pool": "large"}, }), - shard_count = 29, + shard_count = 30, tags = ["cpu:2"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/fakedist/generated_test.go b/pkg/ccl/logictestccl/tests/fakedist/generated_test.go index b9e4f8e0d73b..bf73ec236736 100644 --- a/pkg/ccl/logictestccl/tests/fakedist/generated_test.go +++ b/pkg/ccl/logictestccl/tests/fakedist/generated_test.go @@ -280,3 +280,10 @@ func TestCCLLogic_unique_read_committed( defer leaktest.AfterTest(t)() runCCLLogicTest(t, "unique_read_committed") } + +func TestCCLLogic_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} diff --git a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel 
b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel index 638d7cc79915..2d52175c8146 100644 --- a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel +++ b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel @@ -9,7 +9,7 @@ go_test( "//pkg/ccl/logictestccl:testdata", # keep ], exec_properties = {"test.Pool": "large"}, - shard_count = 28, + shard_count = 29, tags = ["cpu:1"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go index 357b386f166d..5374bf1bf141 100644 --- a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go +++ b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go @@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed( defer leaktest.AfterTest(t)() runCCLLogicTest(t, "unique_read_committed") } + +func TestCCLLogic_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} diff --git a/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel b/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel index 6ebd57d3fe8f..1ce636f6aa8d 100644 --- a/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel +++ b/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel @@ -10,7 +10,7 @@ go_test( "//pkg/sql/opt/exec/execbuilder:testdata", # keep ], exec_properties = {"test.Pool": "large"}, - shard_count = 35, + shard_count = 36, tags = ["cpu:1"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go b/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go index abca7f4e8348..bd066c1e9173 100644 --- a/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go +++ b/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go @@ -301,6 +301,13 @@ func TestReadCommittedLogicCCL_unique_read_committed( runCCLLogicTest(t, 
"unique_read_committed") } +func TestReadCommittedLogicCCL_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} + func TestReadCommittedExecBuild_explain_analyze_read_committed( t *testing.T, ) { diff --git a/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel b/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel index 92f0ed613a5c..8f49d20f505c 100644 --- a/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel +++ b/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel @@ -9,7 +9,7 @@ go_test( "//pkg/ccl/logictestccl:testdata", # keep ], exec_properties = {"test.Pool": "large"}, - shard_count = 28, + shard_count = 29, tags = ["cpu:1"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go b/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go index c9e9d7e13055..67c0c4b6610c 100644 --- a/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go +++ b/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go @@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed( defer leaktest.AfterTest(t)() runCCLLogicTest(t, "unique_read_committed") } + +func TestCCLLogic_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} diff --git a/pkg/ccl/logictestccl/tests/local/BUILD.bazel b/pkg/ccl/logictestccl/tests/local/BUILD.bazel index ef7ea4aaef3b..096337816aaf 100644 --- a/pkg/ccl/logictestccl/tests/local/BUILD.bazel +++ b/pkg/ccl/logictestccl/tests/local/BUILD.bazel @@ -9,7 +9,7 @@ go_test( "//pkg/ccl/logictestccl:testdata", # keep ], exec_properties = {"test.Pool": "large"}, - shard_count = 44, + shard_count = 45, tags = ["cpu:1"], deps = [ "//pkg/base", diff --git a/pkg/ccl/logictestccl/tests/local/generated_test.go b/pkg/ccl/logictestccl/tests/local/generated_test.go index c67aea9b5aaf..837293b36ae2 100644 --- a/pkg/ccl/logictestccl/tests/local/generated_test.go +++ b/pkg/ccl/logictestccl/tests/local/generated_test.go @@ -385,3 +385,10 @@ 
func TestCCLLogic_unique_read_committed( defer leaktest.AfterTest(t)() runCCLLogicTest(t, "unique_read_committed") } + +func TestCCLLogic_vector( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runCCLLogicTest(t, "vector") +} diff --git a/pkg/sql/alter_column_type.go b/pkg/sql/alter_column_type.go index 11eb91a3909b..925667e2a7a8 100644 --- a/pkg/sql/alter_column_type.go +++ b/pkg/sql/alter_column_type.go @@ -107,7 +107,7 @@ func AlterColumnType( } } - err = colinfo.ValidateColumnDefType(ctx, params.EvalContext().Settings.Version, typ) + err = colinfo.ValidateColumnDefType(ctx, params.EvalContext().Settings, typ) if err != nil { return err } diff --git a/pkg/sql/catalog/colinfo/BUILD.bazel b/pkg/sql/catalog/colinfo/BUILD.bazel index 53e850531285..795ccb31fe0a 100644 --- a/pkg/sql/catalog/colinfo/BUILD.bazel +++ b/pkg/sql/catalog/colinfo/BUILD.bazel @@ -16,7 +16,9 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo", visibility = ["//visibility:public"], deps = [ + "//pkg/base", "//pkg/clusterversion", + "//pkg/settings/cluster", "//pkg/sql/catalog", "//pkg/sql/catalog/catpb", "//pkg/sql/catalog/descpb", diff --git a/pkg/sql/catalog/colinfo/col_type_info.go b/pkg/sql/catalog/colinfo/col_type_info.go index b5a6633986ae..9002aa6cb5dd 100644 --- a/pkg/sql/catalog/colinfo/col_type_info.go +++ b/pkg/sql/catalog/colinfo/col_type_info.go @@ -14,7 +14,9 @@ import ( "context" "fmt" + "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/clusterversion" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" @@ -70,7 +72,7 @@ func (ti ColTypeInfo) Type(idx int) *types.T { // ValidateColumnDefType returns an error if the type of a column definition is // not valid. It is checked when a column is created or altered. 
-func ValidateColumnDefType(ctx context.Context, version clusterversion.Handle, t *types.T) error { +func ValidateColumnDefType(ctx context.Context, st *cluster.Settings, t *types.T) error { switch t.Family() { case types.StringFamily, types.CollatedStringFamily: if t.Family() == types.CollatedStringFamily { @@ -102,7 +104,7 @@ func ValidateColumnDefType(ctx context.Context, version clusterversion.Handle, t if err := types.CheckArrayElementType(t.ArrayContents()); err != nil { return err } - return ValidateColumnDefType(ctx, version, t.ArrayContents()) + return ValidateColumnDefType(ctx, st, t.ArrayContents()) case types.BitFamily, types.IntFamily, types.FloatFamily, types.BoolFamily, types.BytesFamily, types.DateFamily, types.INetFamily, types.IntervalFamily, types.JsonFamily, types.OidFamily, types.TimeFamily, @@ -119,6 +121,17 @@ func ValidateColumnDefType(ctx context.Context, version clusterversion.Handle, t return unimplemented.NewWithIssue(70099, "cannot use table record type as table column") } + case types.PGVectorFamily: + if !st.Version.IsActive(ctx, clusterversion.V24_2) { + return pgerror.Newf( + pgcode.FeatureNotSupported, + "pg_vector not supported until version 24.2", + ) + } + if err := base.CheckEnterpriseEnabled(st, "vector datatype"); err != nil { + return err + } + default: return pgerror.Newf(pgcode.InvalidTableDefinition, "value type %s cannot be used for table columns", t.String()) @@ -192,6 +205,8 @@ func MustBeValueEncoded(semanticType *types.T) bool { return true case types.TSVectorFamily, types.TSQueryFamily: return true + case types.PGVectorFamily: + return true } return false } diff --git a/pkg/sql/catalog/colinfo/column_type_properties.go b/pkg/sql/catalog/colinfo/column_type_properties.go index f474af606b49..1898ab46034f 100644 --- a/pkg/sql/catalog/colinfo/column_type_properties.go +++ b/pkg/sql/catalog/colinfo/column_type_properties.go @@ -83,6 +83,7 @@ func CanHaveCompositeKeyEncoding(typ *types.T) bool { types.EnumFamily, 
types.Box2DFamily, types.PGLSNFamily, + types.PGVectorFamily, types.RefCursorFamily, types.VoidFamily, types.EncodedKeyFamily, diff --git a/pkg/sql/catalog/tabledesc/table.go b/pkg/sql/catalog/tabledesc/table.go index 9b55075bfb3a..fb871851b11e 100644 --- a/pkg/sql/catalog/tabledesc/table.go +++ b/pkg/sql/catalog/tabledesc/table.go @@ -161,7 +161,7 @@ func MakeColumnDefDescs( if err != nil { return nil, err } - if err = colinfo.ValidateColumnDefType(ctx, evalCtx.Settings.Version, resType); err != nil { + if err = colinfo.ValidateColumnDefType(ctx, evalCtx.Settings, resType); err != nil { return nil, err } col.Type = resType diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index 77c71298db3b..0a449589876b 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -2051,6 +2051,7 @@ func checkResultType(typ *types.T, fmtCode pgwirebase.FormatCode) error { case types.INetFamily: case types.OidFamily: case types.PGLSNFamily: + case types.PGVectorFamily: case types.RefCursorFamily: case types.TupleFamily: case types.EnumFamily: diff --git a/pkg/sql/logictest/BUILD.bazel b/pkg/sql/logictest/BUILD.bazel index a54b6ee2b99b..59bb3a24bd43 100644 --- a/pkg/sql/logictest/BUILD.bazel +++ b/pkg/sql/logictest/BUILD.bazel @@ -38,6 +38,7 @@ filegroup( }), visibility = [ "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.1:__pkg__", + "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2:__pkg__", "//pkg/sql/logictest/tests/cockroach-go-testserver-23.1:__pkg__", "//pkg/sql/logictest/tests/cockroach-go-testserver-23.2:__pkg__", ], diff --git a/pkg/sql/logictest/testdata/logic_test/grant_table b/pkg/sql/logictest/testdata/logic_test/grant_table index f8384639e343..a47e1331bd56 100644 --- a/pkg/sql/logictest/testdata/logic_test/grant_table +++ b/pkg/sql/logictest/testdata/logic_test/grant_table @@ -603,6 +603,12 @@ test pg_catalog varchar test pg_catalog varchar[] type admin ALL false test pg_catalog varchar[] type public USAGE false test pg_catalog varchar[] 
type root ALL false +test pg_catalog vector type admin ALL false +test pg_catalog vector type public USAGE false +test pg_catalog vector type root ALL false +test pg_catalog vector[] type admin ALL false +test pg_catalog vector[] type public USAGE false +test pg_catalog vector[] type root ALL false test pg_catalog void type admin ALL false test pg_catalog void type public USAGE false test pg_catalog void type root ALL false @@ -791,6 +797,10 @@ test pg_catalog varchar type admin ALL test pg_catalog varchar type root ALL false test pg_catalog varchar[] type admin ALL false test pg_catalog varchar[] type root ALL false +test pg_catalog vector type admin ALL false +test pg_catalog vector type root ALL false +test pg_catalog vector[] type admin ALL false +test pg_catalog vector[] type root ALL false test pg_catalog void type admin ALL false test pg_catalog void type root ALL false test public NULL schema admin ALL true @@ -1403,6 +1413,10 @@ a pg_catalog varchar type admin a pg_catalog varchar type root ALL false a pg_catalog varchar[] type admin ALL false a pg_catalog varchar[] type root ALL false +a pg_catalog vector type admin ALL false +a pg_catalog vector type root ALL false +a pg_catalog vector[] type admin ALL false +a pg_catalog vector[] type root ALL false a pg_catalog void type admin ALL false a pg_catalog void type root ALL false a public NULL schema admin ALL true @@ -1579,6 +1593,10 @@ defaultdb pg_catalog varchar type admin defaultdb pg_catalog varchar type root ALL false defaultdb pg_catalog varchar[] type admin ALL false defaultdb pg_catalog varchar[] type root ALL false +defaultdb pg_catalog vector type admin ALL false +defaultdb pg_catalog vector type root ALL false +defaultdb pg_catalog vector[] type admin ALL false +defaultdb pg_catalog vector[] type root ALL false defaultdb pg_catalog void type admin ALL false defaultdb pg_catalog void type root ALL false defaultdb public NULL schema admin ALL true @@ -1755,6 +1773,10 @@ postgres pg_catalog varchar 
type admin postgres pg_catalog varchar type root ALL false postgres pg_catalog varchar[] type admin ALL false postgres pg_catalog varchar[] type root ALL false +postgres pg_catalog vector type admin ALL false +postgres pg_catalog vector type root ALL false +postgres pg_catalog vector[] type admin ALL false +postgres pg_catalog vector[] type root ALL false postgres pg_catalog void type admin ALL false postgres pg_catalog void type root ALL false postgres public NULL schema admin ALL true @@ -1931,6 +1953,10 @@ system pg_catalog varchar type admin system pg_catalog varchar type root ALL false system pg_catalog varchar[] type admin ALL false system pg_catalog varchar[] type root ALL false +system pg_catalog vector type admin ALL false +system pg_catalog vector type root ALL false +system pg_catalog vector[] type admin ALL false +system pg_catalog vector[] type root ALL false system pg_catalog void type admin ALL false system pg_catalog void type root ALL false system public NULL schema admin ALL true @@ -2487,6 +2513,10 @@ test pg_catalog varchar type admin test pg_catalog varchar type root ALL false test pg_catalog varchar[] type admin ALL false test pg_catalog varchar[] type root ALL false +test pg_catalog vector type admin ALL false +test pg_catalog vector type root ALL false +test pg_catalog vector[] type admin ALL false +test pg_catalog vector[] type root ALL false test pg_catalog void type admin ALL false test pg_catalog void type root ALL false test public NULL schema admin ALL true diff --git a/pkg/sql/logictest/testdata/logic_test/pg_catalog b/pkg/sql/logictest/testdata/logic_test/pg_catalog index 97a1da50aa8c..a77bef11eeba 100644 --- a/pkg/sql/logictest/testdata/logic_test/pg_catalog +++ b/pkg/sql/logictest/testdata/logic_test/pg_catalog @@ -1923,6 +1923,8 @@ oid typname typnamespace typowner typlen typbyval typty 90003 _geography 4294967103 NULL -1 false b 90004 box2d 4294967103 NULL 32 true b 90005 _box2d 4294967103 NULL -1 false b +90006 vector 4294967103 
NULL -1 false b +90007 _vector 4294967103 NULL -1 false b 100110 t1 109 1546506610 -1 false c 100111 t1_m_seq 109 1546506610 -1 false c 100112 t1_n_seq 109 1546506610 -1 false c @@ -2038,6 +2040,8 @@ oid typname typcategory typispreferred typisdefined typdel 90003 _geography A false true : 0 90002 0 90004 box2d U false true , 0 0 90005 90005 _box2d A false true , 0 90004 0 +90006 vector U false true , 0 0 90007 +90007 _vector A false true , 0 90006 0 100110 t1 C false true , 110 0 0 100111 t1_m_seq C false true , 111 0 0 100112 t1_n_seq C false true , 112 0 0 @@ -2153,6 +2157,8 @@ oid typname typinput typoutput typreceive 90003 _geography array_in array_out array_recv array_send 0 0 0 90004 box2d box2d_in box2d_out box2d_recv box2d_send 0 0 0 90005 _box2d array_in array_out array_recv array_send 0 0 0 +90006 vector vectorin vectorout vectorrecv vectorsend 0 0 0 +90007 _vector array_in array_out array_recv array_send 0 0 0 100110 t1 record_in record_out record_recv record_send 0 0 0 100111 t1_m_seq record_in record_out record_recv record_send 0 0 0 100112 t1_n_seq record_in record_out record_recv record_send 0 0 0 @@ -2268,6 +2274,8 @@ oid typname typalign typstorage typnotnull typbasetype ty 90003 _geography NULL NULL false 0 -1 90004 box2d NULL NULL false 0 -1 90005 _box2d NULL NULL false 0 -1 +90006 vector NULL NULL false 0 -1 +90007 _vector NULL NULL false 0 -1 100110 t1 NULL NULL false 0 -1 100111 t1_m_seq NULL NULL false 0 -1 100112 t1_n_seq NULL NULL false 0 -1 @@ -2383,6 +2391,8 @@ oid typname typndims typcollation typdefaultbin typdefault 90003 _geography 0 0 NULL NULL NULL 90004 box2d 0 0 NULL NULL NULL 90005 _box2d 0 0 NULL NULL NULL +90006 vector 0 0 NULL NULL NULL +90007 _vector 0 0 NULL NULL NULL 100110 t1 0 0 NULL NULL NULL 100111 t1_m_seq 0 0 NULL NULL NULL 100112 t1_n_seq 0 0 NULL NULL NULL diff --git a/pkg/sql/logictest/testdata/logic_test/vectoross b/pkg/sql/logictest/testdata/logic_test/vectoross new file mode 100644 index 
000000000000..12a220dbd673 --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/vectoross @@ -0,0 +1,5 @@ +# LogicTest: !local-mixed-23.2 !3node-tenant + +statement error OSS binaries do not include enterprise features +CREATE TABLE v (v vector) + diff --git a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go index 02112869ec57..5074355e8844 100644 --- a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go @@ -2458,6 +2458,13 @@ func TestLogic_vectorize_window( runLogicTest(t, "vectorize_window") } +func TestLogic_vectoross( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "vectoross") +} + func TestLogic_views( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go index f82f27e38146..a88349f4d32d 100644 --- a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go @@ -2451,6 +2451,13 @@ func TestLogic_vectorize_unsupported( runLogicTest(t, "vectorize_unsupported") } +func TestLogic_vectoross( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "vectoross") +} + func TestLogic_views( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/fakedist/generated_test.go b/pkg/sql/logictest/tests/fakedist/generated_test.go index 1b066c6ef182..1818a68527d9 100644 --- a/pkg/sql/logictest/tests/fakedist/generated_test.go +++ b/pkg/sql/logictest/tests/fakedist/generated_test.go @@ -2479,6 +2479,13 @@ func TestLogic_vectorize_window( runLogicTest(t, "vectorize_window") } +func TestLogic_vectoross( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "vectoross") +} + func TestLogic_views( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go 
b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go index cc508f62730a..7945dd8a194b 100644 --- a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go +++ b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go @@ -2451,6 +2451,13 @@ func TestLogic_vectorize_unsupported( runLogicTest(t, "vectorize_unsupported") } +func TestLogic_vectoross( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "vectoross") +} + func TestLogic_views( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local-vec-off/generated_test.go b/pkg/sql/logictest/tests/local-vec-off/generated_test.go index 28dea5367cd3..13ba0ea77d2b 100644 --- a/pkg/sql/logictest/tests/local-vec-off/generated_test.go +++ b/pkg/sql/logictest/tests/local-vec-off/generated_test.go @@ -2479,6 +2479,13 @@ func TestLogic_vectorize_unsupported( runLogicTest(t, "vectorize_unsupported") } +func TestLogic_vectoross( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "vectoross") +} + func TestLogic_views( t *testing.T, ) { diff --git a/pkg/sql/logictest/tests/local/generated_test.go b/pkg/sql/logictest/tests/local/generated_test.go index 9f1ac0809e00..12bfc36634c4 100644 --- a/pkg/sql/logictest/tests/local/generated_test.go +++ b/pkg/sql/logictest/tests/local/generated_test.go @@ -2710,6 +2710,13 @@ func TestLogic_vectorize_window( runLogicTest(t, "vectorize_window") } +func TestLogic_vectoross( + t *testing.T, +) { + defer leaktest.AfterTest(t)() + runLogicTest(t, "vectoross") +} + func TestLogic_views( t *testing.T, ) { diff --git a/pkg/sql/oidext/oidext.go b/pkg/sql/oidext/oidext.go index ac7da4507c7a..6a1d08461c78 100644 --- a/pkg/sql/oidext/oidext.go +++ b/pkg/sql/oidext/oidext.go @@ -34,6 +34,8 @@ const ( T__geography = oid.Oid(90003) T_box2d = oid.Oid(90004) T__box2d = oid.Oid(90005) + T_pgvector = oid.Oid(90006) + T__pgvector = oid.Oid(90007) ) // ExtensionTypeName returns a mapping from extension oids @@ -45,6 +47,8 @@ 
var ExtensionTypeName = map[oid.Oid]string{ T__geography: "_GEOGRAPHY", T_box2d: "BOX2D", T__box2d: "_BOX2D", + T_pgvector: "VECTOR", + T__pgvector: "_VECTOR", } // TypeName checks the name for a given type by first looking up oid.TypeName diff --git a/pkg/sql/opt/operator.go b/pkg/sql/opt/operator.go index 25cab29dc048..234fd30245e8 100644 --- a/pkg/sql/opt/operator.go +++ b/pkg/sql/opt/operator.go @@ -151,23 +151,26 @@ var ComparisonOpReverseMap = map[Operator]treecmp.ComparisonOperatorSymbol{ // BinaryOpReverseMap maps from an optimizer operator type to a semantic tree // binary operator type. var BinaryOpReverseMap = map[Operator]treebin.BinaryOperatorSymbol{ - BitandOp: treebin.Bitand, - BitorOp: treebin.Bitor, - BitxorOp: treebin.Bitxor, - PlusOp: treebin.Plus, - MinusOp: treebin.Minus, - MultOp: treebin.Mult, - DivOp: treebin.Div, - FloorDivOp: treebin.FloorDiv, - ModOp: treebin.Mod, - PowOp: treebin.Pow, - ConcatOp: treebin.Concat, - LShiftOp: treebin.LShift, - RShiftOp: treebin.RShift, - FetchValOp: treebin.JSONFetchVal, - FetchTextOp: treebin.JSONFetchText, - FetchValPathOp: treebin.JSONFetchValPath, - FetchTextPathOp: treebin.JSONFetchTextPath, + BitandOp: treebin.Bitand, + BitorOp: treebin.Bitor, + BitxorOp: treebin.Bitxor, + PlusOp: treebin.Plus, + MinusOp: treebin.Minus, + MultOp: treebin.Mult, + DivOp: treebin.Div, + FloorDivOp: treebin.FloorDiv, + ModOp: treebin.Mod, + PowOp: treebin.Pow, + ConcatOp: treebin.Concat, + LShiftOp: treebin.LShift, + RShiftOp: treebin.RShift, + FetchValOp: treebin.JSONFetchVal, + FetchTextOp: treebin.JSONFetchText, + FetchValPathOp: treebin.JSONFetchValPath, + FetchTextPathOp: treebin.JSONFetchTextPath, + VectorDistanceOp: treebin.Distance, + VectorCosDistanceOp: treebin.CosDistance, + VectorNegInnerProductOp: treebin.NegInnerProduct, } // UnaryOpReverseMap maps from an optimizer operator type to a semantic tree diff --git a/pkg/sql/opt/ops/scalar.opt b/pkg/sql/opt/ops/scalar.opt index 7572e006e9dd..dcc991d28d96 100644 
--- a/pkg/sql/opt/ops/scalar.opt +++ b/pkg/sql/opt/ops/scalar.opt @@ -514,6 +514,30 @@ define TSMatches { Right ScalarExpr } +# VectorDistance is the <-> operator when used with vector operands. +# It maps to tree.Distance. +[Scalar, Binary] +define VectorDistance { + Left ScalarExpr + Right ScalarExpr +} + +# VectorCosDistance is the <=> operator when used with vector operands. +# It maps to tree.CosDistance. +[Scalar, Binary] +define VectorCosDistance { + Left ScalarExpr + Right ScalarExpr +} + +# VectorNegInnerProduct is the <#> operator when used with vector operands. +# It maps to tree.NegInnerProduct. +[Scalar, Binary] +define VectorNegInnerProduct { + Left ScalarExpr + Right ScalarExpr +} + # AnyScalar is the form of ANY which refers to an ANY operation on a # tuple or array, as opposed to Any which operates on a subquery. [Scalar, Bool] diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go index 85441b076aa7..5755c1035106 100644 --- a/pkg/sql/opt/optbuilder/scalar.go +++ b/pkg/sql/opt/optbuilder/scalar.go @@ -853,6 +853,12 @@ func (b *Builder) constructBinary( return b.factory.ConstructFetchValPath(left, right) case treebin.JSONFetchTextPath: return b.factory.ConstructFetchTextPath(left, right) + case treebin.Distance: + return b.factory.ConstructVectorDistance(left, right) + case treebin.CosDistance: + return b.factory.ConstructVectorCosDistance(left, right) + case treebin.NegInnerProduct: + return b.factory.ConstructVectorNegInnerProduct(left, right) } panic(errors.AssertionFailedf("unhandled binary operator: %s", redact.Safe(bin))) } diff --git a/pkg/sql/parser/BUILD.bazel b/pkg/sql/parser/BUILD.bazel index b9bb467341db..a00beb6bc2b2 100644 --- a/pkg/sql/parser/BUILD.bazel +++ b/pkg/sql/parser/BUILD.bazel @@ -37,6 +37,7 @@ go_library( "//pkg/sql/sem/tree/treewindow", # keep "//pkg/sql/types", "//pkg/util/errorutil/unimplemented", + "//pkg/util/vector", # keep "@com_github_cockroachdb_errors//:errors", "@com_github_lib_pq//oid", 
# keep "@org_golang_x_text//cases", diff --git a/pkg/sql/parser/sql.y b/pkg/sql/parser/sql.y index 83b87c3d700b..b08f8e97b9ad 100644 --- a/pkg/sql/parser/sql.y +++ b/pkg/sql/parser/sql.y @@ -41,6 +41,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treecmp" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treewindow" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -927,14 +928,14 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption %token CLUSTER CLUSTERS COALESCE COLLATE COLLATION COLUMN COLUMNS COMMENT COMMENTS COMMIT %token COMMITTED COMPACT COMPLETE COMPLETIONS CONCAT CONCURRENTLY CONFIGURATION CONFIGURATIONS CONFIGURE %token CONFLICT CONNECTION CONNECTIONS CONSTRAINT CONSTRAINTS CONTAINS CONTROLCHANGEFEED CONTROLJOB -%token CONVERSION CONVERT COPY COST COVERING CREATE CREATEDB CREATELOGIN CREATEROLE +%token CONVERSION CONVERT COPY COS_DISTANCE COST COVERING CREATE CREATEDB CREATELOGIN CREATEROLE %token CROSS CSV CUBE CURRENT CURRENT_CATALOG CURRENT_DATE CURRENT_SCHEMA %token CURRENT_ROLE CURRENT_TIME CURRENT_TIMESTAMP %token CURRENT_USER CURSOR CYCLE %token DATA DATABASE DATABASES DATE DAY DEBUG_IDS DEC DEBUG_DUMP_METADATA_SST DECIMAL DEFAULT DEFAULTS DEFINER %token DEALLOCATE DECLARE DEFERRABLE DEFERRED DELETE DELIMITER DEPENDS DESC DESTINATION DETACHED DETAILS -%token DISCARD DISTINCT DO DOMAIN DOUBLE DROP +%token DISCARD DISTANCE DISTINCT DO DOMAIN DOUBLE DROP %token ELSE ENCODING ENCRYPTED ENCRYPTION_INFO_DIR ENCRYPTION_PASSPHRASE END ENUM ENUMS ESCAPE EXCEPT EXCLUDE EXCLUDING %token EXISTS EXECUTE EXECUTION EXPERIMENTAL @@ -977,7 +978,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption %token MULTIPOINT MULTIPOINTM MULTIPOINTZ MULTIPOINTZM %token MULTIPOLYGON MULTIPOLYGONM MULTIPOLYGONZ MULTIPOLYGONZM -%token NAN NAME NAMES NATURAL NEVER NEW_DB_NAME NEW_KMS 
NEXT NO NOCANCELQUERY NOCONTROLCHANGEFEED +%token NAN NAME NAMES NATURAL NEG_INNER_PRODUCT NEVER NEW_DB_NAME NEW_KMS NEXT NO NOCANCELQUERY NOCONTROLCHANGEFEED %token NOCONTROLJOB NOCREATEDB NOCREATELOGIN NOCREATEROLE NODE NOLOGIN NOMODIFYCLUSTERSETTING NOREPLICATION %token NOSQLLOGIN NO_INDEX_JOIN NO_ZIGZAG_JOIN NO_FULL_SCAN NONE NONVOTERS NORMAL NOT %token NOTHING NOTHING_AFTER_RETURNING @@ -1018,7 +1019,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption %token UNBOUNDED UNCOMMITTED UNION UNIQUE UNKNOWN UNLISTEN UNLOGGED UNSAFE_RESTORE_INCOMPATIBLE_VERSION UNSPLIT %token UPDATE UPDATES_CLUSTER_MONITORING_METRICS UPSERT UNSET UNTIL USE USER USERS USING UUID -%token VALID VALIDATE VALUE VALUES VARBIT VARCHAR VARIADIC VERIFY_BACKUP_TABLE_DATA VIEW VARIABLES VARYING VIEWACTIVITY VIEWACTIVITYREDACTED VIEWDEBUG +%token VALID VALIDATE VALUE VALUES VARBIT VARCHAR VARIADIC VECTOR VERIFY_BACKUP_TABLE_DATA VIEW VARIABLES VARYING VIEWACTIVITY VIEWACTIVITYREDACTED VIEWDEBUG %token VIEWCLUSTERMETADATA VIEWCLUSTERSETTING VIRTUAL VISIBLE INVISIBLE VISIBILITY VOLATILE VOTERS %token VIRTUAL_CLUSTER_NAME VIRTUAL_CLUSTER @@ -1603,6 +1604,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption %type <*types.T> character_base %type <*types.T> geo_shape_type %type <*types.T> const_geo +%type <*types.T> const_vector %type extract_arg %type opt_varying @@ -1760,7 +1762,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption // funny behavior of UNBOUNDED on the SQL standard, though. 
%nonassoc UNBOUNDED // ideally should have same precedence as IDENT %nonassoc IDENT NULL PARTITION RANGE ROWS GROUPS PRECEDING FOLLOWING CUBE ROLLUP -%left CONCAT FETCHVAL FETCHTEXT FETCHVAL_PATH FETCHTEXT_PATH REMOVE_PATH AT_AT // multi-character ops +%left CONCAT FETCHVAL FETCHTEXT FETCHVAL_PATH FETCHTEXT_PATH REMOVE_PATH AT_AT DISTANCE COS_DISTANCE NEG_INNER_PRODUCT // multi-character ops %left '|' %left '#' %left '&' @@ -14414,6 +14416,21 @@ const_geo: $$.val = types.MakeGeography($3.geoShapeType(), geopb.SRID(val)) } +const_vector: + VECTOR { $$.val = types.PGVector } +| VECTOR '(' iconst32 ')' + { + dims := $3.int32() + if dims <= 0 { + sqllex.Error("dimensions for type vector must be at least 1") + return 1 + } else if dims > vector.MaxDim { + sqllex.Error(fmt.Sprintf("dimensions for type vector cannot exceed %d", vector.MaxDim)) + return 1 + } + $$.val = types.MakePGVector(dims) + } + // We have a separate const_typename to allow defaulting fixed-length types such // as CHAR() and BIT() to an unspecified length. 
SQL9x requires that these // default to a length of one, but this makes no sense for constructs like CHAR @@ -14431,6 +14448,7 @@ const_typename: | character_with_length | const_datetime | const_geo +| const_vector opt_numeric_modifiers: '(' iconst32 ')' @@ -15016,6 +15034,18 @@ a_expr: { $$.val = &tree.ComparisonExpr{Operator: treecmp.MakeComparisonOperator(treecmp.TSMatches), Left: $1.expr(), Right: $3.expr()} } +| a_expr DISTANCE a_expr + { + $$.val = &tree.BinaryExpr{Operator: treebin.MakeBinaryOperator(treebin.Distance), Left: $1.expr(), Right: $3.expr()} + } +| a_expr COS_DISTANCE a_expr + { + $$.val = &tree.BinaryExpr{Operator: treebin.MakeBinaryOperator(treebin.CosDistance), Left: $1.expr(), Right: $3.expr()} + } +| a_expr NEG_INNER_PRODUCT a_expr + { + $$.val = &tree.BinaryExpr{Operator: treebin.MakeBinaryOperator(treebin.NegInnerProduct), Left: $1.expr(), Right: $3.expr()} + } | a_expr INET_CONTAINS_OR_EQUALS a_expr { $$.val = &tree.FuncExpr{Func: tree.WrapFunction("inet_contains_or_equals"), Exprs: tree.Exprs{$1.expr(), $3.expr()}} @@ -16228,6 +16258,9 @@ all_op: | NOT_REGIMATCH { $$.val = treecmp.MakeComparisonOperator(treecmp.NotRegIMatch) } | AND_AND { $$.val = treecmp.MakeComparisonOperator(treecmp.Overlaps) } | AT_AT { $$.val = treecmp.MakeComparisonOperator(treecmp.TSMatches) } +| DISTANCE { $$.val = treebin.MakeBinaryOperator(treebin.Distance) } +| COS_DISTANCE { $$.val = treebin.MakeBinaryOperator(treebin.CosDistance) } +| NEG_INNER_PRODUCT { $$.val = treebin.MakeBinaryOperator(treebin.NegInnerProduct) } | '~' { $$.val = tree.MakeUnaryOperator(tree.UnaryComplement) } | SQRT { $$.val = tree.MakeUnaryOperator(tree.UnarySqrt) } | CBRT { $$.val = tree.MakeUnaryOperator(tree.UnaryCbrt) } @@ -18203,6 +18236,7 @@ bare_label_keywords: | VARCHAR | VARIABLES | VARIADIC +| VECTOR | VERIFY_BACKUP_TABLE_DATA | VIEW | VIEWACTIVITY @@ -18288,6 +18322,7 @@ col_name_keyword: | VALUES | VARBIT | VARCHAR +| VECTOR | VIRTUAL | WORK diff --git 
a/pkg/sql/parser/testdata/create_table b/pkg/sql/parser/testdata/create_table index 4d7834e1d667..0238fe007724 100644 --- a/pkg/sql/parser/testdata/create_table +++ b/pkg/sql/parser/testdata/create_table @@ -2542,3 +2542,19 @@ ALTER TABLE a PARTITION ALL BY LIST ("a b", "c.d") (PARTITION "e.f" VALUES IN (1 ALTER TABLE a PARTITION ALL BY LIST ("a b", "c.d") (PARTITION "e.f" VALUES IN ((1))) -- fully parenthesized ALTER TABLE a PARTITION ALL BY LIST ("a b", "c.d") (PARTITION "e.f" VALUES IN (_)) -- literals removed ALTER TABLE _ PARTITION ALL BY LIST (_, _) (PARTITION _ VALUES IN (1)) -- identifiers removed + +parse +CREATE TABLE a (a VECTOR(3)) +---- +CREATE TABLE a (a VECTOR(3)) +CREATE TABLE a (a VECTOR(3)) -- fully parenthesized +CREATE TABLE a (a VECTOR(3)) -- literals removed +CREATE TABLE _ (_ VECTOR(3)) -- identifiers removed + +parse +CREATE TABLE a (a VECTOR) +---- +CREATE TABLE a (a VECTOR) +CREATE TABLE a (a VECTOR) -- fully parenthesized +CREATE TABLE a (a VECTOR) -- literals removed +CREATE TABLE _ (_ VECTOR) -- identifiers removed diff --git a/pkg/sql/parser/testdata/select_exprs b/pkg/sql/parser/testdata/select_exprs index 2ce2e9902d0e..3525f0a2298f 100644 --- a/pkg/sql/parser/testdata/select_exprs +++ b/pkg/sql/parser/testdata/select_exprs @@ -2058,3 +2058,43 @@ SELECT my_func('a', 1, true) SELECT (my_func(('a'), (1), (true))) -- fully parenthesized SELECT my_func('_', _, _) -- literals removed SELECT _('a', 1, true) -- identifiers removed + +parse +SELECT "[1,2]" <-> "[3,4]" +---- +SELECT "[1,2]" <-> "[3,4]" +SELECT (("[1,2]") <-> ("[3,4]")) -- fully parenthesized +SELECT "[1,2]" <-> "[3,4]" -- literals removed +SELECT _ <-> _ -- identifiers removed + +parse +SELECT "[2,2]" <=> "[3,4]" +---- +SELECT "[2,2]" <=> "[3,4]" +SELECT (("[2,2]") <=> ("[3,4]")) -- fully parenthesized +SELECT "[2,2]" <=> "[3,4]" -- literals removed +SELECT _ <=> _ -- identifiers removed + +parse +SELECT "[2,2]" <#> "[3,4]" +---- +SELECT "[2,2]" <#> "[3,4]" +SELECT (("[2,2]") 
<#> ("[3,4]")) -- fully parenthesized +SELECT "[2,2]" <#> "[3,4]" -- literals removed +SELECT _ <#> _ -- identifiers removed + +error +SELECT "[2,2]" <# "[3,4]" +---- +at or near "#": syntax error +DETAIL: source SQL: +SELECT "[2,2]" <# "[3,4]" + ^ + +parse +SELECT "[2,2]" <- "[3,4]" +---- +SELECT "[2,2]" < (-"[3,4]") -- normalized! +SELECT (("[2,2]") < ((-("[3,4]")))) -- fully parenthesized +SELECT "[2,2]" < (-"[3,4]") -- literals removed +SELECT _ < (-_) -- identifiers removed diff --git a/pkg/sql/pg_catalog.go b/pkg/sql/pg_catalog.go index e305e2fb89b7..51d6678c8a16 100644 --- a/pkg/sql/pg_catalog.go +++ b/pkg/sql/pg_catalog.go @@ -4775,6 +4775,7 @@ var datumToTypeCategory = map[types.Family]*tree.DString{ types.TupleFamily: typCategoryPseudo, types.OidFamily: typCategoryNumeric, types.PGLSNFamily: typCategoryUserDefined, + types.PGVectorFamily: typCategoryUserDefined, types.RefCursorFamily: typCategoryUserDefined, types.UuidFamily: typCategoryUserDefined, types.INetFamily: typCategoryNetworkAddr, diff --git a/pkg/sql/pgwire/pgwirebase/BUILD.bazel b/pkg/sql/pgwire/pgwirebase/BUILD.bazel index bb2b1d401148..66f7bd826927 100644 --- a/pkg/sql/pgwire/pgwirebase/BUILD.bazel +++ b/pkg/sql/pgwire/pgwirebase/BUILD.bazel @@ -42,6 +42,7 @@ go_library( "//pkg/util/tsearch", "//pkg/util/uint128", "//pkg/util/uuid", + "//pkg/util/vector", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", "@com_github_dustin_go_humanize//:go-humanize", diff --git a/pkg/sql/pgwire/pgwirebase/encoding.go b/pkg/sql/pgwire/pgwirebase/encoding.go index eca9700403fd..9138d56e998f 100644 --- a/pkg/sql/pgwire/pgwirebase/encoding.go +++ b/pkg/sql/pgwire/pgwirebase/encoding.go @@ -44,6 +44,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tsearch" "github.com/cockroachdb/cockroach/pkg/util/uint128" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" 
"github.com/cockroachdb/redact" "github.com/dustin/go-humanize" @@ -477,6 +478,12 @@ func DecodeDatum( return nil, err } return &tree.DTSVector{TSVector: ret}, nil + case oidext.T_pgvector: + ret, err := vector.ParseVector(bs) + if err != nil { + return nil, err + } + return &tree.DPGVector{T: ret}, nil } switch typ.Family() { case types.ArrayFamily, types.TupleFamily: diff --git a/pkg/sql/pgwire/types.go b/pkg/sql/pgwire/types.go index dbe3b6347540..76763862787a 100644 --- a/pkg/sql/pgwire/types.go +++ b/pkg/sql/pgwire/types.go @@ -258,6 +258,10 @@ func writeTextDatumNotNull( b.textFormatter.FormatNode(v) b.writeFromFmtCtx(b.textFormatter) + case *tree.DPGVector: + b.textFormatter.FormatNode(v) + b.writeFromFmtCtx(b.textFormatter) + case *tree.DArray: // Arrays have custom formatting depending on their OID. b.textFormatter.FormatNode(d) diff --git a/pkg/sql/randgen/BUILD.bazel b/pkg/sql/randgen/BUILD.bazel index 32594b4c5e4f..fbf3afb9198e 100644 --- a/pkg/sql/randgen/BUILD.bazel +++ b/pkg/sql/randgen/BUILD.bazel @@ -51,6 +51,7 @@ go_library( "//pkg/util/tsearch", "//pkg/util/uint128", "//pkg/util/uuid", + "//pkg/util/vector", "@com_github_cockroachdb_apd_v3//:apd", "@com_github_cockroachdb_errors//:errors", "@com_github_lib_pq//oid", diff --git a/pkg/sql/randgen/datum.go b/pkg/sql/randgen/datum.go index 298f3d4d0608..54148ce0a96e 100644 --- a/pkg/sql/randgen/datum.go +++ b/pkg/sql/randgen/datum.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tsearch" "github.com/cockroachdb/cockroach/pkg/util/uint128" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" "github.com/twpayne/go-geom" @@ -328,6 +329,8 @@ func RandDatumWithNullChance( return tree.NewDTSVector(tsearch.RandomTSVector(rng)) case types.TSQueryFamily: return tree.NewDTSQuery(tsearch.RandomTSQuery(rng)) + case types.PGVectorFamily: + return tree.NewDPGVector(vector.Random(rng)) 
default: panic(errors.AssertionFailedf("invalid type %v", typ.DebugString())) } diff --git a/pkg/sql/randgen/type.go b/pkg/sql/randgen/type.go index 22c93c7b2df5..30baaad79cfa 100644 --- a/pkg/sql/randgen/type.go +++ b/pkg/sql/randgen/type.go @@ -173,18 +173,18 @@ func IsLegalColumnType(typ *types.T) bool { return false } ctx := context.Background() - version := clustersettings.MakeTestingClusterSettings().Version - return colinfo.ValidateColumnDefType(ctx, version, typ) == nil + st := clustersettings.MakeTestingClusterSettings() + return colinfo.ValidateColumnDefType(ctx, st, typ) == nil } // RandArrayType generates a random array type. func RandArrayType(rng *rand.Rand) *types.T { ctx := context.Background() - version := clustersettings.MakeTestingClusterSettings().Version + st := clustersettings.MakeTestingClusterSettings() for { typ := RandColumnType(rng) resTyp := types.MakeArray(typ) - if err := colinfo.ValidateColumnDefType(ctx, version, resTyp); err == nil { + if err := colinfo.ValidateColumnDefType(ctx, st, resTyp); err == nil { return resTyp } } diff --git a/pkg/sql/rowenc/encoded_datum.go b/pkg/sql/rowenc/encoded_datum.go index de64669dc563..9ff1ad3fab7d 100644 --- a/pkg/sql/rowenc/encoded_datum.go +++ b/pkg/sql/rowenc/encoded_datum.go @@ -325,12 +325,12 @@ func mustUseValueEncodingForFingerprinting(t *types.T) bool { // available, but for historical reasons we will keep on using the // value-encoding (Fingerprint is used by hash routers, so changing its // behavior can result in incorrect results in mixed version clusters). 
- case types.JsonFamily, types.TSQueryFamily, types.TSVectorFamily: + case types.JsonFamily, types.TSQueryFamily, types.TSVectorFamily, types.PGVectorFamily: return true case types.ArrayFamily: // Note that at time of this writing we don't support arrays of JSON - // (tracked via #23468) nor of TSQuery / TSVector types (tracked by - // #90886), so technically we don't need to do a recursive call here, + // (tracked via #23468) nor of TSQuery / TSVector / PGVector types (tracked by + // #90886, #121432), so technically we don't need to do a recursive call here, // but we choose to be on the safe side, so we do it anyway. return mustUseValueEncodingForFingerprinting(t.ArrayContents()) case types.TupleFamily: diff --git a/pkg/sql/rowenc/encoded_datum_test.go b/pkg/sql/rowenc/encoded_datum_test.go index 10d9c499a800..dadb8b0278a5 100644 --- a/pkg/sql/rowenc/encoded_datum_test.go +++ b/pkg/sql/rowenc/encoded_datum_test.go @@ -221,7 +221,7 @@ func TestEncDatumCompare(t *testing.T) { for _, typ := range types.OidToType { switch typ.Family() { case types.AnyFamily, types.UnknownFamily, types.ArrayFamily, types.JsonFamily, types.TupleFamily, types.VoidFamily, - types.TSQueryFamily, types.TSVectorFamily: + types.TSQueryFamily, types.TSVectorFamily, types.PGVectorFamily: continue case types.CollatedStringFamily: typ = types.MakeCollatedString(types.String, *randgen.RandCollationLocale(rng)) diff --git a/pkg/sql/rowenc/keyside/BUILD.bazel b/pkg/sql/rowenc/keyside/BUILD.bazel index 9990b599ed03..4ce546884cf9 100644 --- a/pkg/sql/rowenc/keyside/BUILD.bazel +++ b/pkg/sql/rowenc/keyside/BUILD.bazel @@ -38,6 +38,7 @@ go_test( deps = [ ":keyside", "//pkg/settings/cluster", + "//pkg/sql/catalog/colinfo", "//pkg/sql/randgen", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", diff --git a/pkg/sql/rowenc/keyside/keyside_test.go b/pkg/sql/rowenc/keyside/keyside_test.go index 4c7895437df5..1b8e3837a9d4 100644 --- a/pkg/sql/rowenc/keyside/keyside_test.go +++ 
b/pkg/sql/rowenc/keyside/keyside_test.go @@ -19,6 +19,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" "github.com/cockroachdb/cockroach/pkg/sql/randgen" "github.com/cockroachdb/cockroach/pkg/sql/rowenc/keyside" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" @@ -243,13 +244,13 @@ func genEncodingDirection() gopter.Gen { } func hasKeyEncoding(typ *types.T) bool { - // Only some types are round-trip key encodable. switch typ.Family() { - case types.CollatedStringFamily, types.TupleFamily, types.DecimalFamily, - types.GeographyFamily, types.GeometryFamily, types.TSVectorFamily, types.TSQueryFamily: + // Special case needed for CollatedStringFamily and DecimalFamily which do have + // a key encoding but do not roundtrip. + case types.CollatedStringFamily, types.DecimalFamily: return false case types.ArrayFamily: return hasKeyEncoding(typ.ArrayContents()) } - return true + return !colinfo.MustBeValueEncoded(typ) } diff --git a/pkg/sql/rowenc/valueside/BUILD.bazel b/pkg/sql/rowenc/valueside/BUILD.bazel index de2a8e28bbd2..22c783c72a5a 100644 --- a/pkg/sql/rowenc/valueside/BUILD.bazel +++ b/pkg/sql/rowenc/valueside/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "//pkg/util/timeutil/pgdate", "//pkg/util/tsearch", "//pkg/util/uuid", + "//pkg/util/vector", "@com_github_cockroachdb_errors//:errors", "@com_github_lib_pq//oid", ], diff --git a/pkg/sql/rowenc/valueside/array.go b/pkg/sql/rowenc/valueside/array.go index a96ddd62799b..b9c1b9ac2578 100644 --- a/pkg/sql/rowenc/valueside/array.go +++ b/pkg/sql/rowenc/valueside/array.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/json" "github.com/cockroachdb/cockroach/pkg/util/tsearch" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" ) @@ -329,6 +330,12 @@ func encodeArrayElement(b []byte, d tree.Datum) ([]byte, error) 
{ return nil, err } return encoding.EncodeUntaggedBytesValue(b, encoded), nil + case *tree.DPGVector: + encoded, err := vector.Encode(nil, t.T) + if err != nil { + return nil, err + } + return encoding.EncodeUntaggedBytesValue(b, encoded), nil default: return nil, errors.Errorf("don't know how to encode %s (%T)", d, d) } diff --git a/pkg/sql/rowenc/valueside/decode.go b/pkg/sql/rowenc/valueside/decode.go index a02745612002..ca6fa44a7485 100644 --- a/pkg/sql/rowenc/valueside/decode.go +++ b/pkg/sql/rowenc/valueside/decode.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/json" "github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate" "github.com/cockroachdb/cockroach/pkg/util/tsearch" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -223,6 +224,16 @@ func DecodeUntaggedDatum( return nil, b, err } return tree.NewDTSVector(v), b, nil + case types.PGVectorFamily: + b, data, err := encoding.DecodeUntaggedBytesValue(buf) + if err != nil { + return nil, b, err + } + vec, err := vector.Decode(data) + if err != nil { + return nil, b, err + } + return tree.NewDPGVector(vec), b, nil case types.OidFamily: // TODO: This possibly should decode to uint32 (with corresponding changes // to encoding) to ensure that the value fits in a DOid without any loss of diff --git a/pkg/sql/rowenc/valueside/encode.go b/pkg/sql/rowenc/valueside/encode.go index b999b957b476..afae3a13cc89 100644 --- a/pkg/sql/rowenc/valueside/encode.go +++ b/pkg/sql/rowenc/valueside/encode.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/encoding" "github.com/cockroachdb/cockroach/pkg/util/json" "github.com/cockroachdb/cockroach/pkg/util/tsearch" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" ) @@ -96,6 +97,12 @@ func Encode(appendTo []byte, colID ColumnIDDelta, val tree.Datum, scratch []byte return nil, err } return encoding.EncodeTSVectorValue(appendTo, 
uint32(colID), encoded), nil + case *tree.DPGVector: + encoded, err := vector.Encode(scratch, t.T) + if err != nil { + return nil, err + } + return encoding.EncodePGVectorValue(appendTo, uint32(colID), encoded), nil case *tree.DArray: a, err := encodeArray(t, scratch) if err != nil { diff --git a/pkg/sql/rowenc/valueside/legacy.go b/pkg/sql/rowenc/valueside/legacy.go index 6162a72144ce..db37921d4846 100644 --- a/pkg/sql/rowenc/valueside/legacy.go +++ b/pkg/sql/rowenc/valueside/legacy.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate" "github.com/cockroachdb/cockroach/pkg/util/tsearch" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -168,6 +169,15 @@ func MarshalLegacy(colType *types.T, val tree.Datum) (roachpb.Value, error) { r.SetBytes(data) return r, nil } + case types.PGVectorFamily: + if v, ok := val.(*tree.DPGVector); ok { + data, err := vector.Encode(nil, v.T) + if err != nil { + return r, err + } + r.SetBytes(data) + return r, nil + } case types.ArrayFamily: if v, ok := val.(*tree.DArray); ok { if err := checkElementType(v.ParamTyp, colType.ArrayContents()); err != nil { @@ -422,6 +432,16 @@ func UnmarshalLegacy(a *tree.DatumAlloc, typ *types.T, value roachpb.Value) (tre return nil, err } return tree.NewDTSVector(vec), nil + case types.PGVectorFamily: + v, err := value.GetBytes() + if err != nil { + return nil, err + } + vec, err := vector.Decode(v) + if err != nil { + return nil, err + } + return tree.NewDPGVector(vec), nil case types.EnumFamily: v, err := value.GetBytes() if err != nil { diff --git a/pkg/sql/scanner/scan.go b/pkg/sql/scanner/scan.go index 890e21c2919b..a4cb630e8b88 100644 --- a/pkg/sql/scanner/scan.go +++ b/pkg/sql/scanner/scan.go @@ -315,12 +315,32 @@ func (s *SQLScanner) Scan(lval ScanSymType) { return case '=': // <= s.pos++ + switch s.peek() { + case '>': // <=> + s.pos++ + 
lval.SetID(lexbase.COS_DISTANCE) + return + } lval.SetID(lexbase.LESS_EQUALS) return case '@': // <@ s.pos++ lval.SetID(lexbase.CONTAINED_BY) return + case '-': // <- + switch s.peekN(1) { + case '>': // <-> + s.pos += 2 + lval.SetID(lexbase.DISTANCE) + return + } + case '#': // <# + switch s.peekN(1) { + case '>': // <#> + s.pos += 2 + lval.SetID(lexbase.NEG_INNER_PRODUCT) + return + } } return diff --git a/pkg/sql/schemachanger/comparator_generated_test.go b/pkg/sql/schemachanger/comparator_generated_test.go index bcff16a947d5..e060cf7aedcb 100644 --- a/pkg/sql/schemachanger/comparator_generated_test.go +++ b/pkg/sql/schemachanger/comparator_generated_test.go @@ -2088,6 +2088,11 @@ func TestSchemaChangeComparator_vectorize_window(t *testing.T) { var logicTestFile = "pkg/sql/logictest/testdata/logic_test/vectorize_window" runSchemaChangeComparatorTest(t, logicTestFile) } +func TestSchemaChangeComparator_vectoross(t *testing.T) { + defer leaktest.AfterTest(t)() + var logicTestFile = "pkg/sql/logictest/testdata/logic_test/vectoross" + runSchemaChangeComparatorTest(t, logicTestFile) +} func TestSchemaChangeComparator_views(t *testing.T) { defer leaktest.AfterTest(t)() var logicTestFile = "pkg/sql/logictest/testdata/logic_test/views" diff --git a/pkg/sql/sem/builtins/BUILD.bazel b/pkg/sql/sem/builtins/BUILD.bazel index aaed197b09f1..8bfab120f80a 100644 --- a/pkg/sql/sem/builtins/BUILD.bazel +++ b/pkg/sql/sem/builtins/BUILD.bazel @@ -19,6 +19,7 @@ go_library( "parse_ident_builtin.go", "pg_builtins.go", "pgcrypto_builtins.go", + "pgvector_builtins.go", "replication_builtins.go", "show_create_all_schemas_builtin.go", "show_create_all_tables_builtin.go", @@ -134,6 +135,7 @@ go_library( "//pkg/util/ulid", "//pkg/util/unaccent", "//pkg/util/uuid", + "//pkg/util/vector", "@com_github_cockroachdb_apd_v3//:apd", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", diff --git a/pkg/sql/sem/builtins/builtinconstants/constants.go 
b/pkg/sql/sem/builtins/builtinconstants/constants.go index 699797403da1..c71a01d0b8b8 100644 --- a/pkg/sql/sem/builtins/builtinconstants/constants.go +++ b/pkg/sql/sem/builtins/builtinconstants/constants.go @@ -51,6 +51,7 @@ const ( CategoryJSON = "JSONB" CategoryMultiRegion = "Multi-region" CategoryMultiTenancy = "Multi-tenancy" + CategoryPGVector = "PGVector" CategorySequences = "Sequence" CategorySpatial = "Spatial" CategoryString = "String and byte" diff --git a/pkg/sql/sem/builtins/fixed_oids.go b/pkg/sql/sem/builtins/fixed_oids.go index f7f9c142baa6..d882c601337a 100644 --- a/pkg/sql/sem/builtins/fixed_oids.go +++ b/pkg/sql/sem/builtins/fixed_oids.go @@ -2584,6 +2584,23 @@ var builtinOidsArray = []string{ 2616: `crdb_internal.plan_logical_replication(req: bytes) -> bytes`, 2617: `crdb_internal.start_replication_stream_for_tables(req: bytes) -> bytes`, 2618: `crdb_internal.logical_replication_inject_failures(stream: int, proc: int, percent: int) -> void`, + 2619: `vectorsend(vector: vector) -> bytes`, + 2620: `vectorrecv(input: anyelement) -> vector`, + 2621: `vectorout(vector: vector) -> bytes`, + 2622: `vectorin(input: anyelement) -> vector`, + 2623: `char(vector: vector) -> "char"`, + 2624: `name(vector: vector) -> name`, + 2625: `text(vector: vector) -> string`, + 2626: `varchar(vector: vector) -> varchar`, + 2627: `bpchar(vector: vector) -> char`, + 2628: `vector(string: string) -> vector`, + 2629: `vector(vector: vector) -> vector`, + 2630: `cosine_distance(v1: vector, v2: vector) -> float`, + 2631: `l1_distance(v1: vector, v2: vector) -> float`, + 2632: `l2_distance(v1: vector, v2: vector) -> float`, + 2633: `inner_product(v1: vector, v2: vector) -> float`, + 2634: `vector_dims(vector: vector) -> int`, + 2635: `vector_norm(vector: vector) -> float`, } var builtinOidsBySignature map[string]oid.Oid diff --git a/pkg/sql/sem/builtins/pgvector_builtins.go b/pkg/sql/sem/builtins/pgvector_builtins.go new file mode 100644 index 000000000000..1bcbc03bdccb --- 
/dev/null +++ b/pkg/sql/sem/builtins/pgvector_builtins.go @@ -0,0 +1,142 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package builtins + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/volatility" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/vector" +) + +func init() { + for k, v := range pgvectorBuiltins { + v.props.Category = builtinconstants.CategoryPGVector + v.props.AvailableOnPublicSchema = true + const enforceClass = true + registerBuiltin(k, v, tree.NormalClass, enforceClass) + } +} + +var pgvectorBuiltins = map[string]builtinDefinition{ + "cosine_distance": makeBuiltin(defProps(), + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "v1", Typ: types.PGVector}, + {Name: "v2", Typ: types.PGVector}, + }, + ReturnType: tree.FixedReturnType(types.Float), + Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + v1 := tree.MustBeDPGVector(args[0]) + v2 := tree.MustBeDPGVector(args[1]) + distance, err := vector.CosDistance(v1.T, v2.T) + if err != nil { + return nil, err + } + return tree.NewDFloat(tree.DFloat(distance)), nil + }, + Info: "Returns the cosine distance between the two vectors.", + Volatility: volatility.Immutable, + }, + ), + "inner_product": makeBuiltin(defProps(), + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "v1", Typ: types.PGVector}, + {Name: "v2", Typ: types.PGVector}, + }, + ReturnType: 
tree.FixedReturnType(types.Float), + Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + v1 := tree.MustBeDPGVector(args[0]) + v2 := tree.MustBeDPGVector(args[1]) + distance, err := vector.InnerProduct(v1.T, v2.T) + if err != nil { + return nil, err + } + return tree.NewDFloat(tree.DFloat(distance)), nil + }, + Info: "Returns the inner product between the two vectors.", + Volatility: volatility.Immutable, + }, + ), + "l1_distance": makeBuiltin(defProps(), + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "v1", Typ: types.PGVector}, + {Name: "v2", Typ: types.PGVector}, + }, + ReturnType: tree.FixedReturnType(types.Float), + Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + v1 := tree.MustBeDPGVector(args[0]) + v2 := tree.MustBeDPGVector(args[1]) + distance, err := vector.L1Distance(v1.T, v2.T) + if err != nil { + return nil, err + } + return tree.NewDFloat(tree.DFloat(distance)), nil + }, + Info: "Returns the Manhattan distance between the two vectors.", + Volatility: volatility.Immutable, + }, + ), + "l2_distance": makeBuiltin(defProps(), + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "v1", Typ: types.PGVector}, + {Name: "v2", Typ: types.PGVector}, + }, + ReturnType: tree.FixedReturnType(types.Float), + Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + v1 := tree.MustBeDPGVector(args[0]) + v2 := tree.MustBeDPGVector(args[1]) + distance, err := vector.L2Distance(v1.T, v2.T) + if err != nil { + return nil, err + } + return tree.NewDFloat(tree.DFloat(distance)), nil + }, + Info: "Returns the Euclidean distance between the two vectors.", + Volatility: volatility.Immutable, + }, + ), + "vector_dims": makeBuiltin(defProps(), + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "vector", Typ: types.PGVector}, + }, + ReturnType: tree.FixedReturnType(types.Int), + Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) 
(tree.Datum, error) { + v1 := tree.MustBeDPGVector(args[0]) + return tree.NewDInt(tree.DInt(len(v1.T))), nil + }, + Info: "Returns the number of the dimensions in the vector.", + Volatility: volatility.Immutable, + }, + ), + "vector_norm": makeBuiltin(defProps(), + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "vector", Typ: types.PGVector}, + }, + ReturnType: tree.FixedReturnType(types.Float), + Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + v1 := tree.MustBeDPGVector(args[0]) + return tree.NewDFloat(tree.DFloat(vector.Norm(v1.T))), nil + }, + Info: "Returns the Euclidean norm of the vector.", + Volatility: volatility.Immutable, + }, + ), +} diff --git a/pkg/sql/sem/cast/cast.go b/pkg/sql/sem/cast/cast.go index ad0b1cd6c220..cefdca085d36 100644 --- a/pkg/sql/sem/cast/cast.go +++ b/pkg/sql/sem/cast/cast.go @@ -251,6 +251,20 @@ func LookupCast(src, tgt *types.T) (Cast, bool) { }, true } + if srcFamily == types.ArrayFamily && tgtFamily == types.PGVectorFamily { + return Cast{ + MaxContext: ContextAssignment, + Volatility: volatility.Stable, + }, true + } + + if srcFamily == types.PGVectorFamily && tgtFamily == types.ArrayFamily { + return Cast{ + MaxContext: ContextAssignment, + Volatility: volatility.Stable, + }, true + } + // Casts from array and tuple types to string types are immutable and // allowed in assignment contexts. // TODO(mgartner): Tuple to string casts should be stable. 
They are diff --git a/pkg/sql/sem/cast/cast_map.go b/pkg/sql/sem/cast/cast_map.go index 24be3ecd3693..039a01d77b39 100644 --- a/pkg/sql/sem/cast/cast_map.go +++ b/pkg/sql/sem/cast/cast_map.go @@ -72,6 +72,13 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{ oid.T_varchar: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, oid.T_text: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, }, + oidext.T_pgvector: { + oid.T_bpchar: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_char: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_name: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_varchar: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_text: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + }, oid.T_bpchar: { oid.T_bpchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, oid.T_char: {MaxContext: ContextAssignment, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, @@ -79,11 +86,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{ oid.T_text: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, oid.T_varchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, // Automatic I/O conversions from bpchar to other types. 
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, oid.T_date: { MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, @@ -166,11 +174,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{ // Automatic I/O conversions to string types. oid.T_name: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, // Automatic I/O conversions from "char" to other types. 
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, oid.T_date: { MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, @@ -486,11 +495,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{ // Automatic I/O conversions to string types. oid.T_char: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, // Automatic I/O conversions from NAME to other types. 
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, oid.T_date: { MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, @@ -733,11 +743,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{ oid.T_text: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, oid.T_varchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, // Automatic I/O conversions from TEXT to other types. 
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, oid.T_date: { MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, @@ -966,11 +977,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{ oid.T_text: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, oid.T_varchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable}, // Automatic I/O conversions from VARCHAR to other types. 
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, - oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, + oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable}, oid.T_date: { MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, diff --git a/pkg/sql/sem/eval/BUILD.bazel b/pkg/sql/sem/eval/BUILD.bazel index a1fe5927185d..05c92e17f7db 100644 --- a/pkg/sql/sem/eval/BUILD.bazel +++ b/pkg/sql/sem/eval/BUILD.bazel @@ -97,6 +97,7 @@ go_library( "//pkg/util/tsearch", "//pkg/util/ulid", "//pkg/util/uuid", + "//pkg/util/vector", "@com_github_cockroachdb_apd_v3//:apd", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", diff --git a/pkg/sql/sem/eval/binary_op.go b/pkg/sql/sem/eval/binary_op.go index 9ca3d9bf12f7..fd48722a72de 100644 --- 
a/pkg/sql/sem/eval/binary_op.go +++ b/pkg/sql/sem/eval/binary_op.go @@ -32,6 +32,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeofday" "github.com/cockroachdb/cockroach/pkg/util/trigram" "github.com/cockroachdb/cockroach/pkg/util/tsearch" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" ) @@ -1229,6 +1230,33 @@ func (e *evaluator) EvalTSMatchesVectorQueryOp( return tree.MakeDBool(tree.DBool(ret)), err } +func (e *evaluator) EvalDistanceVectorOp( + ctx context.Context, _ *tree.DistanceVectorOp, left, right tree.Datum, +) (tree.Datum, error) { + v := tree.MustBeDPGVector(left) + q := tree.MustBeDPGVector(right) + ret, err := vector.L2Distance(q.T, v.T) + return tree.NewDFloat(tree.DFloat(ret)), err +} + +func (e *evaluator) EvalCosDistanceVectorOp( + ctx context.Context, _ *tree.CosDistanceVectorOp, left, right tree.Datum, +) (tree.Datum, error) { + v := tree.MustBeDPGVector(left) + q := tree.MustBeDPGVector(right) + ret, err := vector.CosDistance(q.T, v.T) + return tree.NewDFloat(tree.DFloat(ret)), err +} + +func (e *evaluator) EvalNegInnerProductVectorOp( + ctx context.Context, _ *tree.NegInnerProductVectorOp, left, right tree.Datum, +) (tree.Datum, error) { + v := tree.MustBeDPGVector(left) + q := tree.MustBeDPGVector(right) + ret, err := vector.NegInnerProduct(q.T, v.T) + return tree.NewDFloat(tree.DFloat(ret)), err +} + func (e *evaluator) EvalPlusDateIntOp( ctx context.Context, _ *tree.PlusDateIntOp, left, right tree.Datum, ) (tree.Datum, error) { @@ -1648,3 +1676,39 @@ func decimalPGLSNEval( } return tree.NewDPGLSN(resultLSN), nil } + +func (e *evaluator) EvalPlusPGVectorOp( + ctx context.Context, _ *tree.PlusPGVectorOp, left, right tree.Datum, +) (tree.Datum, error) { + t1 := tree.MustBeDPGVector(left) + t2 := tree.MustBeDPGVector(right) + ret, err := vector.Add(t1.T, t2.T) + if err != nil { + return nil, err + } + return tree.NewDPGVector(ret), nil +} + +func (e *evaluator) EvalMinusPGVectorOp( + ctx 
context.Context, _ *tree.MinusPGVectorOp, left, right tree.Datum, +) (tree.Datum, error) { + t1 := tree.MustBeDPGVector(left) + t2 := tree.MustBeDPGVector(right) + ret, err := vector.Minus(t1.T, t2.T) + if err != nil { + return nil, err + } + return tree.NewDPGVector(ret), nil +} + +func (e *evaluator) EvalMultPGVectorOp( + ctx context.Context, _ *tree.MultPGVectorOp, left, right tree.Datum, +) (tree.Datum, error) { + t1 := tree.MustBeDPGVector(left) + t2 := tree.MustBeDPGVector(right) + ret, err := vector.Mult(t1.T, t2.T) + if err != nil { + return nil, err + } + return tree.NewDPGVector(ret), nil +} diff --git a/pkg/sql/sem/eval/cast.go b/pkg/sql/sem/eval/cast.go index b5cc8aa39d5a..41d64cbdfb9c 100644 --- a/pkg/sql/sem/eval/cast.go +++ b/pkg/sql/sem/eval/cast.go @@ -18,6 +18,7 @@ import ( "time" "github.com/cockroachdb/apd/v3" + "github.com/cockroachdb/cockroach/pkg/clusterversion" "github.com/cockroachdb/cockroach/pkg/geo" "github.com/cockroachdb/cockroach/pkg/geo/geopb" "github.com/cockroachdb/cockroach/pkg/sql/lex" @@ -34,6 +35,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate" "github.com/cockroachdb/cockroach/pkg/util/tsearch" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" ) @@ -502,6 +504,8 @@ func performCastWithoutPrecisionTruncation( s = t.TSQuery.String() case *tree.DTSVector: s = t.TSVector.String() + case *tree.DPGVector: + s = t.T.String() case *tree.DEnum: s = t.LogicalRep case *tree.DVoid: @@ -601,6 +605,38 @@ func performCastWithoutPrecisionTruncation( return d, nil } + case types.PGVectorFamily: + if !evalCtx.Settings.Version.IsActive(ctx, clusterversion.V24_2) { + return nil, pgerror.Newf(pgcode.FeatureNotSupported, + "version %v must be finalized to use vector", + clusterversion.V24_2.Version()) + } + switch d := d.(type) { + case *tree.DString: + return tree.ParseDPGVector(string(*d)) + case 
*tree.DCollatedString: + return tree.ParseDPGVector(d.Contents) + case *tree.DArray: + switch d.ParamTyp.Family() { + case types.FloatFamily, types.IntFamily, types.DecimalFamily: + v := make(vector.T, len(d.Array)) + for i, elem := range d.Array { + if elem == tree.DNull { + return nil, pgerror.Newf(pgcode.NullValueNotAllowed, + "array must not contain nulls") + } + datum, err := performCast(ctx, evalCtx, elem, types.Float4, false) + if err != nil { + return nil, err + } + v[i] = float32(*datum.(*tree.DFloat)) + } + return tree.NewDPGVector(v), nil + } + case *tree.DPGVector: + return d, nil + } + case types.RefCursorFamily: switch d := d.(type) { case *tree.DString: @@ -929,6 +965,14 @@ func performCastWithoutPrecisionTruncation( } } return dcast, nil + case *tree.DPGVector: + dcast := tree.NewDArray(t.ArrayContents()) + for i := range v.T { + if err := dcast.Append(tree.NewDFloat(tree.DFloat(v.T[i]))); err != nil { + return nil, err + } + } + return dcast, nil } case types.OidFamily: switch v := d.(type) { diff --git a/pkg/sql/sem/eval/testdata/eval/vector b/pkg/sql/sem/eval/testdata/eval/vector new file mode 100644 index 000000000000..53755ffff368 --- /dev/null +++ b/pkg/sql/sem/eval/testdata/eval/vector @@ -0,0 +1,143 @@ +# Basic smoke tests for pgvector evaluation. 
+ +eval +'[1,2]'::vector <-> '[1,2]'::vector +---- +0.0 + +eval +'[1,2]'::vector <#> '[1,2]'::vector +---- +-5.0 + +eval +'[1,2]'::vector <=> '[1,2]'::vector +---- +0.0 + +eval +'[1,2,3]'::vector <-> '[3,1,2]'::vector +---- +2.449489742783178 + +eval +'[1,2,3]'::vector <#> '[3,1,2]'::vector +---- +-11.0 + +eval +'[1,2,3]'::vector <=> '[3,1,2]'::vector +---- +0.2142857142857143 + +eval +'[1,2,3]'::vector - '[3,1,2]'::vector +---- +[-2,1,1] + +eval +'[1,2,3]'::vector + '[3,1,2]'::vector +---- +[4,3,5] + +eval +'[1,2,3]'::vector * '[3,1,2]'::vector +---- +[3,2,6] + +eval +cosine_distance('[1,2,3]'::vector, '[3,1,2]'::vector) +---- +0.2142857142857143 + +eval +inner_product('[1,2,3]'::vector, '[3,1,2]'::vector) +---- +11.0 + +eval +l1_distance('[1,2,3]'::vector, '[3,1,2]'::vector) +---- +4.0 + +eval +l2_distance('[1,2,3]'::vector, '[3,1,2]'::vector) +---- +2.449489742783178 + +eval +vector_dims('[1,2,3]'::vector) +---- +3 + +eval +vector_norm('[1,2,3]'::vector) +---- +3.7416573867739413 + +eval +'[1,2]'::vector < '[1,2]'::vector +---- +false + +eval +'[1,2]'::vector <= '[1,2]'::vector +---- +true + +eval +'[1,2]'::vector > '[1,2]'::vector +---- +false + +eval +'[1,2]'::vector >= '[1,2]'::vector +---- +true + +eval +'[1,2]'::vector < '[1,3]'::vector +---- +true + +eval +'[2,2]'::vector < '[1,3]'::vector +---- +false + +eval +'[2,2]'::vector > '[1,3]'::vector +---- +true + +# Mixed dimension comparisons + +eval +'[1]'::vector < '[1,1]'::vector +---- +true + +eval +'[1,1]'::vector < '[1]'::vector +---- +false + +eval +'[1]'::vector > '[1,1]'::vector +---- +false + +eval +'[1,1]'::vector > '[1]'::vector +---- +true + +eval +'[100000000000000000000000]'::vector * '[100000000000000000000]'::vector +---- +value out of range: overflow + +eval +'[.000000000000000000000001]'::vector * '[.0000000000000000000001]'::vector +---- +value out of range: underflow diff --git a/pkg/sql/sem/eval/unsupported_types.go b/pkg/sql/sem/eval/unsupported_types.go index 89fce60c6cb2..6ac35c46303f 
100644 --- a/pkg/sql/sem/eval/unsupported_types.go +++ b/pkg/sql/sem/eval/unsupported_types.go @@ -14,22 +14,23 @@ import ( "context" "github.com/cockroachdb/cockroach/pkg/clusterversion" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" ) type unsupportedTypeChecker struct { - // Uncomment this when a new type is introduced. - // version clusterversion.Handle + // Uncomment this when a new type is introduced, or comment it out if there + // are no types in the checker. + version clusterversion.Handle } // NewUnsupportedTypeChecker returns a new tree.UnsupportedTypeChecker that can // be used to check whether a type is allowed by the current cluster version. -func NewUnsupportedTypeChecker(clusterversion.Handle) tree.UnsupportedTypeChecker { - // Right now we don't have any unsupported types, so there is no benefit in - // returning a type checker that never errors, so we just return nil. (The - // infrastructure is already set up to handle such case gracefully.) - return nil +func NewUnsupportedTypeChecker(handle clusterversion.Handle) tree.UnsupportedTypeChecker { + // If there are no types in the checker, change this code to return nil. + return &unsupportedTypeChecker{version: handle} } var _ tree.UnsupportedTypeChecker = &unsupportedTypeChecker{} @@ -38,5 +39,16 @@ var _ tree.UnsupportedTypeChecker = &unsupportedTypeChecker{} func (tc *unsupportedTypeChecker) CheckType(ctx context.Context, typ *types.T) error { // NB: when adding an unsupported type here, change the constructor to not // return nil. 
+ var errorTypeString string + switch typ.Family() { + case types.PGVectorFamily: + errorTypeString = "vector" + } + if errorTypeString != "" && !tc.version.IsActive(ctx, clusterversion.V24_2) { + return pgerror.Newf(pgcode.FeatureNotSupported, + "%s not supported until version 24.2", errorTypeString, + ) + } + return nil } diff --git a/pkg/sql/sem/tree/BUILD.bazel b/pkg/sql/sem/tree/BUILD.bazel index 530f89114d94..9ff37ef0738f 100644 --- a/pkg/sql/sem/tree/BUILD.bazel +++ b/pkg/sql/sem/tree/BUILD.bazel @@ -168,6 +168,7 @@ go_library( "//pkg/util/tsearch", "//pkg/util/uint128", "//pkg/util/uuid", + "//pkg/util/vector", "@com_github_cockroachdb_apd_v3//:apd", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", diff --git a/pkg/sql/sem/tree/constant.go b/pkg/sql/sem/tree/constant.go index 85f727ed416f..a0226851c446 100644 --- a/pkg/sql/sem/tree/constant.go +++ b/pkg/sql/sem/tree/constant.go @@ -544,6 +544,8 @@ var ( types.Jsonb, types.PGLSN, types.PGLSNArray, + types.PGVector, + types.PGVectorArray, types.RefCursor, types.RefCursorArray, types.TSQuery, diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go index 1b264982debd..2ad79ddaf623 100644 --- a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -49,6 +49,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tsearch" "github.com/cockroachdb/cockroach/pkg/util/uint128" "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" "github.com/lib/pq/oid" @@ -3632,6 +3633,112 @@ func (d *DPGLSN) Size() uintptr { return unsafe.Sizeof(*d) } +// DPGVector is the Datum representation of the PGVector type. +type DPGVector struct { + vector.T +} + +// NewDPGVector returns a new PGVector Datum. 
+func NewDPGVector(vector vector.T) *DPGVector { return &DPGVector{vector} } + +// AsDPGVector attempts to retrieve a DPGVector from an Expr, returning a +// DPGVector and a flag signifying whether the assertion was successful. The +// function should be used instead of direct type assertions wherever a +// *DPGVector wrapped by a *DOidWrapper is possible. +func AsDPGVector(e Expr) (*DPGVector, bool) { + switch t := e.(type) { + case *DPGVector: + return t, true + case *DOidWrapper: + return AsDPGVector(t.Wrapped) + } + return nil, false +} + +// MustBeDPGVector attempts to retrieve a DPGVector from an Expr, panicking if the +// assertion fails. +func MustBeDPGVector(e Expr) *DPGVector { + v, ok := AsDPGVector(e) + if !ok { + panic(errors.AssertionFailedf("expected *DPGVector, found %T", e)) + } + return v +} + +// ParseDPGVector takes a string of PGVector and returns a DPGVector value. +func ParseDPGVector(s string) (Datum, error) { + v, err := vector.ParseVector(s) + if err != nil { + return nil, pgerror.Wrapf(err, pgcode.Syntax, "could not parse vector") + } + return NewDPGVector(v), nil +} + +// Format implements the NodeFormatter interface. +func (d *DPGVector) Format(ctx *FmtCtx) { + bareStrings := ctx.HasFlags(FmtFlags(lexbase.EncBareStrings)) + if !bareStrings { + ctx.WriteByte('\'') + } + ctx.WriteString(d.String()) + if !bareStrings { + ctx.WriteByte('\'') + } +} + +// ResolvedType implements the TypedExpr interface. +func (d *DPGVector) ResolvedType() *types.T { return types.PGVector } + +// AmbiguousFormat implements the Datum interface. +func (d *DPGVector) AmbiguousFormat() bool { + return true +} + +func (d *DPGVector) Compare(ctx context.Context, cmpCtx CompareContext, other Datum) (int, error) { + if other == DNull { + // NULL is less than any non-NULL value. 
+ return 1, nil + } + v, ok := cmpCtx.UnwrapDatum(ctx, other).(*DPGVector) + if !ok { + return 0, makeUnsupportedComparisonMessage(d, other) + } + return d.T.Compare(v.T) +} + +// Prev implements the Datum interface. +func (d *DPGVector) Prev(ctx context.Context, cmpCtx CompareContext) (Datum, bool) { + return nil, false +} + +// Next implements the Datum interface. +func (d *DPGVector) Next(ctx context.Context, cmpCtx CompareContext) (Datum, bool) { + return nil, false +} + +// IsMax implements the Datum interface. +func (d *DPGVector) IsMax(ctx context.Context, cmpCtx CompareContext) bool { + return false +} + +// IsMin implements the Datum interface. +func (d *DPGVector) IsMin(ctx context.Context, cmpCtx CompareContext) bool { + return false +} + +// Max implements the Datum interface. +func (d *DPGVector) Max(ctx context.Context, cmpCtx CompareContext) (Datum, bool) { + return nil, false +} + +// Min implements the Datum interface. +func (d *DPGVector) Min(ctx context.Context, cmpCtx CompareContext) (Datum, bool) { return nil, false } + +// Size implements the Datum interface. +func (d *DPGVector) Size() uintptr { + return unsafe.Sizeof(*d) + d.T.Size() +} + // DBox2D is the Datum representation of the Box2D type. type DBox2D struct { geo.CartesianBoundingBox @@ -3906,7 +4013,7 @@ func AsJSON( // This is RFC3339Nano, but without the TZ fields. 
return json.FromString(formatTime(t.UTC(), "2006-01-02T15:04:05.999999999")), nil case *DDate, *DUuid, *DOid, *DInterval, *DBytes, *DIPAddr, *DTime, *DTimeTZ, *DBitArray, *DBox2D, - *DTSVector, *DTSQuery, *DPGLSN: + *DTSVector, *DTSQuery, *DPGLSN, *DPGVector: return json.FromString( AsStringWithFlags(t, FmtBareStrings, FmtDataConversionConfig(dcc), FmtLocation(loc)), ), nil @@ -5963,6 +6070,7 @@ var baseDatumTypeSizes = map[types.Family]struct { types.GeographyFamily: {unsafe.Sizeof(DGeography{}), variableSize}, types.GeometryFamily: {unsafe.Sizeof(DGeometry{}), variableSize}, types.PGLSNFamily: {unsafe.Sizeof(DPGLSN{}), fixedSize}, + types.PGVectorFamily: {unsafe.Sizeof(DPGVector{}), variableSize}, types.RefCursorFamily: {unsafe.Sizeof(DString("")), variableSize}, types.TimeFamily: {unsafe.Sizeof(DTime(0)), fixedSize}, types.TimeTZFamily: {unsafe.Sizeof(DTimeTZ{}), fixedSize}, @@ -6311,6 +6419,14 @@ func AdjustValueToType(typ *types.T, inVal Datum) (outVal Datum, err error) { return nil, err } } + case types.PGVectorFamily: + if in, ok := inVal.(*DPGVector); ok { + width := int(typ.Width()) + if width > 0 && len(in.T) != width { + return nil, pgerror.Newf(pgcode.DataException, + "expected %d dimensions, not %d", typ.Width(), len(in.T)) + } + } } return inVal, nil } diff --git a/pkg/sql/sem/tree/eval.go b/pkg/sql/sem/tree/eval.go index c597129de5a0..d72c47a0a66e 100644 --- a/pkg/sql/sem/tree/eval.go +++ b/pkg/sql/sem/tree/eval.go @@ -801,6 +801,13 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{ EvalOp: &PlusPGLSNDecimalOp{}, Volatility: volatility.Immutable, }, + { + LeftType: types.PGVector, + RightType: types.PGVector, + ReturnType: types.PGVector, + EvalOp: &PlusPGVectorOp{}, + Volatility: volatility.Immutable, + }, }}, treebin.Minus: {overloads: []*BinOp{ @@ -987,6 +994,13 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{ EvalOp: &MinusPGLSNOp{}, Volatility: volatility.Immutable, }, + { + LeftType: types.PGVector, + RightType: 
types.PGVector, + ReturnType: types.PGVector, + EvalOp: &MinusPGVectorOp{}, + Volatility: volatility.Immutable, + }, }}, treebin.Mult: {overloads: []*BinOp{ @@ -1070,6 +1084,13 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{ EvalOp: &MultIntervalDecimalOp{}, Volatility: volatility.Immutable, }, + { + LeftType: types.PGVector, + RightType: types.PGVector, + ReturnType: types.PGVector, + EvalOp: &MultPGVectorOp{}, + Volatility: volatility.Immutable, + }, }}, treebin.Div: {overloads: []*BinOp{ @@ -1382,6 +1403,33 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{ Volatility: volatility.Immutable, }, }}, + treebin.Distance: {overloads: []*BinOp{ + { + LeftType: types.PGVector, + RightType: types.PGVector, + ReturnType: types.Float, + EvalOp: &DistanceVectorOp{}, + Volatility: volatility.Immutable, + }, + }}, + treebin.CosDistance: {overloads: []*BinOp{ + { + LeftType: types.PGVector, + RightType: types.PGVector, + ReturnType: types.Float, + EvalOp: &CosDistanceVectorOp{}, + Volatility: volatility.Immutable, + }, + }}, + treebin.NegInnerProduct: {overloads: []*BinOp{ + { + LeftType: types.PGVector, + RightType: types.PGVector, + ReturnType: types.Float, + EvalOp: &NegInnerProductVectorOp{}, + Volatility: volatility.Immutable, + }, + }}, } // CmpOp is a comparison operator. 
@@ -1575,6 +1623,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{ makeEqFn(types.Jsonb, types.Jsonb, volatility.Immutable), makeEqFn(types.Oid, types.Oid, volatility.Leakproof), makeEqFn(types.PGLSN, types.PGLSN, volatility.Leakproof), + makeEqFn(types.PGVector, types.PGVector, volatility.Leakproof), makeEqFn(types.RefCursor, types.RefCursor, volatility.Leakproof), makeEqFn(types.String, types.String, volatility.Leakproof), makeEqFn(types.Time, types.Time, volatility.Leakproof), @@ -1635,6 +1684,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{ makeLtFn(types.Interval, types.Interval, volatility.Leakproof), makeLtFn(types.Oid, types.Oid, volatility.Leakproof), makeLtFn(types.PGLSN, types.PGLSN, volatility.Leakproof), + makeLtFn(types.PGVector, types.PGVector, volatility.Leakproof), makeLtFn(types.RefCursor, types.RefCursor, volatility.Leakproof), makeLtFn(types.String, types.String, volatility.Leakproof), makeLtFn(types.Time, types.Time, volatility.Leakproof), @@ -1694,6 +1744,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{ makeLeFn(types.Interval, types.Interval, volatility.Leakproof), makeLeFn(types.Oid, types.Oid, volatility.Leakproof), makeLeFn(types.PGLSN, types.PGLSN, volatility.Leakproof), + makeLeFn(types.PGVector, types.PGVector, volatility.Leakproof), makeLeFn(types.RefCursor, types.RefCursor, volatility.Leakproof), makeLeFn(types.String, types.String, volatility.Leakproof), makeLeFn(types.Time, types.Time, volatility.Leakproof), @@ -1774,6 +1825,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{ makeIsFn(types.Jsonb, types.Jsonb, volatility.Immutable), makeIsFn(types.Oid, types.Oid, volatility.Leakproof), makeIsFn(types.PGLSN, types.PGLSN, volatility.Leakproof), + makeIsFn(types.PGVector, types.PGVector, volatility.Leakproof), makeIsFn(types.RefCursor, types.RefCursor, volatility.Leakproof), makeIsFn(types.String, 
types.String, volatility.Leakproof), makeIsFn(types.Time, types.Time, volatility.Leakproof), @@ -1842,6 +1894,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{ makeEvalTupleIn(types.Jsonb, volatility.Leakproof), makeEvalTupleIn(types.Oid, volatility.Leakproof), makeEvalTupleIn(types.PGLSN, volatility.Leakproof), + makeEvalTupleIn(types.PGVector, volatility.Leakproof), makeEvalTupleIn(types.RefCursor, volatility.Leakproof), makeEvalTupleIn(types.String, volatility.Leakproof), makeEvalTupleIn(types.Time, volatility.Leakproof), diff --git a/pkg/sql/sem/tree/eval_binary_ops.go b/pkg/sql/sem/tree/eval_binary_ops.go index 0c96d326d24d..e68a034328a0 100644 --- a/pkg/sql/sem/tree/eval_binary_ops.go +++ b/pkg/sql/sem/tree/eval_binary_ops.go @@ -81,6 +81,15 @@ type TSMatchesVectorQueryOp struct{} // TSMatchesQueryVectorOp is a BinaryEvalOp. type TSMatchesQueryVectorOp struct{} +type ( + // DistanceVectorOp is a BinaryEvalOp. + DistanceVectorOp struct{} + // CosDistanceVectorOp is a BinaryEvalOp. + CosDistanceVectorOp struct{} + // NegInnerProductVectorOp is a BinaryEvalOp. + NegInnerProductVectorOp struct{} +) + // AppendToMaybeNullArrayOp is a BinaryEvalOp. type AppendToMaybeNullArrayOp struct { Typ *types.T @@ -180,6 +189,8 @@ type ( PlusDecimalPGLSNOp struct{} // PlusPGLSNDecimalOp is a BinaryEvalOp. PlusPGLSNDecimalOp struct{} + // PlusPGVectorOp is a BinaryEvalOp. + PlusPGVectorOp struct{} ) type ( @@ -235,6 +246,8 @@ type ( MinusPGLSNDecimalOp struct{} // MinusPGLSNOp is a BinaryEvalOp. MinusPGLSNOp struct{} + // MinusPGVectorOp is a BinaryEvalOp. + MinusPGVectorOp struct{} ) type ( // MultDecimalIntOp is a BinaryEvalOp. @@ -259,6 +272,8 @@ type ( MultIntervalFloatOp struct{} // MultIntervalIntOp is a BinaryEvalOp. MultIntervalIntOp struct{} + // MultPGVectorOp is a BinaryEvalOp. 
+ MultPGVectorOp struct{} ) type ( diff --git a/pkg/sql/sem/tree/eval_expr_generated.go b/pkg/sql/sem/tree/eval_expr_generated.go index 3726d3c19802..e3d8f59439fa 100644 --- a/pkg/sql/sem/tree/eval_expr_generated.go +++ b/pkg/sql/sem/tree/eval_expr_generated.go @@ -222,6 +222,11 @@ func (node *DPGLSN) Eval(ctx context.Context, v ExprEvaluator) (Datum, error) { return node, nil } +// Eval is part of the TypedExpr interface. +func (node *DPGVector) Eval(ctx context.Context, v ExprEvaluator) (Datum, error) { + return node, nil +} + // Eval is part of the TypedExpr interface. func (node *DString) Eval(ctx context.Context, v ExprEvaluator) (Datum, error) { return node, nil diff --git a/pkg/sql/sem/tree/eval_op_generated.go b/pkg/sql/sem/tree/eval_op_generated.go index a99cd26ef4f1..f53bf8f4e845 100644 --- a/pkg/sql/sem/tree/eval_op_generated.go +++ b/pkg/sql/sem/tree/eval_op_generated.go @@ -77,6 +77,8 @@ type BinaryOpEvaluator interface { EvalContainedByJsonbOp(context.Context, *ContainedByJsonbOp, Datum, Datum) (Datum, error) EvalContainsArrayOp(context.Context, *ContainsArrayOp, Datum, Datum) (Datum, error) EvalContainsJsonbOp(context.Context, *ContainsJsonbOp, Datum, Datum) (Datum, error) + EvalCosDistanceVectorOp(context.Context, *CosDistanceVectorOp, Datum, Datum) (Datum, error) + EvalDistanceVectorOp(context.Context, *DistanceVectorOp, Datum, Datum) (Datum, error) EvalDivDecimalIntOp(context.Context, *DivDecimalIntOp, Datum, Datum) (Datum, error) EvalDivDecimalOp(context.Context, *DivDecimalOp, Datum, Datum) (Datum, error) EvalDivFloatOp(context.Context, *DivFloatOp, Datum, Datum) (Datum, error) @@ -121,6 +123,7 @@ type BinaryOpEvaluator interface { EvalMinusJsonbStringOp(context.Context, *MinusJsonbStringOp, Datum, Datum) (Datum, error) EvalMinusPGLSNDecimalOp(context.Context, *MinusPGLSNDecimalOp, Datum, Datum) (Datum, error) EvalMinusPGLSNOp(context.Context, *MinusPGLSNOp, Datum, Datum) (Datum, error) + EvalMinusPGVectorOp(context.Context, *MinusPGVectorOp, 
Datum, Datum) (Datum, error) EvalMinusTimeIntervalOp(context.Context, *MinusTimeIntervalOp, Datum, Datum) (Datum, error) EvalMinusTimeOp(context.Context, *MinusTimeOp, Datum, Datum) (Datum, error) EvalMinusTimeTZIntervalOp(context.Context, *MinusTimeTZIntervalOp, Datum, Datum) (Datum, error) @@ -147,6 +150,8 @@ type BinaryOpEvaluator interface { EvalMultIntervalDecimalOp(context.Context, *MultIntervalDecimalOp, Datum, Datum) (Datum, error) EvalMultIntervalFloatOp(context.Context, *MultIntervalFloatOp, Datum, Datum) (Datum, error) EvalMultIntervalIntOp(context.Context, *MultIntervalIntOp, Datum, Datum) (Datum, error) + EvalMultPGVectorOp(context.Context, *MultPGVectorOp, Datum, Datum) (Datum, error) + EvalNegInnerProductVectorOp(context.Context, *NegInnerProductVectorOp, Datum, Datum) (Datum, error) EvalOverlapsArrayOp(context.Context, *OverlapsArrayOp, Datum, Datum) (Datum, error) EvalOverlapsINetOp(context.Context, *OverlapsINetOp, Datum, Datum) (Datum, error) EvalPlusDateIntOp(context.Context, *PlusDateIntOp, Datum, Datum) (Datum, error) @@ -169,6 +174,7 @@ type BinaryOpEvaluator interface { EvalPlusIntervalTimestampOp(context.Context, *PlusIntervalTimestampOp, Datum, Datum) (Datum, error) EvalPlusIntervalTimestampTZOp(context.Context, *PlusIntervalTimestampTZOp, Datum, Datum) (Datum, error) EvalPlusPGLSNDecimalOp(context.Context, *PlusPGLSNDecimalOp, Datum, Datum) (Datum, error) + EvalPlusPGVectorOp(context.Context, *PlusPGVectorOp, Datum, Datum) (Datum, error) EvalPlusTimeDateOp(context.Context, *PlusTimeDateOp, Datum, Datum) (Datum, error) EvalPlusTimeIntervalOp(context.Context, *PlusTimeIntervalOp, Datum, Datum) (Datum, error) EvalPlusTimeTZDateOp(context.Context, *PlusTimeTZDateOp, Datum, Datum) (Datum, error) @@ -355,6 +361,16 @@ func (op *ContainsJsonbOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) return e.EvalContainsJsonbOp(ctx, op, a, b) } +// Eval is part of the BinaryEvalOp interface. 
+func (op *CosDistanceVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { + return e.EvalCosDistanceVectorOp(ctx, op, a, b) +} + +// Eval is part of the BinaryEvalOp interface. +func (op *DistanceVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { + return e.EvalDistanceVectorOp(ctx, op, a, b) +} + // Eval is part of the BinaryEvalOp interface. func (op *DivDecimalIntOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { return e.EvalDivDecimalIntOp(ctx, op, a, b) @@ -575,6 +591,11 @@ func (op *MinusPGLSNOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Da return e.EvalMinusPGLSNOp(ctx, op, a, b) } +// Eval is part of the BinaryEvalOp interface. +func (op *MinusPGVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { + return e.EvalMinusPGVectorOp(ctx, op, a, b) +} + // Eval is part of the BinaryEvalOp interface. func (op *MinusTimeIntervalOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { return e.EvalMinusTimeIntervalOp(ctx, op, a, b) @@ -705,6 +726,16 @@ func (op *MultIntervalIntOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum return e.EvalMultIntervalIntOp(ctx, op, a, b) } +// Eval is part of the BinaryEvalOp interface. +func (op *MultPGVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { + return e.EvalMultPGVectorOp(ctx, op, a, b) +} + +// Eval is part of the BinaryEvalOp interface. +func (op *NegInnerProductVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { + return e.EvalNegInnerProductVectorOp(ctx, op, a, b) +} + // Eval is part of the BinaryEvalOp interface. 
func (op *OverlapsArrayOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { return e.EvalOverlapsArrayOp(ctx, op, a, b) @@ -815,6 +846,11 @@ func (op *PlusPGLSNDecimalOp) Eval(ctx context.Context, e OpEvaluator, a, b Datu return e.EvalPlusPGLSNDecimalOp(ctx, op, a, b) } +// Eval is part of the BinaryEvalOp interface. +func (op *PlusPGVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { + return e.EvalPlusPGVectorOp(ctx, op, a, b) +} + // Eval is part of the BinaryEvalOp interface. func (op *PlusTimeDateOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) { return e.EvalPlusTimeDateOp(ctx, op, a, b) diff --git a/pkg/sql/sem/tree/expr.go b/pkg/sql/sem/tree/expr.go index 4540cbea894b..11c8d1cac6b1 100644 --- a/pkg/sql/sem/tree/expr.go +++ b/pkg/sql/sem/tree/expr.go @@ -1050,6 +1050,7 @@ var binaryOpPrio = [...]int{ treebin.Bitxor: 6, treebin.Bitor: 7, treebin.Concat: 8, treebin.JSONFetchVal: 8, treebin.JSONFetchText: 8, treebin.JSONFetchValPath: 8, treebin.JSONFetchTextPath: 8, + treebin.Distance: 8, treebin.CosDistance: 8, treebin.NegInnerProduct: 8, } // binaryOpFullyAssoc indicates whether an operator is fully associative. @@ -1063,6 +1064,7 @@ var binaryOpFullyAssoc = [...]bool{ treebin.Bitxor: true, treebin.Bitor: true, treebin.Concat: true, treebin.JSONFetchVal: false, treebin.JSONFetchText: false, treebin.JSONFetchValPath: false, treebin.JSONFetchTextPath: false, + treebin.Distance: false, treebin.CosDistance: false, treebin.NegInnerProduct: false, } // BinaryExpr represents a binary value expression. 
diff --git a/pkg/sql/sem/tree/parse_string.go b/pkg/sql/sem/tree/parse_string.go index d745336f3e9e..68e9c5b31d56 100644 --- a/pkg/sql/sem/tree/parse_string.go +++ b/pkg/sql/sem/tree/parse_string.go @@ -69,6 +69,8 @@ func ParseAndRequireString( d, err = ParseDIntervalWithTypeMetadata(intervalStyle(ctx), s, itm) case types.PGLSNFamily: d, err = ParseDPGLSN(s) + case types.PGVectorFamily: + d, err = ParseDPGVector(s) case types.RefCursorFamily: d = NewDRefCursor(s) case types.Box2DFamily: diff --git a/pkg/sql/sem/tree/treebin/binary_operator.go b/pkg/sql/sem/tree/treebin/binary_operator.go index 332d4954d238..f9203e0a29fb 100644 --- a/pkg/sql/sem/tree/treebin/binary_operator.go +++ b/pkg/sql/sem/tree/treebin/binary_operator.go @@ -61,6 +61,9 @@ const ( JSONFetchValPath JSONFetchTextPath TSMatch + Distance + CosDistance + NegInnerProduct NumBinaryOperatorSymbols ) @@ -86,6 +89,9 @@ var binaryOpName = [...]string{ JSONFetchValPath: "#>", JSONFetchTextPath: "#>>", TSMatch: "@@", + Distance: "<->", + CosDistance: "<=>", + NegInnerProduct: "<#>", } // IsPadded returns whether the binary operator needs to be padded. diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go index e5987341aece..cfe94fc640cb 100644 --- a/pkg/sql/sem/tree/type_check.go +++ b/pkg/sql/sem/tree/type_check.go @@ -2116,6 +2116,12 @@ func (d *DPGLSN) TypeCheck(_ context.Context, _ *SemaContext, _ *types.T) (Typed return d, nil } +// TypeCheck implements the Expr interface. It is implemented as an idempotent +// identity function for Datum. +func (d *DPGVector) TypeCheck(_ context.Context, _ *SemaContext, _ *types.T) (TypedExpr, error) { + return d, nil +} + // TypeCheck implements the Expr interface. It is implemented as an idempotent // identity function for Datum. 
func (d *DGeography) TypeCheck(_ context.Context, _ *SemaContext, _ *types.T) (TypedExpr, error) { diff --git a/pkg/sql/sem/tree/walk.go b/pkg/sql/sem/tree/walk.go index 4775d4926498..c7eed182d673 100644 --- a/pkg/sql/sem/tree/walk.go +++ b/pkg/sql/sem/tree/walk.go @@ -762,6 +762,9 @@ func (expr *DBox2D) Walk(_ Visitor) Expr { return expr } // Walk implements the Expr interface. func (expr *DPGLSN) Walk(_ Visitor) Expr { return expr } +// Walk implements the Expr interface. +func (expr *DPGVector) Walk(_ Visitor) Expr { return expr } + // Walk implements the Expr interface. func (expr *DGeography) Walk(_ Visitor) Expr { return expr } diff --git a/pkg/sql/stats/stats_test.go b/pkg/sql/stats/stats_test.go index cd290fc56f4e..cb383dd048d2 100644 --- a/pkg/sql/stats/stats_test.go +++ b/pkg/sql/stats/stats_test.go @@ -54,7 +54,7 @@ func TestStatsAnyType(t *testing.T) { // Casting random integers to REGTYPE might fail. continue loop } - if err := colinfo.ValidateColumnDefType(ctx, st.Version, typ); err == nil { + if err := colinfo.ValidateColumnDefType(ctx, st, typ); err == nil { break } } diff --git a/pkg/sql/types/oid.go b/pkg/sql/types/oid.go index d6320a967b2b..12a83a87872c 100644 --- a/pkg/sql/types/oid.go +++ b/pkg/sql/types/oid.go @@ -108,6 +108,7 @@ var OidToType = map[oid.Oid]*T{ oidext.T_geometry: Geometry, oidext.T_geography: Geography, oidext.T_box2d: Box2D, + oidext.T_pgvector: PGVector, } // oidToArrayOid maps scalar type Oids to their corresponding array type Oid. 
@@ -155,6 +156,7 @@ var oidToArrayOid = map[oid.Oid]oid.Oid{ oidext.T_geometry: oidext.T__geometry, oidext.T_geography: oidext.T__geography, oidext.T_box2d: oidext.T__box2d, + oidext.T_pgvector: oidext.T__pgvector, } // familyToOid maps each type family to a default OID value that is used when @@ -191,6 +193,7 @@ var familyToOid = map[Family]oid.Oid{ GeometryFamily: oidext.T_geometry, GeographyFamily: oidext.T_geography, Box2DFamily: oidext.T_box2d, + PGVectorFamily: oidext.T_pgvector, } // ArrayOids is a set of all oids which correspond to an array type. diff --git a/pkg/sql/types/types.go b/pkg/sql/types/types.go index da055d3416d6..4f5d21a4b68e 100644 --- a/pkg/sql/types/types.go +++ b/pkg/sql/types/types.go @@ -487,6 +487,15 @@ var ( }, } + // PGVector is the type representing a PGVector object. + PGVector = &T{ + InternalType: InternalType{ + Family: PGVectorFamily, + Oid: oidext.T_pgvector, + Locale: &emptyLocale, + }, + } + // Void is the type representing void. Void = &T{ InternalType: InternalType{ @@ -647,6 +656,10 @@ var ( PGLSNArray = &T{InternalType: InternalType{ Family: ArrayFamily, ArrayContents: PGLSN, Oid: oid.T__pg_lsn, Locale: &emptyLocale}} + // PGVectorArray is the type of an array value having PGVector-typed elements. + PGVectorArray = &T{InternalType: InternalType{ + Family: ArrayFamily, ArrayContents: PGVector, Oid: oidext.T__pgvector, Locale: &emptyLocale}} + // RefCursorArray is the type of an array value having REFCURSOR-typed elements. RefCursorArray = &T{InternalType: InternalType{ Family: ArrayFamily, ArrayContents: RefCursor, Oid: oid.T__refcursor, Locale: &emptyLocale}} @@ -1218,6 +1231,17 @@ func MakeLabeledTuple(contents []*T, labels []string) *T { }} } +// MakePGVector constructs a new instance of a VECTOR type (pg_vector) that has +// the given number of dimensions. 
+func MakePGVector(dims int32) *T { + return &T{InternalType: InternalType{ + Family: PGVectorFamily, + Oid: oidext.T_pgvector, + Width: dims, + Locale: &emptyLocale, + }} +} + // NewCompositeType constructs a new instance of a TupleFamily type with the // given field types and labels, and the given user-defined type OIDs. func NewCompositeType(typeOID, arrayTypeOID oid.Oid, contents []*T, labels []string) *T { @@ -1294,6 +1318,7 @@ func (t *T) Locale() string { // STRING : max # of characters // COLLATEDSTRING: max # of characters // BIT : max # of bits +// VECTOR : # of dimensions // // Width is always 0 for other types. func (t *T) Width() int32 { @@ -1344,7 +1369,7 @@ func (t *T) TypeModifier() int32 { // var header size. return width + 4 } - case BitFamily: + case BitFamily, PGVectorFamily: if width := t.Width(); width != 0 { return width } @@ -1507,6 +1532,7 @@ var familyNames = map[Family]redact.SafeString{ JsonFamily: "jsonb", OidFamily: "oid", PGLSNFamily: "pg_lsn", + PGVectorFamily: "vector", RefCursorFamily: "refcursor", StringFamily: "string", TimeFamily: "time", @@ -1788,6 +1814,8 @@ func (t *T) SQLStandardNameWithTypmod(haveTypmod bool, typmod int) string { } case PGLSNFamily: return "pg_lsn" + case PGVectorFamily: + return "vector" case RefCursorFamily: return "refcursor" case StringFamily, CollatedStringFamily: @@ -1999,6 +2027,11 @@ func (t *T) SQLString() string { // databases when this function is called to produce DDL like in SHOW // CREATE. 
return t.TypeMeta.Name.FQName(false /* explicitCatalog */) + case PGVectorFamily: + if t.Width() == 0 { + return "VECTOR" + } + return fmt.Sprintf("VECTOR(%d)", t.Width()) } return strings.ToUpper(t.Name()) } @@ -2046,7 +2079,7 @@ func (t *T) SQLStringForError() redact.RedactableString { IntervalFamily, StringFamily, BytesFamily, TimestampTZFamily, CollatedStringFamily, OidFamily, UnknownFamily, UuidFamily, INetFamily, TimeFamily, JsonFamily, TimeTZFamily, BitFamily, GeometryFamily, GeographyFamily, Box2DFamily, VoidFamily, EncodedKeyFamily, TSQueryFamily, - TSVectorFamily, AnyFamily, PGLSNFamily, RefCursorFamily: + TSVectorFamily, AnyFamily, PGLSNFamily, PGVectorFamily, RefCursorFamily: // These types do not contain other types, and do not require redaction. return redact.Sprint(redact.SafeString(t.SQLString())) } @@ -2816,6 +2849,8 @@ func IsValidArrayElementType(t *T) (valid bool, issueNum int) { return false, 90886 case TSVectorFamily: return false, 90886 + case PGVectorFamily: + return false, 121432 default: return true, 0 } diff --git a/pkg/sql/types/types.proto b/pkg/sql/types/types.proto index 84ab3a03fa9d..539dbc55f8cc 100644 --- a/pkg/sql/types/types.proto +++ b/pkg/sql/types/types.proto @@ -402,6 +402,10 @@ enum Family { // Oid : T_refcursor RefCursorFamily = 31; + // PGVectorFamily is a type family for the vector type, which is the + // type representing pgvector vectors. + PGVectorFamily = 32; + // AnyFamily is a special type family used during static analysis as a // wildcard type that matches any other type, including scalar, array, and // tuple types. Execution-time values should never have this type. 
As an diff --git a/pkg/util/encoding/encoding.go b/pkg/util/encoding/encoding.go index f0bae698e4d1..c3f62460a452 100644 --- a/pkg/util/encoding/encoding.go +++ b/pkg/util/encoding/encoding.go @@ -1780,6 +1780,7 @@ const ( // Special case JsonEmptyArray Type = 42 JsonEmptyArrayDesc Type = 43 + PGVector Type = 44 ) // typMap maps an encoded type byte to a decoded Type. It's got 256 slots, one @@ -2622,6 +2623,12 @@ func EncodeUntaggedFloatValue(appendTo []byte, f float64) []byte { return EncodeUint64Ascending(appendTo, math.Float64bits(f)) } +// EncodeUntaggedFloat32Value encodes a float32 value, appends it to the supplied buffer, +// and returns the final buffer. +func EncodeUntaggedFloat32Value(appendTo []byte, f float32) []byte { + return EncodeUint32Ascending(appendTo, math.Float32bits(f)) +} + // EncodeBytesValue encodes a byte array value with its value tag, appends it to // the supplied buffer, and returns the final buffer. func EncodeBytesValue(appendTo []byte, colID uint32, data []byte) []byte { @@ -2820,6 +2827,14 @@ func EncodeTSVectorValue(appendTo []byte, colID uint32, data []byte) []byte { return EncodeUntaggedBytesValue(appendTo, data) } +// EncodePGVectorValue encodes an already-byte-encoded PGVector value with no +// value tag but with a length prefix, appends it to the supplied buffer, and +// returns the final buffer. +func EncodePGVectorValue(appendTo []byte, colID uint32, data []byte) []byte { + appendTo = EncodeValueTag(appendTo, colID, PGVector) + return EncodeUntaggedBytesValue(appendTo, data) +} + // DecodeValueTag decodes a value encoded by EncodeValueTag, used as a prefix in // each of the other EncodeFooValue methods. // @@ -2922,6 +2937,16 @@ func DecodeUntaggedFloatValue(b []byte) (remaining []byte, f float64, err error) return b, math.Float64frombits(i), err } +// DecodeUntaggedFloat32Value decodes a value encoded by EncodeUntaggedFloat32Value. 
+func DecodeUntaggedFloat32Value(b []byte) (remaining []byte, f float32, err error) { + if len(b) < 4 { + return b, 0, fmt.Errorf("float32 value should be exactly 4 bytes: %d", len(b)) + } + var i uint32 + b, i, err = DecodeUint32Ascending(b) + return b, math.Float32frombits(i), err +} + // DecodeBytesValue decodes a value encoded by EncodeBytesValue. func DecodeBytesValue(b []byte) (remaining []byte, data []byte, err error) { b, err = decodeValueTypeAssert(b, Bytes) @@ -3208,7 +3233,7 @@ func PeekValueLengthWithOffsetsAndType(b []byte, dataOffset int, typ Type) (leng return dataOffset + n, err case Float: return dataOffset + floatValueEncodedLength, nil - case Bytes, Array, JSON, Geo, TSVector, TSQuery: + case Bytes, Array, JSON, Geo, TSVector, TSQuery, PGVector: _, n, i, err := DecodeNonsortingUvarint(b) return dataOffset + n + int(i), err case Box2D: diff --git a/pkg/util/encoding/type_string.go b/pkg/util/encoding/type_string.go index a6eaee2c3d5e..19174c451844 100644 --- a/pkg/util/encoding/type_string.go +++ b/pkg/util/encoding/type_string.go @@ -53,6 +53,7 @@ func _() { _ = x[JSONObjectDesc-41] _ = x[JsonEmptyArray-42] _ = x[JsonEmptyArrayDesc-43] + _ = x[PGVector-44] } func (i Type) String() string { @@ -145,6 +146,8 @@ func (i Type) String() string { return "JsonEmptyArray" case JsonEmptyArrayDesc: return "JsonEmptyArrayDesc" + case PGVector: + return "PGVector" default: return "Type(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/util/parquet/writer_bench_test.go b/pkg/util/parquet/writer_bench_test.go index 3e550681a341..6525da667b14 100644 --- a/pkg/util/parquet/writer_bench_test.go +++ b/pkg/util/parquet/writer_bench_test.go @@ -69,7 +69,7 @@ func getBenchmarkTypes() []*types.T { for _, typ := range randgen.SeedTypes { switch typ.Family() { case types.AnyFamily, types.TSQueryFamily, types.TSVectorFamily, - types.VoidFamily: + types.VoidFamily, types.PGVectorFamily: case types.TupleFamily: // Replace Any Tuple with Tuple of Ints with size 
5. typs = append(typs, types.MakeTuple([]*types.T{ diff --git a/pkg/util/parquet/writer_test.go b/pkg/util/parquet/writer_test.go index 094667490e89..934a5e59e3bc 100644 --- a/pkg/util/parquet/writer_test.go +++ b/pkg/util/parquet/writer_test.go @@ -51,7 +51,7 @@ func newColSchema(numCols int) *colSchema { // that are not supported by the writer. func typSupported(typ *types.T) bool { switch typ.Family() { - case types.AnyFamily, types.TSQueryFamily, types.TSVectorFamily, types.VoidFamily: + case types.AnyFamily, types.TSQueryFamily, types.TSVectorFamily, types.PGVectorFamily, types.VoidFamily: return false case types.ArrayFamily: if typ.ArrayContents().Family() == types.ArrayFamily || typ.ArrayContents().Family() == types.TupleFamily { diff --git a/pkg/util/vector/BUILD.bazel b/pkg/util/vector/BUILD.bazel new file mode 100644 index 000000000000..f4aeca51ac43 --- /dev/null +++ b/pkg/util/vector/BUILD.bazel @@ -0,0 +1,23 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "vector", + srcs = ["vector.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/util/vector", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgwire/pgcode", + "//pkg/sql/pgwire/pgerror", + "//pkg/util/encoding", + ], +) + +go_test( + name = "vector_test", + srcs = ["vector_test.go"], + embed = [":vector"], + deps = [ + "//pkg/util/randutil", + "@com_github_stretchr_testify//assert", + ], +) diff --git a/pkg/util/vector/vector.go b/pkg/util/vector/vector.go new file mode 100644 index 000000000000..c5ed3085d57d --- /dev/null +++ b/pkg/util/vector/vector.go @@ -0,0 +1,288 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package vector + +import ( + "math" + "math/rand" + "strconv" + "strings" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/util/encoding" +) + +// MaxDim is the maximum number of dimensions a vector can have. +const MaxDim = 16000 + +// T is the type of a PGVector-like vector. +type T []float32 + +// ParseVector parses the Postgres string representation of a vector. +func ParseVector(input string) (T, error) { + input = strings.TrimSpace(input) + if !strings.HasPrefix(input, "[") || !strings.HasSuffix(input, "]") { + return T{}, pgerror.Newf(pgcode.InvalidTextRepresentation, + "malformed vector literal: Vector contents must start with \"[\" and"+ + " end with \"]\"") + } + + input = strings.TrimPrefix(input, "[") + input = strings.TrimSuffix(input, "]") + parts := strings.Split(input, ",") + + if len(parts) > MaxDim { + return T{}, pgerror.Newf(pgcode.ProgramLimitExceeded, "vector cannot have more than %d dimensions", MaxDim) + } + + vector := make([]float32, len(parts)) + for i, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + return T{}, pgerror.New(pgcode.InvalidTextRepresentation, "invalid input syntax for type vector: empty string") + } + + val, err := strconv.ParseFloat(part, 32) + if err != nil { + return T{}, pgerror.Newf(pgcode.InvalidTextRepresentation, "invalid input syntax for type vector: %s", part) + } + + if math.IsInf(val, 0) { + return T{}, pgerror.New(pgcode.DataException, "infinite value not allowed in vector") + } + if math.IsNaN(val) { + return T{}, pgerror.New(pgcode.DataException, "NaN not allowed in vector") + } + vector[i] = float32(val) + } + + return vector, nil +} + +// String 
implements the fmt.Stringer interface. +func (v T) String() string { + var sb strings.Builder + sb.WriteString("[") + // Pre-grow by a reasonable amount to avoid multiple allocations. + sb.Grow(len(v) * 8) + for i, v := range v { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatFloat(float64(v), 'g', -1, 32)) + } + sb.WriteString("]") + return sb.String() +} + +// Size returns the size of the vector in bytes. +func (v T) Size() uintptr { + return 24 + uintptr(cap(v))*4 +} + +// Compare returns -1 if v < v2, 1 if v > v2, and 0 if v == v2. +func (v T) Compare(v2 T) (int, error) { + n := min(len(v), len(v2)) + for i := 0; i < n; i++ { + if v[i] < v2[i] { + return -1, nil + } else if v[i] > v2[i] { + return 1, nil + } + } + if len(v) < len(v2) { + return -1, nil + } else if len(v) > len(v2) { + return 1, nil + } + return 0, nil +} + +// Encode encodes the vector as a byte array suitable for storing in KV. +func Encode(appendTo []byte, t T) ([]byte, error) { + appendTo = encoding.EncodeUint32Ascending(appendTo, uint32(len(t))) + for i := range t { + appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, t[i]) + } + return appendTo, nil +} + +// Decode decodes the byte array into a vector. +func Decode(b []byte) (ret T, err error) { + var n uint32 + b, n, err = encoding.DecodeUint32Ascending(b) + if err != nil { + return nil, err + } + ret = make(T, n) + for i := range ret { + b, ret[i], err = encoding.DecodeUntaggedFloat32Value(b) + if err != nil { + return nil, err + } + } + return ret, nil +} + +func checkDims(t T, t2 T) error { + if len(t) != len(t2) { + return pgerror.Newf(pgcode.DataException, "different vector dimensions %d and %d", len(t), len(t2)) + } + return nil +} + +// L1Distance returns the L1 (Manhattan) distance between t and t2. 
+func L1Distance(t T, t2 T) (float64, error) { + if err := checkDims(t, t2); err != nil { + return 0, err + } + var distance float32 + for i := range len(t) { + diff := t[i] - t2[i] + distance += float32(math.Abs(float64(diff))) + } + return float64(distance), nil +} + +// L2Distance returns the Euclidean distance between t and t2. +func L2Distance(t T, t2 T) (float64, error) { + if err := checkDims(t, t2); err != nil { + return 0, err + } + var distance float32 + for i := range len(t) { + diff := t[i] - t2[i] + distance += diff * diff + } + // TODO(queries): check for overflow and validate intermediate result if needed. + return math.Sqrt(float64(distance)), nil +} + +// CosDistance returns the cosine distance between t and t2. +func CosDistance(t T, t2 T) (float64, error) { + if err := checkDims(t, t2); err != nil { + return 0, err + } + var distance, normA, normB float32 + for i := range len(t) { + distance += t[i] * t2[i] + normA += t[i] * t[i] + normB += t2[i] * t2[i] + } + // Use sqrt(a * b) over sqrt(a) * sqrt(b) + similarity := float64(distance) / math.Sqrt(float64(normA)*float64(normB)) + /* Keep in range */ + if similarity > 1 { + similarity = 1 + } else if similarity < -1 { + similarity = -1 + } + return 1 - similarity, nil +} + +// InnerProduct returns the inner product of t and t2. +func InnerProduct(t T, t2 T) (float64, error) { + if err := checkDims(t, t2); err != nil { + return 0, err + } + var distance float32 + for i := range len(t) { + distance += t[i] * t2[i] + } + return float64(distance), nil +} + +// NegInnerProduct returns the negative inner product of t and t2. +func NegInnerProduct(t T, t2 T) (float64, error) { + p, err := InnerProduct(t, t2) + return p * -1, err +} + +// Norm returns the L2 norm of t. +func Norm(t T) float64 { + var norm float64 + for i := range t { + norm += float64(t[i]) * float64(t[i]) + } + // TODO(queries): check for overflow and validate intermediate result if needed. 
+ return math.Sqrt(norm) +} + +// Add returns t+t2, pointwise. +func Add(t T, t2 T) (T, error) { + if err := checkDims(t, t2); err != nil { + return nil, err + } + ret := make(T, len(t)) + for i := range t { + ret[i] = t[i] + t2[i] + } + for i := range ret { + if math.IsInf(float64(ret[i]), 0) { + return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: overflow") + } + } + return ret, nil +} + +// Minus returns t-t2, pointwise. +func Minus(t T, t2 T) (T, error) { + if err := checkDims(t, t2); err != nil { + return nil, err + } + ret := make(T, len(t)) + for i := range t { + ret[i] = t[i] - t2[i] + } + for i := range ret { + if math.IsInf(float64(ret[i]), 0) { + return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: overflow") + } + } + return ret, nil +} + +// Mult returns t*t2, pointwise. +func Mult(t T, t2 T) (T, error) { + if err := checkDims(t, t2); err != nil { + return nil, err + } + ret := make(T, len(t)) + for i := range t { + ret[i] = t[i] * t2[i] + } + for i := range ret { + if math.IsInf(float64(ret[i]), 0) { + return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: overflow") + } + if ret[i] == 0 && !(t[i] == 0 || t2[i] == 0) { + return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: underflow") + } + } + return ret, nil +} + +// Random returns a random vector. +func Random(rng *rand.Rand) T { + n := 1 + rng.Intn(1000) + v := make(T, n) + for i := range v { + for { + v[i] = math.Float32frombits(rng.Uint32()) + if math.IsNaN(float64(v[i])) || math.IsInf(float64(v[i]), 0) { + continue + } + break + } + } + return v +} diff --git a/pkg/util/vector/vector_test.go b/pkg/util/vector/vector_test.go new file mode 100644 index 000000000000..28fbf998b36b --- /dev/null +++ b/pkg/util/vector/vector_test.go @@ -0,0 +1,208 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package vector + +import ( + "math" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/util/randutil" + "github.com/stretchr/testify/assert" +) + +func TestParseVector(t *testing.T) { + testCases := []struct { + input string + expected T + hasError bool + }{ + {input: "[1,2,3]", expected: T{1, 2, 3}, hasError: false}, + {input: "[1.0, 2.0, 3.0]", expected: T{1.0, 2.0, 3.0}, hasError: false}, + {input: "[1.0, 2.0, 3.0", expected: T{}, hasError: true}, + {input: "1.0, 2.0, 3.0]", expected: T{}, hasError: true}, + {input: "[1.0, 2.0, [3.0]]", expected: T{}, hasError: true}, + {input: "1.0, 2.0, 3.0]", expected: T{}, hasError: true}, + {input: "1.0, , 3.0]", expected: T{}, hasError: true}, + {input: "", expected: T{}, hasError: true}, + {input: "[]", expected: T{}, hasError: true}, + {input: "1.0, 2.0, 3.0", expected: T{}, hasError: true}, + } + + for _, tc := range testCases { + result, err := ParseVector(tc.input) + + if tc.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + // Test roundtripping through String(). + s := result.String() + result, err = ParseVector(s) + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + } + + // Test the maxdims error case. 
+ var sb strings.Builder + sb.WriteString("[") + for i := 0; i < MaxDim; i++ { + sb.WriteString("1,") + } + sb.WriteString("1]") + _, err := ParseVector(sb.String()) + assert.Errorf(t, err, "vector cannot have more than %d dimensions", MaxDim) +} + +func TestRoundtripRandomPGVector(t *testing.T) { + rng, _ := randutil.NewTestRand() + for i := 0; i < 1000; i++ { + v := Random(rng) + encoded, err := Encode(nil, v) + assert.NoError(t, err) + roundtripped, err := Decode(encoded) + assert.NoError(t, err) + assert.Equal(t, v.String(), roundtripped.String()) + reEncoded, err := Encode(nil, roundtripped) + assert.NoError(t, err) + assert.Equal(t, encoded, reEncoded) + } +} + +func TestDistances(t *testing.T) { + // Test L1, L2, Cosine distance. + testCases := []struct { + v1 T + v2 T + l1 float64 + l2 float64 + cos float64 + err bool + }{ + {v1: T{1, 2, 3}, v2: T{4, 5, 6}, l1: 9, l2: 5.196152422, cos: 0.02536815, err: false}, + {v1: T{-1, -2, -3}, v2: T{-4, -5, -6}, l1: 9, l2: 5.196152422, cos: 0.02536815, err: false}, + {v1: T{0, 0, 0}, v2: T{0, 0, 0}, l1: 0, l2: 0, cos: math.NaN(), err: false}, + {v1: T{1, 2, 3}, v2: T{1, 2, 3}, l1: 0, l2: 0, cos: 0, err: false}, + {v1: T{1, 2, 3}, v2: T{1, 2, 4}, l1: 1, l2: 1, cos: 0.008539, err: false}, + // Different vector sizes errors. 
+ {v1: T{1, 2, 3}, v2: T{4, 5}, err: true}, + } + + for _, tc := range testCases { + l1, l1Err := L1Distance(tc.v1, tc.v2) + l2, l2Err := L2Distance(tc.v1, tc.v2) + cos, cosErr := CosDistance(tc.v1, tc.v2) + + if tc.err { + assert.Error(t, l1Err) + assert.Error(t, l2Err) + assert.Error(t, cosErr) + } else { + assert.NoError(t, l1Err) + assert.NoError(t, l2Err) + assert.NoError(t, cosErr) + assert.InDelta(t, tc.l1, l1, 0.000001) + assert.InDelta(t, tc.l2, l2, 0.000001) + assert.InDelta(t, tc.cos, cos, 0.000001) + } + } +} + +func TestProducts(t *testing.T) { + // Test inner product and negative inner product + testCases := []struct { + v1 T + v2 T + ip float64 + negIp float64 + err bool + }{ + {v1: T{1, 2, 3}, v2: T{4, 5, 6}, ip: 32, negIp: -32, err: false}, + {v1: T{-1, -2, -3}, v2: T{-4, -5, -6}, ip: 32, negIp: -32, err: false}, + {v1: T{0, 0, 0}, v2: T{0, 0, 0}, ip: 0, negIp: 0, err: false}, + {v1: T{1, 2, 3}, v2: T{1, 2, 3}, ip: 14, negIp: -14, err: false}, + {v1: T{1, 2, 3}, v2: T{1, 2, 4}, ip: 17, negIp: -17, err: false}, + // Different vector sizes errors. + {v1: T{1, 2, 3}, v2: T{4, 5}, err: true}, + } + + for _, tc := range testCases { + ip, ipErr := InnerProduct(tc.v1, tc.v2) + negIp, negIpErr := NegInnerProduct(tc.v1, tc.v2) + + if tc.err { + assert.Error(t, ipErr) + assert.Error(t, negIpErr) + } else { + assert.NoError(t, ipErr) + assert.NoError(t, negIpErr) + assert.InDelta(t, tc.ip, ip, 0.000001) + assert.InDelta(t, tc.negIp, negIp, 0.000001) + } + } +} + +func TestNorm(t *testing.T) { + testCases := []struct { + v T + norm float64 + }{ + {v: T{1, 2, 3}, norm: 3.7416573867739413}, + {v: T{0, 0, 0}, norm: 0}, + {v: T{-1, -2, -3}, norm: 3.7416573867739413}, + } + + for _, tc := range testCases { + norm := Norm(tc.v) + assert.InDelta(t, tc.norm, norm, 0.000001) + } +} + +func TestPointwiseOps(t *testing.T) { + // Test Add, Minus, and Mult pointwise operations. 
+ testCases := []struct { + v1 T + v2 T + add T + minus T + mult T + err bool + }{ + {v1: T{1, 2, 3}, v2: T{4, 5, 6}, add: T{5, 7, 9}, minus: T{-3, -3, -3}, mult: T{4, 10, 18}, err: false}, + {v1: T{-1, -2, -3}, v2: T{-4, -5, -6}, add: T{-5, -7, -9}, minus: T{3, 3, 3}, mult: T{4, 10, 18}, err: false}, + {v1: T{0, 0, 0}, v2: T{0, 0, 0}, add: T{0, 0, 0}, minus: T{0, 0, 0}, mult: T{0, 0, 0}, err: false}, + {v1: T{1, 2, 3}, v2: T{1, 2, 3}, add: T{2, 4, 6}, minus: T{0, 0, 0}, mult: T{1, 4, 9}, err: false}, + {v1: T{1, 2, 3}, v2: T{1, 2, 4}, add: T{2, 4, 7}, minus: T{0, 0, -1}, mult: T{1, 4, 12}, err: false}, + // Different vector sizes errors. + {v1: T{1, 2, 3}, v2: T{4, 5}, err: true}, + } + + for _, tc := range testCases { + add, addErr := Add(tc.v1, tc.v2) + minus, minusErr := Minus(tc.v1, tc.v2) + mult, multErr := Mult(tc.v1, tc.v2) + + if tc.err { + assert.Error(t, addErr) + assert.Error(t, minusErr) + assert.Error(t, multErr) + } else { + assert.NoError(t, addErr) + assert.NoError(t, minusErr) + assert.NoError(t, multErr) + assert.Equal(t, tc.add, add) + assert.Equal(t, tc.minus, minus) + assert.Equal(t, tc.mult, mult) + } + } +} diff --git a/pkg/workload/rand/rand.go b/pkg/workload/rand/rand.go index f98b0391410c..f583865a82b6 100644 --- a/pkg/workload/rand/rand.go +++ b/pkg/workload/rand/rand.go @@ -436,6 +436,8 @@ func DatumToGoSQL(d tree.Datum) (interface{}, error) { return d.String(), nil case *tree.DTSVector: return d.String(), nil + case *tree.DPGVector: + return d.String(), nil } return nil, errors.Errorf("unhandled datum type: %s", reflect.TypeOf(d)) }