From 3e5a520bc6745332b32cdd69de0f5f25fe4b429e Mon Sep 17 00:00:00 2001
From: Jordan Lewis
Date: Sat, 16 Mar 2024 13:26:07 -0600
Subject: [PATCH] sql: implement pgvector datatype and evaluation
Release note (sql change): implement pgvector encoding, decoding, and
operators, without index acceleration.
---
docs/generated/sql/bnf/stmt_block.bnf | 12 +-
docs/generated/sql/functions.md | 19 ++
docs/generated/sql/operators.md | 23 ++
pkg/BUILD.bazel | 5 +
pkg/ccl/changefeedccl/avro_test.go | 3 +
pkg/ccl/changefeedccl/encoder_test.go | 2 +-
.../logic_test/mixed_version_pgvector | 63 ++++
.../logictestccl/testdata/logic_test/vector | 103 +++++++
.../tests/3node-tenant/generated_test.go | 7 +
.../cockroach-go-testserver-23.2/BUILD.bazel | 30 ++
.../generated_test.go | 86 ++++++
.../tests/fakedist-disk/BUILD.bazel | 2 +-
.../tests/fakedist-disk/generated_test.go | 7 +
.../tests/fakedist-vec-off/BUILD.bazel | 2 +-
.../tests/fakedist-vec-off/generated_test.go | 7 +
.../logictestccl/tests/fakedist/BUILD.bazel | 2 +-
.../tests/fakedist/generated_test.go | 7 +
.../local-legacy-schema-changer/BUILD.bazel | 2 +-
.../generated_test.go | 7 +
.../tests/local-read-committed/BUILD.bazel | 2 +-
.../local-read-committed/generated_test.go | 7 +
.../tests/local-vec-off/BUILD.bazel | 2 +-
.../tests/local-vec-off/generated_test.go | 7 +
pkg/ccl/logictestccl/tests/local/BUILD.bazel | 2 +-
.../tests/local/generated_test.go | 7 +
pkg/sql/alter_column_type.go | 2 +-
pkg/sql/catalog/colinfo/BUILD.bazel | 2 +
pkg/sql/catalog/colinfo/col_type_info.go | 19 +-
.../catalog/colinfo/column_type_properties.go | 1 +
pkg/sql/catalog/tabledesc/table.go | 2 +-
pkg/sql/exec_util.go | 1 +
pkg/sql/logictest/BUILD.bazel | 1 +
.../logictest/testdata/logic_test/grant_table | 30 ++
.../logictest/testdata/logic_test/pg_catalog | 10 +
.../logictest/testdata/logic_test/vectoross | 5 +
.../tests/fakedist-disk/generated_test.go | 7 +
.../tests/fakedist-vec-off/generated_test.go | 7 +
.../tests/fakedist/generated_test.go | 7 +
.../generated_test.go | 7 +
.../tests/local-vec-off/generated_test.go | 7 +
.../logictest/tests/local/generated_test.go | 7 +
pkg/sql/oidext/oidext.go | 4 +
pkg/sql/opt/operator.go | 37 +--
pkg/sql/opt/ops/scalar.opt | 24 ++
pkg/sql/opt/optbuilder/scalar.go | 6 +
pkg/sql/parser/BUILD.bazel | 1 +
pkg/sql/parser/sql.y | 45 ++-
pkg/sql/parser/testdata/create_table | 16 +
pkg/sql/parser/testdata/select_exprs | 40 +++
pkg/sql/pg_catalog.go | 1 +
pkg/sql/pgwire/pgwirebase/BUILD.bazel | 1 +
pkg/sql/pgwire/pgwirebase/encoding.go | 7 +
pkg/sql/pgwire/types.go | 4 +
pkg/sql/randgen/BUILD.bazel | 1 +
pkg/sql/randgen/datum.go | 3 +
pkg/sql/randgen/type.go | 8 +-
pkg/sql/rowenc/encoded_datum.go | 6 +-
pkg/sql/rowenc/encoded_datum_test.go | 2 +-
pkg/sql/rowenc/keyside/BUILD.bazel | 1 +
pkg/sql/rowenc/keyside/keyside_test.go | 9 +-
pkg/sql/rowenc/valueside/BUILD.bazel | 1 +
pkg/sql/rowenc/valueside/array.go | 7 +
pkg/sql/rowenc/valueside/decode.go | 11 +
pkg/sql/rowenc/valueside/encode.go | 7 +
pkg/sql/rowenc/valueside/legacy.go | 20 ++
pkg/sql/scanner/scan.go | 20 ++
.../comparator_generated_test.go | 5 +
pkg/sql/sem/builtins/BUILD.bazel | 2 +
.../builtins/builtinconstants/constants.go | 1 +
pkg/sql/sem/builtins/fixed_oids.go | 17 ++
pkg/sql/sem/builtins/pgvector_builtins.go | 142 +++++++++
pkg/sql/sem/cast/cast.go | 14 +
pkg/sql/sem/cast/cast_map.go | 62 ++--
pkg/sql/sem/eval/BUILD.bazel | 1 +
pkg/sql/sem/eval/binary_op.go | 64 ++++
pkg/sql/sem/eval/cast.go | 44 +++
pkg/sql/sem/eval/testdata/eval/vector | 143 +++++++++
pkg/sql/sem/eval/unsupported_types.go | 26 +-
pkg/sql/sem/tree/BUILD.bazel | 1 +
pkg/sql/sem/tree/constant.go | 2 +
pkg/sql/sem/tree/datum.go | 118 ++++++-
pkg/sql/sem/tree/eval.go | 53 ++++
pkg/sql/sem/tree/eval_binary_ops.go | 15 +
pkg/sql/sem/tree/eval_expr_generated.go | 5 +
pkg/sql/sem/tree/eval_op_generated.go | 36 +++
pkg/sql/sem/tree/expr.go | 2 +
pkg/sql/sem/tree/parse_string.go | 2 +
pkg/sql/sem/tree/treebin/binary_operator.go | 6 +
pkg/sql/sem/tree/type_check.go | 6 +
pkg/sql/sem/tree/walk.go | 3 +
pkg/sql/stats/stats_test.go | 2 +-
pkg/sql/types/oid.go | 3 +
pkg/sql/types/types.go | 39 ++-
pkg/sql/types/types.proto | 4 +
pkg/util/encoding/encoding.go | 27 +-
pkg/util/encoding/type_string.go | 3 +
pkg/util/parquet/writer_bench_test.go | 2 +-
pkg/util/parquet/writer_test.go | 2 +-
pkg/util/vector/BUILD.bazel | 23 ++
pkg/util/vector/vector.go | 288 ++++++++++++++++++
pkg/util/vector/vector_test.go | 208 +++++++++++++
pkg/workload/rand/rand.go | 2 +
102 files changed, 2100 insertions(+), 86 deletions(-)
create mode 100644 pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector
create mode 100644 pkg/ccl/logictestccl/testdata/logic_test/vector
create mode 100644 pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel
create mode 100644 pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go
create mode 100644 pkg/sql/logictest/testdata/logic_test/vectoross
create mode 100644 pkg/sql/sem/builtins/pgvector_builtins.go
create mode 100644 pkg/sql/sem/eval/testdata/eval/vector
create mode 100644 pkg/util/vector/BUILD.bazel
create mode 100644 pkg/util/vector/vector.go
create mode 100644 pkg/util/vector/vector_test.go
diff --git a/docs/generated/sql/bnf/stmt_block.bnf b/docs/generated/sql/bnf/stmt_block.bnf
index 82cecd617097..b2955c979faf 100644
--- a/docs/generated/sql/bnf/stmt_block.bnf
+++ b/docs/generated/sql/bnf/stmt_block.bnf
@@ -1568,6 +1568,7 @@ col_name_keyword ::=
| 'VALUES'
| 'VARBIT'
| 'VARCHAR'
+ | 'VECTOR'
| 'VIRTUAL'
| 'WORK'
@@ -1755,7 +1756,7 @@ backup_options_list ::=
( backup_options ) ( ( ',' backup_options ) )*
a_expr ::=
- ( c_expr | '+' a_expr | '-' a_expr | '~' a_expr | 'SQRT' a_expr | 'CBRT' a_expr | qual_op a_expr | 'NOT' a_expr | 'NOT' a_expr | row 'OVERLAPS' row | 'DEFAULT' ) ( ( 'TYPECAST' cast_target | 'TYPEANNOTATE' typename | 'COLLATE' collation_name | 'AT' 'TIME' 'ZONE' a_expr | '+' a_expr | '-' a_expr | '*' a_expr | '/' a_expr | 'FLOORDIV' a_expr | '%' a_expr | '^' a_expr | '#' a_expr | '&' a_expr | '|' a_expr | '<' a_expr | '>' a_expr | '?' a_expr | 'JSON_SOME_EXISTS' a_expr | 'JSON_ALL_EXISTS' a_expr | 'CONTAINS' a_expr | 'CONTAINED_BY' a_expr | '=' a_expr | 'CONCAT' a_expr | 'LSHIFT' a_expr | 'RSHIFT' a_expr | 'FETCHVAL' a_expr | 'FETCHTEXT' a_expr | 'FETCHVAL_PATH' a_expr | 'FETCHTEXT_PATH' a_expr | 'REMOVE_PATH' a_expr | 'INET_CONTAINED_BY_OR_EQUALS' a_expr | 'AND_AND' a_expr | 'AT_AT' a_expr | 'INET_CONTAINS_OR_EQUALS' a_expr | 'LESS_EQUALS' a_expr | 'GREATER_EQUALS' a_expr | 'NOT_EQUALS' a_expr | qual_op a_expr | 'AND' a_expr | 'OR' a_expr | 'LIKE' a_expr | 'LIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'LIKE' a_expr | 'NOT' 'LIKE' a_expr 'ESCAPE' a_expr | 'ILIKE' a_expr | 'ILIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'ILIKE' a_expr | 'NOT' 'ILIKE' a_expr 'ESCAPE' a_expr | 'SIMILAR' 'TO' a_expr | 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | '~' a_expr | 'NOT_REGMATCH' a_expr | 'REGIMATCH' a_expr | 'NOT_REGIMATCH' a_expr | 'IS' 'NAN' | 'IS' 'NOT' 'NAN' | 'IS' 'NULL' | 'ISNULL' | 'IS' 'NOT' 'NULL' | 'NOTNULL' | 'IS' 'TRUE' | 'IS' 'NOT' 'TRUE' | 'IS' 'FALSE' | 'IS' 'NOT' 'FALSE' | 'IS' 'UNKNOWN' | 'IS' 'NOT' 'UNKNOWN' | 'IS' 'DISTINCT' 'FROM' a_expr | 'IS' 'NOT' 'DISTINCT' 'FROM' a_expr | 'IS' 'OF' '(' type_list ')' | 'IS' 'NOT' 'OF' '(' type_list ')' | 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'NOT' 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'NOT' 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'IN' in_expr | 'NOT' 'IN' in_expr | subquery_op sub_type a_expr ) )*
 ( c_expr | '+' a_expr | '-' a_expr | '~' a_expr | 'SQRT' a_expr | 'CBRT' a_expr | qual_op a_expr | 'NOT' a_expr | 'NOT' a_expr | row 'OVERLAPS' row | 'DEFAULT' ) ( ( 'TYPECAST' cast_target | 'TYPEANNOTATE' typename | 'COLLATE' collation_name | 'AT' 'TIME' 'ZONE' a_expr | '+' a_expr | '-' a_expr | '*' a_expr | '/' a_expr | 'FLOORDIV' a_expr | '%' a_expr | '^' a_expr | '#' a_expr | '&' a_expr | '|' a_expr | '<' a_expr | '>' a_expr | '?' a_expr | 'JSON_SOME_EXISTS' a_expr | 'JSON_ALL_EXISTS' a_expr | 'CONTAINS' a_expr | 'CONTAINED_BY' a_expr | '=' a_expr | 'CONCAT' a_expr | 'LSHIFT' a_expr | 'RSHIFT' a_expr | 'FETCHVAL' a_expr | 'FETCHTEXT' a_expr | 'FETCHVAL_PATH' a_expr | 'FETCHTEXT_PATH' a_expr | 'REMOVE_PATH' a_expr | 'INET_CONTAINED_BY_OR_EQUALS' a_expr | 'AND_AND' a_expr | 'AT_AT' a_expr | 'DISTANCE' a_expr | 'COS_DISTANCE' a_expr | 'NEG_INNER_PRODUCT' a_expr | 'INET_CONTAINS_OR_EQUALS' a_expr | 'LESS_EQUALS' a_expr | 'GREATER_EQUALS' a_expr | 'NOT_EQUALS' a_expr | qual_op a_expr | 'AND' a_expr | 'OR' a_expr | 'LIKE' a_expr | 'LIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'LIKE' a_expr | 'NOT' 'LIKE' a_expr 'ESCAPE' a_expr | 'ILIKE' a_expr | 'ILIKE' a_expr 'ESCAPE' a_expr | 'NOT' 'ILIKE' a_expr | 'NOT' 'ILIKE' a_expr 'ESCAPE' a_expr | 'SIMILAR' 'TO' a_expr | 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr | 'NOT' 'SIMILAR' 'TO' a_expr 'ESCAPE' a_expr | '~' a_expr | 'NOT_REGMATCH' a_expr | 'REGIMATCH' a_expr | 'NOT_REGIMATCH' a_expr | 'IS' 'NAN' | 'IS' 'NOT' 'NAN' | 'IS' 'NULL' | 'ISNULL' | 'IS' 'NOT' 'NULL' | 'NOTNULL' | 'IS' 'TRUE' | 'IS' 'NOT' 'TRUE' | 'IS' 'FALSE' | 'IS' 'NOT' 'FALSE' | 'IS' 'UNKNOWN' | 'IS' 'NOT' 'UNKNOWN' | 'IS' 'DISTINCT' 'FROM' a_expr | 'IS' 'NOT' 'DISTINCT' 'FROM' a_expr | 'IS' 'OF' '(' type_list ')' | 'IS' 'NOT' 'OF' '(' type_list ')' | 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'NOT' 'BETWEEN' opt_asymmetric b_expr 'AND' a_expr | 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'NOT' 'BETWEEN' 'SYMMETRIC' b_expr 'AND' a_expr | 'IN' in_expr | 'NOT' 'IN' in_expr | subquery_op sub_type a_expr ) )*
for_schedules_clause ::=
'FOR' 'SCHEDULES' select_stmt
@@ -3139,6 +3140,9 @@ all_op ::=
| 'NOT_REGIMATCH'
| 'AND_AND'
| 'AT_AT'
+ | 'DISTANCE'
+ | 'COS_DISTANCE'
+ | 'NEG_INNER_PRODUCT'
| '~'
| 'SQRT'
| 'CBRT'
@@ -3397,6 +3401,7 @@ const_typename ::=
| character_with_length
| const_datetime
| const_geo
+ | const_vector
interval_type ::=
'INTERVAL'
@@ -4160,6 +4165,7 @@ bare_label_keywords ::=
| 'VARCHAR'
| 'VARIABLES'
| 'VARIADIC'
+ | 'VECTOR'
| 'VERIFY_BACKUP_TABLE_DATA'
| 'VIEW'
| 'VIEWACTIVITY'
@@ -4313,6 +4319,10 @@ const_geo ::=
| 'GEOMETRY' '(' geo_shape_type ',' signed_iconst ')'
| 'GEOGRAPHY' '(' geo_shape_type ',' signed_iconst ')'
+const_vector ::=
+ 'VECTOR'
+ | 'VECTOR' '(' iconst32 ')'
+
interval_qualifier ::=
'YEAR'
| 'MONTH'
diff --git a/docs/generated/sql/functions.md b/docs/generated/sql/functions.md
index df9e17317f33..763f1c063d10 100644
--- a/docs/generated/sql/functions.md
+++ b/docs/generated/sql/functions.md
@@ -1279,6 +1279,25 @@ the locality flag on node startup. Returns an error if no region is set.
Stable |
+### PGVector functions
+
+
+Function → Returns | Description | Volatility |
+
+cosine_distance(v1: vector, v2: vector) → float | Returns the cosine distance between the two vectors.
+ | Immutable |
+inner_product(v1: vector, v2: vector) → float | Returns the inner product between the two vectors.
+ | Immutable |
+l1_distance(v1: vector, v2: vector) → float | Returns the Manhattan distance between the two vectors.
+ | Immutable |
+l2_distance(v1: vector, v2: vector) → float | Returns the Euclidean distance between the two vectors.
+ | Immutable |
+vector_dims(vector: vector) → int | Returns the number of the dimensions in the vector.
+ | Immutable |
+vector_norm(vector: vector) → float | Returns the Euclidean norm of the vector.
+ | Immutable |
+
+
### STRING[] functions
+
+<#> | Return |
+
+vector <#> vector | float |
+
+
+<-> | Return |
+
+vector <-> vector | float |
+
+<=> | Return |
+
+vector <=> vector | float |
>> | Return |
@@ -412,6 +433,7 @@
tuple IN tuple | bool |
uuid IN tuple | bool |
varbit IN tuple | bool |
+vector IN tuple | bool |
IS NOT DISTINCT FROM | Return |
@@ -475,6 +497,7 @@
uuid IS NOT DISTINCT FROM uuid | bool |
uuid[] IS NOT DISTINCT FROM uuid[] | bool |
varbit IS NOT DISTINCT FROM varbit | bool |
+vector IS NOT DISTINCT FROM vector | bool |
void IS NOT DISTINCT FROM unknown | bool |
diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel
index 14d74fb8aef9..2430998eab01 100644
--- a/pkg/BUILD.bazel
+++ b/pkg/BUILD.bazel
@@ -54,6 +54,7 @@ ALL_TESTS = [
"//pkg/ccl/logictestccl/tests/3node-tenant-multiregion:3node-tenant-multiregion_test",
"//pkg/ccl/logictestccl/tests/3node-tenant:3node-tenant_test",
"//pkg/ccl/logictestccl/tests/5node:5node_test",
+ "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2:cockroach-go-testserver-23_2_test",
"//pkg/ccl/logictestccl/tests/fakedist-disk:fakedist-disk_test",
"//pkg/ccl/logictestccl/tests/fakedist-vec-off:fakedist-vec-off_test",
"//pkg/ccl/logictestccl/tests/fakedist:fakedist_test",
@@ -754,6 +755,7 @@ ALL_TESTS = [
"//pkg/util/ulid:ulid_test",
"//pkg/util/unique:unique_test",
"//pkg/util/uuid:uuid_test",
+ "//pkg/util/vector:vector_test",
"//pkg/util/version:version_test",
"//pkg/util:util_test",
"//pkg/workload/bank:bank_test",
@@ -899,6 +901,7 @@ GO_TARGETS = [
"//pkg/ccl/logictestccl/tests/3node-tenant-multiregion:3node-tenant-multiregion_test",
"//pkg/ccl/logictestccl/tests/3node-tenant:3node-tenant_test",
"//pkg/ccl/logictestccl/tests/5node:5node_test",
+ "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2:cockroach-go-testserver-23_2_test",
"//pkg/ccl/logictestccl/tests/fakedist-disk:fakedist-disk_test",
"//pkg/ccl/logictestccl/tests/fakedist-vec-off:fakedist-vec-off_test",
"//pkg/ccl/logictestccl/tests/fakedist:fakedist_test",
@@ -2604,6 +2607,8 @@ GO_TARGETS = [
"//pkg/util/unique:unique_test",
"//pkg/util/uuid:uuid",
"//pkg/util/uuid:uuid_test",
+ "//pkg/util/vector:vector",
+ "//pkg/util/vector:vector_test",
"//pkg/util/version:version",
"//pkg/util/version:version_test",
"//pkg/util:util",
diff --git a/pkg/ccl/changefeedccl/avro_test.go b/pkg/ccl/changefeedccl/avro_test.go
index a7b87e0290cf..c852b653b97f 100644
--- a/pkg/ccl/changefeedccl/avro_test.go
+++ b/pkg/ccl/changefeedccl/avro_test.go
@@ -320,6 +320,9 @@ func TestAvroSchema(t *testing.T) {
case types.AnyFamily, types.OidFamily, types.TupleFamily:
// These aren't expected to be needed for changefeeds.
return true
+ case types.PGVectorFamily:
+ // We don't support PGVector in Avro yet.
+ return true
case types.ArrayFamily:
if !randgen.IsAllowedForArray(typ.ArrayContents()) {
return true
diff --git a/pkg/ccl/changefeedccl/encoder_test.go b/pkg/ccl/changefeedccl/encoder_test.go
index eddb0d177a78..e0a15735f5b4 100644
--- a/pkg/ccl/changefeedccl/encoder_test.go
+++ b/pkg/ccl/changefeedccl/encoder_test.go
@@ -1167,7 +1167,7 @@ func TestJsonRountrip(t *testing.T) {
switch typ {
case types.Jsonb:
// Unsupported by sql/catalog/colinfo
- case types.TSQuery, types.TSVector:
+ case types.TSQuery, types.TSVector, types.PGVector:
// Unsupported by pkg/sql/parser
default:
if arrayTyp.InternalType.ArrayContents == typ {
diff --git a/pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector b/pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector
new file mode 100644
index 000000000000..30d74d9dc1dd
--- /dev/null
+++ b/pkg/ccl/logictestccl/testdata/logic_test/mixed_version_pgvector
@@ -0,0 +1,63 @@
+# LogicTest: cockroach-go-testserver-23.2
+
+# Verify that all nodes are running the previous version.
+
+query T nodeidx=0
+SELECT crdb_internal.node_executable_version()
+----
+23.2
+
+query T nodeidx=1
+SELECT crdb_internal.node_executable_version()
+----
+23.2
+
+query T nodeidx=2
+SELECT crdb_internal.node_executable_version()
+----
+23.2
+
+statement error syntax error
+CREATE TABLE t (v VECTOR(1))
+
+# Upgrade one node to 24.2
+
+upgrade 0
+
+# Verify that node index 0 is now running 24.2 binary.
+
+query T nodeidx=0
+SELECT crdb_internal.release_series(crdb_internal.node_executable_version())
+----
+24.2
+
+statement error pg_vector not supported until version 24.2
+CREATE TABLE t (v VECTOR(1))
+
+upgrade 1
+
+upgrade 2
+
+statement ok
+SET CLUSTER SETTING version = crdb_internal.node_executable_version();
+
+query T nodeidx=1
+SELECT crdb_internal.release_series(crdb_internal.node_executable_version())
+----
+24.2
+
+query T nodeidx=2
+SELECT crdb_internal.release_series(crdb_internal.node_executable_version())
+----
+24.2
+
+query B retry
+SELECT crdb_internal.is_at_least_version('24.1-02')
+----
+true
+
+# Note: the following statement would succeed if the cluster had an enterprise
+# license, but the mixed version logic framework doesn't support adding one.
+# This is tested normally in the vector ccl logic test.
+statement error pgcode XXC02 use of vector datatype requires an enterprise license
+CREATE TABLE t (v VECTOR(1))
diff --git a/pkg/ccl/logictestccl/testdata/logic_test/vector b/pkg/ccl/logictestccl/testdata/logic_test/vector
new file mode 100644
index 000000000000..153e6eeb1ca7
--- /dev/null
+++ b/pkg/ccl/logictestccl/testdata/logic_test/vector
@@ -0,0 +1,103 @@
+# LogicTest: !local-mixed-23.2
+
+query F
+SELECT '[1,2,3]'::vector <-> '[4,5,6]'::vector
+----
+5.196152422706632
+
+statement error pgcode 42601 dimensions for type vector must be at least 1
+CREATE TABLE v (v vector(0))
+
+statement error pgcode 42601 dimensions for type vector cannot exceed 16000
+CREATE TABLE v (v vector(16001))
+
+statement error column v is of type vector and thus is not indexable
+CREATE TABLE v (v vector(2) PRIMARY KEY)
+
+statement ok
+CREATE TABLE v (v vector);
+CREATE TABLE v2 (v vector(2))
+
+statement ok
+INSERT INTO v VALUES('[1]'), ('[2,3]')
+
+query T rowsort
+SELECT * FROM v
+----
+[1]
+[2,3]
+
+query T
+SELECT * FROM v WHERE v = '[1,2]'
+----
+
+query error pgcode 22000 different vector dimensions 2 and 1
+SELECT l2_distance('[1,2]', '[1]')
+
+statement error pgcode 22000 expected 2 dimensions, not 1
+INSERT INTO v2 VALUES('[1]'), ('[2,3]')
+
+statement ok
+INSERT INTO v2 VALUES('[1,2]'), ('[3,4]')
+
+query T rowsort
+SELECT * FROM v2
+----
+[1,2]
+[3,4]
+
+query T
+SELECT * FROM v2 WHERE v = '[1,2]'
+----
+[1,2]
+
+query TT
+SELECT '[1,2]'::text::vector, ARRAY[1,2]::vector
+----
+[1,2] [1,2]
+
+query error pgcode 22004 array must not contain nulls
+SELECT ARRAY[1,2,null]::vector
+
+query error pgcode 22000 expected 1 dimensions, not 2
+select '[3,1]'::vector(1)
+
+query error pgcode 22000 NaN not allowed in vector
+select '[3,NaN]'::vector
+
+query error pgcode 22000 infinite value not allowed in vector
+select '[3,Inf]'::vector
+
+query error pgcode 22000 infinite value not allowed in vector
+select '[3,-Inf]'::vector
+
+statement ok
+CREATE TABLE x (a float[], b real[])
+
+# Test implicit cast from vector to array.
+statement ok
+INSERT INTO x VALUES('[1,2]'::vector, '[3,4]'::vector)
+
+statement ok
+CREATE TABLE v3 (v1 vector(1), v2 vector(1));
+INSERT INTO v3 VALUES
+('[1]', '[2]'),
+('[1]', '[-2]'),
+(NULL, '[1]'),
+('[1]', NULL)
+
+query FFFTTT rowsort
+SELECT v1<->v2, v1<#>v2, v1<=>v2, v1+v2, v1-v2, v1*v2 FROM v3
+----
+1 -2 0 [3] [-1] [2]
+3 2 2 [-1] [3] [-2]
+NULL NULL NULL NULL NULL NULL
+NULL NULL NULL NULL NULL NULL
+
+query FFFFFI rowsort
+SELECT l1_distance(v1,v2), l2_distance(v1,v2), cosine_distance(v1,v2), inner_product(v1,v2), vector_norm(v1), vector_dims(v1) FROM v3
+----
+1 1 0 2 1 1
+3 3 2 -2 1 1
+NULL NULL NULL NULL NULL NULL
+NULL NULL NULL NULL 1 1
diff --git a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go
index 0469f99fc6b3..11be5fd5b63a 100644
--- a/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go
@@ -2825,6 +2825,13 @@ func TestTenantLogicCCL_unique_read_committed(
runCCLLogicTest(t, "unique_read_committed")
}
+func TestTenantLogicCCL_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
+
func TestTenantLogicCCL_zone_config_secondary_tenants(
t *testing.T,
) {
diff --git a/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel
new file mode 100644
index 000000000000..a7aef9e5a479
--- /dev/null
+++ b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/BUILD.bazel
@@ -0,0 +1,30 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+go_test(
+ name = "cockroach-go-testserver-23_2_test",
+ size = "enormous",
+ srcs = ["generated_test.go"],
+ data = [
+ "//c-deps:libgeos", # keep
+ "//pkg/ccl/logictestccl:testdata", # keep
+ "//pkg/cmd/cockroach-short", # keep
+ "//pkg/sql/logictest:cockroach_predecessor_version", # keep
+ ],
+ exec_properties = {"test.Pool": "large"},
+ shard_count = 1,
+ tags = ["cpu:2"],
+ deps = [
+ "//pkg/base",
+ "//pkg/build/bazel",
+ "//pkg/ccl",
+ "//pkg/security/securityassets",
+ "//pkg/security/securitytest",
+ "//pkg/server",
+ "//pkg/sql/logictest",
+ "//pkg/testutils/serverutils",
+ "//pkg/testutils/skip",
+ "//pkg/testutils/testcluster",
+ "//pkg/util/leaktest",
+ "//pkg/util/randutil",
+ ],
+)
diff --git a/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go
new file mode 100644
index 000000000000..dda3b4feab45
--- /dev/null
+++ b/pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2/generated_test.go
@@ -0,0 +1,86 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+// Code generated by generate-logictest, DO NOT EDIT.
+
+package testcockroach_go_testserver_232
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/cockroachdb/cockroach/pkg/base"
+ "github.com/cockroachdb/cockroach/pkg/build/bazel"
+ "github.com/cockroachdb/cockroach/pkg/ccl"
+ "github.com/cockroachdb/cockroach/pkg/security/securityassets"
+ "github.com/cockroachdb/cockroach/pkg/security/securitytest"
+ "github.com/cockroachdb/cockroach/pkg/server"
+ "github.com/cockroachdb/cockroach/pkg/sql/logictest"
+ "github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+ "github.com/cockroachdb/cockroach/pkg/testutils/skip"
+ "github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
+ "github.com/cockroachdb/cockroach/pkg/util/leaktest"
+ "github.com/cockroachdb/cockroach/pkg/util/randutil"
+)
+
+const configIdx = 20
+
+var cclLogicTestDir string
+
+func init() {
+ if bazel.BuiltWithBazel() {
+ var err error
+ cclLogicTestDir, err = bazel.Runfile("pkg/ccl/logictestccl/testdata/logic_test")
+ if err != nil {
+ panic(err)
+ }
+ } else {
+ cclLogicTestDir = "../../../../ccl/logictestccl/testdata/logic_test"
+ }
+}
+
+func TestMain(m *testing.M) {
+ defer ccl.TestingEnableEnterprise()()
+ securityassets.SetLoader(securitytest.EmbeddedAssets)
+ randutil.SeedForTests()
+ serverutils.InitTestServerFactory(server.TestServerFactory)
+ serverutils.InitTestClusterFactory(testcluster.TestClusterFactory)
+
+ defer serverutils.TestingSetDefaultTenantSelectionOverride(
+ base.TestIsForStuffThatShouldWorkWithSecondaryTenantsButDoesntYet(76378),
+ )()
+
+ os.Exit(m.Run())
+}
+
+func runCCLLogicTest(t *testing.T, file string) {
+ skip.UnderDeadlock(t, "times out and/or hangs")
+ logictest.RunLogicTest(t, logictest.TestServerArgs{}, configIdx, filepath.Join(cclLogicTestDir, file))
+}
+
+// TestLogic_tmp runs any tests that are prefixed with "_", in which a dedicated
+// test is not generated for. This allows developers to create and run temporary
+// test files that are not checked into the repository, without repeatedly
+// regenerating and reverting changes to this file, generated_test.go.
+//
+// TODO(mgartner): Add file filtering so that individual files can be run,
+// instead of all files with the "_" prefix.
+func TestLogic_tmp(t *testing.T) {
+ defer leaktest.AfterTest(t)()
+ var glob string
+ glob = filepath.Join(cclLogicTestDir, "_*")
+ logictest.RunLogicTests(t, logictest.TestServerArgs{}, configIdx, glob)
+}
+
+func TestCCLLogic_mixed_version_pgvector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "mixed_version_pgvector")
+}
diff --git a/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel b/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel
index c4bb0b39dfba..5297e82de11f 100644
--- a/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/fakedist-disk/BUILD.bazel
@@ -12,7 +12,7 @@ go_test(
"//build/toolchains:is_heavy": {"test.Pool": "heavy"},
"//conditions:default": {"test.Pool": "large"},
}),
- shard_count = 28,
+ shard_count = 29,
tags = ["cpu:2"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go b/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go
index a2fdeec1cc9c..d61025d5a97f 100644
--- a/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/fakedist-disk/generated_test.go
@@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed(
defer leaktest.AfterTest(t)()
runCCLLogicTest(t, "unique_read_committed")
}
+
+func TestCCLLogic_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
diff --git a/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel b/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel
index e7eb49f07487..d63444795f09 100644
--- a/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/fakedist-vec-off/BUILD.bazel
@@ -12,7 +12,7 @@ go_test(
"//build/toolchains:is_heavy": {"test.Pool": "heavy"},
"//conditions:default": {"test.Pool": "large"},
}),
- shard_count = 28,
+ shard_count = 29,
tags = ["cpu:2"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go b/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go
index 41b1bc49cbcc..c24a523488fc 100644
--- a/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/fakedist-vec-off/generated_test.go
@@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed(
defer leaktest.AfterTest(t)()
runCCLLogicTest(t, "unique_read_committed")
}
+
+func TestCCLLogic_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
diff --git a/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel b/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel
index 9f97dfa4f620..4c8d25f92413 100644
--- a/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/fakedist/BUILD.bazel
@@ -12,7 +12,7 @@ go_test(
"//build/toolchains:is_heavy": {"test.Pool": "heavy"},
"//conditions:default": {"test.Pool": "large"},
}),
- shard_count = 29,
+ shard_count = 30,
tags = ["cpu:2"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/fakedist/generated_test.go b/pkg/ccl/logictestccl/tests/fakedist/generated_test.go
index b9e4f8e0d73b..bf73ec236736 100644
--- a/pkg/ccl/logictestccl/tests/fakedist/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/fakedist/generated_test.go
@@ -280,3 +280,10 @@ func TestCCLLogic_unique_read_committed(
defer leaktest.AfterTest(t)()
runCCLLogicTest(t, "unique_read_committed")
}
+
+func TestCCLLogic_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
diff --git a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel
index 638d7cc79915..2d52175c8146 100644
--- a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
"//pkg/ccl/logictestccl:testdata", # keep
],
exec_properties = {"test.Pool": "large"},
- shard_count = 28,
+ shard_count = 29,
tags = ["cpu:1"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go
index 357b386f166d..5374bf1bf141 100644
--- a/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/local-legacy-schema-changer/generated_test.go
@@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed(
defer leaktest.AfterTest(t)()
runCCLLogicTest(t, "unique_read_committed")
}
+
+func TestCCLLogic_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
diff --git a/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel b/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel
index 6ebd57d3fe8f..1ce636f6aa8d 100644
--- a/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/local-read-committed/BUILD.bazel
@@ -10,7 +10,7 @@ go_test(
"//pkg/sql/opt/exec/execbuilder:testdata", # keep
],
exec_properties = {"test.Pool": "large"},
- shard_count = 35,
+ shard_count = 36,
tags = ["cpu:1"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go b/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go
index abca7f4e8348..bd066c1e9173 100644
--- a/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/local-read-committed/generated_test.go
@@ -301,6 +301,13 @@ func TestReadCommittedLogicCCL_unique_read_committed(
runCCLLogicTest(t, "unique_read_committed")
}
+func TestReadCommittedLogicCCL_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
+
func TestReadCommittedExecBuild_explain_analyze_read_committed(
t *testing.T,
) {
diff --git a/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel b/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel
index 92f0ed613a5c..8f49d20f505c 100644
--- a/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/local-vec-off/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
"//pkg/ccl/logictestccl:testdata", # keep
],
exec_properties = {"test.Pool": "large"},
- shard_count = 28,
+ shard_count = 29,
tags = ["cpu:1"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go b/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go
index c9e9d7e13055..67c0c4b6610c 100644
--- a/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/local-vec-off/generated_test.go
@@ -273,3 +273,10 @@ func TestCCLLogic_unique_read_committed(
defer leaktest.AfterTest(t)()
runCCLLogicTest(t, "unique_read_committed")
}
+
+func TestCCLLogic_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
diff --git a/pkg/ccl/logictestccl/tests/local/BUILD.bazel b/pkg/ccl/logictestccl/tests/local/BUILD.bazel
index ef7ea4aaef3b..096337816aaf 100644
--- a/pkg/ccl/logictestccl/tests/local/BUILD.bazel
+++ b/pkg/ccl/logictestccl/tests/local/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
"//pkg/ccl/logictestccl:testdata", # keep
],
exec_properties = {"test.Pool": "large"},
- shard_count = 44,
+ shard_count = 45,
tags = ["cpu:1"],
deps = [
"//pkg/base",
diff --git a/pkg/ccl/logictestccl/tests/local/generated_test.go b/pkg/ccl/logictestccl/tests/local/generated_test.go
index c67aea9b5aaf..837293b36ae2 100644
--- a/pkg/ccl/logictestccl/tests/local/generated_test.go
+++ b/pkg/ccl/logictestccl/tests/local/generated_test.go
@@ -385,3 +385,10 @@ func TestCCLLogic_unique_read_committed(
defer leaktest.AfterTest(t)()
runCCLLogicTest(t, "unique_read_committed")
}
+
+func TestCCLLogic_vector(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runCCLLogicTest(t, "vector")
+}
diff --git a/pkg/sql/alter_column_type.go b/pkg/sql/alter_column_type.go
index 11eb91a3909b..925667e2a7a8 100644
--- a/pkg/sql/alter_column_type.go
+++ b/pkg/sql/alter_column_type.go
@@ -107,7 +107,7 @@ func AlterColumnType(
}
}
- err = colinfo.ValidateColumnDefType(ctx, params.EvalContext().Settings.Version, typ)
+ err = colinfo.ValidateColumnDefType(ctx, params.EvalContext().Settings, typ)
if err != nil {
return err
}
diff --git a/pkg/sql/catalog/colinfo/BUILD.bazel b/pkg/sql/catalog/colinfo/BUILD.bazel
index 53e850531285..795ccb31fe0a 100644
--- a/pkg/sql/catalog/colinfo/BUILD.bazel
+++ b/pkg/sql/catalog/colinfo/BUILD.bazel
@@ -16,7 +16,9 @@ go_library(
importpath = "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/base",
"//pkg/clusterversion",
+ "//pkg/settings/cluster",
"//pkg/sql/catalog",
"//pkg/sql/catalog/catpb",
"//pkg/sql/catalog/descpb",
diff --git a/pkg/sql/catalog/colinfo/col_type_info.go b/pkg/sql/catalog/colinfo/col_type_info.go
index b5a6633986ae..9002aa6cb5dd 100644
--- a/pkg/sql/catalog/colinfo/col_type_info.go
+++ b/pkg/sql/catalog/colinfo/col_type_info.go
@@ -14,7 +14,9 @@ import (
"context"
"fmt"
+ "github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/clusterversion"
+ "github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/catalog"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
@@ -70,7 +72,7 @@ func (ti ColTypeInfo) Type(idx int) *types.T {
// ValidateColumnDefType returns an error if the type of a column definition is
// not valid. It is checked when a column is created or altered.
-func ValidateColumnDefType(ctx context.Context, version clusterversion.Handle, t *types.T) error {
+func ValidateColumnDefType(ctx context.Context, st *cluster.Settings, t *types.T) error {
switch t.Family() {
case types.StringFamily, types.CollatedStringFamily:
if t.Family() == types.CollatedStringFamily {
@@ -102,7 +104,7 @@ func ValidateColumnDefType(ctx context.Context, version clusterversion.Handle, t
if err := types.CheckArrayElementType(t.ArrayContents()); err != nil {
return err
}
- return ValidateColumnDefType(ctx, version, t.ArrayContents())
+ return ValidateColumnDefType(ctx, st, t.ArrayContents())
case types.BitFamily, types.IntFamily, types.FloatFamily, types.BoolFamily, types.BytesFamily, types.DateFamily,
types.INetFamily, types.IntervalFamily, types.JsonFamily, types.OidFamily, types.TimeFamily,
@@ -119,6 +121,17 @@ func ValidateColumnDefType(ctx context.Context, version clusterversion.Handle, t
return unimplemented.NewWithIssue(70099, "cannot use table record type as table column")
}
+ case types.PGVectorFamily:
+ if !st.Version.IsActive(ctx, clusterversion.V24_2) {
+ return pgerror.Newf(
+ pgcode.FeatureNotSupported,
+ "pg_vector not supported until version 24.2",
+ )
+ }
+ if err := base.CheckEnterpriseEnabled(st, "vector datatype"); err != nil {
+ return err
+ }
+
default:
return pgerror.Newf(pgcode.InvalidTableDefinition,
"value type %s cannot be used for table columns", t.String())
@@ -192,6 +205,8 @@ func MustBeValueEncoded(semanticType *types.T) bool {
return true
case types.TSVectorFamily, types.TSQueryFamily:
return true
+ case types.PGVectorFamily:
+ return true
}
return false
}
diff --git a/pkg/sql/catalog/colinfo/column_type_properties.go b/pkg/sql/catalog/colinfo/column_type_properties.go
index f474af606b49..1898ab46034f 100644
--- a/pkg/sql/catalog/colinfo/column_type_properties.go
+++ b/pkg/sql/catalog/colinfo/column_type_properties.go
@@ -83,6 +83,7 @@ func CanHaveCompositeKeyEncoding(typ *types.T) bool {
types.EnumFamily,
types.Box2DFamily,
types.PGLSNFamily,
+ types.PGVectorFamily,
types.RefCursorFamily,
types.VoidFamily,
types.EncodedKeyFamily,
diff --git a/pkg/sql/catalog/tabledesc/table.go b/pkg/sql/catalog/tabledesc/table.go
index 9b55075bfb3a..fb871851b11e 100644
--- a/pkg/sql/catalog/tabledesc/table.go
+++ b/pkg/sql/catalog/tabledesc/table.go
@@ -161,7 +161,7 @@ func MakeColumnDefDescs(
if err != nil {
return nil, err
}
- if err = colinfo.ValidateColumnDefType(ctx, evalCtx.Settings.Version, resType); err != nil {
+ if err = colinfo.ValidateColumnDefType(ctx, evalCtx.Settings, resType); err != nil {
return nil, err
}
col.Type = resType
diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go
index 77c71298db3b..0a449589876b 100644
--- a/pkg/sql/exec_util.go
+++ b/pkg/sql/exec_util.go
@@ -2051,6 +2051,7 @@ func checkResultType(typ *types.T, fmtCode pgwirebase.FormatCode) error {
case types.INetFamily:
case types.OidFamily:
case types.PGLSNFamily:
+ case types.PGVectorFamily:
case types.RefCursorFamily:
case types.TupleFamily:
case types.EnumFamily:
diff --git a/pkg/sql/logictest/BUILD.bazel b/pkg/sql/logictest/BUILD.bazel
index a54b6ee2b99b..59bb3a24bd43 100644
--- a/pkg/sql/logictest/BUILD.bazel
+++ b/pkg/sql/logictest/BUILD.bazel
@@ -38,6 +38,7 @@ filegroup(
}),
visibility = [
"//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.1:__pkg__",
+ "//pkg/ccl/logictestccl/tests/cockroach-go-testserver-23.2:__pkg__",
"//pkg/sql/logictest/tests/cockroach-go-testserver-23.1:__pkg__",
"//pkg/sql/logictest/tests/cockroach-go-testserver-23.2:__pkg__",
],
diff --git a/pkg/sql/logictest/testdata/logic_test/grant_table b/pkg/sql/logictest/testdata/logic_test/grant_table
index f8384639e343..a47e1331bd56 100644
--- a/pkg/sql/logictest/testdata/logic_test/grant_table
+++ b/pkg/sql/logictest/testdata/logic_test/grant_table
@@ -603,6 +603,12 @@ test pg_catalog varchar
test pg_catalog varchar[] type admin ALL false
test pg_catalog varchar[] type public USAGE false
test pg_catalog varchar[] type root ALL false
+test pg_catalog vector type admin ALL false
+test pg_catalog vector type public USAGE false
+test pg_catalog vector type root ALL false
+test pg_catalog vector[] type admin ALL false
+test pg_catalog vector[] type public USAGE false
+test pg_catalog vector[] type root ALL false
test pg_catalog void type admin ALL false
test pg_catalog void type public USAGE false
test pg_catalog void type root ALL false
@@ -791,6 +797,10 @@ test pg_catalog varchar type admin ALL
test pg_catalog varchar type root ALL false
test pg_catalog varchar[] type admin ALL false
test pg_catalog varchar[] type root ALL false
+test pg_catalog vector type admin ALL false
+test pg_catalog vector type root ALL false
+test pg_catalog vector[] type admin ALL false
+test pg_catalog vector[] type root ALL false
test pg_catalog void type admin ALL false
test pg_catalog void type root ALL false
test public NULL schema admin ALL true
@@ -1403,6 +1413,10 @@ a pg_catalog varchar type admin
a pg_catalog varchar type root ALL false
a pg_catalog varchar[] type admin ALL false
a pg_catalog varchar[] type root ALL false
+a pg_catalog vector type admin ALL false
+a pg_catalog vector type root ALL false
+a pg_catalog vector[] type admin ALL false
+a pg_catalog vector[] type root ALL false
a pg_catalog void type admin ALL false
a pg_catalog void type root ALL false
a public NULL schema admin ALL true
@@ -1579,6 +1593,10 @@ defaultdb pg_catalog varchar type admin
defaultdb pg_catalog varchar type root ALL false
defaultdb pg_catalog varchar[] type admin ALL false
defaultdb pg_catalog varchar[] type root ALL false
+defaultdb pg_catalog vector type admin ALL false
+defaultdb pg_catalog vector type root ALL false
+defaultdb pg_catalog vector[] type admin ALL false
+defaultdb pg_catalog vector[] type root ALL false
defaultdb pg_catalog void type admin ALL false
defaultdb pg_catalog void type root ALL false
defaultdb public NULL schema admin ALL true
@@ -1755,6 +1773,10 @@ postgres pg_catalog varchar type admin
postgres pg_catalog varchar type root ALL false
postgres pg_catalog varchar[] type admin ALL false
postgres pg_catalog varchar[] type root ALL false
+postgres pg_catalog vector type admin ALL false
+postgres pg_catalog vector type root ALL false
+postgres pg_catalog vector[] type admin ALL false
+postgres pg_catalog vector[] type root ALL false
postgres pg_catalog void type admin ALL false
postgres pg_catalog void type root ALL false
postgres public NULL schema admin ALL true
@@ -1931,6 +1953,10 @@ system pg_catalog varchar type admin
system pg_catalog varchar type root ALL false
system pg_catalog varchar[] type admin ALL false
system pg_catalog varchar[] type root ALL false
+system pg_catalog vector type admin ALL false
+system pg_catalog vector type root ALL false
+system pg_catalog vector[] type admin ALL false
+system pg_catalog vector[] type root ALL false
system pg_catalog void type admin ALL false
system pg_catalog void type root ALL false
system public NULL schema admin ALL true
@@ -2487,6 +2513,10 @@ test pg_catalog varchar type admin
test pg_catalog varchar type root ALL false
test pg_catalog varchar[] type admin ALL false
test pg_catalog varchar[] type root ALL false
+test pg_catalog vector type admin ALL false
+test pg_catalog vector type root ALL false
+test pg_catalog vector[] type admin ALL false
+test pg_catalog vector[] type root ALL false
test pg_catalog void type admin ALL false
test pg_catalog void type root ALL false
test public NULL schema admin ALL true
diff --git a/pkg/sql/logictest/testdata/logic_test/pg_catalog b/pkg/sql/logictest/testdata/logic_test/pg_catalog
index 97a1da50aa8c..a77bef11eeba 100644
--- a/pkg/sql/logictest/testdata/logic_test/pg_catalog
+++ b/pkg/sql/logictest/testdata/logic_test/pg_catalog
@@ -1923,6 +1923,8 @@ oid typname typnamespace typowner typlen typbyval typty
90003 _geography 4294967103 NULL -1 false b
90004 box2d 4294967103 NULL 32 true b
90005 _box2d 4294967103 NULL -1 false b
+90006 vector 4294967103 NULL -1 false b
+90007 _vector 4294967103 NULL -1 false b
100110 t1 109 1546506610 -1 false c
100111 t1_m_seq 109 1546506610 -1 false c
100112 t1_n_seq 109 1546506610 -1 false c
@@ -2038,6 +2040,8 @@ oid typname typcategory typispreferred typisdefined typdel
90003 _geography A false true : 0 90002 0
90004 box2d U false true , 0 0 90005
90005 _box2d A false true , 0 90004 0
+90006 vector U false true , 0 0 90007
+90007 _vector A false true , 0 90006 0
100110 t1 C false true , 110 0 0
100111 t1_m_seq C false true , 111 0 0
100112 t1_n_seq C false true , 112 0 0
@@ -2153,6 +2157,8 @@ oid typname typinput typoutput typreceive
90003 _geography array_in array_out array_recv array_send 0 0 0
90004 box2d box2d_in box2d_out box2d_recv box2d_send 0 0 0
90005 _box2d array_in array_out array_recv array_send 0 0 0
+90006 vector vectorin vectorout vectorrecv vectorsend 0 0 0
+90007 _vector array_in array_out array_recv array_send 0 0 0
100110 t1 record_in record_out record_recv record_send 0 0 0
100111 t1_m_seq record_in record_out record_recv record_send 0 0 0
100112 t1_n_seq record_in record_out record_recv record_send 0 0 0
@@ -2268,6 +2274,8 @@ oid typname typalign typstorage typnotnull typbasetype ty
90003 _geography NULL NULL false 0 -1
90004 box2d NULL NULL false 0 -1
90005 _box2d NULL NULL false 0 -1
+90006 vector NULL NULL false 0 -1
+90007 _vector NULL NULL false 0 -1
100110 t1 NULL NULL false 0 -1
100111 t1_m_seq NULL NULL false 0 -1
100112 t1_n_seq NULL NULL false 0 -1
@@ -2383,6 +2391,8 @@ oid typname typndims typcollation typdefaultbin typdefault
90003 _geography 0 0 NULL NULL NULL
90004 box2d 0 0 NULL NULL NULL
90005 _box2d 0 0 NULL NULL NULL
+90006 vector 0 0 NULL NULL NULL
+90007 _vector 0 0 NULL NULL NULL
100110 t1 0 0 NULL NULL NULL
100111 t1_m_seq 0 0 NULL NULL NULL
100112 t1_n_seq 0 0 NULL NULL NULL
diff --git a/pkg/sql/logictest/testdata/logic_test/vectoross b/pkg/sql/logictest/testdata/logic_test/vectoross
new file mode 100644
index 000000000000..12a220dbd673
--- /dev/null
+++ b/pkg/sql/logictest/testdata/logic_test/vectoross
@@ -0,0 +1,5 @@
+# LogicTest: !local-mixed-23.2 !3node-tenant
+
+statement error OSS binaries do not include enterprise features
+CREATE TABLE v (v vector)
+
diff --git a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go
index 02112869ec57..5074355e8844 100644
--- a/pkg/sql/logictest/tests/fakedist-disk/generated_test.go
+++ b/pkg/sql/logictest/tests/fakedist-disk/generated_test.go
@@ -2458,6 +2458,13 @@ func TestLogic_vectorize_window(
runLogicTest(t, "vectorize_window")
}
+func TestLogic_vectoross(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runLogicTest(t, "vectoross")
+}
+
func TestLogic_views(
t *testing.T,
) {
diff --git a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go
index f82f27e38146..a88349f4d32d 100644
--- a/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go
+++ b/pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go
@@ -2451,6 +2451,13 @@ func TestLogic_vectorize_unsupported(
runLogicTest(t, "vectorize_unsupported")
}
+func TestLogic_vectoross(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runLogicTest(t, "vectoross")
+}
+
func TestLogic_views(
t *testing.T,
) {
diff --git a/pkg/sql/logictest/tests/fakedist/generated_test.go b/pkg/sql/logictest/tests/fakedist/generated_test.go
index 1b066c6ef182..1818a68527d9 100644
--- a/pkg/sql/logictest/tests/fakedist/generated_test.go
+++ b/pkg/sql/logictest/tests/fakedist/generated_test.go
@@ -2479,6 +2479,13 @@ func TestLogic_vectorize_window(
runLogicTest(t, "vectorize_window")
}
+func TestLogic_vectoross(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runLogicTest(t, "vectoross")
+}
+
func TestLogic_views(
t *testing.T,
) {
diff --git a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go
index cc508f62730a..7945dd8a194b 100644
--- a/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go
+++ b/pkg/sql/logictest/tests/local-legacy-schema-changer/generated_test.go
@@ -2451,6 +2451,13 @@ func TestLogic_vectorize_unsupported(
runLogicTest(t, "vectorize_unsupported")
}
+func TestLogic_vectoross(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runLogicTest(t, "vectoross")
+}
+
func TestLogic_views(
t *testing.T,
) {
diff --git a/pkg/sql/logictest/tests/local-vec-off/generated_test.go b/pkg/sql/logictest/tests/local-vec-off/generated_test.go
index 28dea5367cd3..13ba0ea77d2b 100644
--- a/pkg/sql/logictest/tests/local-vec-off/generated_test.go
+++ b/pkg/sql/logictest/tests/local-vec-off/generated_test.go
@@ -2479,6 +2479,13 @@ func TestLogic_vectorize_unsupported(
runLogicTest(t, "vectorize_unsupported")
}
+func TestLogic_vectoross(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runLogicTest(t, "vectoross")
+}
+
func TestLogic_views(
t *testing.T,
) {
diff --git a/pkg/sql/logictest/tests/local/generated_test.go b/pkg/sql/logictest/tests/local/generated_test.go
index 9f1ac0809e00..12bfc36634c4 100644
--- a/pkg/sql/logictest/tests/local/generated_test.go
+++ b/pkg/sql/logictest/tests/local/generated_test.go
@@ -2710,6 +2710,13 @@ func TestLogic_vectorize_window(
runLogicTest(t, "vectorize_window")
}
+func TestLogic_vectoross(
+ t *testing.T,
+) {
+ defer leaktest.AfterTest(t)()
+ runLogicTest(t, "vectoross")
+}
+
func TestLogic_views(
t *testing.T,
) {
diff --git a/pkg/sql/oidext/oidext.go b/pkg/sql/oidext/oidext.go
index ac7da4507c7a..6a1d08461c78 100644
--- a/pkg/sql/oidext/oidext.go
+++ b/pkg/sql/oidext/oidext.go
@@ -34,6 +34,8 @@ const (
T__geography = oid.Oid(90003)
T_box2d = oid.Oid(90004)
T__box2d = oid.Oid(90005)
+ T_pgvector = oid.Oid(90006)
+ T__pgvector = oid.Oid(90007)
)
// ExtensionTypeName returns a mapping from extension oids
@@ -45,6 +47,8 @@ var ExtensionTypeName = map[oid.Oid]string{
T__geography: "_GEOGRAPHY",
T_box2d: "BOX2D",
T__box2d: "_BOX2D",
+ T_pgvector: "VECTOR",
+ T__pgvector: "_VECTOR",
}
// TypeName checks the name for a given type by first looking up oid.TypeName
diff --git a/pkg/sql/opt/operator.go b/pkg/sql/opt/operator.go
index 25cab29dc048..234fd30245e8 100644
--- a/pkg/sql/opt/operator.go
+++ b/pkg/sql/opt/operator.go
@@ -151,23 +151,26 @@ var ComparisonOpReverseMap = map[Operator]treecmp.ComparisonOperatorSymbol{
// BinaryOpReverseMap maps from an optimizer operator type to a semantic tree
// binary operator type.
var BinaryOpReverseMap = map[Operator]treebin.BinaryOperatorSymbol{
- BitandOp: treebin.Bitand,
- BitorOp: treebin.Bitor,
- BitxorOp: treebin.Bitxor,
- PlusOp: treebin.Plus,
- MinusOp: treebin.Minus,
- MultOp: treebin.Mult,
- DivOp: treebin.Div,
- FloorDivOp: treebin.FloorDiv,
- ModOp: treebin.Mod,
- PowOp: treebin.Pow,
- ConcatOp: treebin.Concat,
- LShiftOp: treebin.LShift,
- RShiftOp: treebin.RShift,
- FetchValOp: treebin.JSONFetchVal,
- FetchTextOp: treebin.JSONFetchText,
- FetchValPathOp: treebin.JSONFetchValPath,
- FetchTextPathOp: treebin.JSONFetchTextPath,
+ BitandOp: treebin.Bitand,
+ BitorOp: treebin.Bitor,
+ BitxorOp: treebin.Bitxor,
+ PlusOp: treebin.Plus,
+ MinusOp: treebin.Minus,
+ MultOp: treebin.Mult,
+ DivOp: treebin.Div,
+ FloorDivOp: treebin.FloorDiv,
+ ModOp: treebin.Mod,
+ PowOp: treebin.Pow,
+ ConcatOp: treebin.Concat,
+ LShiftOp: treebin.LShift,
+ RShiftOp: treebin.RShift,
+ FetchValOp: treebin.JSONFetchVal,
+ FetchTextOp: treebin.JSONFetchText,
+ FetchValPathOp: treebin.JSONFetchValPath,
+ FetchTextPathOp: treebin.JSONFetchTextPath,
+ VectorDistanceOp: treebin.Distance,
+ VectorCosDistanceOp: treebin.CosDistance,
+ VectorNegInnerProductOp: treebin.NegInnerProduct,
}
// UnaryOpReverseMap maps from an optimizer operator type to a semantic tree
diff --git a/pkg/sql/opt/ops/scalar.opt b/pkg/sql/opt/ops/scalar.opt
index 7572e006e9dd..dcc991d28d96 100644
--- a/pkg/sql/opt/ops/scalar.opt
+++ b/pkg/sql/opt/ops/scalar.opt
@@ -514,6 +514,30 @@ define TSMatches {
Right ScalarExpr
}
+# VectorDistance is the <-> operator when used with vector operands.
+# It maps to tree.Distance.
+[Scalar, Binary]
+define VectorDistance {
+ Left ScalarExpr
+ Right ScalarExpr
+}
+
+# VectorCosDistance is the <=> operator when used with vector operands.
+# It maps to tree.CosDistance.
+[Scalar, Binary]
+define VectorCosDistance {
+ Left ScalarExpr
+ Right ScalarExpr
+}
+
+# VectorNegInnerProduct is the <#> operator when used with vector operands.
+# It maps to tree.NegInnerProduct.
+[Scalar, Binary]
+define VectorNegInnerProduct {
+ Left ScalarExpr
+ Right ScalarExpr
+}
+
# AnyScalar is the form of ANY which refers to an ANY operation on a
# tuple or array, as opposed to Any which operates on a subquery.
[Scalar, Bool]
diff --git a/pkg/sql/opt/optbuilder/scalar.go b/pkg/sql/opt/optbuilder/scalar.go
index 85441b076aa7..5755c1035106 100644
--- a/pkg/sql/opt/optbuilder/scalar.go
+++ b/pkg/sql/opt/optbuilder/scalar.go
@@ -853,6 +853,12 @@ func (b *Builder) constructBinary(
return b.factory.ConstructFetchValPath(left, right)
case treebin.JSONFetchTextPath:
return b.factory.ConstructFetchTextPath(left, right)
+ case treebin.Distance:
+ return b.factory.ConstructVectorDistance(left, right)
+ case treebin.CosDistance:
+ return b.factory.ConstructVectorCosDistance(left, right)
+ case treebin.NegInnerProduct:
+ return b.factory.ConstructVectorNegInnerProduct(left, right)
}
panic(errors.AssertionFailedf("unhandled binary operator: %s", redact.Safe(bin)))
}
diff --git a/pkg/sql/parser/BUILD.bazel b/pkg/sql/parser/BUILD.bazel
index b9bb467341db..a00beb6bc2b2 100644
--- a/pkg/sql/parser/BUILD.bazel
+++ b/pkg/sql/parser/BUILD.bazel
@@ -37,6 +37,7 @@ go_library(
"//pkg/sql/sem/tree/treewindow", # keep
"//pkg/sql/types",
"//pkg/util/errorutil/unimplemented",
+ "//pkg/util/vector", # keep
"@com_github_cockroachdb_errors//:errors",
"@com_github_lib_pq//oid", # keep
"@org_golang_x_text//cases",
diff --git a/pkg/sql/parser/sql.y b/pkg/sql/parser/sql.y
index 83b87c3d700b..b08f8e97b9ad 100644
--- a/pkg/sql/parser/sql.y
+++ b/pkg/sql/parser/sql.y
@@ -41,6 +41,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treecmp"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treewindow"
"github.com/cockroachdb/cockroach/pkg/sql/types"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)
@@ -927,14 +928,14 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption
%token CLUSTER CLUSTERS COALESCE COLLATE COLLATION COLUMN COLUMNS COMMENT COMMENTS COMMIT
%token COMMITTED COMPACT COMPLETE COMPLETIONS CONCAT CONCURRENTLY CONFIGURATION CONFIGURATIONS CONFIGURE
%token CONFLICT CONNECTION CONNECTIONS CONSTRAINT CONSTRAINTS CONTAINS CONTROLCHANGEFEED CONTROLJOB
-%token CONVERSION CONVERT COPY COST COVERING CREATE CREATEDB CREATELOGIN CREATEROLE
+%token CONVERSION CONVERT COPY COS_DISTANCE COST COVERING CREATE CREATEDB CREATELOGIN CREATEROLE
%token CROSS CSV CUBE CURRENT CURRENT_CATALOG CURRENT_DATE CURRENT_SCHEMA
%token CURRENT_ROLE CURRENT_TIME CURRENT_TIMESTAMP
%token CURRENT_USER CURSOR CYCLE
%token DATA DATABASE DATABASES DATE DAY DEBUG_IDS DEC DEBUG_DUMP_METADATA_SST DECIMAL DEFAULT DEFAULTS DEFINER
%token DEALLOCATE DECLARE DEFERRABLE DEFERRED DELETE DELIMITER DEPENDS DESC DESTINATION DETACHED DETAILS
-%token DISCARD DISTINCT DO DOMAIN DOUBLE DROP
+%token DISCARD DISTANCE DISTINCT DO DOMAIN DOUBLE DROP
%token ELSE ENCODING ENCRYPTED ENCRYPTION_INFO_DIR ENCRYPTION_PASSPHRASE END ENUM ENUMS ESCAPE EXCEPT EXCLUDE EXCLUDING
%token EXISTS EXECUTE EXECUTION EXPERIMENTAL
@@ -977,7 +978,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption
%token MULTIPOINT MULTIPOINTM MULTIPOINTZ MULTIPOINTZM
%token MULTIPOLYGON MULTIPOLYGONM MULTIPOLYGONZ MULTIPOLYGONZM
-%token NAN NAME NAMES NATURAL NEVER NEW_DB_NAME NEW_KMS NEXT NO NOCANCELQUERY NOCONTROLCHANGEFEED
+%token NAN NAME NAMES NATURAL NEG_INNER_PRODUCT NEVER NEW_DB_NAME NEW_KMS NEXT NO NOCANCELQUERY NOCONTROLCHANGEFEED
%token NOCONTROLJOB NOCREATEDB NOCREATELOGIN NOCREATEROLE NODE NOLOGIN NOMODIFYCLUSTERSETTING NOREPLICATION
%token NOSQLLOGIN NO_INDEX_JOIN NO_ZIGZAG_JOIN NO_FULL_SCAN NONE NONVOTERS NORMAL NOT
%token NOTHING NOTHING_AFTER_RETURNING
@@ -1018,7 +1019,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption
%token UNBOUNDED UNCOMMITTED UNION UNIQUE UNKNOWN UNLISTEN UNLOGGED UNSAFE_RESTORE_INCOMPATIBLE_VERSION UNSPLIT
%token UPDATE UPDATES_CLUSTER_MONITORING_METRICS UPSERT UNSET UNTIL USE USER USERS USING UUID
-%token VALID VALIDATE VALUE VALUES VARBIT VARCHAR VARIADIC VERIFY_BACKUP_TABLE_DATA VIEW VARIABLES VARYING VIEWACTIVITY VIEWACTIVITYREDACTED VIEWDEBUG
+%token VALID VALIDATE VALUE VALUES VARBIT VARCHAR VARIADIC VECTOR VERIFY_BACKUP_TABLE_DATA VIEW VARIABLES VARYING VIEWACTIVITY VIEWACTIVITYREDACTED VIEWDEBUG
%token VIEWCLUSTERMETADATA VIEWCLUSTERSETTING VIRTUAL VISIBLE INVISIBLE VISIBILITY VOLATILE VOTERS
%token VIRTUAL_CLUSTER_NAME VIRTUAL_CLUSTER
@@ -1603,6 +1604,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption
%type <*types.T> character_base
%type <*types.T> geo_shape_type
%type <*types.T> const_geo
+%type <*types.T> const_vector
%type extract_arg
%type opt_varying
@@ -1760,7 +1762,7 @@ func (u *sqlSymUnion) logicalReplicationOptions() *tree.LogicalReplicationOption
// funny behavior of UNBOUNDED on the SQL standard, though.
%nonassoc UNBOUNDED // ideally should have same precedence as IDENT
%nonassoc IDENT NULL PARTITION RANGE ROWS GROUPS PRECEDING FOLLOWING CUBE ROLLUP
-%left CONCAT FETCHVAL FETCHTEXT FETCHVAL_PATH FETCHTEXT_PATH REMOVE_PATH AT_AT // multi-character ops
+%left CONCAT FETCHVAL FETCHTEXT FETCHVAL_PATH FETCHTEXT_PATH REMOVE_PATH AT_AT DISTANCE COS_DISTANCE NEG_INNER_PRODUCT // multi-character ops
%left '|'
%left '#'
%left '&'
@@ -14414,6 +14416,21 @@ const_geo:
$$.val = types.MakeGeography($3.geoShapeType(), geopb.SRID(val))
}
+const_vector:
+ VECTOR { $$.val = types.PGVector }
+| VECTOR '(' iconst32 ')'
+ {
+ dims := $3.int32()
+ if dims <= 0 {
+ sqllex.Error("dimensions for type vector must be at least 1")
+ return 1
+ } else if dims > vector.MaxDim {
+ sqllex.Error(fmt.Sprintf("dimensions for type vector cannot exceed %d", vector.MaxDim))
+ return 1
+ }
+ $$.val = types.MakePGVector(dims)
+ }
+
// We have a separate const_typename to allow defaulting fixed-length types such
// as CHAR() and BIT() to an unspecified length. SQL9x requires that these
// default to a length of one, but this makes no sense for constructs like CHAR
@@ -14431,6 +14448,7 @@ const_typename:
| character_with_length
| const_datetime
| const_geo
+| const_vector
opt_numeric_modifiers:
'(' iconst32 ')'
@@ -15016,6 +15034,18 @@ a_expr:
{
$$.val = &tree.ComparisonExpr{Operator: treecmp.MakeComparisonOperator(treecmp.TSMatches), Left: $1.expr(), Right: $3.expr()}
}
+| a_expr DISTANCE a_expr
+ {
+ $$.val = &tree.BinaryExpr{Operator: treebin.MakeBinaryOperator(treebin.Distance), Left: $1.expr(), Right: $3.expr()}
+ }
+| a_expr COS_DISTANCE a_expr
+ {
+ $$.val = &tree.BinaryExpr{Operator: treebin.MakeBinaryOperator(treebin.CosDistance), Left: $1.expr(), Right: $3.expr()}
+ }
+| a_expr NEG_INNER_PRODUCT a_expr
+ {
+ $$.val = &tree.BinaryExpr{Operator: treebin.MakeBinaryOperator(treebin.NegInnerProduct), Left: $1.expr(), Right: $3.expr()}
+ }
| a_expr INET_CONTAINS_OR_EQUALS a_expr
{
$$.val = &tree.FuncExpr{Func: tree.WrapFunction("inet_contains_or_equals"), Exprs: tree.Exprs{$1.expr(), $3.expr()}}
@@ -16228,6 +16258,9 @@ all_op:
| NOT_REGIMATCH { $$.val = treecmp.MakeComparisonOperator(treecmp.NotRegIMatch) }
| AND_AND { $$.val = treecmp.MakeComparisonOperator(treecmp.Overlaps) }
| AT_AT { $$.val = treecmp.MakeComparisonOperator(treecmp.TSMatches) }
+| DISTANCE { $$.val = treebin.MakeBinaryOperator(treebin.Distance) }
+| COS_DISTANCE { $$.val = treebin.MakeBinaryOperator(treebin.CosDistance) }
+| NEG_INNER_PRODUCT { $$.val = treebin.MakeBinaryOperator(treebin.NegInnerProduct) }
| '~' { $$.val = tree.MakeUnaryOperator(tree.UnaryComplement) }
| SQRT { $$.val = tree.MakeUnaryOperator(tree.UnarySqrt) }
| CBRT { $$.val = tree.MakeUnaryOperator(tree.UnaryCbrt) }
@@ -18203,6 +18236,7 @@ bare_label_keywords:
| VARCHAR
| VARIABLES
| VARIADIC
+| VECTOR
| VERIFY_BACKUP_TABLE_DATA
| VIEW
| VIEWACTIVITY
@@ -18288,6 +18322,7 @@ col_name_keyword:
| VALUES
| VARBIT
| VARCHAR
+| VECTOR
| VIRTUAL
| WORK
diff --git a/pkg/sql/parser/testdata/create_table b/pkg/sql/parser/testdata/create_table
index 4d7834e1d667..0238fe007724 100644
--- a/pkg/sql/parser/testdata/create_table
+++ b/pkg/sql/parser/testdata/create_table
@@ -2542,3 +2542,19 @@ ALTER TABLE a PARTITION ALL BY LIST ("a b", "c.d") (PARTITION "e.f" VALUES IN (1
ALTER TABLE a PARTITION ALL BY LIST ("a b", "c.d") (PARTITION "e.f" VALUES IN ((1))) -- fully parenthesized
ALTER TABLE a PARTITION ALL BY LIST ("a b", "c.d") (PARTITION "e.f" VALUES IN (_)) -- literals removed
ALTER TABLE _ PARTITION ALL BY LIST (_, _) (PARTITION _ VALUES IN (1)) -- identifiers removed
+
+parse
+CREATE TABLE a (a VECTOR(3))
+----
+CREATE TABLE a (a VECTOR(3))
+CREATE TABLE a (a VECTOR(3)) -- fully parenthesized
+CREATE TABLE a (a VECTOR(3)) -- literals removed
+CREATE TABLE _ (_ VECTOR(3)) -- identifiers removed
+
+parse
+CREATE TABLE a (a VECTOR)
+----
+CREATE TABLE a (a VECTOR)
+CREATE TABLE a (a VECTOR) -- fully parenthesized
+CREATE TABLE a (a VECTOR) -- literals removed
+CREATE TABLE _ (_ VECTOR) -- identifiers removed
diff --git a/pkg/sql/parser/testdata/select_exprs b/pkg/sql/parser/testdata/select_exprs
index 2ce2e9902d0e..3525f0a2298f 100644
--- a/pkg/sql/parser/testdata/select_exprs
+++ b/pkg/sql/parser/testdata/select_exprs
@@ -2058,3 +2058,43 @@ SELECT my_func('a', 1, true)
SELECT (my_func(('a'), (1), (true))) -- fully parenthesized
SELECT my_func('_', _, _) -- literals removed
SELECT _('a', 1, true) -- identifiers removed
+
+parse
+SELECT "[1,2]" <-> "[3,4]"
+----
+SELECT "[1,2]" <-> "[3,4]"
+SELECT (("[1,2]") <-> ("[3,4]")) -- fully parenthesized
+SELECT "[1,2]" <-> "[3,4]" -- literals removed
+SELECT _ <-> _ -- identifiers removed
+
+parse
+SELECT "[2,2]" <=> "[3,4]"
+----
+SELECT "[2,2]" <=> "[3,4]"
+SELECT (("[2,2]") <=> ("[3,4]")) -- fully parenthesized
+SELECT "[2,2]" <=> "[3,4]" -- literals removed
+SELECT _ <=> _ -- identifiers removed
+
+parse
+SELECT "[2,2]" <#> "[3,4]"
+----
+SELECT "[2,2]" <#> "[3,4]"
+SELECT (("[2,2]") <#> ("[3,4]")) -- fully parenthesized
+SELECT "[2,2]" <#> "[3,4]" -- literals removed
+SELECT _ <#> _ -- identifiers removed
+
+error
+SELECT "[2,2]" <# "[3,4]"
+----
+at or near "#": syntax error
+DETAIL: source SQL:
+SELECT "[2,2]" <# "[3,4]"
+ ^
+
+parse
+SELECT "[2,2]" <- "[3,4]"
+----
+SELECT "[2,2]" < (-"[3,4]") -- normalized!
+SELECT (("[2,2]") < ((-("[3,4]")))) -- fully parenthesized
+SELECT "[2,2]" < (-"[3,4]") -- literals removed
+SELECT _ < (-_) -- identifiers removed
diff --git a/pkg/sql/pg_catalog.go b/pkg/sql/pg_catalog.go
index e305e2fb89b7..51d6678c8a16 100644
--- a/pkg/sql/pg_catalog.go
+++ b/pkg/sql/pg_catalog.go
@@ -4775,6 +4775,7 @@ var datumToTypeCategory = map[types.Family]*tree.DString{
types.TupleFamily: typCategoryPseudo,
types.OidFamily: typCategoryNumeric,
types.PGLSNFamily: typCategoryUserDefined,
+ types.PGVectorFamily: typCategoryUserDefined,
types.RefCursorFamily: typCategoryUserDefined,
types.UuidFamily: typCategoryUserDefined,
types.INetFamily: typCategoryNetworkAddr,
diff --git a/pkg/sql/pgwire/pgwirebase/BUILD.bazel b/pkg/sql/pgwire/pgwirebase/BUILD.bazel
index bb2b1d401148..66f7bd826927 100644
--- a/pkg/sql/pgwire/pgwirebase/BUILD.bazel
+++ b/pkg/sql/pgwire/pgwirebase/BUILD.bazel
@@ -42,6 +42,7 @@ go_library(
"//pkg/util/tsearch",
"//pkg/util/uint128",
"//pkg/util/uuid",
+ "//pkg/util/vector",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
"@com_github_dustin_go_humanize//:go-humanize",
diff --git a/pkg/sql/pgwire/pgwirebase/encoding.go b/pkg/sql/pgwire/pgwirebase/encoding.go
index eca9700403fd..9138d56e998f 100644
--- a/pkg/sql/pgwire/pgwirebase/encoding.go
+++ b/pkg/sql/pgwire/pgwirebase/encoding.go
@@ -44,6 +44,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
"github.com/cockroachdb/cockroach/pkg/util/uint128"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
"github.com/dustin/go-humanize"
@@ -477,6 +478,12 @@ func DecodeDatum(
return nil, err
}
return &tree.DTSVector{TSVector: ret}, nil
+ case oidext.T_pgvector:
+ ret, err := vector.ParseVector(bs)
+ if err != nil {
+ return nil, err
+ }
+ return &tree.DPGVector{T: ret}, nil
}
switch typ.Family() {
case types.ArrayFamily, types.TupleFamily:
diff --git a/pkg/sql/pgwire/types.go b/pkg/sql/pgwire/types.go
index dbe3b6347540..76763862787a 100644
--- a/pkg/sql/pgwire/types.go
+++ b/pkg/sql/pgwire/types.go
@@ -258,6 +258,10 @@ func writeTextDatumNotNull(
b.textFormatter.FormatNode(v)
b.writeFromFmtCtx(b.textFormatter)
+ case *tree.DPGVector:
+ b.textFormatter.FormatNode(v)
+ b.writeFromFmtCtx(b.textFormatter)
+
case *tree.DArray:
// Arrays have custom formatting depending on their OID.
b.textFormatter.FormatNode(d)
diff --git a/pkg/sql/randgen/BUILD.bazel b/pkg/sql/randgen/BUILD.bazel
index 32594b4c5e4f..fbf3afb9198e 100644
--- a/pkg/sql/randgen/BUILD.bazel
+++ b/pkg/sql/randgen/BUILD.bazel
@@ -51,6 +51,7 @@ go_library(
"//pkg/util/tsearch",
"//pkg/util/uint128",
"//pkg/util/uuid",
+ "//pkg/util/vector",
"@com_github_cockroachdb_apd_v3//:apd",
"@com_github_cockroachdb_errors//:errors",
"@com_github_lib_pq//oid",
diff --git a/pkg/sql/randgen/datum.go b/pkg/sql/randgen/datum.go
index 298f3d4d0608..54148ce0a96e 100644
--- a/pkg/sql/randgen/datum.go
+++ b/pkg/sql/randgen/datum.go
@@ -35,6 +35,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
"github.com/cockroachdb/cockroach/pkg/util/uint128"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
"github.com/twpayne/go-geom"
@@ -328,6 +329,8 @@ func RandDatumWithNullChance(
return tree.NewDTSVector(tsearch.RandomTSVector(rng))
case types.TSQueryFamily:
return tree.NewDTSQuery(tsearch.RandomTSQuery(rng))
+ case types.PGVectorFamily:
+ return tree.NewDPGVector(vector.Random(rng))
default:
panic(errors.AssertionFailedf("invalid type %v", typ.DebugString()))
}
diff --git a/pkg/sql/randgen/type.go b/pkg/sql/randgen/type.go
index 22c93c7b2df5..30baaad79cfa 100644
--- a/pkg/sql/randgen/type.go
+++ b/pkg/sql/randgen/type.go
@@ -173,18 +173,18 @@ func IsLegalColumnType(typ *types.T) bool {
return false
}
ctx := context.Background()
- version := clustersettings.MakeTestingClusterSettings().Version
- return colinfo.ValidateColumnDefType(ctx, version, typ) == nil
+ st := clustersettings.MakeTestingClusterSettings()
+ return colinfo.ValidateColumnDefType(ctx, st, typ) == nil
}
// RandArrayType generates a random array type.
func RandArrayType(rng *rand.Rand) *types.T {
ctx := context.Background()
- version := clustersettings.MakeTestingClusterSettings().Version
+ st := clustersettings.MakeTestingClusterSettings()
for {
typ := RandColumnType(rng)
resTyp := types.MakeArray(typ)
- if err := colinfo.ValidateColumnDefType(ctx, version, resTyp); err == nil {
+ if err := colinfo.ValidateColumnDefType(ctx, st, resTyp); err == nil {
return resTyp
}
}
diff --git a/pkg/sql/rowenc/encoded_datum.go b/pkg/sql/rowenc/encoded_datum.go
index de64669dc563..9ff1ad3fab7d 100644
--- a/pkg/sql/rowenc/encoded_datum.go
+++ b/pkg/sql/rowenc/encoded_datum.go
@@ -325,12 +325,12 @@ func mustUseValueEncodingForFingerprinting(t *types.T) bool {
// available, but for historical reasons we will keep on using the
// value-encoding (Fingerprint is used by hash routers, so changing its
// behavior can result in incorrect results in mixed version clusters).
- case types.JsonFamily, types.TSQueryFamily, types.TSVectorFamily:
+ case types.JsonFamily, types.TSQueryFamily, types.TSVectorFamily, types.PGVectorFamily:
return true
case types.ArrayFamily:
// Note that at time of this writing we don't support arrays of JSON
- // (tracked via #23468) nor of TSQuery / TSVector types (tracked by
- // #90886), so technically we don't need to do a recursive call here,
+ // (tracked via #23468) nor of TSQuery / TSVector / PGVector types (tracked by
+ // #90886, #121432), so technically we don't need to do a recursive call here,
// but we choose to be on the safe side, so we do it anyway.
return mustUseValueEncodingForFingerprinting(t.ArrayContents())
case types.TupleFamily:
diff --git a/pkg/sql/rowenc/encoded_datum_test.go b/pkg/sql/rowenc/encoded_datum_test.go
index 10d9c499a800..dadb8b0278a5 100644
--- a/pkg/sql/rowenc/encoded_datum_test.go
+++ b/pkg/sql/rowenc/encoded_datum_test.go
@@ -221,7 +221,7 @@ func TestEncDatumCompare(t *testing.T) {
for _, typ := range types.OidToType {
switch typ.Family() {
case types.AnyFamily, types.UnknownFamily, types.ArrayFamily, types.JsonFamily, types.TupleFamily, types.VoidFamily,
- types.TSQueryFamily, types.TSVectorFamily:
+ types.TSQueryFamily, types.TSVectorFamily, types.PGVectorFamily:
continue
case types.CollatedStringFamily:
typ = types.MakeCollatedString(types.String, *randgen.RandCollationLocale(rng))
diff --git a/pkg/sql/rowenc/keyside/BUILD.bazel b/pkg/sql/rowenc/keyside/BUILD.bazel
index 9990b599ed03..4ce546884cf9 100644
--- a/pkg/sql/rowenc/keyside/BUILD.bazel
+++ b/pkg/sql/rowenc/keyside/BUILD.bazel
@@ -38,6 +38,7 @@ go_test(
deps = [
":keyside",
"//pkg/settings/cluster",
+ "//pkg/sql/catalog/colinfo",
"//pkg/sql/randgen",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/tree",
diff --git a/pkg/sql/rowenc/keyside/keyside_test.go b/pkg/sql/rowenc/keyside/keyside_test.go
index 4c7895437df5..1b8e3837a9d4 100644
--- a/pkg/sql/rowenc/keyside/keyside_test.go
+++ b/pkg/sql/rowenc/keyside/keyside_test.go
@@ -19,6 +19,7 @@ import (
"time"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
+ "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo"
"github.com/cockroachdb/cockroach/pkg/sql/randgen"
"github.com/cockroachdb/cockroach/pkg/sql/rowenc/keyside"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
@@ -243,13 +244,13 @@ func genEncodingDirection() gopter.Gen {
}
func hasKeyEncoding(typ *types.T) bool {
- // Only some types are round-trip key encodable.
switch typ.Family() {
- case types.CollatedStringFamily, types.TupleFamily, types.DecimalFamily,
- types.GeographyFamily, types.GeometryFamily, types.TSVectorFamily, types.TSQueryFamily:
+ // Special case needed for CollatedStringFamily and DecimalFamily which do have
+ // a key encoding but do not roundtrip.
+ case types.CollatedStringFamily, types.DecimalFamily:
return false
case types.ArrayFamily:
return hasKeyEncoding(typ.ArrayContents())
}
- return true
+ return !colinfo.MustBeValueEncoded(typ)
}
diff --git a/pkg/sql/rowenc/valueside/BUILD.bazel b/pkg/sql/rowenc/valueside/BUILD.bazel
index de2a8e28bbd2..22c783c72a5a 100644
--- a/pkg/sql/rowenc/valueside/BUILD.bazel
+++ b/pkg/sql/rowenc/valueside/BUILD.bazel
@@ -29,6 +29,7 @@ go_library(
"//pkg/util/timeutil/pgdate",
"//pkg/util/tsearch",
"//pkg/util/uuid",
+ "//pkg/util/vector",
"@com_github_cockroachdb_errors//:errors",
"@com_github_lib_pq//oid",
],
diff --git a/pkg/sql/rowenc/valueside/array.go b/pkg/sql/rowenc/valueside/array.go
index a96ddd62799b..b9c1b9ac2578 100644
--- a/pkg/sql/rowenc/valueside/array.go
+++ b/pkg/sql/rowenc/valueside/array.go
@@ -17,6 +17,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
"github.com/cockroachdb/cockroach/pkg/util/json"
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
)
@@ -329,6 +330,12 @@ func encodeArrayElement(b []byte, d tree.Datum) ([]byte, error) {
return nil, err
}
return encoding.EncodeUntaggedBytesValue(b, encoded), nil
+ case *tree.DPGVector:
+ encoded, err := vector.Encode(nil, t.T)
+ if err != nil {
+ return nil, err
+ }
+ return encoding.EncodeUntaggedBytesValue(b, encoded), nil
default:
return nil, errors.Errorf("don't know how to encode %s (%T)", d, d)
}
diff --git a/pkg/sql/rowenc/valueside/decode.go b/pkg/sql/rowenc/valueside/decode.go
index a02745612002..ca6fa44a7485 100644
--- a/pkg/sql/rowenc/valueside/decode.go
+++ b/pkg/sql/rowenc/valueside/decode.go
@@ -22,6 +22,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/json"
"github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate"
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)
@@ -223,6 +224,16 @@ func DecodeUntaggedDatum(
return nil, b, err
}
return tree.NewDTSVector(v), b, nil
+ case types.PGVectorFamily:
+ b, data, err := encoding.DecodeUntaggedBytesValue(buf)
+ if err != nil {
+ return nil, b, err
+ }
+ vec, err := vector.Decode(data)
+ if err != nil {
+ return nil, b, err
+ }
+ return tree.NewDPGVector(vec), b, nil
case types.OidFamily:
// TODO: This possibly should decode to uint32 (with corresponding changes
// to encoding) to ensure that the value fits in a DOid without any loss of
diff --git a/pkg/sql/rowenc/valueside/encode.go b/pkg/sql/rowenc/valueside/encode.go
index b999b957b476..afae3a13cc89 100644
--- a/pkg/sql/rowenc/valueside/encode.go
+++ b/pkg/sql/rowenc/valueside/encode.go
@@ -17,6 +17,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/encoding"
"github.com/cockroachdb/cockroach/pkg/util/json"
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
)
@@ -96,6 +97,12 @@ func Encode(appendTo []byte, colID ColumnIDDelta, val tree.Datum, scratch []byte
return nil, err
}
return encoding.EncodeTSVectorValue(appendTo, uint32(colID), encoded), nil
+ case *tree.DPGVector:
+ encoded, err := vector.Encode(scratch, t.T)
+ if err != nil {
+ return nil, err
+ }
+ return encoding.EncodePGVectorValue(appendTo, uint32(colID), encoded), nil
case *tree.DArray:
a, err := encodeArray(t, scratch)
if err != nil {
diff --git a/pkg/sql/rowenc/valueside/legacy.go b/pkg/sql/rowenc/valueside/legacy.go
index 6162a72144ce..db37921d4846 100644
--- a/pkg/sql/rowenc/valueside/legacy.go
+++ b/pkg/sql/rowenc/valueside/legacy.go
@@ -22,6 +22,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate"
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)
@@ -168,6 +169,15 @@ func MarshalLegacy(colType *types.T, val tree.Datum) (roachpb.Value, error) {
r.SetBytes(data)
return r, nil
}
+ case types.PGVectorFamily:
+ if v, ok := val.(*tree.DPGVector); ok {
+ data, err := vector.Encode(nil, v.T)
+ if err != nil {
+ return r, err
+ }
+ r.SetBytes(data)
+ return r, nil
+ }
case types.ArrayFamily:
if v, ok := val.(*tree.DArray); ok {
if err := checkElementType(v.ParamTyp, colType.ArrayContents()); err != nil {
@@ -422,6 +432,16 @@ func UnmarshalLegacy(a *tree.DatumAlloc, typ *types.T, value roachpb.Value) (tre
return nil, err
}
return tree.NewDTSVector(vec), nil
+ case types.PGVectorFamily:
+ v, err := value.GetBytes()
+ if err != nil {
+ return nil, err
+ }
+ vec, err := vector.Decode(v)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDPGVector(vec), nil
case types.EnumFamily:
v, err := value.GetBytes()
if err != nil {
diff --git a/pkg/sql/scanner/scan.go b/pkg/sql/scanner/scan.go
index 890e21c2919b..a4cb630e8b88 100644
--- a/pkg/sql/scanner/scan.go
+++ b/pkg/sql/scanner/scan.go
@@ -315,12 +315,32 @@ func (s *SQLScanner) Scan(lval ScanSymType) {
return
case '=': // <=
s.pos++
+ switch s.peek() {
+ case '>': // <=>
+ s.pos++
+ lval.SetID(lexbase.COS_DISTANCE)
+ return
+ }
lval.SetID(lexbase.LESS_EQUALS)
return
case '@': // <@
s.pos++
lval.SetID(lexbase.CONTAINED_BY)
return
+ case '-': // <-
+ switch s.peekN(1) {
+ case '>': // <->
+ s.pos += 2
+ lval.SetID(lexbase.DISTANCE)
+ return
+ }
+ case '#': // <#
+ switch s.peekN(1) {
+ case '>': // <#>
+ s.pos += 2
+ lval.SetID(lexbase.NEG_INNER_PRODUCT)
+ return
+ }
}
return
diff --git a/pkg/sql/schemachanger/comparator_generated_test.go b/pkg/sql/schemachanger/comparator_generated_test.go
index bcff16a947d5..e060cf7aedcb 100644
--- a/pkg/sql/schemachanger/comparator_generated_test.go
+++ b/pkg/sql/schemachanger/comparator_generated_test.go
@@ -2088,6 +2088,11 @@ func TestSchemaChangeComparator_vectorize_window(t *testing.T) {
var logicTestFile = "pkg/sql/logictest/testdata/logic_test/vectorize_window"
runSchemaChangeComparatorTest(t, logicTestFile)
}
+func TestSchemaChangeComparator_vectoross(t *testing.T) {
+ defer leaktest.AfterTest(t)()
+ var logicTestFile = "pkg/sql/logictest/testdata/logic_test/vectoross"
+ runSchemaChangeComparatorTest(t, logicTestFile)
+}
func TestSchemaChangeComparator_views(t *testing.T) {
defer leaktest.AfterTest(t)()
var logicTestFile = "pkg/sql/logictest/testdata/logic_test/views"
diff --git a/pkg/sql/sem/builtins/BUILD.bazel b/pkg/sql/sem/builtins/BUILD.bazel
index aaed197b09f1..8bfab120f80a 100644
--- a/pkg/sql/sem/builtins/BUILD.bazel
+++ b/pkg/sql/sem/builtins/BUILD.bazel
@@ -19,6 +19,7 @@ go_library(
"parse_ident_builtin.go",
"pg_builtins.go",
"pgcrypto_builtins.go",
+ "pgvector_builtins.go",
"replication_builtins.go",
"show_create_all_schemas_builtin.go",
"show_create_all_tables_builtin.go",
@@ -134,6 +135,7 @@ go_library(
"//pkg/util/ulid",
"//pkg/util/unaccent",
"//pkg/util/uuid",
+ "//pkg/util/vector",
"@com_github_cockroachdb_apd_v3//:apd",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
diff --git a/pkg/sql/sem/builtins/builtinconstants/constants.go b/pkg/sql/sem/builtins/builtinconstants/constants.go
index 699797403da1..c71a01d0b8b8 100644
--- a/pkg/sql/sem/builtins/builtinconstants/constants.go
+++ b/pkg/sql/sem/builtins/builtinconstants/constants.go
@@ -51,6 +51,7 @@ const (
CategoryJSON = "JSONB"
CategoryMultiRegion = "Multi-region"
CategoryMultiTenancy = "Multi-tenancy"
+ CategoryPGVector = "PGVector"
CategorySequences = "Sequence"
CategorySpatial = "Spatial"
CategoryString = "String and byte"
diff --git a/pkg/sql/sem/builtins/fixed_oids.go b/pkg/sql/sem/builtins/fixed_oids.go
index f7f9c142baa6..d882c601337a 100644
--- a/pkg/sql/sem/builtins/fixed_oids.go
+++ b/pkg/sql/sem/builtins/fixed_oids.go
@@ -2584,6 +2584,23 @@ var builtinOidsArray = []string{
2616: `crdb_internal.plan_logical_replication(req: bytes) -> bytes`,
2617: `crdb_internal.start_replication_stream_for_tables(req: bytes) -> bytes`,
2618: `crdb_internal.logical_replication_inject_failures(stream: int, proc: int, percent: int) -> void`,
+ 2619: `vectorsend(vector: vector) -> bytes`,
+ 2620: `vectorrecv(input: anyelement) -> vector`,
+ 2621: `vectorout(vector: vector) -> bytes`,
+ 2622: `vectorin(input: anyelement) -> vector`,
+ 2623: `char(vector: vector) -> "char"`,
+ 2624: `name(vector: vector) -> name`,
+ 2625: `text(vector: vector) -> string`,
+ 2626: `varchar(vector: vector) -> varchar`,
+ 2627: `bpchar(vector: vector) -> char`,
+ 2628: `vector(string: string) -> vector`,
+ 2629: `vector(vector: vector) -> vector`,
+ 2630: `cosine_distance(v1: vector, v2: vector) -> float`,
+ 2631: `l1_distance(v1: vector, v2: vector) -> float`,
+ 2632: `l2_distance(v1: vector, v2: vector) -> float`,
+ 2633: `inner_product(v1: vector, v2: vector) -> float`,
+ 2634: `vector_dims(vector: vector) -> int`,
+ 2635: `vector_norm(vector: vector) -> float`,
}
var builtinOidsBySignature map[string]oid.Oid
diff --git a/pkg/sql/sem/builtins/pgvector_builtins.go b/pkg/sql/sem/builtins/pgvector_builtins.go
new file mode 100644
index 000000000000..1bcbc03bdccb
--- /dev/null
+++ b/pkg/sql/sem/builtins/pgvector_builtins.go
@@ -0,0 +1,142 @@
+// Copyright 2024 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package builtins
+
+import (
+ "context"
+
+ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinconstants"
+ "github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
+ "github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+ "github.com/cockroachdb/cockroach/pkg/sql/sem/volatility"
+ "github.com/cockroachdb/cockroach/pkg/sql/types"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
+)
+
+func init() {
+ for k, v := range pgvectorBuiltins {
+ v.props.Category = builtinconstants.CategoryPGVector
+ v.props.AvailableOnPublicSchema = true
+ const enforceClass = true
+ registerBuiltin(k, v, tree.NormalClass, enforceClass)
+ }
+}
+
+var pgvectorBuiltins = map[string]builtinDefinition{
+ "cosine_distance": makeBuiltin(defProps(),
+ tree.Overload{
+ Types: tree.ParamTypes{
+ {Name: "v1", Typ: types.PGVector},
+ {Name: "v2", Typ: types.PGVector},
+ },
+ ReturnType: tree.FixedReturnType(types.Float),
+ Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+ v1 := tree.MustBeDPGVector(args[0])
+ v2 := tree.MustBeDPGVector(args[1])
+ distance, err := vector.CosDistance(v1.T, v2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDFloat(tree.DFloat(distance)), nil
+ },
+ Info: "Returns the cosine distance between the two vectors.",
+ Volatility: volatility.Immutable,
+ },
+ ),
+ "inner_product": makeBuiltin(defProps(),
+ tree.Overload{
+ Types: tree.ParamTypes{
+ {Name: "v1", Typ: types.PGVector},
+ {Name: "v2", Typ: types.PGVector},
+ },
+ ReturnType: tree.FixedReturnType(types.Float),
+ Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+ v1 := tree.MustBeDPGVector(args[0])
+ v2 := tree.MustBeDPGVector(args[1])
+ distance, err := vector.InnerProduct(v1.T, v2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDFloat(tree.DFloat(distance)), nil
+ },
+ Info: "Returns the inner product between the two vectors.",
+ Volatility: volatility.Immutable,
+ },
+ ),
+ "l1_distance": makeBuiltin(defProps(),
+ tree.Overload{
+ Types: tree.ParamTypes{
+ {Name: "v1", Typ: types.PGVector},
+ {Name: "v2", Typ: types.PGVector},
+ },
+ ReturnType: tree.FixedReturnType(types.Float),
+ Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+ v1 := tree.MustBeDPGVector(args[0])
+ v2 := tree.MustBeDPGVector(args[1])
+ distance, err := vector.L1Distance(v1.T, v2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDFloat(tree.DFloat(distance)), nil
+ },
+ Info: "Returns the Manhattan distance between the two vectors.",
+ Volatility: volatility.Immutable,
+ },
+ ),
+ "l2_distance": makeBuiltin(defProps(),
+ tree.Overload{
+ Types: tree.ParamTypes{
+ {Name: "v1", Typ: types.PGVector},
+ {Name: "v2", Typ: types.PGVector},
+ },
+ ReturnType: tree.FixedReturnType(types.Float),
+ Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+ v1 := tree.MustBeDPGVector(args[0])
+ v2 := tree.MustBeDPGVector(args[1])
+ distance, err := vector.L2Distance(v1.T, v2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDFloat(tree.DFloat(distance)), nil
+ },
+ Info: "Returns the Euclidean distance between the two vectors.",
+ Volatility: volatility.Immutable,
+ },
+ ),
+ "vector_dims": makeBuiltin(defProps(),
+ tree.Overload{
+ Types: tree.ParamTypes{
+ {Name: "vector", Typ: types.PGVector},
+ },
+ ReturnType: tree.FixedReturnType(types.Int),
+ Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+ v1 := tree.MustBeDPGVector(args[0])
+ return tree.NewDInt(tree.DInt(len(v1.T))), nil
+ },
+ Info: "Returns the number of dimensions in the vector.",
+ Volatility: volatility.Immutable,
+ },
+ ),
+ "vector_norm": makeBuiltin(defProps(),
+ tree.Overload{
+ Types: tree.ParamTypes{
+ {Name: "vector", Typ: types.PGVector},
+ },
+ ReturnType: tree.FixedReturnType(types.Float),
+ Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+ v1 := tree.MustBeDPGVector(args[0])
+ return tree.NewDFloat(tree.DFloat(vector.Norm(v1.T))), nil
+ },
+ Info: "Returns the Euclidean norm of the vector.",
+ Volatility: volatility.Immutable,
+ },
+ ),
+}
diff --git a/pkg/sql/sem/cast/cast.go b/pkg/sql/sem/cast/cast.go
index ad0b1cd6c220..cefdca085d36 100644
--- a/pkg/sql/sem/cast/cast.go
+++ b/pkg/sql/sem/cast/cast.go
@@ -251,6 +251,20 @@ func LookupCast(src, tgt *types.T) (Cast, bool) {
}, true
}
+ if srcFamily == types.ArrayFamily && tgtFamily == types.PGVectorFamily {
+ return Cast{
+ MaxContext: ContextAssignment,
+ Volatility: volatility.Stable,
+ }, true
+ }
+
+ if srcFamily == types.PGVectorFamily && tgtFamily == types.ArrayFamily {
+ return Cast{
+ MaxContext: ContextAssignment,
+ Volatility: volatility.Stable,
+ }, true
+ }
+
// Casts from array and tuple types to string types are immutable and
// allowed in assignment contexts.
// TODO(mgartner): Tuple to string casts should be stable. They are
diff --git a/pkg/sql/sem/cast/cast_map.go b/pkg/sql/sem/cast/cast_map.go
index 24be3ecd3693..039a01d77b39 100644
--- a/pkg/sql/sem/cast/cast_map.go
+++ b/pkg/sql/sem/cast/cast_map.go
@@ -72,6 +72,13 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{
oid.T_varchar: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
oid.T_text: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
},
+ oidext.T_pgvector: {
+ oid.T_bpchar: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_char: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_name: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_varchar: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_text: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ },
oid.T_bpchar: {
oid.T_bpchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
oid.T_char: {MaxContext: ContextAssignment, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
@@ -79,11 +86,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{
oid.T_text: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
oid.T_varchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
// Automatic I/O conversions from bpchar to other types.
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
oid.T_date: {
MaxContext: ContextExplicit,
origin: ContextOriginAutomaticIOConversion,
@@ -166,11 +174,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{
// Automatic I/O conversions to string types.
oid.T_name: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
// Automatic I/O conversions from "char" to other types.
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
oid.T_date: {
MaxContext: ContextExplicit,
origin: ContextOriginAutomaticIOConversion,
@@ -486,11 +495,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{
// Automatic I/O conversions to string types.
oid.T_char: {MaxContext: ContextAssignment, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
// Automatic I/O conversions from NAME to other types.
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
oid.T_date: {
MaxContext: ContextExplicit,
origin: ContextOriginAutomaticIOConversion,
@@ -733,11 +743,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{
oid.T_text: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
oid.T_varchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
// Automatic I/O conversions from TEXT to other types.
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
oid.T_date: {
MaxContext: ContextExplicit,
origin: ContextOriginAutomaticIOConversion,
@@ -966,11 +977,12 @@ var castMap = map[oid.Oid]map[oid.Oid]Cast{
oid.T_text: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
oid.T_varchar: {MaxContext: ContextImplicit, origin: ContextOriginPgCast, Volatility: volatility.Immutable},
// Automatic I/O conversions from VARCHAR to other types.
- oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
- oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bit: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bool: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_box2d: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_pg_lsn: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oidext.T_pgvector: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
+ oid.T_bytea: {MaxContext: ContextExplicit, origin: ContextOriginAutomaticIOConversion, Volatility: volatility.Immutable},
oid.T_date: {
MaxContext: ContextExplicit,
origin: ContextOriginAutomaticIOConversion,
diff --git a/pkg/sql/sem/eval/BUILD.bazel b/pkg/sql/sem/eval/BUILD.bazel
index a1fe5927185d..05c92e17f7db 100644
--- a/pkg/sql/sem/eval/BUILD.bazel
+++ b/pkg/sql/sem/eval/BUILD.bazel
@@ -97,6 +97,7 @@ go_library(
"//pkg/util/tsearch",
"//pkg/util/ulid",
"//pkg/util/uuid",
+ "//pkg/util/vector",
"@com_github_cockroachdb_apd_v3//:apd",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
diff --git a/pkg/sql/sem/eval/binary_op.go b/pkg/sql/sem/eval/binary_op.go
index 9ca3d9bf12f7..fd48722a72de 100644
--- a/pkg/sql/sem/eval/binary_op.go
+++ b/pkg/sql/sem/eval/binary_op.go
@@ -32,6 +32,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/timeofday"
"github.com/cockroachdb/cockroach/pkg/util/trigram"
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
)
@@ -1229,6 +1230,33 @@ func (e *evaluator) EvalTSMatchesVectorQueryOp(
return tree.MakeDBool(tree.DBool(ret)), err
}
+func (e *evaluator) EvalDistanceVectorOp(
+ ctx context.Context, _ *tree.DistanceVectorOp, left, right tree.Datum,
+) (tree.Datum, error) {
+ v := tree.MustBeDPGVector(left)
+ q := tree.MustBeDPGVector(right)
+ ret, err := vector.L2Distance(q.T, v.T)
+ return tree.NewDFloat(tree.DFloat(ret)), err
+}
+
+func (e *evaluator) EvalCosDistanceVectorOp(
+ ctx context.Context, _ *tree.CosDistanceVectorOp, left, right tree.Datum,
+) (tree.Datum, error) {
+ v := tree.MustBeDPGVector(left)
+ q := tree.MustBeDPGVector(right)
+ ret, err := vector.CosDistance(q.T, v.T)
+ return tree.NewDFloat(tree.DFloat(ret)), err
+}
+
+func (e *evaluator) EvalNegInnerProductVectorOp(
+ ctx context.Context, _ *tree.NegInnerProductVectorOp, left, right tree.Datum,
+) (tree.Datum, error) {
+ v := tree.MustBeDPGVector(left)
+ q := tree.MustBeDPGVector(right)
+ ret, err := vector.NegInnerProduct(q.T, v.T)
+ return tree.NewDFloat(tree.DFloat(ret)), err
+}
+
func (e *evaluator) EvalPlusDateIntOp(
ctx context.Context, _ *tree.PlusDateIntOp, left, right tree.Datum,
) (tree.Datum, error) {
@@ -1648,3 +1676,39 @@ func decimalPGLSNEval(
}
return tree.NewDPGLSN(resultLSN), nil
}
+
+func (e *evaluator) EvalPlusPGVectorOp(
+ ctx context.Context, _ *tree.PlusPGVectorOp, left, right tree.Datum,
+) (tree.Datum, error) {
+ t1 := tree.MustBeDPGVector(left)
+ t2 := tree.MustBeDPGVector(right)
+ ret, err := vector.Add(t1.T, t2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDPGVector(ret), nil
+}
+
+func (e *evaluator) EvalMinusPGVectorOp(
+ ctx context.Context, _ *tree.MinusPGVectorOp, left, right tree.Datum,
+) (tree.Datum, error) {
+ t1 := tree.MustBeDPGVector(left)
+ t2 := tree.MustBeDPGVector(right)
+ ret, err := vector.Minus(t1.T, t2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDPGVector(ret), nil
+}
+
+func (e *evaluator) EvalMultPGVectorOp(
+ ctx context.Context, _ *tree.MultPGVectorOp, left, right tree.Datum,
+) (tree.Datum, error) {
+ t1 := tree.MustBeDPGVector(left)
+ t2 := tree.MustBeDPGVector(right)
+ ret, err := vector.Mult(t1.T, t2.T)
+ if err != nil {
+ return nil, err
+ }
+ return tree.NewDPGVector(ret), nil
+}
diff --git a/pkg/sql/sem/eval/cast.go b/pkg/sql/sem/eval/cast.go
index b5cc8aa39d5a..41d64cbdfb9c 100644
--- a/pkg/sql/sem/eval/cast.go
+++ b/pkg/sql/sem/eval/cast.go
@@ -18,6 +18,7 @@ import (
"time"
"github.com/cockroachdb/apd/v3"
+ "github.com/cockroachdb/cockroach/pkg/clusterversion"
"github.com/cockroachdb/cockroach/pkg/geo"
"github.com/cockroachdb/cockroach/pkg/geo/geopb"
"github.com/cockroachdb/cockroach/pkg/sql/lex"
@@ -34,6 +35,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate"
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/lib/pq/oid"
)
@@ -502,6 +504,8 @@ func performCastWithoutPrecisionTruncation(
s = t.TSQuery.String()
case *tree.DTSVector:
s = t.TSVector.String()
+ case *tree.DPGVector:
+ s = t.T.String()
case *tree.DEnum:
s = t.LogicalRep
case *tree.DVoid:
@@ -601,6 +605,38 @@ func performCastWithoutPrecisionTruncation(
return d, nil
}
+ case types.PGVectorFamily:
+ if !evalCtx.Settings.Version.IsActive(ctx, clusterversion.V24_2) {
+ return nil, pgerror.Newf(pgcode.FeatureNotSupported,
+ "version %v must be finalized to use vector",
+ clusterversion.V24_2.Version())
+ }
+ switch d := d.(type) {
+ case *tree.DString:
+ return tree.ParseDPGVector(string(*d))
+ case *tree.DCollatedString:
+ return tree.ParseDPGVector(d.Contents)
+ case *tree.DArray:
+ switch d.ParamTyp.Family() {
+ case types.FloatFamily, types.IntFamily, types.DecimalFamily:
+ v := make(vector.T, len(d.Array))
+ for i, elem := range d.Array {
+ if elem == tree.DNull {
+ return nil, pgerror.Newf(pgcode.NullValueNotAllowed,
+ "array must not contain nulls")
+ }
+ datum, err := performCast(ctx, evalCtx, elem, types.Float4, false)
+ if err != nil {
+ return nil, err
+ }
+ v[i] = float32(*datum.(*tree.DFloat))
+ }
+ return tree.NewDPGVector(v), nil
+ }
+ case *tree.DPGVector:
+ return d, nil
+ }
+
case types.RefCursorFamily:
switch d := d.(type) {
case *tree.DString:
@@ -929,6 +965,14 @@ func performCastWithoutPrecisionTruncation(
}
}
return dcast, nil
+ case *tree.DPGVector:
+ dcast := tree.NewDArray(t.ArrayContents())
+ for i := range v.T {
+ if err := dcast.Append(tree.NewDFloat(tree.DFloat(v.T[i]))); err != nil {
+ return nil, err
+ }
+ }
+ return dcast, nil
}
case types.OidFamily:
switch v := d.(type) {
diff --git a/pkg/sql/sem/eval/testdata/eval/vector b/pkg/sql/sem/eval/testdata/eval/vector
new file mode 100644
index 000000000000..53755ffff368
--- /dev/null
+++ b/pkg/sql/sem/eval/testdata/eval/vector
@@ -0,0 +1,143 @@
+# Basic smoke tests for pgvector evaluation.
+
+eval
+'[1,2]'::vector <-> '[1,2]'::vector
+----
+0.0
+
+eval
+'[1,2]'::vector <#> '[1,2]'::vector
+----
+-5.0
+
+eval
+'[1,2]'::vector <=> '[1,2]'::vector
+----
+0.0
+
+eval
+'[1,2,3]'::vector <-> '[3,1,2]'::vector
+----
+2.449489742783178
+
+eval
+'[1,2,3]'::vector <#> '[3,1,2]'::vector
+----
+-11.0
+
+eval
+'[1,2,3]'::vector <=> '[3,1,2]'::vector
+----
+0.2142857142857143
+
+eval
+'[1,2,3]'::vector - '[3,1,2]'::vector
+----
+[-2,1,1]
+
+eval
+'[1,2,3]'::vector + '[3,1,2]'::vector
+----
+[4,3,5]
+
+eval
+'[1,2,3]'::vector * '[3,1,2]'::vector
+----
+[3,2,6]
+
+eval
+cosine_distance('[1,2,3]'::vector, '[3,1,2]'::vector)
+----
+0.2142857142857143
+
+eval
+inner_product('[1,2,3]'::vector, '[3,1,2]'::vector)
+----
+11.0
+
+eval
+l1_distance('[1,2,3]'::vector, '[3,1,2]'::vector)
+----
+4.0
+
+eval
+l2_distance('[1,2,3]'::vector, '[3,1,2]'::vector)
+----
+2.449489742783178
+
+eval
+vector_dims('[1,2,3]'::vector)
+----
+3
+
+eval
+vector_norm('[1,2,3]'::vector)
+----
+3.7416573867739413
+
+eval
+'[1,2]'::vector < '[1,2]'::vector
+----
+false
+
+eval
+'[1,2]'::vector <= '[1,2]'::vector
+----
+true
+
+eval
+'[1,2]'::vector > '[1,2]'::vector
+----
+false
+
+eval
+'[1,2]'::vector >= '[1,2]'::vector
+----
+true
+
+eval
+'[1,2]'::vector < '[1,3]'::vector
+----
+true
+
+eval
+'[2,2]'::vector < '[1,3]'::vector
+----
+false
+
+eval
+'[2,2]'::vector > '[1,3]'::vector
+----
+true
+
+# Mixed dimension comparisons
+
+eval
+'[1]'::vector < '[1,1]'::vector
+----
+true
+
+eval
+'[1,1]'::vector < '[1]'::vector
+----
+false
+
+eval
+'[1]'::vector > '[1,1]'::vector
+----
+false
+
+eval
+'[1,1]'::vector > '[1]'::vector
+----
+true
+
+eval
+'[100000000000000000000000]'::vector * '[100000000000000000000]'::vector
+----
+value out of range: overflow
+
+eval
+'[.000000000000000000000001]'::vector * '[.0000000000000000000001]'::vector
+----
+value out of range: underflow
diff --git a/pkg/sql/sem/eval/unsupported_types.go b/pkg/sql/sem/eval/unsupported_types.go
index 89fce60c6cb2..6ac35c46303f 100644
--- a/pkg/sql/sem/eval/unsupported_types.go
+++ b/pkg/sql/sem/eval/unsupported_types.go
@@ -14,22 +14,23 @@ import (
"context"
"github.com/cockroachdb/cockroach/pkg/clusterversion"
+ "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
+ "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
)
type unsupportedTypeChecker struct {
- // Uncomment this when a new type is introduced.
- // version clusterversion.Handle
+ // Uncomment this when a new type is introduced, or comment it out if there
+ // are no types in the checker.
+ version clusterversion.Handle
}
// NewUnsupportedTypeChecker returns a new tree.UnsupportedTypeChecker that can
// be used to check whether a type is allowed by the current cluster version.
-func NewUnsupportedTypeChecker(clusterversion.Handle) tree.UnsupportedTypeChecker {
- // Right now we don't have any unsupported types, so there is no benefit in
- // returning a type checker that never errors, so we just return nil. (The
- // infrastructure is already set up to handle such case gracefully.)
- return nil
+func NewUnsupportedTypeChecker(handle clusterversion.Handle) tree.UnsupportedTypeChecker {
+ // If there are no types in the checker, change this code to return nil.
+ return &unsupportedTypeChecker{version: handle}
}
var _ tree.UnsupportedTypeChecker = &unsupportedTypeChecker{}
@@ -38,5 +39,16 @@ var _ tree.UnsupportedTypeChecker = &unsupportedTypeChecker{}
func (tc *unsupportedTypeChecker) CheckType(ctx context.Context, typ *types.T) error {
// NB: when adding an unsupported type here, change the constructor to not
// return nil.
+ var errorTypeString string
+ switch typ.Family() {
+ case types.PGVectorFamily:
+ errorTypeString = "vector"
+ }
+ if errorTypeString != "" && !tc.version.IsActive(ctx, clusterversion.V24_2) {
+ return pgerror.Newf(pgcode.FeatureNotSupported,
+ "%s not supported until version 24.2", errorTypeString,
+ )
+ }
+
return nil
}
diff --git a/pkg/sql/sem/tree/BUILD.bazel b/pkg/sql/sem/tree/BUILD.bazel
index 530f89114d94..9ff37ef0738f 100644
--- a/pkg/sql/sem/tree/BUILD.bazel
+++ b/pkg/sql/sem/tree/BUILD.bazel
@@ -168,6 +168,7 @@ go_library(
"//pkg/util/tsearch",
"//pkg/util/uint128",
"//pkg/util/uuid",
+ "//pkg/util/vector",
"@com_github_cockroachdb_apd_v3//:apd",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
diff --git a/pkg/sql/sem/tree/constant.go b/pkg/sql/sem/tree/constant.go
index 85f727ed416f..a0226851c446 100644
--- a/pkg/sql/sem/tree/constant.go
+++ b/pkg/sql/sem/tree/constant.go
@@ -544,6 +544,8 @@ var (
types.Jsonb,
types.PGLSN,
types.PGLSNArray,
+ types.PGVector,
+ types.PGVectorArray,
types.RefCursor,
types.RefCursorArray,
types.TSQuery,
diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go
index 1b264982debd..2ad79ddaf623 100644
--- a/pkg/sql/sem/tree/datum.go
+++ b/pkg/sql/sem/tree/datum.go
@@ -49,6 +49,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/tsearch"
"github.com/cockroachdb/cockroach/pkg/util/uint128"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
+ "github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
"github.com/lib/pq/oid"
@@ -3632,6 +3633,112 @@ func (d *DPGLSN) Size() uintptr {
return unsafe.Sizeof(*d)
}
+// DPGVector is the Datum representation of the PGVector type.
+type DPGVector struct {
+ vector.T
+}
+
+// NewDPGVector returns a new PGVector Datum.
+func NewDPGVector(vector vector.T) *DPGVector { return &DPGVector{vector} }
+
+// AsDPGVector attempts to retrieve a DPGVector from an Expr, returning a
+// DPGVector and a flag signifying whether the assertion was successful. The
+// function should be used instead of direct type assertions wherever a
+// *DPGVector wrapped by a *DOidWrapper is possible.
+func AsDPGVector(e Expr) (*DPGVector, bool) {
+ switch t := e.(type) {
+ case *DPGVector:
+ return t, true
+ case *DOidWrapper:
+ return AsDPGVector(t.Wrapped)
+ }
+ return nil, false
+}
+
+// MustBeDPGVector attempts to retrieve a DPGVector from an Expr, panicking if the
+// assertion fails.
+func MustBeDPGVector(e Expr) *DPGVector {
+ v, ok := AsDPGVector(e)
+ if !ok {
+ panic(errors.AssertionFailedf("expected *DPGVector, found %T", e))
+ }
+ return v
+}
+
+// ParseDPGVector parses a string representation of a vector and returns a DPGVector value.
+func ParseDPGVector(s string) (Datum, error) {
+ v, err := vector.ParseVector(s)
+ if err != nil {
+ return nil, pgerror.Wrapf(err, pgcode.Syntax, "could not parse vector")
+ }
+ return NewDPGVector(v), nil
+}
+
+// Format implements the NodeFormatter interface.
+func (d *DPGVector) Format(ctx *FmtCtx) {
+ bareStrings := ctx.HasFlags(FmtFlags(lexbase.EncBareStrings))
+ if !bareStrings {
+ ctx.WriteByte('\'')
+ }
+ ctx.WriteString(d.String())
+ if !bareStrings {
+ ctx.WriteByte('\'')
+ }
+}
+
+// ResolvedType implements the TypedExpr interface.
+func (d *DPGVector) ResolvedType() *types.T { return types.PGVector }
+
+// AmbiguousFormat implements the Datum interface.
+func (d *DPGVector) AmbiguousFormat() bool {
+ return true
+}
+
+func (d *DPGVector) Compare(ctx context.Context, cmpCtx CompareContext, other Datum) (int, error) {
+ if other == DNull {
+ // NULL is less than any non-NULL value.
+ return 1, nil
+ }
+ v, ok := cmpCtx.UnwrapDatum(ctx, other).(*DPGVector)
+ if !ok {
+ return 0, makeUnsupportedComparisonMessage(d, other)
+ }
+ return d.T.Compare(v.T)
+}
+
+// Prev implements the Datum interface.
+func (d *DPGVector) Prev(ctx context.Context, cmpCtx CompareContext) (Datum, bool) {
+ return nil, false
+}
+
+// Next implements the Datum interface.
+func (d *DPGVector) Next(ctx context.Context, cmpCtx CompareContext) (Datum, bool) {
+ return nil, false
+}
+
+// IsMax implements the Datum interface.
+func (d *DPGVector) IsMax(ctx context.Context, cmpCtx CompareContext) bool {
+ return false
+}
+
+// IsMin implements the Datum interface.
+func (d *DPGVector) IsMin(ctx context.Context, cmpCtx CompareContext) bool {
+ return false
+}
+
+// Max implements the Datum interface.
+func (d *DPGVector) Max(ctx context.Context, cmpCtx CompareContext) (Datum, bool) {
+ return nil, false
+}
+
+// Min implements the Datum interface.
+func (d *DPGVector) Min(ctx context.Context, cmpCtx CompareContext) (Datum, bool) { return nil, false }
+
+// Size implements the Datum interface.
+func (d *DPGVector) Size() uintptr {
+ return unsafe.Sizeof(*d) + d.T.Size()
+}
+
// DBox2D is the Datum representation of the Box2D type.
type DBox2D struct {
geo.CartesianBoundingBox
@@ -3906,7 +4013,7 @@ func AsJSON(
// This is RFC3339Nano, but without the TZ fields.
return json.FromString(formatTime(t.UTC(), "2006-01-02T15:04:05.999999999")), nil
case *DDate, *DUuid, *DOid, *DInterval, *DBytes, *DIPAddr, *DTime, *DTimeTZ, *DBitArray, *DBox2D,
- *DTSVector, *DTSQuery, *DPGLSN:
+ *DTSVector, *DTSQuery, *DPGLSN, *DPGVector:
return json.FromString(
AsStringWithFlags(t, FmtBareStrings, FmtDataConversionConfig(dcc), FmtLocation(loc)),
), nil
@@ -5963,6 +6070,7 @@ var baseDatumTypeSizes = map[types.Family]struct {
types.GeographyFamily: {unsafe.Sizeof(DGeography{}), variableSize},
types.GeometryFamily: {unsafe.Sizeof(DGeometry{}), variableSize},
types.PGLSNFamily: {unsafe.Sizeof(DPGLSN{}), fixedSize},
+ types.PGVectorFamily: {unsafe.Sizeof(DPGVector{}), variableSize},
types.RefCursorFamily: {unsafe.Sizeof(DString("")), variableSize},
types.TimeFamily: {unsafe.Sizeof(DTime(0)), fixedSize},
types.TimeTZFamily: {unsafe.Sizeof(DTimeTZ{}), fixedSize},
@@ -6311,6 +6419,14 @@ func AdjustValueToType(typ *types.T, inVal Datum) (outVal Datum, err error) {
return nil, err
}
}
+ case types.PGVectorFamily:
+ if in, ok := inVal.(*DPGVector); ok {
+ width := int(typ.Width())
+ if width > 0 && len(in.T) != width {
+ return nil, pgerror.Newf(pgcode.DataException,
+ "expected %d dimensions, not %d", typ.Width(), len(in.T))
+ }
+ }
}
return inVal, nil
}
diff --git a/pkg/sql/sem/tree/eval.go b/pkg/sql/sem/tree/eval.go
index c597129de5a0..d72c47a0a66e 100644
--- a/pkg/sql/sem/tree/eval.go
+++ b/pkg/sql/sem/tree/eval.go
@@ -801,6 +801,13 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{
EvalOp: &PlusPGLSNDecimalOp{},
Volatility: volatility.Immutable,
},
+ {
+ LeftType: types.PGVector,
+ RightType: types.PGVector,
+ ReturnType: types.PGVector,
+ EvalOp: &PlusPGVectorOp{},
+ Volatility: volatility.Immutable,
+ },
}},
treebin.Minus: {overloads: []*BinOp{
@@ -987,6 +994,13 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{
EvalOp: &MinusPGLSNOp{},
Volatility: volatility.Immutable,
},
+ {
+ LeftType: types.PGVector,
+ RightType: types.PGVector,
+ ReturnType: types.PGVector,
+ EvalOp: &MinusPGVectorOp{},
+ Volatility: volatility.Immutable,
+ },
}},
treebin.Mult: {overloads: []*BinOp{
@@ -1070,6 +1084,13 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{
EvalOp: &MultIntervalDecimalOp{},
Volatility: volatility.Immutable,
},
+ {
+ LeftType: types.PGVector,
+ RightType: types.PGVector,
+ ReturnType: types.PGVector,
+ EvalOp: &MultPGVectorOp{},
+ Volatility: volatility.Immutable,
+ },
}},
treebin.Div: {overloads: []*BinOp{
@@ -1382,6 +1403,33 @@ var BinOps = map[treebin.BinaryOperatorSymbol]*BinOpOverloads{
Volatility: volatility.Immutable,
},
}},
+ treebin.Distance: {overloads: []*BinOp{
+ {
+ LeftType: types.PGVector,
+ RightType: types.PGVector,
+ ReturnType: types.Float,
+ EvalOp: &DistanceVectorOp{},
+ Volatility: volatility.Immutable,
+ },
+ }},
+ treebin.CosDistance: {overloads: []*BinOp{
+ {
+ LeftType: types.PGVector,
+ RightType: types.PGVector,
+ ReturnType: types.Float,
+ EvalOp: &CosDistanceVectorOp{},
+ Volatility: volatility.Immutable,
+ },
+ }},
+ treebin.NegInnerProduct: {overloads: []*BinOp{
+ {
+ LeftType: types.PGVector,
+ RightType: types.PGVector,
+ ReturnType: types.Float,
+ EvalOp: &NegInnerProductVectorOp{},
+ Volatility: volatility.Immutable,
+ },
+ }},
}
// CmpOp is a comparison operator.
@@ -1575,6 +1623,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{
makeEqFn(types.Jsonb, types.Jsonb, volatility.Immutable),
makeEqFn(types.Oid, types.Oid, volatility.Leakproof),
makeEqFn(types.PGLSN, types.PGLSN, volatility.Leakproof),
+ makeEqFn(types.PGVector, types.PGVector, volatility.Leakproof),
makeEqFn(types.RefCursor, types.RefCursor, volatility.Leakproof),
makeEqFn(types.String, types.String, volatility.Leakproof),
makeEqFn(types.Time, types.Time, volatility.Leakproof),
@@ -1635,6 +1684,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{
makeLtFn(types.Interval, types.Interval, volatility.Leakproof),
makeLtFn(types.Oid, types.Oid, volatility.Leakproof),
makeLtFn(types.PGLSN, types.PGLSN, volatility.Leakproof),
+ makeLtFn(types.PGVector, types.PGVector, volatility.Leakproof),
makeLtFn(types.RefCursor, types.RefCursor, volatility.Leakproof),
makeLtFn(types.String, types.String, volatility.Leakproof),
makeLtFn(types.Time, types.Time, volatility.Leakproof),
@@ -1694,6 +1744,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{
makeLeFn(types.Interval, types.Interval, volatility.Leakproof),
makeLeFn(types.Oid, types.Oid, volatility.Leakproof),
makeLeFn(types.PGLSN, types.PGLSN, volatility.Leakproof),
+ makeLeFn(types.PGVector, types.PGVector, volatility.Leakproof),
makeLeFn(types.RefCursor, types.RefCursor, volatility.Leakproof),
makeLeFn(types.String, types.String, volatility.Leakproof),
makeLeFn(types.Time, types.Time, volatility.Leakproof),
@@ -1774,6 +1825,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{
makeIsFn(types.Jsonb, types.Jsonb, volatility.Immutable),
makeIsFn(types.Oid, types.Oid, volatility.Leakproof),
makeIsFn(types.PGLSN, types.PGLSN, volatility.Leakproof),
+ makeIsFn(types.PGVector, types.PGVector, volatility.Leakproof),
makeIsFn(types.RefCursor, types.RefCursor, volatility.Leakproof),
makeIsFn(types.String, types.String, volatility.Leakproof),
makeIsFn(types.Time, types.Time, volatility.Leakproof),
@@ -1842,6 +1894,7 @@ var CmpOps = cmpOpFixups(map[treecmp.ComparisonOperatorSymbol]*CmpOpOverloads{
makeEvalTupleIn(types.Jsonb, volatility.Leakproof),
makeEvalTupleIn(types.Oid, volatility.Leakproof),
makeEvalTupleIn(types.PGLSN, volatility.Leakproof),
+ makeEvalTupleIn(types.PGVector, volatility.Leakproof),
makeEvalTupleIn(types.RefCursor, volatility.Leakproof),
makeEvalTupleIn(types.String, volatility.Leakproof),
makeEvalTupleIn(types.Time, volatility.Leakproof),
diff --git a/pkg/sql/sem/tree/eval_binary_ops.go b/pkg/sql/sem/tree/eval_binary_ops.go
index 0c96d326d24d..e68a034328a0 100644
--- a/pkg/sql/sem/tree/eval_binary_ops.go
+++ b/pkg/sql/sem/tree/eval_binary_ops.go
@@ -81,6 +81,15 @@ type TSMatchesVectorQueryOp struct{}
// TSMatchesQueryVectorOp is a BinaryEvalOp.
type TSMatchesQueryVectorOp struct{}
+type (
+ // DistanceVectorOp is a BinaryEvalOp.
+ DistanceVectorOp struct{}
+ // CosDistanceVectorOp is a BinaryEvalOp.
+ CosDistanceVectorOp struct{}
+ // NegInnerProductVectorOp is a BinaryEvalOp.
+ NegInnerProductVectorOp struct{}
+)
+
// AppendToMaybeNullArrayOp is a BinaryEvalOp.
type AppendToMaybeNullArrayOp struct {
Typ *types.T
@@ -180,6 +189,8 @@ type (
PlusDecimalPGLSNOp struct{}
// PlusPGLSNDecimalOp is a BinaryEvalOp.
PlusPGLSNDecimalOp struct{}
+ // PlusPGVectorOp is a BinaryEvalOp.
+ PlusPGVectorOp struct{}
)
type (
@@ -235,6 +246,8 @@ type (
MinusPGLSNDecimalOp struct{}
// MinusPGLSNOp is a BinaryEvalOp.
MinusPGLSNOp struct{}
+ // MinusPGVectorOp is a BinaryEvalOp.
+ MinusPGVectorOp struct{}
)
type (
// MultDecimalIntOp is a BinaryEvalOp.
@@ -259,6 +272,8 @@ type (
MultIntervalFloatOp struct{}
// MultIntervalIntOp is a BinaryEvalOp.
MultIntervalIntOp struct{}
+ // MultPGVectorOp is a BinaryEvalOp.
+ MultPGVectorOp struct{}
)
type (
diff --git a/pkg/sql/sem/tree/eval_expr_generated.go b/pkg/sql/sem/tree/eval_expr_generated.go
index 3726d3c19802..e3d8f59439fa 100644
--- a/pkg/sql/sem/tree/eval_expr_generated.go
+++ b/pkg/sql/sem/tree/eval_expr_generated.go
@@ -222,6 +222,11 @@ func (node *DPGLSN) Eval(ctx context.Context, v ExprEvaluator) (Datum, error) {
return node, nil
}
+// Eval is part of the TypedExpr interface.
+func (node *DPGVector) Eval(ctx context.Context, v ExprEvaluator) (Datum, error) {
+ return node, nil
+}
+
// Eval is part of the TypedExpr interface.
func (node *DString) Eval(ctx context.Context, v ExprEvaluator) (Datum, error) {
return node, nil
diff --git a/pkg/sql/sem/tree/eval_op_generated.go b/pkg/sql/sem/tree/eval_op_generated.go
index a99cd26ef4f1..f53bf8f4e845 100644
--- a/pkg/sql/sem/tree/eval_op_generated.go
+++ b/pkg/sql/sem/tree/eval_op_generated.go
@@ -77,6 +77,8 @@ type BinaryOpEvaluator interface {
EvalContainedByJsonbOp(context.Context, *ContainedByJsonbOp, Datum, Datum) (Datum, error)
EvalContainsArrayOp(context.Context, *ContainsArrayOp, Datum, Datum) (Datum, error)
EvalContainsJsonbOp(context.Context, *ContainsJsonbOp, Datum, Datum) (Datum, error)
+ EvalCosDistanceVectorOp(context.Context, *CosDistanceVectorOp, Datum, Datum) (Datum, error)
+ EvalDistanceVectorOp(context.Context, *DistanceVectorOp, Datum, Datum) (Datum, error)
EvalDivDecimalIntOp(context.Context, *DivDecimalIntOp, Datum, Datum) (Datum, error)
EvalDivDecimalOp(context.Context, *DivDecimalOp, Datum, Datum) (Datum, error)
EvalDivFloatOp(context.Context, *DivFloatOp, Datum, Datum) (Datum, error)
@@ -121,6 +123,7 @@ type BinaryOpEvaluator interface {
EvalMinusJsonbStringOp(context.Context, *MinusJsonbStringOp, Datum, Datum) (Datum, error)
EvalMinusPGLSNDecimalOp(context.Context, *MinusPGLSNDecimalOp, Datum, Datum) (Datum, error)
EvalMinusPGLSNOp(context.Context, *MinusPGLSNOp, Datum, Datum) (Datum, error)
+ EvalMinusPGVectorOp(context.Context, *MinusPGVectorOp, Datum, Datum) (Datum, error)
EvalMinusTimeIntervalOp(context.Context, *MinusTimeIntervalOp, Datum, Datum) (Datum, error)
EvalMinusTimeOp(context.Context, *MinusTimeOp, Datum, Datum) (Datum, error)
EvalMinusTimeTZIntervalOp(context.Context, *MinusTimeTZIntervalOp, Datum, Datum) (Datum, error)
@@ -147,6 +150,8 @@ type BinaryOpEvaluator interface {
EvalMultIntervalDecimalOp(context.Context, *MultIntervalDecimalOp, Datum, Datum) (Datum, error)
EvalMultIntervalFloatOp(context.Context, *MultIntervalFloatOp, Datum, Datum) (Datum, error)
EvalMultIntervalIntOp(context.Context, *MultIntervalIntOp, Datum, Datum) (Datum, error)
+ EvalMultPGVectorOp(context.Context, *MultPGVectorOp, Datum, Datum) (Datum, error)
+ EvalNegInnerProductVectorOp(context.Context, *NegInnerProductVectorOp, Datum, Datum) (Datum, error)
EvalOverlapsArrayOp(context.Context, *OverlapsArrayOp, Datum, Datum) (Datum, error)
EvalOverlapsINetOp(context.Context, *OverlapsINetOp, Datum, Datum) (Datum, error)
EvalPlusDateIntOp(context.Context, *PlusDateIntOp, Datum, Datum) (Datum, error)
@@ -169,6 +174,7 @@ type BinaryOpEvaluator interface {
EvalPlusIntervalTimestampOp(context.Context, *PlusIntervalTimestampOp, Datum, Datum) (Datum, error)
EvalPlusIntervalTimestampTZOp(context.Context, *PlusIntervalTimestampTZOp, Datum, Datum) (Datum, error)
EvalPlusPGLSNDecimalOp(context.Context, *PlusPGLSNDecimalOp, Datum, Datum) (Datum, error)
+ EvalPlusPGVectorOp(context.Context, *PlusPGVectorOp, Datum, Datum) (Datum, error)
EvalPlusTimeDateOp(context.Context, *PlusTimeDateOp, Datum, Datum) (Datum, error)
EvalPlusTimeIntervalOp(context.Context, *PlusTimeIntervalOp, Datum, Datum) (Datum, error)
EvalPlusTimeTZDateOp(context.Context, *PlusTimeTZDateOp, Datum, Datum) (Datum, error)
@@ -355,6 +361,16 @@ func (op *ContainsJsonbOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum)
return e.EvalContainsJsonbOp(ctx, op, a, b)
}
+// Eval is part of the BinaryEvalOp interface.
+func (op *CosDistanceVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
+ return e.EvalCosDistanceVectorOp(ctx, op, a, b)
+}
+
+// Eval is part of the BinaryEvalOp interface.
+func (op *DistanceVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
+ return e.EvalDistanceVectorOp(ctx, op, a, b)
+}
+
// Eval is part of the BinaryEvalOp interface.
func (op *DivDecimalIntOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
return e.EvalDivDecimalIntOp(ctx, op, a, b)
@@ -575,6 +591,11 @@ func (op *MinusPGLSNOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Da
return e.EvalMinusPGLSNOp(ctx, op, a, b)
}
+// Eval is part of the BinaryEvalOp interface.
+func (op *MinusPGVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
+ return e.EvalMinusPGVectorOp(ctx, op, a, b)
+}
+
// Eval is part of the BinaryEvalOp interface.
func (op *MinusTimeIntervalOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
return e.EvalMinusTimeIntervalOp(ctx, op, a, b)
@@ -705,6 +726,16 @@ func (op *MultIntervalIntOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum
return e.EvalMultIntervalIntOp(ctx, op, a, b)
}
+// Eval is part of the BinaryEvalOp interface.
+func (op *MultPGVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
+ return e.EvalMultPGVectorOp(ctx, op, a, b)
+}
+
+// Eval is part of the BinaryEvalOp interface.
+func (op *NegInnerProductVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
+ return e.EvalNegInnerProductVectorOp(ctx, op, a, b)
+}
+
// Eval is part of the BinaryEvalOp interface.
func (op *OverlapsArrayOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
return e.EvalOverlapsArrayOp(ctx, op, a, b)
@@ -815,6 +846,11 @@ func (op *PlusPGLSNDecimalOp) Eval(ctx context.Context, e OpEvaluator, a, b Datu
return e.EvalPlusPGLSNDecimalOp(ctx, op, a, b)
}
+// Eval is part of the BinaryEvalOp interface.
+func (op *PlusPGVectorOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
+ return e.EvalPlusPGVectorOp(ctx, op, a, b)
+}
+
// Eval is part of the BinaryEvalOp interface.
func (op *PlusTimeDateOp) Eval(ctx context.Context, e OpEvaluator, a, b Datum) (Datum, error) {
return e.EvalPlusTimeDateOp(ctx, op, a, b)
diff --git a/pkg/sql/sem/tree/expr.go b/pkg/sql/sem/tree/expr.go
index 4540cbea894b..11c8d1cac6b1 100644
--- a/pkg/sql/sem/tree/expr.go
+++ b/pkg/sql/sem/tree/expr.go
@@ -1050,6 +1050,7 @@ var binaryOpPrio = [...]int{
treebin.Bitxor: 6,
treebin.Bitor: 7,
treebin.Concat: 8, treebin.JSONFetchVal: 8, treebin.JSONFetchText: 8, treebin.JSONFetchValPath: 8, treebin.JSONFetchTextPath: 8,
+ treebin.Distance: 8, treebin.CosDistance: 8, treebin.NegInnerProduct: 8,
}
// binaryOpFullyAssoc indicates whether an operator is fully associative.
@@ -1063,6 +1064,7 @@ var binaryOpFullyAssoc = [...]bool{
treebin.Bitxor: true,
treebin.Bitor: true,
treebin.Concat: true, treebin.JSONFetchVal: false, treebin.JSONFetchText: false, treebin.JSONFetchValPath: false, treebin.JSONFetchTextPath: false,
+ treebin.Distance: false, treebin.CosDistance: false, treebin.NegInnerProduct: false,
}
// BinaryExpr represents a binary value expression.
diff --git a/pkg/sql/sem/tree/parse_string.go b/pkg/sql/sem/tree/parse_string.go
index d745336f3e9e..68e9c5b31d56 100644
--- a/pkg/sql/sem/tree/parse_string.go
+++ b/pkg/sql/sem/tree/parse_string.go
@@ -69,6 +69,8 @@ func ParseAndRequireString(
d, err = ParseDIntervalWithTypeMetadata(intervalStyle(ctx), s, itm)
case types.PGLSNFamily:
d, err = ParseDPGLSN(s)
+ case types.PGVectorFamily:
+ d, err = ParseDPGVector(s)
case types.RefCursorFamily:
d = NewDRefCursor(s)
case types.Box2DFamily:
diff --git a/pkg/sql/sem/tree/treebin/binary_operator.go b/pkg/sql/sem/tree/treebin/binary_operator.go
index 332d4954d238..f9203e0a29fb 100644
--- a/pkg/sql/sem/tree/treebin/binary_operator.go
+++ b/pkg/sql/sem/tree/treebin/binary_operator.go
@@ -61,6 +61,9 @@ const (
JSONFetchValPath
JSONFetchTextPath
TSMatch
+ Distance
+ CosDistance
+ NegInnerProduct
NumBinaryOperatorSymbols
)
@@ -86,6 +89,9 @@ var binaryOpName = [...]string{
JSONFetchValPath: "#>",
JSONFetchTextPath: "#>>",
TSMatch: "@@",
+ Distance: "<->",
+ CosDistance: "<=>",
+ NegInnerProduct: "<#>",
}
// IsPadded returns whether the binary operator needs to be padded.
diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go
index e5987341aece..cfe94fc640cb 100644
--- a/pkg/sql/sem/tree/type_check.go
+++ b/pkg/sql/sem/tree/type_check.go
@@ -2116,6 +2116,12 @@ func (d *DPGLSN) TypeCheck(_ context.Context, _ *SemaContext, _ *types.T) (Typed
return d, nil
}
+// TypeCheck implements the Expr interface. It is implemented as an idempotent
+// identity function for Datum.
+func (d *DPGVector) TypeCheck(_ context.Context, _ *SemaContext, _ *types.T) (TypedExpr, error) {
+ return d, nil
+}
+
// TypeCheck implements the Expr interface. It is implemented as an idempotent
// identity function for Datum.
func (d *DGeography) TypeCheck(_ context.Context, _ *SemaContext, _ *types.T) (TypedExpr, error) {
diff --git a/pkg/sql/sem/tree/walk.go b/pkg/sql/sem/tree/walk.go
index 4775d4926498..c7eed182d673 100644
--- a/pkg/sql/sem/tree/walk.go
+++ b/pkg/sql/sem/tree/walk.go
@@ -762,6 +762,9 @@ func (expr *DBox2D) Walk(_ Visitor) Expr { return expr }
// Walk implements the Expr interface.
func (expr *DPGLSN) Walk(_ Visitor) Expr { return expr }
+// Walk implements the Expr interface.
+func (expr *DPGVector) Walk(_ Visitor) Expr { return expr }
+
// Walk implements the Expr interface.
func (expr *DGeography) Walk(_ Visitor) Expr { return expr }
diff --git a/pkg/sql/stats/stats_test.go b/pkg/sql/stats/stats_test.go
index cd290fc56f4e..cb383dd048d2 100644
--- a/pkg/sql/stats/stats_test.go
+++ b/pkg/sql/stats/stats_test.go
@@ -54,7 +54,7 @@ func TestStatsAnyType(t *testing.T) {
// Casting random integers to REGTYPE might fail.
continue loop
}
- if err := colinfo.ValidateColumnDefType(ctx, st.Version, typ); err == nil {
+ if err := colinfo.ValidateColumnDefType(ctx, st, typ); err == nil {
break
}
}
diff --git a/pkg/sql/types/oid.go b/pkg/sql/types/oid.go
index d6320a967b2b..12a83a87872c 100644
--- a/pkg/sql/types/oid.go
+++ b/pkg/sql/types/oid.go
@@ -108,6 +108,7 @@ var OidToType = map[oid.Oid]*T{
oidext.T_geometry: Geometry,
oidext.T_geography: Geography,
oidext.T_box2d: Box2D,
+ oidext.T_pgvector: PGVector,
}
// oidToArrayOid maps scalar type Oids to their corresponding array type Oid.
@@ -155,6 +156,7 @@ var oidToArrayOid = map[oid.Oid]oid.Oid{
oidext.T_geometry: oidext.T__geometry,
oidext.T_geography: oidext.T__geography,
oidext.T_box2d: oidext.T__box2d,
+ oidext.T_pgvector: oidext.T__pgvector,
}
// familyToOid maps each type family to a default OID value that is used when
@@ -191,6 +193,7 @@ var familyToOid = map[Family]oid.Oid{
GeometryFamily: oidext.T_geometry,
GeographyFamily: oidext.T_geography,
Box2DFamily: oidext.T_box2d,
+ PGVectorFamily: oidext.T_pgvector,
}
// ArrayOids is a set of all oids which correspond to an array type.
diff --git a/pkg/sql/types/types.go b/pkg/sql/types/types.go
index da055d3416d6..4f5d21a4b68e 100644
--- a/pkg/sql/types/types.go
+++ b/pkg/sql/types/types.go
@@ -487,6 +487,15 @@ var (
},
}
+ // PGVector is the type representing a PGVector object.
+ PGVector = &T{
+ InternalType: InternalType{
+ Family: PGVectorFamily,
+ Oid: oidext.T_pgvector,
+ Locale: &emptyLocale,
+ },
+ }
+
// Void is the type representing void.
Void = &T{
InternalType: InternalType{
@@ -647,6 +656,10 @@ var (
PGLSNArray = &T{InternalType: InternalType{
Family: ArrayFamily, ArrayContents: PGLSN, Oid: oid.T__pg_lsn, Locale: &emptyLocale}}
+ // PGVectorArray is the type of an array value having PGVector-typed elements.
+ PGVectorArray = &T{InternalType: InternalType{
+ Family: ArrayFamily, ArrayContents: PGVector, Oid: oidext.T__pgvector, Locale: &emptyLocale}}
+
// RefCursorArray is the type of an array value having REFCURSOR-typed elements.
RefCursorArray = &T{InternalType: InternalType{
Family: ArrayFamily, ArrayContents: RefCursor, Oid: oid.T__refcursor, Locale: &emptyLocale}}
@@ -1218,6 +1231,17 @@ func MakeLabeledTuple(contents []*T, labels []string) *T {
}}
}
+// MakePGVector constructs a new instance of a VECTOR type (pg_vector) that has
+// the given number of dimensions.
+func MakePGVector(dims int32) *T {
+ return &T{InternalType: InternalType{
+ Family: PGVectorFamily,
+ Oid: oidext.T_pgvector,
+ Width: dims,
+ Locale: &emptyLocale,
+ }}
+}
+
// NewCompositeType constructs a new instance of a TupleFamily type with the
// given field types and labels, and the given user-defined type OIDs.
func NewCompositeType(typeOID, arrayTypeOID oid.Oid, contents []*T, labels []string) *T {
@@ -1294,6 +1318,7 @@ func (t *T) Locale() string {
// STRING : max # of characters
// COLLATEDSTRING: max # of characters
// BIT : max # of bits
+// VECTOR : # of dimensions
//
// Width is always 0 for other types.
func (t *T) Width() int32 {
@@ -1344,7 +1369,7 @@ func (t *T) TypeModifier() int32 {
// var header size.
return width + 4
}
- case BitFamily:
+ case BitFamily, PGVectorFamily:
if width := t.Width(); width != 0 {
return width
}
@@ -1507,6 +1532,7 @@ var familyNames = map[Family]redact.SafeString{
JsonFamily: "jsonb",
OidFamily: "oid",
PGLSNFamily: "pg_lsn",
+ PGVectorFamily: "vector",
RefCursorFamily: "refcursor",
StringFamily: "string",
TimeFamily: "time",
@@ -1788,6 +1814,8 @@ func (t *T) SQLStandardNameWithTypmod(haveTypmod bool, typmod int) string {
}
case PGLSNFamily:
return "pg_lsn"
+ case PGVectorFamily:
+ return "vector"
case RefCursorFamily:
return "refcursor"
case StringFamily, CollatedStringFamily:
@@ -1999,6 +2027,11 @@ func (t *T) SQLString() string {
// databases when this function is called to produce DDL like in SHOW
// CREATE.
return t.TypeMeta.Name.FQName(false /* explicitCatalog */)
+ case PGVectorFamily:
+ if t.Width() == 0 {
+ return "VECTOR"
+ }
+ return fmt.Sprintf("VECTOR(%d)", t.Width())
}
return strings.ToUpper(t.Name())
}
@@ -2046,7 +2079,7 @@ func (t *T) SQLStringForError() redact.RedactableString {
IntervalFamily, StringFamily, BytesFamily, TimestampTZFamily, CollatedStringFamily, OidFamily,
UnknownFamily, UuidFamily, INetFamily, TimeFamily, JsonFamily, TimeTZFamily, BitFamily,
GeometryFamily, GeographyFamily, Box2DFamily, VoidFamily, EncodedKeyFamily, TSQueryFamily,
- TSVectorFamily, AnyFamily, PGLSNFamily, RefCursorFamily:
+ TSVectorFamily, AnyFamily, PGLSNFamily, PGVectorFamily, RefCursorFamily:
// These types do not contain other types, and do not require redaction.
return redact.Sprint(redact.SafeString(t.SQLString()))
}
@@ -2816,6 +2849,8 @@ func IsValidArrayElementType(t *T) (valid bool, issueNum int) {
return false, 90886
case TSVectorFamily:
return false, 90886
+ case PGVectorFamily:
+ return false, 121432
default:
return true, 0
}
diff --git a/pkg/sql/types/types.proto b/pkg/sql/types/types.proto
index 84ab3a03fa9d..539dbc55f8cc 100644
--- a/pkg/sql/types/types.proto
+++ b/pkg/sql/types/types.proto
@@ -402,6 +402,10 @@ enum Family {
// Oid : T_refcursor
RefCursorFamily = 31;
+ // PGVectorFamily is a type family for the vector type, which is the
+ // type representing pgvector vectors.
+ PGVectorFamily = 32;
+
// AnyFamily is a special type family used during static analysis as a
// wildcard type that matches any other type, including scalar, array, and
// tuple types. Execution-time values should never have this type. As an
diff --git a/pkg/util/encoding/encoding.go b/pkg/util/encoding/encoding.go
index f0bae698e4d1..c3f62460a452 100644
--- a/pkg/util/encoding/encoding.go
+++ b/pkg/util/encoding/encoding.go
@@ -1780,6 +1780,7 @@ const (
// Special case
JsonEmptyArray Type = 42
JsonEmptyArrayDesc Type = 43
+ PGVector Type = 44
)
// typMap maps an encoded type byte to a decoded Type. It's got 256 slots, one
@@ -2622,6 +2623,12 @@ func EncodeUntaggedFloatValue(appendTo []byte, f float64) []byte {
return EncodeUint64Ascending(appendTo, math.Float64bits(f))
}
+// EncodeUntaggedFloat32Value encodes a float32 value, appends it to the supplied buffer,
+// and returns the final buffer.
+func EncodeUntaggedFloat32Value(appendTo []byte, f float32) []byte {
+ return EncodeUint32Ascending(appendTo, math.Float32bits(f))
+}
+
// EncodeBytesValue encodes a byte array value with its value tag, appends it to
// the supplied buffer, and returns the final buffer.
func EncodeBytesValue(appendTo []byte, colID uint32, data []byte) []byte {
@@ -2820,6 +2827,14 @@ func EncodeTSVectorValue(appendTo []byte, colID uint32, data []byte) []byte {
return EncodeUntaggedBytesValue(appendTo, data)
}
+// EncodePGVectorValue encodes an already-byte-encoded PGVector value with its
+// value tag and a length prefix, appends it to the supplied buffer, and
+// returns the final buffer.
+func EncodePGVectorValue(appendTo []byte, colID uint32, data []byte) []byte {
+	appendTo = EncodeValueTag(appendTo, colID, PGVector)
+	return EncodeUntaggedBytesValue(appendTo, data)
+}
+
// DecodeValueTag decodes a value encoded by EncodeValueTag, used as a prefix in
// each of the other EncodeFooValue methods.
//
@@ -2922,6 +2937,16 @@ func DecodeUntaggedFloatValue(b []byte) (remaining []byte, f float64, err error)
return b, math.Float64frombits(i), err
}
+// DecodeUntaggedFloat32Value decodes a value encoded by EncodeUntaggedFloat32Value.
+func DecodeUntaggedFloat32Value(b []byte) (remaining []byte, f float32, err error) {
+ if len(b) < 4 {
+ return b, 0, fmt.Errorf("float32 value should be exactly 4 bytes: %d", len(b))
+ }
+ var i uint32
+ b, i, err = DecodeUint32Ascending(b)
+ return b, math.Float32frombits(i), err
+}
+
// DecodeBytesValue decodes a value encoded by EncodeBytesValue.
func DecodeBytesValue(b []byte) (remaining []byte, data []byte, err error) {
b, err = decodeValueTypeAssert(b, Bytes)
@@ -3208,7 +3233,7 @@ func PeekValueLengthWithOffsetsAndType(b []byte, dataOffset int, typ Type) (leng
return dataOffset + n, err
case Float:
return dataOffset + floatValueEncodedLength, nil
- case Bytes, Array, JSON, Geo, TSVector, TSQuery:
+ case Bytes, Array, JSON, Geo, TSVector, TSQuery, PGVector:
_, n, i, err := DecodeNonsortingUvarint(b)
return dataOffset + n + int(i), err
case Box2D:
diff --git a/pkg/util/encoding/type_string.go b/pkg/util/encoding/type_string.go
index a6eaee2c3d5e..19174c451844 100644
--- a/pkg/util/encoding/type_string.go
+++ b/pkg/util/encoding/type_string.go
@@ -53,6 +53,7 @@ func _() {
_ = x[JSONObjectDesc-41]
_ = x[JsonEmptyArray-42]
_ = x[JsonEmptyArrayDesc-43]
+ _ = x[PGVector-44]
}
func (i Type) String() string {
@@ -145,6 +146,8 @@ func (i Type) String() string {
return "JsonEmptyArray"
case JsonEmptyArrayDesc:
return "JsonEmptyArrayDesc"
+ case PGVector:
+ return "PGVector"
default:
return "Type(" + strconv.FormatInt(int64(i), 10) + ")"
}
diff --git a/pkg/util/parquet/writer_bench_test.go b/pkg/util/parquet/writer_bench_test.go
index 3e550681a341..6525da667b14 100644
--- a/pkg/util/parquet/writer_bench_test.go
+++ b/pkg/util/parquet/writer_bench_test.go
@@ -69,7 +69,7 @@ func getBenchmarkTypes() []*types.T {
for _, typ := range randgen.SeedTypes {
switch typ.Family() {
case types.AnyFamily, types.TSQueryFamily, types.TSVectorFamily,
- types.VoidFamily:
+ types.VoidFamily, types.PGVectorFamily:
case types.TupleFamily:
// Replace Any Tuple with Tuple of Ints with size 5.
typs = append(typs, types.MakeTuple([]*types.T{
diff --git a/pkg/util/parquet/writer_test.go b/pkg/util/parquet/writer_test.go
index 094667490e89..934a5e59e3bc 100644
--- a/pkg/util/parquet/writer_test.go
+++ b/pkg/util/parquet/writer_test.go
@@ -51,7 +51,7 @@ func newColSchema(numCols int) *colSchema {
// that are not supported by the writer.
func typSupported(typ *types.T) bool {
switch typ.Family() {
- case types.AnyFamily, types.TSQueryFamily, types.TSVectorFamily, types.VoidFamily:
+ case types.AnyFamily, types.TSQueryFamily, types.TSVectorFamily, types.PGVectorFamily, types.VoidFamily:
return false
case types.ArrayFamily:
if typ.ArrayContents().Family() == types.ArrayFamily || typ.ArrayContents().Family() == types.TupleFamily {
diff --git a/pkg/util/vector/BUILD.bazel b/pkg/util/vector/BUILD.bazel
new file mode 100644
index 000000000000..f4aeca51ac43
--- /dev/null
+++ b/pkg/util/vector/BUILD.bazel
@@ -0,0 +1,23 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "vector",
+ srcs = ["vector.go"],
+ importpath = "github.com/cockroachdb/cockroach/pkg/util/vector",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sql/pgwire/pgcode",
+ "//pkg/sql/pgwire/pgerror",
+ "//pkg/util/encoding",
+ ],
+)
+
+go_test(
+ name = "vector_test",
+ srcs = ["vector_test.go"],
+ embed = [":vector"],
+ deps = [
+ "//pkg/util/randutil",
+ "@com_github_stretchr_testify//assert",
+ ],
+)
diff --git a/pkg/util/vector/vector.go b/pkg/util/vector/vector.go
new file mode 100644
index 000000000000..c5ed3085d57d
--- /dev/null
+++ b/pkg/util/vector/vector.go
@@ -0,0 +1,288 @@
+// Copyright 2024 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package vector
+
+import (
+ "math"
+ "math/rand"
+ "strconv"
+ "strings"
+
+ "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
+ "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
+ "github.com/cockroachdb/cockroach/pkg/util/encoding"
+)
+
+// MaxDim is the maximum number of dimensions a vector can have.
+const MaxDim = 16000
+
+// T is the type of a PGVector-like vector.
+type T []float32
+
+// ParseVector parses the Postgres string representation of a vector.
+func ParseVector(input string) (T, error) {
+ input = strings.TrimSpace(input)
+ if !strings.HasPrefix(input, "[") || !strings.HasSuffix(input, "]") {
+ return T{}, pgerror.Newf(pgcode.InvalidTextRepresentation,
+ "malformed vector literal: Vector contents must start with \"[\" and"+
+ " end with \"]\"")
+ }
+
+ input = strings.TrimPrefix(input, "[")
+ input = strings.TrimSuffix(input, "]")
+ parts := strings.Split(input, ",")
+
+ if len(parts) > MaxDim {
+ return T{}, pgerror.Newf(pgcode.ProgramLimitExceeded, "vector cannot have more than %d dimensions", MaxDim)
+ }
+
+ vector := make([]float32, len(parts))
+ for i, part := range parts {
+ part = strings.TrimSpace(part)
+ if part == "" {
+ return T{}, pgerror.New(pgcode.InvalidTextRepresentation, "invalid input syntax for type vector: empty string")
+ }
+
+ val, err := strconv.ParseFloat(part, 32)
+ if err != nil {
+ return T{}, pgerror.Newf(pgcode.InvalidTextRepresentation, "invalid input syntax for type vector: %s", part)
+ }
+
+ if math.IsInf(val, 0) {
+ return T{}, pgerror.New(pgcode.DataException, "infinite value not allowed in vector")
+ }
+ if math.IsNaN(val) {
+ return T{}, pgerror.New(pgcode.DataException, "NaN not allowed in vector")
+ }
+ vector[i] = float32(val)
+ }
+
+ return vector, nil
+}
+
+// String implements the fmt.Stringer interface.
+func (v T) String() string {
+ var sb strings.Builder
+ sb.WriteString("[")
+ // Pre-grow by a reasonable amount to avoid multiple allocations.
+ sb.Grow(len(v) * 8)
+ for i, v := range v {
+ if i > 0 {
+ sb.WriteString(",")
+ }
+ sb.WriteString(strconv.FormatFloat(float64(v), 'g', -1, 32))
+ }
+ sb.WriteString("]")
+ return sb.String()
+}
+
+// Size returns the size of the vector in bytes.
+func (v T) Size() uintptr {
+ return 24 + uintptr(cap(v))*4
+}
+
+// Compare returns -1 if v < v2, 1 if v > v2, and 0 if v == v2.
+func (v T) Compare(v2 T) (int, error) {
+ n := min(len(v), len(v2))
+ for i := 0; i < n; i++ {
+ if v[i] < v2[i] {
+ return -1, nil
+ } else if v[i] > v2[i] {
+ return 1, nil
+ }
+ }
+ if len(v) < len(v2) {
+ return -1, nil
+ } else if len(v) > len(v2) {
+ return 1, nil
+ }
+ return 0, nil
+}
+
+// Encode encodes the vector as a byte array suitable for storing in KV.
+func Encode(appendTo []byte, t T) ([]byte, error) {
+ appendTo = encoding.EncodeUint32Ascending(appendTo, uint32(len(t)))
+ for i := range t {
+ appendTo = encoding.EncodeUntaggedFloat32Value(appendTo, t[i])
+ }
+ return appendTo, nil
+}
+
+// Decode decodes the byte array into a vector.
+func Decode(b []byte) (ret T, err error) {
+ var n uint32
+ b, n, err = encoding.DecodeUint32Ascending(b)
+ if err != nil {
+ return nil, err
+ }
+ ret = make(T, n)
+ for i := range ret {
+ b, ret[i], err = encoding.DecodeUntaggedFloat32Value(b)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return ret, nil
+}
+
+func checkDims(t T, t2 T) error {
+ if len(t) != len(t2) {
+ return pgerror.Newf(pgcode.DataException, "different vector dimensions %d and %d", len(t), len(t2))
+ }
+ return nil
+}
+
+// L1Distance returns the L1 (Manhattan) distance between t and t2.
+func L1Distance(t T, t2 T) (float64, error) {
+ if err := checkDims(t, t2); err != nil {
+ return 0, err
+ }
+ var distance float32
+ for i := range len(t) {
+ diff := t[i] - t2[i]
+ distance += float32(math.Abs(float64(diff)))
+ }
+ return float64(distance), nil
+}
+
+// L2Distance returns the Euclidean distance between t and t2.
+func L2Distance(t T, t2 T) (float64, error) {
+ if err := checkDims(t, t2); err != nil {
+ return 0, err
+ }
+ var distance float32
+ for i := range len(t) {
+ diff := t[i] - t2[i]
+ distance += diff * diff
+ }
+ // TODO(queries): check for overflow and validate intermediate result if needed.
+ return math.Sqrt(float64(distance)), nil
+}
+
+// CosDistance returns the cosine distance between t and t2.
+func CosDistance(t T, t2 T) (float64, error) {
+ if err := checkDims(t, t2); err != nil {
+ return 0, err
+ }
+ var distance, normA, normB float32
+ for i := range len(t) {
+ distance += t[i] * t2[i]
+ normA += t[i] * t[i]
+ normB += t2[i] * t2[i]
+ }
+ // Use sqrt(a * b) over sqrt(a) * sqrt(b)
+ similarity := float64(distance) / math.Sqrt(float64(normA)*float64(normB))
+ /* Keep in range */
+ if similarity > 1 {
+ similarity = 1
+ } else if similarity < -1 {
+ similarity = -1
+ }
+ return 1 - similarity, nil
+}
+
+// InnerProduct returns the inner product (dot product) of t and t2.
+func InnerProduct(t T, t2 T) (float64, error) {
+	if err := checkDims(t, t2); err != nil {
+		return 0, err
+	}
+	var distance float32
+	for i := range len(t) {
+		distance += t[i] * t2[i]
+	}
+	return float64(distance), nil
+}
+
+// NegInnerProduct returns the negative inner product of t and t2.
+func NegInnerProduct(t T, t2 T) (float64, error) {
+	p, err := InnerProduct(t, t2)
+	return p * -1, err
+}
+
+// Norm returns the L2 norm of t.
+func Norm(t T) float64 {
+ var norm float64
+ for i := range t {
+ norm += float64(t[i]) * float64(t[i])
+ }
+ // TODO(queries): check for overflow and validate intermediate result if needed.
+ return math.Sqrt(norm)
+}
+
+// Add returns t+t2, pointwise.
+func Add(t T, t2 T) (T, error) {
+ if err := checkDims(t, t2); err != nil {
+ return nil, err
+ }
+ ret := make(T, len(t))
+ for i := range t {
+ ret[i] = t[i] + t2[i]
+ }
+ for i := range ret {
+ if math.IsInf(float64(ret[i]), 0) {
+ return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: overflow")
+ }
+ }
+ return ret, nil
+}
+
+// Minus returns t-t2, pointwise.
+func Minus(t T, t2 T) (T, error) {
+ if err := checkDims(t, t2); err != nil {
+ return nil, err
+ }
+ ret := make(T, len(t))
+ for i := range t {
+ ret[i] = t[i] - t2[i]
+ }
+ for i := range ret {
+ if math.IsInf(float64(ret[i]), 0) {
+ return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: overflow")
+ }
+ }
+ return ret, nil
+}
+
+// Mult returns t*t2, pointwise.
+func Mult(t T, t2 T) (T, error) {
+ if err := checkDims(t, t2); err != nil {
+ return nil, err
+ }
+ ret := make(T, len(t))
+ for i := range t {
+ ret[i] = t[i] * t2[i]
+ }
+ for i := range ret {
+ if math.IsInf(float64(ret[i]), 0) {
+ return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: overflow")
+ }
+ if ret[i] == 0 && !(t[i] == 0 || t2[i] == 0) {
+ return nil, pgerror.New(pgcode.NumericValueOutOfRange, "value out of range: underflow")
+ }
+ }
+ return ret, nil
+}
+
+// Random returns a random vector.
+func Random(rng *rand.Rand) T {
+ n := 1 + rng.Intn(1000)
+ v := make(T, n)
+ for i := range v {
+ for {
+ v[i] = math.Float32frombits(rng.Uint32())
+ if math.IsNaN(float64(v[i])) || math.IsInf(float64(v[i]), 0) {
+ continue
+ }
+ break
+ }
+ }
+ return v
+}
diff --git a/pkg/util/vector/vector_test.go b/pkg/util/vector/vector_test.go
new file mode 100644
index 000000000000..28fbf998b36b
--- /dev/null
+++ b/pkg/util/vector/vector_test.go
@@ -0,0 +1,208 @@
+// Copyright 2024 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package vector
+
+import (
+ "math"
+ "strings"
+ "testing"
+
+ "github.com/cockroachdb/cockroach/pkg/util/randutil"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestParseVector(t *testing.T) {
+ testCases := []struct {
+ input string
+ expected T
+ hasError bool
+ }{
+ {input: "[1,2,3]", expected: T{1, 2, 3}, hasError: false},
+ {input: "[1.0, 2.0, 3.0]", expected: T{1.0, 2.0, 3.0}, hasError: false},
+ {input: "[1.0, 2.0, 3.0", expected: T{}, hasError: true},
+ {input: "1.0, 2.0, 3.0]", expected: T{}, hasError: true},
+ {input: "[1.0, 2.0, [3.0]]", expected: T{}, hasError: true},
+ {input: "1.0, 2.0, 3.0]", expected: T{}, hasError: true},
+ {input: "1.0, , 3.0]", expected: T{}, hasError: true},
+ {input: "", expected: T{}, hasError: true},
+ {input: "[]", expected: T{}, hasError: true},
+ {input: "1.0, 2.0, 3.0", expected: T{}, hasError: true},
+ }
+
+ for _, tc := range testCases {
+ result, err := ParseVector(tc.input)
+
+ if tc.hasError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ // Test roundtripping through String().
+ s := result.String()
+ result, err = ParseVector(s)
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expected, result)
+ }
+ }
+
+ // Test the maxdims error case.
+ var sb strings.Builder
+ sb.WriteString("[")
+ for i := 0; i < MaxDim; i++ {
+ sb.WriteString("1,")
+ }
+ sb.WriteString("1]")
+ _, err := ParseVector(sb.String())
+ assert.Errorf(t, err, "vector cannot have more than %d dimensions", MaxDim)
+}
+
+func TestRoundtripRandomPGVector(t *testing.T) {
+ rng, _ := randutil.NewTestRand()
+ for i := 0; i < 1000; i++ {
+ v := Random(rng)
+ encoded, err := Encode(nil, v)
+ assert.NoError(t, err)
+ roundtripped, err := Decode(encoded)
+ assert.NoError(t, err)
+ assert.Equal(t, v.String(), roundtripped.String())
+ reEncoded, err := Encode(nil, roundtripped)
+ assert.NoError(t, err)
+ assert.Equal(t, encoded, reEncoded)
+ }
+}
+
+func TestDistances(t *testing.T) {
+ // Test L1, L2, Cosine distance.
+ testCases := []struct {
+ v1 T
+ v2 T
+ l1 float64
+ l2 float64
+ cos float64
+ err bool
+ }{
+ {v1: T{1, 2, 3}, v2: T{4, 5, 6}, l1: 9, l2: 5.196152422, cos: 0.02536815, err: false},
+ {v1: T{-1, -2, -3}, v2: T{-4, -5, -6}, l1: 9, l2: 5.196152422, cos: 0.02536815, err: false},
+ {v1: T{0, 0, 0}, v2: T{0, 0, 0}, l1: 0, l2: 0, cos: math.NaN(), err: false},
+ {v1: T{1, 2, 3}, v2: T{1, 2, 3}, l1: 0, l2: 0, cos: 0, err: false},
+ {v1: T{1, 2, 3}, v2: T{1, 2, 4}, l1: 1, l2: 1, cos: 0.008539, err: false},
+ // Different vector sizes errors.
+ {v1: T{1, 2, 3}, v2: T{4, 5}, err: true},
+ }
+
+ for _, tc := range testCases {
+ l1, l1Err := L1Distance(tc.v1, tc.v2)
+ l2, l2Err := L2Distance(tc.v1, tc.v2)
+ cos, cosErr := CosDistance(tc.v1, tc.v2)
+
+ if tc.err {
+ assert.Error(t, l1Err)
+ assert.Error(t, l2Err)
+ assert.Error(t, cosErr)
+ } else {
+ assert.NoError(t, l1Err)
+ assert.NoError(t, l2Err)
+ assert.NoError(t, cosErr)
+ assert.InDelta(t, tc.l1, l1, 0.000001)
+ assert.InDelta(t, tc.l2, l2, 0.000001)
+ assert.InDelta(t, tc.cos, cos, 0.000001)
+ }
+ }
+}
+
+func TestProducts(t *testing.T) {
+ // Test inner product and negative inner product
+ testCases := []struct {
+ v1 T
+ v2 T
+ ip float64
+ negIp float64
+ err bool
+ }{
+ {v1: T{1, 2, 3}, v2: T{4, 5, 6}, ip: 32, negIp: -32, err: false},
+ {v1: T{-1, -2, -3}, v2: T{-4, -5, -6}, ip: 32, negIp: -32, err: false},
+ {v1: T{0, 0, 0}, v2: T{0, 0, 0}, ip: 0, negIp: 0, err: false},
+ {v1: T{1, 2, 3}, v2: T{1, 2, 3}, ip: 14, negIp: -14, err: false},
+ {v1: T{1, 2, 3}, v2: T{1, 2, 4}, ip: 17, negIp: -17, err: false},
+ // Different vector sizes errors.
+ {v1: T{1, 2, 3}, v2: T{4, 5}, err: true},
+ }
+
+ for _, tc := range testCases {
+ ip, ipErr := InnerProduct(tc.v1, tc.v2)
+ negIp, negIpErr := NegInnerProduct(tc.v1, tc.v2)
+
+ if tc.err {
+ assert.Error(t, ipErr)
+ assert.Error(t, negIpErr)
+ } else {
+ assert.NoError(t, ipErr)
+ assert.NoError(t, negIpErr)
+ assert.InDelta(t, tc.ip, ip, 0.000001)
+ assert.InDelta(t, tc.negIp, negIp, 0.000001)
+ }
+ }
+}
+
+func TestNorm(t *testing.T) {
+ testCases := []struct {
+ v T
+ norm float64
+ }{
+ {v: T{1, 2, 3}, norm: 3.7416573867739413},
+ {v: T{0, 0, 0}, norm: 0},
+ {v: T{-1, -2, -3}, norm: 3.7416573867739413},
+ }
+
+ for _, tc := range testCases {
+ norm := Norm(tc.v)
+ assert.InDelta(t, tc.norm, norm, 0.000001)
+ }
+}
+
+func TestPointwiseOps(t *testing.T) {
+	// Test pointwise Add, Minus, and Mult operations.
+	testCases := []struct {
+		v1    T
+		v2    T
+		add   T
+		minus T
+		mult  T
+		err   bool
+	}{
+		{v1: T{1, 2, 3}, v2: T{4, 5, 6}, add: T{5, 7, 9}, minus: T{-3, -3, -3}, mult: T{4, 10, 18}, err: false},
+		{v1: T{-1, -2, -3}, v2: T{-4, -5, -6}, add: T{-5, -7, -9}, minus: T{3, 3, 3}, mult: T{4, 10, 18}, err: false},
+		{v1: T{0, 0, 0}, v2: T{0, 0, 0}, add: T{0, 0, 0}, minus: T{0, 0, 0}, mult: T{0, 0, 0}, err: false},
+		{v1: T{1, 2, 3}, v2: T{1, 2, 3}, add: T{2, 4, 6}, minus: T{0, 0, 0}, mult: T{1, 4, 9}, err: false},
+		{v1: T{1, 2, 3}, v2: T{1, 2, 4}, add: T{2, 4, 7}, minus: T{0, 0, -1}, mult: T{1, 4, 12}, err: false},
+		// Different vector sizes errors.
+		{v1: T{1, 2, 3}, v2: T{4, 5}, err: true},
+	}
+
+	for _, tc := range testCases {
+		add, addErr := Add(tc.v1, tc.v2)
+		minus, minusErr := Minus(tc.v1, tc.v2)
+		mult, multErr := Mult(tc.v1, tc.v2)
+
+		if tc.err {
+			assert.Error(t, addErr)
+			assert.Error(t, minusErr)
+			assert.Error(t, multErr)
+		} else {
+			assert.NoError(t, addErr)
+			assert.NoError(t, minusErr)
+			assert.NoError(t, multErr)
+			assert.Equal(t, tc.add, add)
+			assert.Equal(t, tc.minus, minus)
+			assert.Equal(t, tc.mult, mult)
+		}
+	}
+}
diff --git a/pkg/workload/rand/rand.go b/pkg/workload/rand/rand.go
index f98b0391410c..f583865a82b6 100644
--- a/pkg/workload/rand/rand.go
+++ b/pkg/workload/rand/rand.go
@@ -436,6 +436,8 @@ func DatumToGoSQL(d tree.Datum) (interface{}, error) {
return d.String(), nil
case *tree.DTSVector:
return d.String(), nil
+ case *tree.DPGVector:
+ return d.String(), nil
}
return nil, errors.Errorf("unhandled datum type: %s", reflect.TypeOf(d))
}