JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder #1931

Open: wants to merge 13 commits into base: 4.x
Original file line number Diff line number Diff line change
@@ -60,7 +60,8 @@ public String getClassName() {
@NonNull
@Override
public String asCql(boolean includeFrozen, boolean pretty) {
- return String.format("'%s(%d)'", getClassName(), getDimensions());
+ return String.format(
+ "vector<%s, %d>", this.subtype.asCql(includeFrozen, pretty), getDimensions());
}

/* ============== General class implementation ============== */
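As a standalone illustration of the string this hunk now produces, the sketch below mirrors only the new format call. `VectorTypeSketch` and its two fields are hypothetical stand-ins and not driver classes; the real method delegates to the subtype's own `asCql(includeFrozen, pretty)`, which is simplified here to a plain string.

```java
public class VectorTypeSketch {
  // Hypothetical stand-in: the CQL name of the element type, e.g. "float".
  private final String subtypeCql;
  // Number of elements in the vector.
  private final int dimensions;

  public VectorTypeSketch(String subtypeCql, int dimensions) {
    this.subtypeCql = subtypeCql;
    this.dimensions = dimensions;
  }

  // Same format string as the added line in the diff: vector<subtype, dimensions>.
  public String asCql() {
    return String.format("vector<%s, %d>", subtypeCql, dimensions);
  }

  public static void main(String[] args) {
    System.out.println(new VectorTypeSketch("float", 3).asCql());
  }
}
```

With a `float` subtype and three dimensions this yields `vector<float, 3>`, the form the schema builder tests in this PR expect.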
10 changes: 10 additions & 0 deletions query-builder/revapi.json
@@ -2772,6 +2772,16 @@
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.querybuilder.update.UpdateStart com.datastax.oss.driver.api.querybuilder.update.UpdateStart::usingTtl(int)",
"justification": "JAVA-2210: Add ability to set TTL for modification queries"
},
{
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(java.lang.String, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
},
{
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(com.datastax.oss.driver.api.core.CqlIdentifier, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
}
]
}
@@ -18,6 +18,7 @@
package com.datastax.oss.driver.api.querybuilder.select;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.querybuilder.BindMarker;
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
@@ -146,6 +147,16 @@ default Select orderBy(@NonNull String columnName, @NonNull ClusteringOrder orde
return orderBy(CqlIdentifier.fromCql(columnName), order);
}

/**
* Shortcut for {@link #orderByAnnOf(CqlIdentifier, CqlVector)}, adding an ORDER BY ... ANN OF ...
* clause
*/
@NonNull
Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann);

/** Adds the ORDER BY ... ANN OF ... clause */
@NonNull
Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann);
Contributor

I think I'm changing my answer on this: we should remove the type bound here and just make this a CqlVector<?>. My rationale goes as follows.

The fact that ANN comparisons only support float vectors actually isn't a constraint on the underlying Java type used here. In theory any Java type whose codec will generate a serialized value that can be understood as a float when the server receives the message would work here. That means we could pass in a CqlVector of floats, decimals or doubles and still have the server handle that without issue. I'll also note that there's no common supertype for Float, BigDecimal or Double (the corresponding Java types based on the docs) other than Number... and we can't use that here because it includes types other than these three. So there's no meaningful supertype we can use for a type bound at this point.

A second point: the query builder is built around the notion of generating CQL based on whatever the user passes in; there's almost no checking whether types correlate to things the user specified. So, for example, OngoingValues doesn't have any kind of type bounds around its various value() methods. I accept that this isn't exactly the same as the float constraint we're discussing here... but it is an indication that the query builder largely doesn't concern itself with type checking, with the expectation that the server will handle that when it receives the query.

Finally, I'll note that avoiding the type bound here at least exposes the idea (at an API level) of using external/custom types which serialize to CQL float values by way of their own codecs. Scala users, for example, might want to support using Spire numerics for their CQL queries. This actually won't work for now since we constrain the codec registry used to the (immutable) default which doesn't include any support for Spire types... but that's an implementation detail. By avoiding type bounds at an API level we've left that door open for such a change without having to make a (potentially breaking) API change and without sacrificing our design principles elsewhere.
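
To make the difference between the two signatures concrete, here is a minimal sketch using stand-in types. `Vec` is a hypothetical placeholder for `CqlVector`, `Ratio` is a hypothetical external numeric type, and `bounded`/`unbounded` are illustrative names; none of this is driver code.

```java
import java.util.Arrays;
import java.util.List;

public class WildcardBoundSketch {

  /** Hypothetical stand-in for CqlVector<T>; it only holds its elements. */
  public static final class Vec<T> {
    final List<T> values;

    @SafeVarargs
    public Vec(T... values) {
      this.values = Arrays.asList(values);
    }
  }

  /** Hypothetical external numeric type that does not extend java.lang.Number. */
  public static final class Ratio {
    final int num;
    final int den;

    public Ratio(int num, int den) {
      this.num = num;
      this.den = den;
    }

    @Override
    public String toString() {
      return num + "/" + den;
    }
  }

  /** Shape of the signature in the diff: any Number subtype compiles, Integer included. */
  public static String bounded(Vec<? extends Number> v) {
    return v.values.toString();
  }

  /** Shape proposed in the comment: no compile-time check on the element type. */
  public static String unbounded(Vec<?> v) {
    return v.values.toString();
  }

  public static void main(String[] args) {
    System.out.println(bounded(new Vec<>(0.1f, 0.2f)));
    // Integer also satisfies the bound, so the bound does not enforce "float-like only".
    System.out.println(bounded(new Vec<>(1, 2)));
    // bounded(new Vec<>(new Ratio(1, 2))); // would not compile: Ratio is not a Number
    System.out.println(unbounded(new Vec<>(new Ratio(1, 2))));
  }
}
```

The sketch only shows what the compiler accepts. Whether the server accepts the resulting value still depends on how a codec serializes it, which is the point made above.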

/**
* Adds a LIMIT clause to this query with a literal value.
*
@@ -20,8 +20,10 @@
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.SimpleStatementBuilder;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.querybuilder.BindMarker;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.relation.Relation;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.api.querybuilder.select.SelectFrom;
@@ -49,6 +51,7 @@ public class DefaultSelect implements SelectFrom, Select {
private final ImmutableList<Relation> relations;
private final ImmutableList<Selector> groupByClauses;
private final ImmutableMap<CqlIdentifier, ClusteringOrder> orderings;
private final Ann ann;
private final Object limit;
private final Object perPartitionLimit;
private final boolean allowsFiltering;
@@ -65,6 +68,7 @@ public DefaultSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier ta
ImmutableMap.of(),
null,
null,
null,
false);
}

@@ -74,6 +78,8 @@ public DefaultSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier ta
* @param selectors if it contains {@link AllSelector#INSTANCE}, that must be the only element.
* This isn't re-checked because methods that call this constructor internally already do it,
* make sure you do it yourself.
* @param ann Approximate nearest neighbor. ANN ordering does not support secondary ordering or
* ASC order.
*/
public DefaultSelect(
@Nullable CqlIdentifier keyspace,
@@ -84,6 +90,7 @@ public DefaultSelect(
@NonNull ImmutableList<Relation> relations,
@NonNull ImmutableList<Selector> groupByClauses,
@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder> orderings,
@Nullable Ann ann,
@Nullable Object limit,
@Nullable Object perPartitionLimit,
boolean allowsFiltering) {
@@ -94,6 +101,9 @@ public DefaultSelect(
|| (limit instanceof Integer && (Integer) limit > 0)
|| limit instanceof BindMarker,
"limit must be a strictly positive integer or a bind marker");
Preconditions.checkArgument(
orderings.isEmpty() || ann == null, "ANN ordering does not support secondary ordering");
this.ann = ann;
this.keyspace = keyspace;
this.table = table;
this.isJson = isJson;
@@ -117,6 +127,7 @@ public SelectFrom json() {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -134,6 +145,7 @@ public SelectFrom distinct() {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -193,6 +205,7 @@ public Select withSelectors(@NonNull ImmutableList<Selector> newSelectors) {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -221,6 +234,7 @@ public Select withRelations(@NonNull ImmutableList<Relation> newRelations) {
newRelations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -249,6 +263,7 @@ public Select withGroupByClauses(@NonNull ImmutableList<Selector> newGroupByClau
relations,
newGroupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -260,6 +275,19 @@ public Select orderBy(@NonNull CqlIdentifier columnId, @NonNull ClusteringOrder
return withOrderings(ImmutableCollections.append(orderings, columnId, order));
}

@NonNull
@Override
public Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann) {
return withAnn(new Ann(CqlIdentifier.fromCql(columnName), ann));
}

@NonNull
@Override
public Select orderByAnnOf(
@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann) {
return withAnn(new Ann(columnId, ann));
}

@NonNull
@Override
public Select orderByIds(@NonNull Map<CqlIdentifier, ClusteringOrder> newOrderings) {
@@ -277,6 +305,24 @@ public Select withOrderings(@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder
relations,
groupByClauses,
newOrderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
}

@NonNull
Select withAnn(@NonNull Ann ann) {
return new DefaultSelect(
keyspace,
table,
isJson,
isDistinct,
selectors,
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -295,6 +341,7 @@ public Select limit(int limit) {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -312,6 +359,7 @@ public Select limit(@Nullable BindMarker bindMarker) {
relations,
groupByClauses,
orderings,
ann,
bindMarker,
perPartitionLimit,
allowsFiltering);
@@ -331,6 +379,7 @@ public Select perPartitionLimit(int perPartitionLimit) {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -348,6 +397,7 @@ public Select perPartitionLimit(@Nullable BindMarker bindMarker) {
relations,
groupByClauses,
orderings,
ann,
limit,
bindMarker,
allowsFiltering);
@@ -365,6 +415,7 @@ public Select allowFiltering() {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
true);
@@ -391,15 +442,20 @@ public String asCql() {
CqlHelper.append(relations, builder, " WHERE ", " AND ", null);
CqlHelper.append(groupByClauses, builder, " GROUP BY ", ",", null);

- boolean first = true;
- for (Map.Entry<CqlIdentifier, ClusteringOrder> entry : orderings.entrySet()) {
- if (first) {
- builder.append(" ORDER BY ");
- first = false;
- } else {
- builder.append(",");
+ if (ann != null) {
+ builder.append(" ORDER BY ").append(this.ann.columnId.asCql(true)).append(" ANN OF ");
+ QueryBuilder.literal(ann.vector).appendTo(builder);
+ } else {
+ boolean first = true;
+ for (Map.Entry<CqlIdentifier, ClusteringOrder> entry : orderings.entrySet()) {
+ if (first) {
+ builder.append(" ORDER BY ");
+ first = false;
+ } else {
+ builder.append(",");
  }
+ builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name());
+ }
- builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name());
  }

if (limit != null) {
@@ -512,4 +568,14 @@ public boolean allowsFiltering() {
public String toString() {
return asCql();
}

public static class Ann {
private final CqlVector<? extends Number> vector;
private final CqlIdentifier columnId;

private Ann(CqlIdentifier columnId, CqlVector<? extends Number> vector) {
this.vector = vector;
this.columnId = columnId;
}
}
}
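
The ordering logic added to DefaultSelect can be summarized in a small standalone sketch: an ANN ordering and regular clustering orderings are mutually exclusive, and when an ANN ordering is present it is the only thing rendered after ORDER BY. Everything below is a simplified assumption for illustration: `AnnOrderingSketch` and `orderByClause` are hypothetical names, column identifiers and orders are plain strings, and the vector literal is rendered with `Number.toString()` joined by ", " to match the `[0.1, 0.2]` form seen in this PR's tests.

```java
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

public class AnnOrderingSketch {

  /**
   * Renders the ORDER BY fragment of a SELECT. Pass annColumn == null for a query
   * without ANN ordering. Simplified: identifiers and orders are plain strings.
   */
  public static String orderByClause(
      Map<String, String> orderings, String annColumn, List<? extends Number> annVector) {
    // Same rule as the precondition in the diff: ANN excludes any other ordering.
    if (annColumn != null && !orderings.isEmpty()) {
      throw new IllegalArgumentException("ANN ordering does not support secondary ordering");
    }
    if (annColumn != null) {
      // Assumed literal format: [e1, e2, ...]
      StringJoiner literal = new StringJoiner(", ", "[", "]");
      for (Number n : annVector) {
        literal.add(n.toString());
      }
      return " ORDER BY " + annColumn + " ANN OF " + literal;
    }
    // Regular orderings: comma-separated "column ORDER" pairs, or nothing at all.
    StringJoiner clause = new StringJoiner(",", " ORDER BY ", "");
    clause.setEmptyValue("");
    orderings.forEach((column, order) -> clause.add(column + " " + order));
    return clause.toString();
  }
}
```

For example, calling `orderByClause` with an empty map, the column `v`, and the floats 0.1 and 0.2 produces ` ORDER BY v ANN OF [0.1, 0.2]`, while passing both a non-empty map and an ANN column throws.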
@@ -22,6 +22,7 @@
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.deleteFrom;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;

import com.datastax.oss.driver.api.core.data.CqlVector;
import org.junit.Test;

public class DeleteSelectorTest {
@@ -34,6 +35,16 @@ public void should_generate_column_deletion() {
.hasCql("DELETE v FROM ks.foo WHERE k=?");
}

@Test
public void should_generate_vector_deletion() {
assertThat(
deleteFrom("foo")
.column("v")
.whereColumn("k")
.isEqualTo(literal(CqlVector.newInstance(0.1, 0.2))))
.hasCql("DELETE v FROM foo WHERE k=[0.1, 0.2]");
}

@Test
public void should_generate_field_deletion() {
assertThat(
@@ -23,6 +23,7 @@
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import static org.assertj.core.api.Assertions.catchThrowable;

import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.querybuilder.term.Term;
import com.datastax.oss.driver.internal.querybuilder.insert.DefaultInsert;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
@@ -41,6 +42,12 @@ public void should_generate_column_assignments() {
.hasCql("INSERT INTO foo (a,b) VALUES (?,?)");
}

@Test
public void should_generate_vector_literals() {
assertThat(insertInto("foo").value("a", literal(CqlVector.newInstance(0.1, 0.2, 0.3))))
.hasCql("INSERT INTO foo (a) VALUES ([0.1, 0.2, 0.3])");
}

@Test
public void should_keep_last_assignment_if_column_listed_twice() {
assertThat(
@@ -108,4 +108,10 @@ public void should_generate_alter_table_with_no_compression() {
assertThat(alterTable("bar").withNoCompression())
.hasCql("ALTER TABLE bar WITH compression={'sstable_compression':''}");
}

@Test
public void should_generate_alter_table_with_vector() {
assertThat(alterTable("bar").alterColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("ALTER TABLE bar ALTER v TYPE vector<float, 3>");
}
}
@@ -53,4 +53,10 @@ public void should_generate_alter_table_with_rename_three_columns() {
assertThat(alterType("bar").renameField("x", "y").renameField("u", "v").renameField("b", "a"))
.hasCql("ALTER TYPE bar RENAME x TO y AND u TO v AND b TO a");
}

@Test
public void should_generate_alter_type_with_vector() {
assertThat(alterType("foo", "bar").alterField("vec", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("ALTER TYPE foo.bar ALTER vec TYPE vector<float, 3>");
}
}
@@ -307,4 +307,13 @@ public void should_generate_create_table_time_window_compaction() {
.hasCql(
"CREATE TABLE bar (k int PRIMARY KEY,v text) WITH compaction={'class':'TimeWindowCompactionStrategy','compaction_window_size':10,'compaction_window_unit':'DAYS','timestamp_resolution':'MICROSECONDS','unsafe_aggressive_sstable_expiration':false}");
}

@Test
public void should_generate_vector_column() {
assertThat(
createTable("foo")
.withPartitionKey("k", DataTypes.INT)
.withColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("CREATE TABLE foo (k int PRIMARY KEY,v vector<float, 3>)");
}
}
@@ -83,4 +83,13 @@ public void should_create_type_with_collections() {
.withField("map", DataTypes.mapOf(DataTypes.INT, DataTypes.TEXT)))
.hasCql("CREATE TYPE ks1.type (map map<int, text>)");
}

@Test
public void should_create_type_with_vector() {
assertThat(
createType("ks1", "type")
.withField("c1", DataTypes.INT)
.withField("vec", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("CREATE TYPE ks1.type (c1 int,vec vector<float, 3>)");
}
}