From 768c5af847a77a4a48b302cb929d1e0e24486e9a Mon Sep 17 00:00:00 2001 From: Chris Hostetter Date: Mon, 18 Mar 2024 13:27:24 -0700 Subject: [PATCH] SOLR-17164: Add 2 arg variant of vectorSimilarity() function --- solr/CHANGES.txt | 2 + .../apache/solr/search/ValueSourceParser.java | 40 +--- .../search/VectorSimilaritySourceParser.java | 183 +++++++++++++++++ .../apache/solr/search/QueryEqualityTest.java | 66 ++++++- .../VectorSimilaritySourceParserTest.java | 187 ++++++++++++++++++ .../TestDenseVectorFunctionQuery.java | 165 ++++++++++++++++ .../query-guide/pages/function-queries.adoc | 32 ++- 7 files changed, 622 insertions(+), 53 deletions(-) create mode 100644 solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java create mode 100644 solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 70571a9c2ed..9dc9fb0aa47 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -120,6 +120,8 @@ Improvements * SOLR-17172: Add QueryLimits termination to the existing heavy SearchComponent-s. This allows query limits (e.g. timeAllowed, cpuAllowed) to terminate expensive operations within components if limits are exceeded. 
(Andrzej Bialecki) +* SOLR-17164: Add 2 arg variant of vectorSimilarity() function (Sanjay Dutt, hossman) + Optimizations --------------------- * SOLR-17144: Close searcherExecutor thread per core after 1 minute (Pierre Salagnac, Christine Poerschke) diff --git a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java index f3054efa1ea..a6bd40f8f49 100644 --- a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java +++ b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java @@ -26,15 +26,12 @@ import java.util.Map; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; -import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.queries.function.FunctionValues; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.queries.function.docvalues.BoolDocValues; import org.apache.lucene.queries.function.docvalues.DoubleDocValues; import org.apache.lucene.queries.function.docvalues.LongDocValues; -import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; import org.apache.lucene.queries.function.valuesource.ConstNumberSource; import org.apache.lucene.queries.function.valuesource.ConstValueSource; import org.apache.lucene.queries.function.valuesource.DefFunction; @@ -42,7 +39,6 @@ import org.apache.lucene.queries.function.valuesource.DocFreqValueSource; import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource; import org.apache.lucene.queries.function.valuesource.DualFloatFunction; -import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; import org.apache.lucene.queries.function.valuesource.IDFValueSource; import org.apache.lucene.queries.function.valuesource.IfFunction; import 
org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource; @@ -344,41 +340,7 @@ public ValueSource parse(FunctionQParser fp) throws SyntaxError { } }); alias("sum", "add"); - addParser( - "vectorSimilarity", - new ValueSourceParser() { - @Override - public ValueSource parse(FunctionQParser fp) throws SyntaxError { - - VectorEncoding vectorEncoding = VectorEncoding.valueOf(fp.parseArg()); - VectorSimilarityFunction functionName = VectorSimilarityFunction.valueOf(fp.parseArg()); - - int vectorEncodingFlag = - vectorEncoding.equals(VectorEncoding.BYTE) - ? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING - : 0; - ValueSource v1 = - fp.parseValueSource( - FunctionQParser.FLAG_DEFAULT - | FunctionQParser.FLAG_CONSUME_DELIMITER - | vectorEncodingFlag); - ValueSource v2 = - fp.parseValueSource( - FunctionQParser.FLAG_DEFAULT - | FunctionQParser.FLAG_CONSUME_DELIMITER - | vectorEncodingFlag); - - switch (vectorEncoding) { - case FLOAT32: - return new FloatVectorSimilarityFunction(functionName, v1, v2); - case BYTE: - return new ByteVectorSimilarityFunction(functionName, v1, v2); - default: - throw new SyntaxError("Invalid vector encoding: " + vectorEncoding); - } - } - }); - + addParser("vectorSimilarity", new VectorSimilaritySourceParser()); addParser( "product", new ValueSourceParser() { diff --git a/solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java b/solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java new file mode 100644 index 00000000000..aed934cdc8e --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search; + +import static org.apache.solr.common.SolrException.ErrorCode; +import static org.apache.solr.common.SolrException.ErrorCode.BAD_REQUEST; + +import java.util.Arrays; +import java.util.Locale; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; +import org.apache.solr.common.SolrException; +import org.apache.solr.schema.DenseVectorField; +import org.apache.solr.schema.FieldType; +import org.apache.solr.schema.SchemaField; + +/** + * This class provides implementation for two variants for parsing function query vectorSimilarity + * which is used to calculate the similarity between two vectors. + */ +public class VectorSimilaritySourceParser extends ValueSourceParser { + @Override + public ValueSource parse(FunctionQParser fp) throws SyntaxError { + + final String arg1Str = fp.parseArg(); + if (arg1Str == null || !fp.hasMoreArguments()) + throw new SolrException( + BAD_REQUEST, "Invalid number of arguments. Please provide either two or four arguments."); + + final String arg2Str = peekIsConstVector(fp) ? 
null : fp.parseArg(); + if (fp.hasMoreArguments() && arg2Str != null) { + return handle4ArgsVariant(fp, arg1Str, arg2Str); + } + return handle2ArgsVariant(fp, arg1Str, arg2Str); + } + + /** + * returns true if and only if the next argument is a constant vector, taking into consideration + * that the next (literal) argument may be a param reference + */ + private boolean peekIsConstVector(final FunctionQParser fp) throws SyntaxError { + final char rawPeek = fp.sp.peek(); + if ('[' == rawPeek) { + return true; + } + if ('$' == rawPeek) { + final int savedPos = fp.sp.pos; + try { + final String rawParam = fp.parseArg(); + return ((null != rawParam) && ('[' == (new StrParser(rawParam)).peek())); + } finally { + fp.sp.pos = savedPos; + } + } + return false; + } + + private static int buildVectorEncodingFlag(final VectorEncoding vectorEncoding) { + return FunctionQParser.FLAG_DEFAULT + | FunctionQParser.FLAG_CONSUME_DELIMITER + | (vectorEncoding.equals(VectorEncoding.BYTE) + ? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING + : 0); + } + + /** Expects to find args #3 and #4 (two vector ValueSources) still in the function parser */ + private ValueSource handle4ArgsVariant(FunctionQParser fp, String vecEncStr, String vecSimFuncStr) + throws SyntaxError { + final var vectorEncoding = enumValueOrBadRequest(VectorEncoding.class, vecEncStr); + final var vectorSimilarityFunction = + enumValueOrBadRequest(VectorSimilarityFunction.class, vecSimFuncStr); + final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding); + final ValueSource v1 = fp.parseValueSource(vectorEncodingFlag); + final ValueSource v2 = fp.parseValueSource(vectorEncodingFlag); + return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2); + } + + /** + * If field2Name is null, then expects to find a constant vector as the only + * remaining arg in the function parser. 
+ */ + private ValueSource handle2ArgsVariant(FunctionQParser fp, String field1Name, String field2Name) + throws SyntaxError { + + final SchemaField field1 = fp.req.getSchema().getField(field1Name); + final DenseVectorField field1Type = requireVectorType(field1); + + final var vectorEncoding = field1Type.getVectorEncoding(); + final var vectorSimilarityFunction = field1Type.getSimilarityFunction(); + + final ValueSource v1 = field1Type.getValueSource(field1, fp); + final ValueSource v2; + + if (null == field2Name) { + final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding); + v2 = fp.parseValueSource(vectorEncodingFlag); + + } else { + final SchemaField field2 = fp.req.getSchema().getField(field2Name); + final DenseVectorField field2Type = requireVectorType(field2); + if (vectorEncoding != field2Type.getVectorEncoding() + || vectorSimilarityFunction != field2Type.getSimilarityFunction()) { + throw new SolrException( + BAD_REQUEST, + String.format( + Locale.ROOT, + "Invalid arguments: vector field %s and vector field %s must have the same vectorEncoding and similarityFunction", + field1.getName(), + field2.getName())); + } + v2 = field2Type.getValueSource(field2, fp); + } + return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2); + } + + private ValueSource createSimilarityFunction( + VectorSimilarityFunction functionName, + VectorEncoding vectorEncoding, + ValueSource v1, + ValueSource v2) + throws SyntaxError { + switch (vectorEncoding) { + case FLOAT32: + return new FloatVectorSimilarityFunction(functionName, v1, v2); + case BYTE: + return new ByteVectorSimilarityFunction(functionName, v1, v2); + default: + throw new SyntaxError("Invalid vector encoding: " + vectorEncoding); + } + } + + private DenseVectorField requireVectorType(final SchemaField field) throws SyntaxError { + final FieldType fieldType = field.getType(); + if (fieldType instanceof DenseVectorField) { + return (DenseVectorField) field.getType(); + } + throw 
new SolrException( + BAD_REQUEST, + String.format( + Locale.ROOT, + "Type mismatch: Expected [%s], but found a different field type for field: [%s]", + DenseVectorField.class.getSimpleName(), + field.getName())); + } + + /** + * Helper method that returns the correct Enum instance for the arg String, or throws + * a {@link ErrorCode#BAD_REQUEST} with specifics on the "Invalid argument" + */ + private static > T enumValueOrBadRequest( + final Class enumClass, final String arg) throws SolrException { + assert null != enumClass; + try { + return Enum.valueOf(enumClass, arg); + } catch (IllegalArgumentException | NullPointerException e) { + throw new SolrException( + BAD_REQUEST, + String.format( + Locale.ROOT, + "Invalid argument: %s is not a valid %s. Expected one of %s", + arg, + enumClass.getSimpleName(), + Arrays.toString(enumClass.getEnumConstants()))); + } + } +} diff --git a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java index 8eb1c3da71f..653c0935879 100644 --- a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java +++ b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java @@ -911,13 +911,67 @@ public void testFuncVector() throws Exception { } public void testFuncKnnVector() throws Exception { - assertFuncEquals( - "vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])", - "vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])"); + try (SolrQueryRequest req = + req( + "v1", "[1,2,3]", + "v2", " [1,2,3] ", + "v3", " [1, 2, 3] ")) { + assertFuncEquals( + req, + "vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])", + "vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])", + "vectorSimilarity(FLOAT32, COSINE,$v1, [4, 5, 6])", + "vectorSimilarity(FLOAT32, COSINE, $v2 , [4, 5, 6])", + "vectorSimilarity(FLOAT32, COSINE, $v3 , [4, 5, 6])"); + } - assertFuncEquals( - "vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])", - "vectorSimilarity(BYTE, EUCLIDEAN, 
field(bar_i), [4, 5, 6])"); + try (SolrQueryRequest req = + req( + "f1", "bar_i", + "f2", " bar_i ", + "f3", " field(bar_i) ")) { + assertFuncEquals( + req, + "vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])", + "vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])", + "vectorSimilarity(BYTE, EUCLIDEAN,$f1, [4, 5, 6])", + "vectorSimilarity(BYTE, EUCLIDEAN, $f1, [4, 5, 6])", + "vectorSimilarity(BYTE, EUCLIDEAN, $f2, [4, 5, 6])", + "vectorSimilarity(BYTE, EUCLIDEAN, $f3, [4, 5, 6])"); + } + + try (SolrQueryRequest req = + req( + "f", "vector", + "v1", "[1,2,3,4]", + "v2", " [1, 2, 3, 4]")) { + assertFuncEquals( + req, + "vectorSimilarity(FLOAT32,COSINE,vector,[1,2,3,4])", + "vectorSimilarity(FLOAT32,COSINE,vector,$v1)", + "vectorSimilarity(FLOAT32,COSINE,vector, $v1)", + "vectorSimilarity(FLOAT32,COSINE,vector,$v2)", + "vectorSimilarity(FLOAT32,COSINE,vector, $v2)", + "vectorSimilarity(vector,[1,2,3,4])", + "vectorSimilarity( vector,[1,2,3,4])", + "vectorSimilarity( $f,[1,2,3,4])", + "vectorSimilarity(vector,$v1)", + "vectorSimilarity(vector, $v1)", + "vectorSimilarity( $f, $v1)", + "vectorSimilarity(vector,$v2)", + "vectorSimilarity(vector, $v2)"); + } + + // contrived, but helps us test the param resolution + // for both field names in the 2arg usecase + try (SolrQueryRequest req = req("f", "vector")) { + assertFuncEquals( + req, + "vectorSimilarity($f, $f)", + "vectorSimilarity($f, vector)", + "vectorSimilarity(vector, $f)", + "vectorSimilarity(vector, vector)"); + } } public void testFuncQuery() throws Exception { diff --git a/solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java b/solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java new file mode 100644 index 00000000000..943bc350056 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/VectorSimilaritySourceParserTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search; + +import static org.apache.solr.SolrTestCaseJ4.assumeWorkingMockito; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; +import org.apache.solr.SolrTestCase; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.schema.BinaryField; +import org.apache.solr.schema.DenseVectorField; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.IntPointField; +import org.apache.solr.schema.SchemaField; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +/** Test for {@link VectorSimilaritySourceParser} */ +public class VectorSimilaritySourceParserTest extends SolrTestCase { + private static final VectorSimilaritySourceParser vecSimilarity = + new VectorSimilaritySourceParser(); + private SolrQueryRequest request; + 
private SolrParams localParams; + private SolrParams params; + private IndexSchema indexSchema; + + @BeforeClass + public static void beforeClass() { + assumeWorkingMockito(); + } + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + resetMocks(); + } + + @Test + public void testReportErrorPassingZeroArg() throws SyntaxError { + SolrException error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity()")); + assertEquals( + "Invalid number of arguments. Please provide either two or four arguments.", + error.getMessage()); + } + + @Test + public void testReportErrorPassingOneArg() throws SyntaxError { + SolrException error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1)")); + assertEquals( + "Invalid number of arguments. Please provide either two or four arguments.", + error.getMessage()); + } + + @Test + public void testReportErrorIfSecArgsEmpty() throws Exception { + SchemaField field1 = new SchemaField("field1", new DenseVectorField(5)); + when(indexSchema.getField("field1")).thenReturn(field1); + + SolrException error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1,)")); + assertEquals( + "Invalid number of arguments. 
Please provide either two or four arguments.", + error.getMessage()); + } + + @Test + public void testReportErrorIfFirstArgNotVector() throws SyntaxError { + SchemaField field1 = new SchemaField("field1", new IntPointField()); + when(indexSchema.getField("field1")).thenReturn(field1); + + SolrException error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1, field2)")); + assertEquals( + "Type mismatch: Expected [DenseVectorField], but found a different field type for field: [field1]", + error.getMessage()); + } + + @Test + public void testReportErrorIfSecArgNotVector() throws SyntaxError { + DenseVectorField fieldType = new DenseVectorField(5); + SchemaField field1 = new SchemaField("field1", fieldType); + SchemaField field2 = new SchemaField("field2", new BinaryField()); + when(indexSchema.getField("field1")).thenReturn(field1); + when(indexSchema.getField("field2")).thenReturn(field2); + + SolrException error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1, field2)")); + assertEquals( + "Type mismatch: Expected [DenseVectorField], but found a different field type for field: [field2]", + error.getMessage()); + } + + @Test + public void testReportErrorIfFieldMissmatch() throws SyntaxError { + DenseVectorField vectorField1 = + new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.BYTE); + SchemaField field1 = new SchemaField("field1", vectorField1); + DenseVectorField vectorField2 = + new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.FLOAT32); + SchemaField field2 = new SchemaField("field2", vectorField2); + DenseVectorField vectorField3 = + new DenseVectorField(5, VectorSimilarityFunction.DOT_PRODUCT, VectorEncoding.FLOAT32); + SchemaField field3 = new SchemaField("field3", vectorField3); + + when(indexSchema.getField("field1")).thenReturn(field1); + when(indexSchema.getField("field2")).thenReturn(field2); + 
when(indexSchema.getField("field3")).thenReturn(field3); + + SolrException error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field1, field2)")); + assertEquals( + "Invalid arguments: vector field field1 and vector field field2 must have the same vectorEncoding and similarityFunction", + error.getMessage()); + + error = + assertThrows(SolrException.class, () -> parseWithMocks("vectorSimilarity(field2, field3)")); + assertEquals( + "Invalid arguments: vector field field2 and vector field field3 must have the same vectorEncoding and similarityFunction", + error.getMessage()); + } + + @Test + public void test2ArgsByteVectorField() throws SyntaxError { + DenseVectorField vectorField = + new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.BYTE); + SchemaField field1 = new SchemaField("field1", vectorField); + SchemaField field2 = new SchemaField("field2", vectorField); + when(indexSchema.getField("field1")).thenReturn(field1); + when(indexSchema.getField("field2")).thenReturn(field2); + + ValueSource valueSource = parseWithMocks("vectorSimilarity(field1, field2)"); + assertTrue(valueSource instanceof ByteVectorSimilarityFunction); + } + + @Test + public void test2ArgsFloatVectorAndConst() throws Exception { + DenseVectorField vectorField = + new DenseVectorField(5, VectorSimilarityFunction.COSINE, VectorEncoding.FLOAT32); + SchemaField field1 = new SchemaField("field1", vectorField); + when(indexSchema.getField("field1")).thenReturn(field1); + + ValueSource valueSource = parseWithMocks("vectorSimilarity(field1, [1, 2, 3, 4, 5])"); + assertTrue(valueSource instanceof FloatVectorSimilarityFunction); + } + + private void resetMocks() { + request = mock(SolrQueryRequest.class); + localParams = mock(SolrParams.class); + params = mock(SolrParams.class); + indexSchema = mock(IndexSchema.class); + when(request.getSchema()).thenReturn(indexSchema); + } + + protected ValueSource parseWithMocks(final String input) throws 
SyntaxError { + final String funcPrefix = "vectorSimilarity("; + assert input.startsWith(funcPrefix); + final FunctionQParser fqp = + new FunctionQParser(input.substring(funcPrefix.length()), localParams, params, request); + return vecSimilarity.parse(fqp); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java index c6573ff693c..503edd739db 100644 --- a/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java +++ b/solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.List; import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.params.CommonParams; import org.junit.After; @@ -200,4 +201,168 @@ public void vectorQueryInRerankQParser_ShouldRescoreOnlyFirstKResults() { "//result/doc[3]/float[@name='score'][.='0.7002023']", "//result/doc[4]/float[@name='score'][.='0.7002023']"); } + + @Test + public void testReportsErrorInvalidNumberOfArgs() { + assertQEx( + "vectorSimilarity test number of arguments failed!", + "Invalid number of arguments. Please provide either two or four arguments.", + req(CommonParams.Q, "{!func} vectorSimilarity()", "fq", "id:(1 2 3)", "fl", "id, score"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity test number of arguments failed!", + "Invalid number of arguments. Please provide either two or four arguments.", + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector)", + "fq", + "id:(1 2 3)", + "fl", + "id, score"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity test number of arguments failed!", + "Invalid number of arguments. 
Please provide either two or four arguments.", + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector,)", + "fq", + "id:(1 2 3)", + "fl", + "id, score"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void testReportsErrorInvalidArgs() { + assertQEx( + "vectorSimilarity 2arg: first arg non-vector field", + "undefined field: \"bogus\"", + req(CommonParams.Q, "{!func} vectorSimilarity(bogus, vector_byte_encoding)"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 2arg: second arg non-vector field", + "undefined field: \"bogus\"", + req(CommonParams.Q, "{!func} vectorSimilarity(vector_byte_encoding, bogus)"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 3+ args: 1st arg not valid encoding", + "Invalid argument: BOGUS is not a valid VectorEncoding. Expected one of [", + req( + CommonParams.Q, + "{!func} vectorSimilarity(BOGUS, DOT_PRODUCT, vector_byte_encoding, vector_byte_encoding)"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 3+ args: 2nd arg not valid encoding", + "Invalid argument: BOGUS is not a valid VectorSimilarityFunction. 
Expected one of [", + req( + CommonParams.Q, + "{!func} vectorSimilarity(BYTE, BOGUS, vector_byte_encoding, vector_byte_encoding)"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 3 args: first two are valid for 2 arg syntax", + "SyntaxError: Expected ')'", + req(CommonParams.Q, "{!func} vectorSimilarity(vector_byte_encoding,[1,2,3,3],BOGUS)"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 3 args: first two are valid for 4 arg syntax, w/valid 3rd arg field", + "SyntaxError: Expected identifier", + req(CommonParams.Q, "{!func} vectorSimilarity(BYTE, DOT_PRODUCT, vector_byte_encoding)"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 3 args: first two are valid for 4 arg syntax, w/valid 3rd arg const vector", + "SyntaxError: Expected identifier", + req(CommonParams.Q, "{!func} vectorSimilarity(BYTE, DOT_PRODUCT, [1,2,3,3])"), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "vectorSimilarity 5 args: valid 4 arg syntax with extra cruft", + "SyntaxError: Expected ')'", + req( + CommonParams.Q, + "{!func} vectorSimilarity(BYTE, DOT_PRODUCT, vector_byte_encoding, vector_byte_encoding, BOGUS)"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void test2ArgsByteFieldAndConstVector() throws Exception { + assertQ( + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector_byte_encoding, [1,2,3,3])", + "fq", + "id:(1 2)", + "fl", + "id, score", + "rows", + "1"), + "//result[@numFound='" + 2 + "']", + "//result/doc[1]/str[@name='id'][.=1]"); + assertQ( + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector_byte_encoding, [3,3,2,1])", + "fq", + "id:(1 2)", + "fl", + "id, score", + "rows", + "1"), + "//result[@numFound='" + 2 + "']", + "//result/doc[1]/str[@name='id'][.=2]"); + } + + @Test + public void test2ArgsFloatFieldAndConstVector() throws Exception { + assertQ( + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector, [1,2,3,3])", + "fq", + "id:(1 2 3)", + "fl", + 
"id, score"), + "//result[@numFound='" + 3 + "']", + "//result/doc[1]/str[@name='id'][.=2]", + "//result/doc[2]/str[@name='id'][.=3]", + "//result/doc[3]/str[@name='id'][.=1]"); + } + + @Test + public void test2ArgsFloatVectorField() throws Exception { + assertQ( + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector, vector2)", + "fq", + "id:(1 2 3 4)", + "fl", + "id, score"), + "//result[@numFound='" + 4 + "']", + "//result/doc[1]/str[@name='id'][.=2]", + "//result/doc[2]/str[@name='id'][.=1]"); + } + + @Test + public void test2ArgsIfEitherFieldMissingValueDocScoreZero() { + assertQ( + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector, vector2)", + "fq", + "id:(3)", + "fl", + "id, score"), + "//result[@numFound='" + 1 + "']", + "//result/doc[1]/float[@name='score'][.=0.0]"); + assertQ( + req( + CommonParams.Q, + "{!func} vectorSimilarity(vector, vector2)", + "fq", + "id:(4)", + "fl", + "id, score"), + "//result[@numFound='" + 1 + "']", + "//result/doc[1]/float[@name='score'][.=0.0]"); + } } diff --git a/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc b/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc index db839dc873e..48f9345f1cd 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/function-queries.adoc @@ -153,19 +153,35 @@ There must be an even number of ValueSource instances passed in and the method a * `dist(1,x,y,z,e,f,g)`: Manhattan distance between (x,y,z) and (e,f,g) where each letter is a field name. === vectorSimilarity Function -Returns the similarity between two Knn vectors in an n-dimensional space. -Takes in input the vector element encoding, the similarity measure plus two ValueSource instances and calculates the similarity between the two vectors. +Returns the similarity between two Knn vectors in an n-dimensional space. There are two variants of this function. -* The encodings supported are: `BYTE`, `FLOAT32`. 
-* The similarities supported are: `EUCLIDEAN`, `COSINE`, `DOT_PRODUCT` +==== vectorSimilarity(vector1, vector2) + +This function accepts two vectors as input: The first argument must be the name of a `DenseVectorField`. The second argument can be either the name of a second `DenseVectorField` or a constant vector. + +If two field names are specified, they must be configured with the same `vectorDimensions`, `vectorEncoding`, and `similarityFunction`. If a constant vector is specified, then it will be parsed using the `vectorEncoding` configured on the field specified by the first argument and must have the same dimensions. + +*Syntax Examples* -Each ValueSource must be a knn vector (field or constant). +* `vectorSimilarity(vectorField1, vectorField2)`: calculates the configured similarity between vector fields `vectorField1` and `vectorField2` for each document. +* `vectorSimilarity(vectorField1, [1,2,3,4])`: calculates the configured similarity between vector field `vectorField1` and `[1, 2, 3, 4]` for each document. + +[NOTE] +Only field names that follow xref:indexing-guide:fields.adoc#field-properties[recommended field naming conventions] are guaranteed to work with this syntax. Atypical field names requiring `field("...")` syntax when used in Function Queries must use the more complex 4 argument variant syntax of the `vectorSimilarity(...)` function described below. + +==== vectorSimilarity(ENCODING, SIMILARITY_FUNCTION, vector1, vector2) + +Takes as input the vector element encoding and the similarity measure, plus two ValueSource instances (each either a `DenseVectorField` or a constant vector), and calculates the similarity between the two vectors. 
+ +* The encodings supported are: `BYTE`, `FLOAT32` +** This is used to parse any constant vector arguments +* The similarities supported are: `EUCLIDEAN`, `COSINE`, `DOT_PRODUCT` *Syntax Examples* -* `vectorSimilarity(FLOAT32, COSINE, [1,2,3], [4,5,6])`: calculates the cosine similarity between [1, 2, 3] and [4, 5, 6] for each document. -* `vectorSimilarity(FLOAT32, DOT_PRODUCT, vectorField1, vectorField2)`: calculates the dot product similarity between the vector in 'vectorField1' and in 'vectorField2' for each document. -* `vectorSimilarity(BYTE, EUCLIDEAN, [1,5,4,3], vectorField)`: calculates the euclidean similarity between the vector in 'vectorField' and the constant vector [1, 5, 4, 3] for each document. +* `vectorSimilarity(FLOAT32, COSINE, [1,2,3], [4,5,6])`: calculates the cosine similarity between `[1, 2, 3]` and `[4, 5, 6]` for each document. +* `vectorSimilarity(FLOAT32, DOT_PRODUCT, vectorField1, vectorField2)`: calculates the dot product similarity between the vector in `vectorField1` and in `vectorField2` for each document. +* `vectorSimilarity(BYTE, EUCLIDEAN, [1,5,4,3], vectorField)`: calculates the euclidean similarity between the vector in `vectorField` and the constant vector `[1, 5, 4, 3]` for each document. === docfreq(field,val) Function Returns the number of documents that contain the term in the field.