From fc546e4876ae1f27862880ed33f34b7756b98de8 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 7 Dec 2023 15:04:39 +0100 Subject: [PATCH] Add support for pgvector. [closes #612] Signed-off-by: Mark Paluch --- README.md | 98 +++--- .../codec/BuiltinDynamicCodecs.java | 17 +- .../io/r2dbc/postgresql/codec/Vector.java | 128 +++++++ .../r2dbc/postgresql/codec/VectorCodec.java | 319 ++++++++++++++++++ .../postgresql/codec/VectorFloatCodec.java | 105 ++++++ .../AbstractCodecIntegrationTests.java | 13 + .../postgresql/VectorIntegrationTests.java | 202 +++++++++++ .../codec/VectorCodecUnitTests.java | 84 +++++ .../postgresql/codec/VectorUnitTests.java | 37 ++ 9 files changed, 949 insertions(+), 54 deletions(-) create mode 100644 src/main/java/io/r2dbc/postgresql/codec/Vector.java create mode 100644 src/main/java/io/r2dbc/postgresql/codec/VectorCodec.java create mode 100644 src/main/java/io/r2dbc/postgresql/codec/VectorFloatCodec.java create mode 100644 src/test/java/io/r2dbc/postgresql/VectorIntegrationTests.java create mode 100644 src/test/java/io/r2dbc/postgresql/codec/VectorCodecUnitTests.java create mode 100644 src/test/java/io/r2dbc/postgresql/codec/VectorUnitTests.java diff --git a/README.md b/README.md index c127542c..6d4d9e82 100644 --- a/README.md +++ b/README.md @@ -425,54 +425,55 @@ When available, the driver registers also an array variant of the codec. This reference table shows the type mapping between [PostgreSQL][p] and Java data types: -| PostgreSQL Type | Supported Data Type | -|:------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------| -| [`bigint`][psql-bigint-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref] | -| [`bit`][psql-bit-ref] | Not yet supported.| -| [`bit varying`][psql-bit-ref] | Not yet supported.| -| [`boolean or bool`][psql-boolean-ref] | [`Boolean`][java-boolean-ref]| -| [`box`][psql-box-ref] | **`Box`**| -| [`bytea`][psql-bytea-ref] | [**`ByteBuffer`**][java-ByteBuffer-ref], [`byte[]`][java-byte-ref], [`Blob`][r2dbc-blob-ref]| -| [`character`][psql-character-ref] | [`String`][java-string-ref]| -| [`character varying`][psql-character-ref] | [`String`][java-string-ref]| -| [`cidr`][psql-cidr-ref] | Not yet supported.| -| [`circle`][psql-circle-ref] | **`Circle`**| -| [`date`][psql-date-ref] | [`LocalDate`][java-ld-ref]| -| [`double precision`][psql-floating-point-ref] | [**`Double`**][java-double-ref], [`Float`][java-float-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [enumerated types][psql-enum-ref] | Client code `Enum` types through `EnumCodec`| -| [`geometry`][postgis-ref] | **`org.locationtech.jts.geom.Geometry`**| -| [`hstore`][psql-hstore-ref] | [**`Map`**][java-map-ref]| -| [`inet`][psql-inet-ref] | [**`InetAddress`**][java-inet-ref]| -| [`integer`][psql-integer-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [`interval`][psql-interval-ref] | **`Interval`**| -| [`json`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`[`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]| -| [`jsonb`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`[`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]| -| [`line`][psql-line-ref] | **`Line`**| -| [`lseg`][psql-lseq-ref] | **`Lseg`**| -| [`macaddr`][psql-macaddr-ref] | Not yet supported.| -| [`macaddr8`][psql-macaddr8-ref] | Not yet supported.| -| [`money`][psql-money-ref] | Not yet supported. Please don't use this type. It is a very poor implementation. | -| [`name`][psql-name-ref] | [**`String`**][java-string-ref] -| [`numeric`][psql-bignumeric-ref] | [`BigDecimal`][java-bigdecimal-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigInteger`][java-biginteger-ref]| -| [`oid`][psql-oid-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [`path`][psql-path-ref] | **`Path`**| -| [`pg_lsn`][psql-pg_lsn-ref] | Not yet supported.| -| [`point`][psql-point-ref] | **`Point`**| -| [`polygon`][psql-polygon-ref] | **`Polygon`**| -| [`real`][psql-real-ref] | [**`Float`**][java-float-ref], [`Double`][java-double-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [`smallint`][psql-smallint-ref] | [**`Short`**][java-short-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [`smallserial`][psql-smallserial-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [`serial`][psql-serial-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| -| [`text`][psql-text-ref] | [**`String`**][java-string-ref], [`Clob`][r2dbc-clob-ref]| -| [`time [without time zone]`][psql-time-ref] | [`LocalTime`][java-lt-ref]| -| [`time [with time zone]`][psql-time-ref] | [`OffsetTime`][java-ot-ref]| -| [`timestamp [without time zone]`][psql-time-ref]|[**`LocalDateTime`**][java-ldt-ref], [`LocalTime`][java-lt-ref], [`LocalDate`][java-ld-ref], [`java.util.Date`][java-legacy-date-ref]| -| [`timestamp [with time zone]`][psql-time-ref] | [**`OffsetDatetime`**][java-odt-ref], [`ZonedDateTime`][java-zdt-ref], [`Instant`][java-instant-ref]| -| [`tsquery`][psql-tsquery-ref] | Not yet supported.| -| [`tsvector`][psql-tsvector-ref] | Not yet supported.| -| [`txid_snapshot`][psql-txid_snapshot-ref] | Not yet supported.| -| [`uuid`][psql-uuid-ref] | [**`UUID`**][java-uuid-ref], [`String`][java-string-ref]|| -| [`xml`][psql-xml-ref] | Not yet supported. | +| PostgreSQL Type | Supported Data Type | +|:-------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------| +| [`bigint`][psql-bigint-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref] | +| [`bit`][psql-bit-ref] | Not yet supported.| +| [`bit varying`][psql-bit-ref] | Not yet supported.| +| [`boolean or bool`][psql-boolean-ref] | [`Boolean`][java-boolean-ref]| +| [`box`][psql-box-ref] | **`Box`**| +| [`bytea`][psql-bytea-ref] | [**`ByteBuffer`**][java-ByteBuffer-ref], [`byte[]`][java-byte-ref], [`Blob`][r2dbc-blob-ref]| +| [`character`][psql-character-ref] | [`String`][java-string-ref]| +| [`character varying`][psql-character-ref] | [`String`][java-string-ref]| +| [`cidr`][psql-cidr-ref] | Not yet supported.| +| [`circle`][psql-circle-ref] | **`Circle`**| +| [`date`][psql-date-ref] | [`LocalDate`][java-ld-ref]| +| [`double precision`][psql-floating-point-ref] | [**`Double`**][java-double-ref], [`Float`][java-float-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [enumerated types][psql-enum-ref] | Client code `Enum` types through `EnumCodec`| +| [`geometry`][postgis-ref] | **`org.locationtech.jts.geom.Geometry`**| +| [`hstore`][psql-hstore-ref] | [**`Map`**][java-map-ref]| +| [`inet`][psql-inet-ref] | [**`InetAddress`**][java-inet-ref]| +| [`integer`][psql-integer-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [`interval`][psql-interval-ref] | **`Interval`**| +| [`json`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`[`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]| +| [`jsonb`][psql-json-ref] | **`Json`**, [`String`][java-string-ref]. Reading: `ByteBuf`[`byte[]`][java-primitive-ref], [`ByteBuffer`][java-ByteBuffer-ref], [`String`][java-string-ref], [`InputStream`][java-inputstream-ref]| +| [`line`][psql-line-ref] | **`Line`**| +| [`lseg`][psql-lseq-ref] | **`Lseg`**| +| [`macaddr`][psql-macaddr-ref] | Not yet supported.| +| [`macaddr8`][psql-macaddr8-ref] | Not yet supported.| +| [`money`][psql-money-ref] | Not yet supported. Please don't use this type. It is a very poor implementation. | +| [`name`][psql-name-ref] | [**`String`**][java-string-ref] +| [`numeric`][psql-bignumeric-ref] | [`BigDecimal`][java-bigdecimal-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigInteger`][java-biginteger-ref]| +| [`oid`][psql-oid-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [`path`][psql-path-ref] | **`Path`**| +| [`pg_lsn`][psql-pg_lsn-ref] | Not yet supported.| +| [`point`][psql-point-ref] | **`Point`**| +| [`polygon`][psql-polygon-ref] | **`Polygon`**| +| [`real`][psql-real-ref] | [**`Float`**][java-float-ref], [`Double`][java-double-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [`smallint`][psql-smallint-ref] | [**`Short`**][java-short-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Integer`][java-integer-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [`smallserial`][psql-smallserial-ref] | [**`Integer`**][java-integer-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Long`][java-long-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [`serial`][psql-serial-ref] | [**`Long`**][java-long-ref], [`Boolean`][java-boolean-ref], [`Byte`][java-byte-ref], [`Short`][java-short-ref], [`Integer`][java-integer-ref], [`BigDecimal`][java-bigdecimal-ref], [`BigInteger`][java-biginteger-ref]| +| [`text`][psql-text-ref] | [**`String`**][java-string-ref], [`Clob`][r2dbc-clob-ref]| +| [`time [without time zone]`][psql-time-ref] | [`LocalTime`][java-lt-ref]| +| [`time [with time zone]`][psql-time-ref] | [`OffsetTime`][java-ot-ref]| +| [`timestamp [without time zone]`][psql-time-ref] |[**`LocalDateTime`**][java-ldt-ref], [`LocalTime`][java-lt-ref], [`LocalDate`][java-ld-ref], [`java.util.Date`][java-legacy-date-ref]| +| [`timestamp [with time zone]`][psql-time-ref] | [**`OffsetDatetime`**][java-odt-ref], [`ZonedDateTime`][java-zdt-ref], [`Instant`][java-instant-ref]| +| [`tsquery`][psql-tsquery-ref] | Not yet supported.| +| [`tsvector`][psql-tsvector-ref] | Not yet supported.| +| [`txid_snapshot`][psql-txid_snapshot-ref] | Not yet supported.| +| [`uuid`][psql-uuid-ref] | [**`UUID`**][java-uuid-ref], [`String`][java-string-ref]|| +| [`xml`][psql-xml-ref] | Not yet supported. | +| [`vector`][psql-vector-ref] | **`Vector`**, [`float[]`][java-float-ref] | Types in **bold** indicate the native (default) Java type. @@ -550,6 +551,7 @@ Support for the following single-dimensional arrays (read and write): [psql-xml-ref]: https://www.postgresql.org/docs/current/datatype-xml.html [psql-runtime-config]: https://www.postgresql.org/docs/current/runtime-config-client.html [postgis-ref]: http://postgis.net/workshops/postgis-intro/geometries.html +[psql-vector-ref]: https://github.com/pgvector/pgvector [r2dbc-blob-ref]: https://r2dbc.io/spec/0.9.0.RELEASE/api/io/r2dbc/spi/Blob.html [r2dbc-clob-ref]: https://r2dbc.io/spec/0.9.0.RELEASE/api/io/r2dbc/spi/Clob.html diff --git a/src/main/java/io/r2dbc/postgresql/codec/BuiltinDynamicCodecs.java b/src/main/java/io/r2dbc/postgresql/codec/BuiltinDynamicCodecs.java index 53bcadb1..917fec50 100644 --- a/src/main/java/io/r2dbc/postgresql/codec/BuiltinDynamicCodecs.java +++ b/src/main/java/io/r2dbc/postgresql/codec/BuiltinDynamicCodecs.java @@ -24,6 +24,7 @@ import reactor.util.annotation.Nullable; import java.util.Arrays; +import java.util.Collections; import java.util.stream.Collectors; /** @@ -44,7 +45,7 @@ enum BuiltinCodec { public boolean isSupported() { return this.jtsPresent; } - }; + }, VECTOR("vector"); private final String name; @@ -52,13 +53,16 @@ public boolean isSupported() { this.name = name; } - public Codec createCodec(ByteBufAllocator byteBufAllocator, int oid) { + public Iterable> createCodec(ByteBufAllocator byteBufAllocator, int oid, int typarray) { switch (this) { case HSTORE: - return new HStoreCodec(byteBufAllocator, oid); + return Collections.singletonList(new HStoreCodec(byteBufAllocator, oid)); case POSTGIS_GEOMETRY: - return new PostgisGeometryCodec(oid); + return Collections.singletonList(new PostgisGeometryCodec(oid)); + case VECTOR: + VectorCodec vectorCodec = new VectorCodec(byteBufAllocator, oid, typarray); + return Arrays.asList(vectorCodec, new VectorCodec.VectorArrayCodec(byteBufAllocator, vectorCodec), new VectorFloatCodec(byteBufAllocator, oid)); default: throw new UnsupportedOperationException(String.format("Codec %s for OID %d not supported", name(), oid)); } @@ -93,11 +97,12 @@ public Publisher register(PostgresqlConnection connection, ByteBufAllocato .flatMap(it -> it.map((row, rowMetadata) -> { int oid = PostgresqlObjectId.toInt(row.get("oid", Long.class)); + int typarray = PostgresqlObjectId.toInt(row.get("typarray", Long.class)); String typname = row.get("typname", String.class); BuiltinCodec lookup = BuiltinCodec.lookup(typname); if (lookup.isSupported()) { - registry.addLast(lookup.createCodec(byteBufAllocator, oid)); + lookup.createCodec(byteBufAllocator, oid, typarray).forEach(registry::addLast); } return EMPTY; @@ -106,7 +111,7 @@ public Publisher register(PostgresqlConnection connection, ByteBufAllocato } private PostgresqlStatement createQuery(PostgresqlConnection connection) { - return connection.createStatement(String.format("SELECT oid, typname FROM pg_catalog.pg_type WHERE typname IN (%s)", getPlaceholders())); + return connection.createStatement(String.format("SELECT oid, typname, typarray FROM pg_catalog.pg_type WHERE typname IN (%s)", getPlaceholders())); } private static String getPlaceholders() { diff --git a/src/main/java/io/r2dbc/postgresql/codec/Vector.java b/src/main/java/io/r2dbc/postgresql/codec/Vector.java new file mode 100644 index 00000000..19811799 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/codec/Vector.java @@ -0,0 +1,128 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql.codec; + +import io.r2dbc.postgresql.util.Assert; + +import java.util.Arrays; +import java.util.Collection; + +/** + * Value object that maps to the {@code vector} datatype provided by Postgres pgvector. + * + * @since 1.0.3 + */ +public class Vector { + + private static final Vector EMPTY = new Vector(new float[0]); + + private final float[] vec; + + private Vector(float[] vec) { + this.vec = Assert.requireNonNull(vec, "Vector must not be null"); + } + + /** + * Create a new empty {@link Vector}. + * + * @return the empty {@link Vector} object + */ + public static Vector empty() { + return EMPTY; + } + + /** + * Create a new {@link Vector} given {@code vector} points. + * + * @param vec the vector values + * @return the new {@link Vector} object + */ + public static Vector of(float... vec) { + Assert.requireNonNull(vec, "Vector must not be null"); + return vec.length == 0 ? empty() : new Vector(vec); + } + + /** + * Create a new {@link Vector} given {@code vector} points. + * + * @param vec the vector values + * @return the new {@link Vector} object + */ + public static Vector of(Collection vec) { + Assert.requireNonNull(vec, "Vector must not be null"); + + if (vec.isEmpty()) { + return empty(); + } + + float[] floats = new float[vec.size()]; + int index = 0; + for (Number number : vec) { + Number next = Assert.requireNonNull(number, "Vector must not contain null elements"); + floats[index++] = next.floatValue(); + } + + return new Vector(floats); + } + + /** + * Return the vector values. + * + * @return the vector values. + */ + public float[] getVector() { + if (this.vec.length == 0) { + return this.vec; + } + float[] copy = new float[this.vec.length]; + System.arraycopy(this.vec, 0, copy, 0, this.vec.length); + return copy; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Vector other = (Vector) o; + return Arrays.equals(this.vec, other.vec); + } + + @Override + public int hashCode() { + return Arrays.hashCode(this.vec); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append('['); + + for (int i = 0; i < this.vec.length; i++) { + if (i != 0) { + builder.append(','); + } + builder.append(this.vec[i]); + } + builder.append(']'); + + return builder.toString(); + } +} diff --git a/src/main/java/io/r2dbc/postgresql/codec/VectorCodec.java b/src/main/java/io/r2dbc/postgresql/codec/VectorCodec.java new file mode 100644 index 00000000..a93507e5 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/codec/VectorCodec.java @@ -0,0 +1,319 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.r2dbc.postgresql.client.EncodedParameter; +import io.r2dbc.postgresql.message.Format; +import io.r2dbc.postgresql.util.Assert; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static io.r2dbc.postgresql.client.EncodedParameter.NULL_VALUE; +import static io.r2dbc.postgresql.message.Format.FORMAT_BINARY; +import static io.r2dbc.postgresql.message.Format.FORMAT_TEXT; + +/** + * Codec for pgvector. + * + * @since 1.0.3 + */ +public class VectorCodec implements Codec, CodecMetadata, ArrayCodecDelegate { + + private final ByteBufAllocator byteBufAllocator; + + private final int oid; + + private final PostgresTypeIdentifier arrayOid; + + VectorCodec(ByteBufAllocator byteBufAllocator, int oid, int arrayOid) { + this.byteBufAllocator = Assert.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); + this.oid = oid; + this.arrayOid = () -> arrayOid; + } + + @Override + public EncodedParameter encodeNull() { + return new EncodedParameter(Format.FORMAT_BINARY, this.oid, NULL_VALUE); + } + + @Override + public Class type() { + return Vector.class; + } + + @Override + public PostgresTypeIdentifier getArrayDataType() { + return this.arrayOid; + } + + @Override + public Iterable getDataTypes() { + return Collections.singleton(AbstractCodec.getDataType(this.oid)); + } + + @Override + public boolean canDecode(int dataType, Format format, Class type) { + Assert.requireNonNull(format, "format must not be null"); + Assert.requireNonNull(type, "type must not be null"); + + return dataType == this.oid && (type == Object.class || Vector.class.isAssignableFrom(type) || type.isAssignableFrom(float[].class) || type.isAssignableFrom(Float[].class)); + } + + @Override + public boolean canEncode(Object value) { + return value instanceof Vector; + } + + @Override + public boolean canEncodeNull(Class type) { + return (type == Object.class || Vector.class.isAssignableFrom(type) || type.isAssignableFrom(float[].class) || type.isAssignableFrom(Float[].class)); + } + + @Override + public Vector decode(ByteBuf buffer, PostgresTypeIdentifier dataType, Format format, Class type) { + return decode(buffer, dataType.getObjectId(), format, type); + } + + @Override + public Vector decode(@Nullable ByteBuf buffer, int dataType, Format format, Class type) { + return buffer == null ? null : Vector.of(decode(buffer, format)); + } + + static float[] decode(ByteBuf buffer, Format format) { + + if (format == Format.FORMAT_TEXT) { + List objects = decodeText(buffer, (byte) ','); + float[] values = new float[objects.size()]; + for (int i = 0; i < objects.size(); i++) { + Number v = objects.get(i); + values[i] = v.floatValue(); + } + + return values; + } + + int dim = buffer.readShort(); + buffer.readShort(); // unused + float[] values = new float[dim]; + for (int i = 0; i < dim; i++) { + values[i] = buffer.readFloat(); + } + + return values; + } + + @Override + public EncodedParameter encode(Object value) { + return encode(value, this.oid); + } + + @Override + public EncodedParameter encode(Object value, int dataType) { + + Assert.requireNonNull(value, "value must not be null"); + + return new EncodedParameter(FORMAT_BINARY, dataType, Mono.fromSupplier(() -> encodeBinary(((Vector) value).getVector()))); + } + + @Override + public String encodeToText(Vector value) { + return value.toString(); + } + + private ByteBuf allocateBuffer(int dim) { + return this.byteBufAllocator.buffer(estimateBufferSize(dim)); + } + + private ByteBuf encodeBinary(float[] values) { + ByteBuf buffer = allocateBuffer(values.length); + encodeBinary(buffer, values); + return buffer; + } + + static void encodeBinary(ByteBuf buffer, float[] values) { + prepareBuffer(buffer, values.length); + + for (float v : values) { + buffer.writeFloat(v); + } + } + + static int estimateBufferSize(int dim) { + return 2 + 2 + (4 * dim); + } + + private static void prepareBuffer(ByteBuf buffer, int dim) { + buffer.writeShort(dim); + buffer.writeShort(0); // unused + } + + // duplicates ArrayCodec parsing. + private static List decodeText(ByteBuf buf, byte delimiter) { + + boolean insideString = false; + boolean wasInsideString = false; // needed for checking if NULL + // value occurred + List decoded = new ArrayList<>(); // array dimension arrays + + int indentEscape = 0; + int readFrom = 0; + boolean requiresEscapeCharFiltering = false; + while (buf.isReadable()) { + + byte currentChar = buf.readByte(); + // escape character that we need to skip + + if (currentChar == '\\') { + indentEscape++; + buf.skipBytes(1); + requiresEscapeCharFiltering = true; + } else if (!insideString && currentChar == '[') { + // subarray start + + for (int t = indentEscape + 1; t < buf.writerIndex(); t++) { + if (!Character.isWhitespace(buf.getByte(t)) && buf.getByte(t) != '[') { + break; + } + } + + readFrom = buf.readerIndex(); + } else if (currentChar == '"') { + // quoted element + insideString = !insideString; + wasInsideString = true; + } else if (!insideString && Character.isWhitespace(currentChar)) { + // white space + continue; + } else if ((!insideString && (currentChar == delimiter || currentChar == ']')) + || indentEscape == buf.writerIndex() - 1) { + // array end or element end + // when character that is a part of array element + int skipTrailingBytes = 0; + if (currentChar != ']' && currentChar != delimiter && readFrom > 0) { + skipTrailingBytes++; + } + + if (wasInsideString) { + readFrom++; + skipTrailingBytes++; + } + + ByteBuf slice = buf.slice(readFrom, (buf.readerIndex() - readFrom) - (skipTrailingBytes + /* skip current char as we've over-read */ 1)); + try { + if (requiresEscapeCharFiltering) { + ByteBuf filtered = slice.alloc().buffer(slice.readableBytes()); + while (slice.isReadable()) { + byte ch = slice.readByte(); + if (ch == '\\') { + ch = slice.readByte(); + } + filtered.writeByte(ch); + } + slice = filtered; + } + + // add element to current array + if (slice.isReadable() || wasInsideString) { + if (!wasInsideString && slice.readableBytes() == 4 && slice.getByte(0) == 'N' && "NULL".equals(slice.toString(StandardCharsets.US_ASCII))) { + decoded.add(null); + } else { + decoded.add(NumericDecodeUtils.decodeNumber(slice, PostgresqlObjectId.FLOAT4, FORMAT_TEXT)); + } + } + } finally { + + if (requiresEscapeCharFiltering) { + slice.release(); + } + } + + wasInsideString = false; + requiresEscapeCharFiltering = false; + readFrom = buf.readerIndex(); + } + } + + return decoded; + } + + /** + * Array support for single-dimensional Vector arrays. + */ + static class VectorArrayCodec extends ArrayCodec { + + private final ByteBufAllocator byteBufAllocator; + + public VectorArrayCodec(ByteBufAllocator byteBufAllocator, VectorCodec delegate) { + super(byteBufAllocator, delegate, Vector.class); + this.byteBufAllocator = byteBufAllocator; + } + + @Override + public VectorCodec getDelegate() { + return (VectorCodec) super.getDelegate(); + } + + @Override + EncodedParameter doEncode(Object[] value, PostgresTypeIdentifier dataType) { + boolean hasNulls = hasNulls(value); + + return new EncodedParameter(FORMAT_BINARY, dataType.getObjectId(), Mono.fromSupplier(() -> { + + ByteBuf buffer = this.byteBufAllocator.buffer(); + buffer.writeInt(1); // 1-dimensional vectors supported for now. + + buffer.writeInt(hasNulls ? 1 : 0); // flags: 0=no-nulls, 1=has-nulls + buffer.writeInt(getDelegate().getArrayDataType().getObjectId()); + + buffer.writeInt(value.length); // dimension size + buffer.writeInt(0); // lower bound ignored + + for (Object o : value) { + + if (o == null) { + buffer.writeInt(-1); + } else { + ByteBuf nested = this.byteBufAllocator.buffer(); + VectorCodec.encodeBinary(nested, ((Vector) o).getVector()); + buffer.writeInt(nested.readableBytes()); + buffer.writeBytes(nested); + nested.release(); + } + } + + return buffer; + })); + } + + private static boolean hasNulls(Object[] value) { + for (Object o : value) { + if (o == null) { + return true; + } + } + return false; + } + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/codec/VectorFloatCodec.java b/src/main/java/io/r2dbc/postgresql/codec/VectorFloatCodec.java new file mode 100644 index 00000000..af985d06 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/codec/VectorFloatCodec.java @@ -0,0 +1,105 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.r2dbc.postgresql.client.EncodedParameter; +import io.r2dbc.postgresql.message.Format; +import io.r2dbc.postgresql.util.Assert; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +import java.util.Collections; + +import static io.r2dbc.postgresql.client.EncodedParameter.NULL_VALUE; +import static io.r2dbc.postgresql.message.Format.FORMAT_BINARY; + +public class VectorFloatCodec implements Codec, CodecMetadata { + + private final ByteBufAllocator byteBufAllocator; + + private final int oid; + + VectorFloatCodec(ByteBufAllocator byteBufAllocator, int oid) { + this.byteBufAllocator = Assert.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); + this.oid = oid; + } + + @Override + public EncodedParameter encodeNull() { + return new EncodedParameter(Format.FORMAT_BINARY, this.oid, NULL_VALUE); + } + + @Override + public Class type() { + return float[].class; + } + + @Override + public Iterable getDataTypes() { + return Collections.singleton(AbstractCodec.getDataType(this.oid)); + } + + @Override + public boolean canDecode(int dataType, Format format, Class type) { + Assert.requireNonNull(format, "format must not be null"); + Assert.requireNonNull(type, "type must not be null"); + + return dataType == this.oid && type == float[].class; + } + + @Override + public boolean canEncode(Object value) { + return value instanceof float[]; + } + + @Override + public boolean canEncodeNull(Class type) { + return false; + } + + @Override + public float[] decode(@Nullable ByteBuf buffer, int dataType, Format format, Class type) { + + if (buffer == null) { + return null; + } + + return VectorCodec.decode(buffer, format); + } + + @Override + public EncodedParameter encode(Object value) { + return encode(value, this.oid); + } + + @Override + public EncodedParameter encode(Object value, int dataType) { + + Assert.requireNonNull(value, "value must not be null"); + + float[] vec = (float[]) value; + return new EncodedParameter(FORMAT_BINARY, dataType, Mono.fromSupplier(() -> { + + ByteBuf buffer = this.byteBufAllocator.buffer(VectorCodec.estimateBufferSize(vec.length)); + VectorCodec.encodeBinary(buffer, vec); + return buffer; + })); + } + +} diff --git a/src/test/java/io/r2dbc/postgresql/AbstractCodecIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/AbstractCodecIntegrationTests.java index 011813c1..be231a8b 100644 --- a/src/test/java/io/r2dbc/postgresql/AbstractCodecIntegrationTests.java +++ b/src/test/java/io/r2dbc/postgresql/AbstractCodecIntegrationTests.java @@ -31,6 +31,7 @@ import io.r2dbc.postgresql.codec.Point; import io.r2dbc.postgresql.codec.Polygon; import io.r2dbc.postgresql.codec.PostgresqlObjectId; +import io.r2dbc.postgresql.codec.Vector; import io.r2dbc.spi.Blob; import io.r2dbc.spi.Clob; import io.r2dbc.spi.Connection; @@ -70,6 +71,7 @@ import java.time.temporal.ChronoUnit; import java.util.Date; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.function.BiConsumer; @@ -78,6 +80,7 @@ import static io.r2dbc.postgresql.util.TestByteBufAllocator.TEST; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; /** * Integrations tests for our built-in codecs. @@ -690,6 +693,16 @@ void polygonTwoDimensionalArray() { Point.of(.42, 5.3), Point.of(-3.5, 0.)), null}}, "POLYGON[][]"); } + @Test + void vector() { + + List> extensions = SERVER.getJdbcOperations().queryForList("select * from pg_available_extensions() where name = 'vector'"); + assumeThat(extensions).isNotEmpty(); + + SERVER.getJdbcOperations().execute("CREATE EXTENSION IF NOT EXISTS vector"); + testCodec(Vector.class, Vector.of(1, 2.2f, 3), "VECTOR"); + } + private static Mono close(Connection connection) { return Mono.from(connection.close()).then(Mono.empty()); } diff --git a/src/test/java/io/r2dbc/postgresql/VectorIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/VectorIntegrationTests.java new file mode 100644 index 00000000..7e1edb59 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/VectorIntegrationTests.java @@ -0,0 +1,202 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.api.PostgresqlConnection; +import io.r2dbc.postgresql.api.PostgresqlResult; +import io.r2dbc.postgresql.codec.Vector; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcOperations; +import reactor.test.StepVerifier; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assumptions.assumeThat; + +/** + * Integration tests for {@link Vector}. + */ +final class VectorIntegrationTests extends AbstractIntegrationTests { + + @Override + @BeforeEach + void setUp() { + + JdbcOperations jdbcOperations = SERVER.getJdbcOperations(); + List> extensions = jdbcOperations.queryForList("select * from pg_available_extensions() where name = 'vector'"); + assumeThat(extensions).isNotEmpty(); + + jdbcOperations.execute("CREATE EXTENSION IF NOT EXISTS VECTOR;"); + jdbcOperations.execute("CREATE TABLE IF NOT EXISTS vector_arrays (id bigserial PRIMARY KEY, embedding vector[3]);"); + jdbcOperations.execute("CREATE TABLE IF NOT EXISTS vector_items (id bigserial PRIMARY KEY, embedding vector(3));"); + + jdbcOperations.execute("DELETE FROM vector_arrays;"); + jdbcOperations.execute("DELETE FROM vector_items;"); + + super.setUp(); + } + + @Override + protected void customize(PostgresqlConnectionConfiguration.Builder builder) { + builder.forceBinary(true); + super.customize(builder); + } + + @Test + void shouldReadVector() { + + JdbcOperations jdbcOperations = SERVER.getJdbcOperations(); + + jdbcOperations.execute("INSERT INTO vector_items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]');"); + + PostgresqlConnection connection = this.connectionFactory.create().block(); + + connection.createStatement("SELECT * FROM vector_items WHERE id != $1 ORDER BY embedding <-> '[3,1,2]' ") + .bind("$1", 1) + .execute() + .flatMap(result -> result.map(readable -> readable.get("embedding"))) + .as(StepVerifier::create) + .assertNext(o -> { + assertThat(o).isInstanceOf(Vector.class).isEqualTo(Vector.of(1, 2, 3)); + }) + .assertNext(o -> { + assertThat(o).isInstanceOf(Vector.class).isEqualTo(Vector.of(4, 5, 6)); + }) + .verifyComplete(); + + connection.createStatement("SELECT * FROM vector_items WHERE id != $1 ORDER BY embedding <-> '[3,1,2]' ") + .bind("$1", 1) + .execute() + .flatMap(result -> result.map(readable -> readable.get("embedding", float[].class))) + .as(StepVerifier::create) + .assertNext(o -> { + assertThat(o).contains(1f, 2f, 3f); + }) + .assertNext(o -> { + assertThat(o).contains(4, 5, 6); + }) + .verifyComplete(); + + connection.close().block(); + } + + @Test + void shouldReadVectorArray() { + + JdbcOperations jdbcOperations = SERVER.getJdbcOperations(); + jdbcOperations.execute("INSERT INTO vector_arrays (embedding) VALUES (ARRAY['[1,2,3]'::vector,'[4,5,6]'::vector]);"); + + PostgresqlConnection connection = this.connectionFactory.create().block(); + + connection.createStatement("SELECT * FROM vector_arrays WHERE id != $1") + .bind("$1", 1) + .execute() + .flatMap(result -> result.map(readable -> readable.get("embedding"))) + .as(StepVerifier::create) + .assertNext(o -> { + assertThat(o).isInstanceOf(Vector[].class); + assertThat((Vector[]) o).contains(Vector.of(1, 2, 3), Vector.of(4, 5, 6)); + }) + .verifyComplete(); + + connection.close().block(); + } + + @Test + void shouldWriteVector() { + + PostgresqlConnection connection = this.connectionFactory.create().block(); + + connection.createStatement("INSERT INTO vector_items (embedding) VALUES ($1)") + .bind("$1", Vector.of(1, 2, 3)) + .execute() + .flatMap(PostgresqlResult::getRowsUpdated) + .as(StepVerifier::create) + .expectNext(1L) + .verifyComplete(); + + connection.createStatement("SELECT * FROM vector_items") + .execute() + .flatMap(result -> result.map(readable -> readable.get("embedding"))) + .as(StepVerifier::create) + .assertNext(o -> { + assertThat(o).isInstanceOf(Vector.class).isEqualTo(Vector.of(1, 2, 3)); + }).verifyComplete(); + + connection.close().block(); + } + + @Test + void shouldWriteVectorArray() { + + PostgresqlConnection connection = this.connectionFactory.create().block(); + + connection.createStatement("INSERT INTO vector_arrays (embedding) VALUES ($1);") + .bind("$1", new Vector[]{Vector.of(1, 2, 3), Vector.of(4, 5, 6)}) + .execute() + .flatMap(PostgresqlResult::getRowsUpdated) + .as(StepVerifier::create) + .expectNext(1L) + .verifyComplete(); + + connection.createStatement("SELECT * FROM vector_arrays WHERE id != $1") + .bind("$1", 1) + .execute() + .flatMap(result -> result.map(readable -> readable.get("embedding"))) + .as(StepVerifier::create) + .assertNext(o -> { + assertThat(o).isInstanceOf(Vector[].class); + assertThat((Vector[]) o).contains(Vector.of(1, 2, 3), Vector.of(4, 5, 6)); + }) + .verifyComplete(); + + connection.close().block(); + } + + @Test + void shouldWriteVectorArrayWithNulls() { + + JdbcOperations jdbcOperations = SERVER.getJdbcOperations(); + + PostgresqlConnection connection = this.connectionFactory.create().block(); + + connection.createStatement("INSERT INTO vector_arrays (embedding) VALUES ($1);") + .bind("$1", new Vector[]{Vector.of(1, 2, 3), null, Vector.of(4, 5, 6)}) + .execute() + .flatMap(PostgresqlResult::getRowsUpdated) + .as(StepVerifier::create) + .expectNext(1L) + .verifyComplete(); + + connection.createStatement("SELECT * FROM vector_arrays WHERE id != $1") + .bind("$1", 1) + .execute() + .flatMap(result -> result.map(readable -> readable.get("embedding"))) + .as(StepVerifier::create) + .assertNext(o -> { + assertThat(o).isInstanceOf(Vector[].class); + assertThat((Vector[]) o).contains(Vector.of(1, 2, 3), Vector.of(4, 5, 6)); + }) + .verifyComplete(); + + connection.close().block(); + } + +} diff --git a/src/test/java/io/r2dbc/postgresql/codec/VectorCodecUnitTests.java b/src/test/java/io/r2dbc/postgresql/codec/VectorCodecUnitTests.java new file mode 100644 index 00000000..04987995 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/codec/VectorCodecUnitTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql.codec; + +import io.netty.buffer.ByteBuf; +import io.r2dbc.postgresql.client.EncodedParameter; +import io.r2dbc.postgresql.client.ParameterAssert; +import org.junit.jupiter.api.Test; + +import static io.r2dbc.postgresql.client.EncodedParameter.NULL_VALUE; +import static io.r2dbc.postgresql.message.Format.FORMAT_BINARY; +import static io.r2dbc.postgresql.message.Format.FORMAT_TEXT; +import static io.r2dbc.postgresql.util.ByteBufUtils.encode; +import static io.r2dbc.postgresql.util.TestByteBufAllocator.TEST; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Unit tests for {@link VectorCodec}. + */ +class VectorCodecUnitTests { + + @Test + void constructorNoByteBufAllocator() { + assertThatIllegalArgumentException().isThrownBy(() -> new VectorCodec(null, 0,0)) + .withMessage("byteBufAllocator must not be null"); + } + + @Test + void doCanDecode() { + VectorCodec codec = new VectorCodec(TEST, 16384,0); + + assertThat(codec.canDecode(16384, FORMAT_BINARY, Object.class)).isTrue(); + assertThat(codec.canDecode(16384, FORMAT_TEXT, Object.class)).isTrue(); + assertThat(codec.canDecode(16384, FORMAT_BINARY, Vector.class)).isTrue(); + assertThat(codec.canDecode(16384, FORMAT_TEXT, Vector.class)).isTrue(); + assertThat(codec.canDecode(16384, FORMAT_BINARY, Float[].class)).isTrue(); + assertThat(codec.canDecode(16384, FORMAT_BINARY, float[].class)).isTrue(); + } + + @Test + void decodeText() { + VectorCodec codec = new VectorCodec(TEST, 16384,0); + Vector vector = Vector.of(1.1f, 2.2f, 3f); + ByteBuf vectorAsText = encode(TEST, "[1.1,2.2,3]"); + assertThat(codec.decode(vectorAsText, 16384, FORMAT_TEXT, Vector.class)).isEqualTo(vector); + } + + @Test + void decodeBinary() { + VectorCodec codec = new VectorCodec(TEST, 16384,0); + Vector vector = Vector.of(1.1f, 2.2f, 3f); + ByteBuf vectorAsBinary = TEST.buffer(16).writeShort(3).writeShort(0).writeFloat(1.1f).writeFloat(2.2f).writeFloat(3f); + assertThat(codec.decode(vectorAsBinary, 16384, FORMAT_BINARY, Vector.class)).isEqualTo(vector); + } + + @Test + void encodeNull() { + ParameterAssert.assertThat(new VectorCodec(TEST, 1234,0).encodeNull()) + .isEqualTo(new EncodedParameter(FORMAT_BINARY, 1234, NULL_VALUE)); + } + + @Test + void encodeBinary() { + ParameterAssert.assertThat(new VectorCodec(TEST, 1234,0).encode(Vector.of(1.1f, 2.2f, 3f))) + .hasFormat(FORMAT_BINARY) + .hasType(1234) + .hasValue(TEST.buffer(16).writeShort(3).writeShort(0).writeFloat(1.1f).writeFloat(2.2f).writeFloat(3f)); + } +} diff --git a/src/test/java/io/r2dbc/postgresql/codec/VectorUnitTests.java b/src/test/java/io/r2dbc/postgresql/codec/VectorUnitTests.java new file mode 100644 index 00000000..65f92f93 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/codec/VectorUnitTests.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql.codec; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link Vector}. + */ +class VectorUnitTests { + + @Test + void createsVectorCorrectly() { + + assertThat(Vector.of(1, 2, 3)).isEqualTo(Vector.of(1, 2, 3)).hasSameHashCodeAs(Vector.of(1, 2, 3)).hasToString("[1.0,2.0,3.0]"); + assertThat(Vector.of(1, 2, 3)).isNotEqualTo(Vector.of(2, 1, 3)); + assertThat(Vector.of(Arrays.asList(1, 2, 3))).isEqualTo(Vector.of(1, 2, 3)); + } +}