diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 88772169e260c..7d385c189479b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -28,15 +28,6 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.queries.function.FunctionQuery; -import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; -import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource; -import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; -import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; -import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; @@ -67,6 +58,7 @@ import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.aggregations.support.CoreValuesSourceType; +import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery; import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; @@ -1484,19 +1476,7 @@ private Query createExactKnnByteQuery(byte[] queryVector) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType); - return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER) - .add( - new FunctionQuery( - new ByteVectorSimilarityFunction( - vectorSimilarityFunction, - new ByteKnnVectorFieldSource(name()), - new ConstKnnByteVectorValueSource(queryVector) - ) - ), - BooleanClause.Occur.SHOULD - ) - .build(); + return new DenseVectorQuery.Bytes(queryVector, name()); } private Query createExactKnnFloatQuery(float[] queryVector) { @@ -1519,19 +1499,7 @@ && isNotUnitVector(squaredMagnitude)) { } } } - VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType); - return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER) - .add( - new FunctionQuery( - new FloatVectorSimilarityFunction( - vectorSimilarityFunction, - new FloatKnnVectorFieldSource(name()), - new ConstKnnFloatValueSource(queryVector) - ) - ), - BooleanClause.Occur.SHOULD - ) - .build(); + return new DenseVectorQuery.Floats(queryVector, name()); } Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java new file mode 100644 index 0000000000000..8fd59a0e6f224 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * Exact knn query. Will iterate and score all documents that have the provided dense vector field in the index. + */ +public abstract class DenseVectorQuery extends Query { + + protected final String field; + + public DenseVectorQuery(String field) { + this.field = field; + } + + @Override + public void visit(QueryVisitor queryVisitor) { + queryVisitor.visitLeaf(this); + } + + abstract static class DenseVectorWeight extends Weight { + private final String field; + private final float boost; + + protected DenseVectorWeight(DenseVectorQuery query, float boost) { + super(query); + this.field = query.field; + this.boost = boost; + } + + abstract VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException; + + @Override + public Explanation explain(LeafReaderContext leafReaderContext, int i) throws IOException { + VectorScorer vectorScorer = vectorScorer(leafReaderContext); + if (vectorScorer == null) { + return Explanation.noMatch("No vector values found for field: " + field); + } + DocIdSetIterator iterator = vectorScorer.iterator(); + iterator.advance(i); + if (iterator.docID() == i) { + float score = vectorScorer.score(); + return Explanation.match(vectorScorer.score() * boost, "found vector with calculated similarity: " + score); + } + return Explanation.noMatch("Document not found in vector values for field: " + field); + } + + @Override + public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException { + VectorScorer vectorScorer = vectorScorer(leafReaderContext); + if (vectorScorer == null) { + return null; + } + return new DenseVectorScorer(this, vectorScorer); + } + + @Override + public boolean isCacheable(LeafReaderContext leafReaderContext) { + return true; + } + } + + public static class Floats extends DenseVectorQuery { + + private final float[] query; + + public Floats(float[] query, String field) { + super(field); + this.query = query; + } + + public float[] getQuery() { + return query; + } + + @Override + public String toString(String field) { + return "DenseVectorQuery.Floats"; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new DenseVectorWeight(Floats.this, boost) { + @Override + VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException { + FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(field); + if (vectorValues == null) { + return null; + } + return vectorValues.scorer(query); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Floats floats = (Floats) o; + return Objects.equals(field, floats.field) && Objects.deepEquals(query, floats.query); + } + + @Override + public int hashCode() { + return Objects.hash(field, Arrays.hashCode(query)); + } + } + + public static class Bytes extends DenseVectorQuery { + + private final byte[] query; + + public Bytes(byte[] query, String field) { + super(field); + this.query = query; + } + + @Override + public String toString(String field) { + return "DenseVectorQuery.Bytes"; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new DenseVectorWeight(Bytes.this, boost) { + @Override + VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException { + ByteVectorValues vectorValues = leafReaderContext.reader().getByteVectorValues(field); + if (vectorValues == null) { + return null; + } + return vectorValues.scorer(query); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Bytes bytes = (Bytes) o; + return Objects.equals(field, bytes.field) && Objects.deepEquals(query, bytes.query); + } + + @Override + public int hashCode() { + return Objects.hash(field, Arrays.hashCode(query)); + } + } + + static class DenseVectorScorer extends Scorer { + + private final VectorScorer vectorScorer; + private final DocIdSetIterator iterator; + private final float boost; + + DenseVectorScorer(DenseVectorWeight weight, VectorScorer vectorScorer) { + super(weight); + this.vectorScorer = vectorScorer; + this.iterator = vectorScorer.iterator(); + this.boost = weight.boost; + } + + @Override + public DocIdSetIterator iterator() { + return vectorScorer.iterator(); + } + + @Override + public float getMaxScore(int i) throws IOException { + // TODO: can we optimize this at all? + return Float.POSITIVE_INFINITY; + } + + @Override + public float score() throws IOException { + assert iterator.docID() != -1; + return vectorScorer.score() * boost; + } + + @Override + public int docID() { + return iterator.docID(); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index fa4c8bb089855..f178e66955fdc 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -8,11 +8,6 @@ package org.elasticsearch.index.mapper.vectors; -import org.apache.lucene.queries.function.FunctionQuery; -import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; -import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; @@ -25,6 +20,7 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; +import org.elasticsearch.search.vectors.DenseVectorQuery; import org.elasticsearch.search.vectors.VectorData; import java.io.IOException; @@ -218,16 +214,7 @@ public void testExactKnnQuery() { queryVector[i] = randomFloat(); } Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector)); - assertTrue(query instanceof BooleanQuery); - BooleanQuery booleanQuery = (BooleanQuery) query; - boolean foundFunction = false; - for (BooleanClause clause : booleanQuery) { - if (clause.getQuery() instanceof FunctionQuery functionQuery) { - foundFunction = true; - assertTrue(functionQuery.getValueSource() instanceof FloatVectorSimilarityFunction); - } - } - assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction); + assertTrue(query instanceof DenseVectorQuery.Floats); } { DenseVectorFieldType field = new DenseVectorFieldType( @@ -245,16 +232,7 @@ public void testExactKnnQuery() { queryVector[i] = randomByte(); } Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector)); - assertTrue(query instanceof BooleanQuery); - BooleanQuery booleanQuery = (BooleanQuery) query; - boolean foundFunction = false; - for (BooleanClause clause : booleanQuery) { - if (clause.getQuery() instanceof FunctionQuery functionQuery) { - foundFunction = true; - assertTrue(functionQuery.getValueSource() instanceof ByteVectorSimilarityFunction); - } - } - assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction); + assertTrue(query instanceof DenseVectorQuery.Bytes); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractDenseVectorQueryTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractDenseVectorQueryTestCase.java new file mode 100644 index 0000000000000..6d2d600a18a81 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractDenseVectorQueryTestCase.java @@ -0,0 +1,307 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.store.BaseDirectoryWrapper; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +abstract class AbstractDenseVectorQueryTestCase extends ESTestCase { + + abstract DenseVectorQuery getDenseVectorQuery(String field, float[] query); + + abstract float[] randomVector(int dim); + + abstract Field getKnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction); + + public void testEquals() { + DenseVectorQuery q1 = getDenseVectorQuery("f1", new float[] { 0, 1 }); + DenseVectorQuery q2 = getDenseVectorQuery("f1", new float[] { 0, 1 }); + + assertEquals(q2, q1); + + assertNotEquals(null, q1); + assertNotEquals(q1, new TermQuery(new Term("f1", "x"))); + + assertNotEquals(q1, getDenseVectorQuery("f2", new float[] { 0, 1 })); + assertNotEquals(q1, getDenseVectorQuery("f1", new float[] { 1, 1 })); + } + + public void testEmptyIndex() throws IOException { + try (Directory indexStore = getIndexStore("field"); IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + DenseVectorQuery kvq = getDenseVectorQuery("field", new float[] { 1, 2 }); + assertMatches(searcher, kvq, 0); + } + } + + /** testDimensionMismatch */ + public void testDimensionMismatch() throws IOException { + try ( + Directory indexStore = getIndexStore("field", new float[] { 0, 1 }, new float[] { 1, 2 }, new float[] { 0, 0 }); + IndexReader reader = DirectoryReader.open(indexStore) + ) { + IndexSearcher searcher = newSearcher(reader); + DenseVectorQuery kvq = getDenseVectorQuery("field", new float[] { 0 }); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10)); + assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage()); + } + } + + /** testNonVectorField */ + public void testNonVectorField() throws IOException { + try ( + Directory indexStore = getIndexStore("field", new float[] { 0, 1 }, new float[] { 1, 2 }, new float[] { 0, 0 }); + IndexReader reader = DirectoryReader.open(indexStore) + ) { + IndexSearcher searcher = newSearcher(reader); + assertMatches(searcher, getDenseVectorQuery("xyzzy", new float[] { 0 }), 0); + assertMatches(searcher, getDenseVectorQuery("id", new float[] { 0 }), 0); + } + } + + public void testScoreEuclidean() throws IOException { + float[][] vectors = new float[5][]; + for (int j = 0; j < 5; j++) { + vectors[j] = new float[] { j, j }; + } + try ( + Directory d = getStableIndexStore("field", VectorSimilarityFunction.EUCLIDEAN, vectors); + IndexReader reader = DirectoryReader.open(d) + ) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] queryVector = new float[] { 2, 3 }; + DenseVectorQuery query = getDenseVectorQuery("field", queryVector); + Query rewritten = query.rewrite(searcher); + Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); + Scorer scorer = weight.scorer(reader.leaves().get(0)); + + // prior to advancing, score is 0 + assertEquals(-1, scorer.docID()); + + DocIdSetIterator it = scorer.iterator(); + assertEquals(5, it.cost()); + it.nextDoc(); + int curDoc = 0; + // iterate the docs and assert the scores are what we expect + while (it.docID() != NO_MORE_DOCS) { + assertEquals(VectorSimilarityFunction.EUCLIDEAN.compare(vectors[curDoc], queryVector), scorer.score(), 0.0001); + curDoc++; + it.nextDoc(); + } + } + } + + public void testScoreCosine() throws IOException { + float[][] vectors = new float[5][]; + for (int j = 1; j <= 5; j++) { + vectors[j - 1] = new float[] { j, j * j }; + } + try (Directory d = getStableIndexStore("field", COSINE, vectors)) { + try (IndexReader reader = DirectoryReader.open(d)) { + assertEquals(1, reader.leaves().size()); + IndexSearcher searcher = new IndexSearcher(reader); + float[] queryVector = new float[] { 2, 3 }; + DenseVectorQuery query = getDenseVectorQuery("field", queryVector); + Query rewritten = query.rewrite(searcher); + Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); + Scorer scorer = weight.scorer(reader.leaves().get(0)); + + // prior to advancing, score is undefined + assertEquals(-1, scorer.docID()); + DocIdSetIterator it = scorer.iterator(); + assertEquals(5, it.cost()); + it.nextDoc(); + int curDoc = 0; + // iterate the docs and assert the scores are what we expect + while (it.docID() != NO_MORE_DOCS) { + assertEquals(COSINE.compare(vectors[curDoc], queryVector), scorer.score(), 0.0001); + curDoc++; + it.nextDoc(); + } + } + } + } + + public void testScoreMIP() throws IOException { + try ( + Directory indexStore = getIndexStore( + "field", + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + new float[] { 0, 1 }, + new float[] { 1, 2 }, + new float[] { 0, 0 } + ); + IndexReader reader = DirectoryReader.open(indexStore) + ) { + IndexSearcher searcher = newSearcher(reader); + DenseVectorQuery kvq = getDenseVectorQuery("field", new float[] { 0, -1 }); + assertMatches(searcher, kvq, 3); + ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs; + assertIdMatches(reader, "id2", scoreDocs[0]); + assertIdMatches(reader, "id0", scoreDocs[1]); + assertIdMatches(reader, "id1", scoreDocs[2]); + + assertEquals(1.0, scoreDocs[0].score, 1e-7); + assertEquals(1 / 2f, scoreDocs[1].score, 1e-7); + assertEquals(1 / 3f, scoreDocs[2].score, 1e-7); + } + } + + public void testExplain() throws IOException { + float[][] vectors = new float[5][]; + for (int j = 0; j < 5; j++) { + vectors[j] = new float[] { j, j }; + } + try (Directory d = getStableIndexStore("field", VectorSimilarityFunction.EUCLIDEAN, vectors)) { + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = new IndexSearcher(reader); + DenseVectorQuery query = getDenseVectorQuery("field", new float[] { 2, 3 }); + Explanation matched = searcher.explain(query, 2); + assertTrue(matched.isMatch()); + assertEquals(1 / 2f, matched.getValue()); + assertEquals(0, matched.getDetails().length); + + Explanation nomatch = searcher.explain(query, 6); + assertFalse(nomatch.isMatch()); + + nomatch = searcher.explain(getDenseVectorQuery("someMissingField", new float[] { 2, 3 }), 6); + assertFalse(nomatch.isMatch()); + } + } + } + + public void testRandom() throws IOException { + int numDocs = atLeast(100); + int dimension = atLeast(5); + int numIters = atLeast(10); + boolean everyDocHasAVector = random().nextBoolean(); + try (Directory d = newDirectoryForTest()) { + RandomIndexWriter w = new RandomIndexWriter(random(), d); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (everyDocHasAVector || random().nextInt(10) != 2) { + doc.add(getKnnVectorField("field", randomVector(dimension), VectorSimilarityFunction.EUCLIDEAN)); + } + w.addDocument(doc); + } + w.close(); + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + DenseVectorQuery query = getDenseVectorQuery("field", randomVector(dimension)); + int n = random().nextInt(100) + 1; + TopDocs results = searcher.search(query, n); + assert reader.hasDeletions() == false; + assertTrue(results.totalHits.value >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + void assertIdMatches(IndexReader reader, String expectedId, ScoreDoc scoreDoc) throws IOException { + String actualId = reader.storedFields().document(scoreDoc.doc).get("id"); + assertEquals(expectedId, actualId); + } + + private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches) throws IOException { + ScoreDoc[] result = searcher.search(q, 1000).scoreDocs; + assertEquals(expectedMatches, result.length); + } + + Directory getIndexStore(String field, float[]... contents) throws IOException { + return getIndexStore(field, VectorSimilarityFunction.EUCLIDEAN, contents); + } + + private Directory getStableIndexStore(String field, VectorSimilarityFunction vectorSimilarityFunction, float[]... contents) + throws IOException { + Directory indexStore = newDirectoryForTest(); + try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) { + for (int i = 0; i < contents.length; ++i) { + Document doc = new Document(); + doc.add(getKnnVectorField(field, contents[i], vectorSimilarityFunction)); + doc.add(new StringField("id", "id" + i, Field.Store.YES)); + writer.addDocument(doc); + } + // Add some documents without a vector + for (int i = 0; i < 5; i++) { + Document doc = new Document(); + doc.add(new StringField("other", "value", Field.Store.NO)); + writer.addDocument(doc); + } + } + return indexStore; + } + + Directory getIndexStore(String field, VectorSimilarityFunction vectorSimilarityFunction, float[]... contents) throws IOException { + Directory indexStore = newDirectoryForTest(); + RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); + for (int i = 0; i < contents.length; ++i) { + Document doc = new Document(); + doc.add(getKnnVectorField(field, contents[i], vectorSimilarityFunction)); + doc.add(new StringField("id", "id" + i, Field.Store.YES)); + writer.addDocument(doc); + if (randomBoolean()) { + // Add some documents without a vector + for (int j = 0; j < randomIntBetween(1, 5); j++) { + doc = new Document(); + doc.add(new StringField("other", "value", Field.Store.NO)); + // Add fields that will be matched by our test filters but won't have vectors + doc.add(new StringField("id", "id" + j, Field.Store.YES)); + writer.addDocument(doc); + } + } + } + // Add some documents without a vector + for (int i = 0; i < 5; i++) { + Document doc = new Document(); + doc.add(new StringField("other", "value", Field.Store.NO)); + writer.addDocument(doc); + } + writer.close(); + return indexStore; + } + + protected BaseDirectoryWrapper newDirectoryForTest() { + return LuceneTestCase.newDirectory(random()); + } + +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DenseVectorQueryBytesTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DenseVectorQueryBytesTests.java new file mode 100644 index 0000000000000..8007f5048adca --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/DenseVectorQueryBytesTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.VectorSimilarityFunction; + +public class DenseVectorQueryBytesTests extends AbstractDenseVectorQueryTestCase { + @Override + DenseVectorQuery getDenseVectorQuery(String field, float[] query) { + byte[] bytes = new byte[query.length]; + for (int i = 0; i < query.length; i++) { + bytes[i] = (byte) query[i]; + } + return new DenseVectorQuery.Bytes(bytes, field); + } + + @Override + float[] randomVector(int dim) { + byte[] bytes = new byte[dim]; + random().nextBytes(bytes); + float[] floats = new float[dim]; + for (int i = 0; i < dim; i++) { + floats[i] = bytes[i]; + } + return floats; + } + + @Override + Field getKnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { + byte[] bytes = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + bytes[i] = (byte) vector[i]; + } + return new KnnByteVectorField(name, bytes, similarityFunction); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/DenseVectorQueryFloatsTests.java b/server/src/test/java/org/elasticsearch/search/vectors/DenseVectorQueryFloatsTests.java new file mode 100644 index 0000000000000..04355ee53d3c9 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/DenseVectorQueryFloatsTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.VectorSimilarityFunction; + +public class DenseVectorQueryFloatsTests extends AbstractDenseVectorQueryTestCase { + @Override + DenseVectorQuery getDenseVectorQuery(String field, float[] query) { + return new DenseVectorQuery.Floats(query, field); + } + + @Override + float[] randomVector(int dim) { + float[] vector = new float[dim]; + for (int i = 0; i < vector.length; i++) { + vector[i] = randomFloat(); + } + return vector; + } + + @Override + Field getKnnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java index 02093d9fa0e44..1e77e35b60a4c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java @@ -8,11 +8,8 @@ package org.elasticsearch.search.vectors; -import org.apache.lucene.queries.function.FunctionQuery; -import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.IndexVersions; @@ -25,9 +22,9 @@ import org.elasticsearch.xcontent.XContentFactory; import java.io.IOException; +import java.util.Arrays; import java.util.Collection; import java.util.List; -import java.util.Locale; public class ExactKnnQueryBuilderTests extends AbstractQueryTestCase { @@ -86,22 +83,16 @@ public void testValidOutput() { @Override protected void doAssertLuceneQuery(ExactKnnQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { - assertTrue(query instanceof BooleanQuery); - BooleanQuery booleanQuery = (BooleanQuery) query; - boolean foundFunction = false; - for (BooleanClause clause : booleanQuery) { - if (clause.getQuery() instanceof FunctionQuery functionQuery) { - foundFunction = true; - assertTrue(functionQuery.getValueSource() instanceof FloatVectorSimilarityFunction); - String description = functionQuery.getValueSource().description().toLowerCase(Locale.ROOT); - if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE)) { - assertTrue(description, description.contains("dot_product")); - } else { - assertTrue(description, description.contains("cosine")); - } - } + assertTrue(query instanceof DenseVectorQuery.Floats); + DenseVectorQuery.Floats denseVectorQuery = (DenseVectorQuery.Floats) query; + assertEquals(VECTOR_FIELD, denseVectorQuery.field); + float[] expected = Arrays.copyOf(queryBuilder.getQuery().asFloatVector(), queryBuilder.getQuery().asFloatVector().length); + if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE)) { + VectorUtil.l2normalize(expected); + assertArrayEquals(expected, denseVectorQuery.getQuery(), 0.0f); + } else { + assertArrayEquals(expected, denseVectorQuery.getQuery(), 0.0f); } - assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction); } @Override