Skip to content

Commit

Permalink
SOLR-17164: Add 2 arg variant of vectorSimilarity() function
Browse files Browse the repository at this point in the history
  • Loading branch information
hossman committed Mar 18, 2024
1 parent f67b718 commit 768c5af
Show file tree
Hide file tree
Showing 7 changed files with 622 additions and 53 deletions.
2 changes: 2 additions & 0 deletions solr/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ Improvements
* SOLR-17172: Add QueryLimits termination to the existing heavy SearchComponent-s. This allows query limits (e.g. timeAllowed,
cpuAllowed) to terminate expensive operations within components if limits are exceeded. (Andrzej Bialecki)

* SOLR-17164: Add 2 arg variant of vectorSimilarity() function (Sanjay Dutt, hossman)

Optimizations
---------------------
* SOLR-17144: Close searcherExecutor thread per core after 1 minute (Pierre Salagnac, Christine Poerschke)
Expand Down
40 changes: 1 addition & 39 deletions solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,19 @@
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.docvalues.BoolDocValues;
import org.apache.lucene.queries.function.docvalues.DoubleDocValues;
import org.apache.lucene.queries.function.docvalues.LongDocValues;
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ConstNumberSource;
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
import org.apache.lucene.queries.function.valuesource.DefFunction;
import org.apache.lucene.queries.function.valuesource.DivFloatFunction;
import org.apache.lucene.queries.function.valuesource.DocFreqValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.queries.function.valuesource.DualFloatFunction;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.IDFValueSource;
import org.apache.lucene.queries.function.valuesource.IfFunction;
import org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource;
Expand Down Expand Up @@ -344,41 +340,7 @@ public ValueSource parse(FunctionQParser fp) throws SyntaxError {
}
});
alias("sum", "add");
addParser(
"vectorSimilarity",
new ValueSourceParser() {
@Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError {

VectorEncoding vectorEncoding = VectorEncoding.valueOf(fp.parseArg());
VectorSimilarityFunction functionName = VectorSimilarityFunction.valueOf(fp.parseArg());

int vectorEncodingFlag =
vectorEncoding.equals(VectorEncoding.BYTE)
? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
: 0;
ValueSource v1 =
fp.parseValueSource(
FunctionQParser.FLAG_DEFAULT
| FunctionQParser.FLAG_CONSUME_DELIMITER
| vectorEncodingFlag);
ValueSource v2 =
fp.parseValueSource(
FunctionQParser.FLAG_DEFAULT
| FunctionQParser.FLAG_CONSUME_DELIMITER
| vectorEncodingFlag);

switch (vectorEncoding) {
case FLOAT32:
return new FloatVectorSimilarityFunction(functionName, v1, v2);
case BYTE:
return new ByteVectorSimilarityFunction(functionName, v1, v2);
default:
throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
}
}
});

addParser("vectorSimilarity", new VectorSimilaritySourceParser());
addParser(
"product",
new ValueSourceParser() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;

import static org.apache.solr.common.SolrException.ErrorCode;
import static org.apache.solr.common.SolrException.ErrorCode.BAD_REQUEST;

import java.util.Arrays;
import java.util.Locale;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.solr.common.SolrException;
import org.apache.solr.schema.DenseVectorField;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.SchemaField;

/**
* This class provides implementation for two variants for parsing function query vectorSimilarity
* which is used to calculate the similarity between two vectors.
*/
public class VectorSimilaritySourceParser extends ValueSourceParser {
@Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError {

final String arg1Str = fp.parseArg();
if (arg1Str == null || !fp.hasMoreArguments())
throw new SolrException(
BAD_REQUEST, "Invalid number of arguments. Please provide either two or four arguments.");

final String arg2Str = peekIsConstVector(fp) ? null : fp.parseArg();
if (fp.hasMoreArguments() && arg2Str != null) {
return handle4ArgsVariant(fp, arg1Str, arg2Str);
}
return handle2ArgsVariant(fp, arg1Str, arg2Str);
}

/**
* returns true if and only if the next argument is a constant vector, taking into consideration
* that the next (literal) argument may be a param reference
*/
private boolean peekIsConstVector(final FunctionQParser fp) throws SyntaxError {
final char rawPeek = fp.sp.peek();
if ('[' == rawPeek) {
return true;
}
if ('$' == rawPeek) {
final int savedPos = fp.sp.pos;
try {
final String rawParam = fp.parseArg();
return ((null != rawParam) && ('[' == (new StrParser(rawParam)).peek()));
} finally {
fp.sp.pos = savedPos;
}
}
return false;
}

private static int buildVectorEncodingFlag(final VectorEncoding vectorEncoding) {
return FunctionQParser.FLAG_DEFAULT
| FunctionQParser.FLAG_CONSUME_DELIMITER
| (vectorEncoding.equals(VectorEncoding.BYTE)
? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
: 0);
}

/** Expects to find args #3 and #4 (two vector ValueSources) still in the function parser */
private ValueSource handle4ArgsVariant(FunctionQParser fp, String vecEncStr, String vecSimFuncStr)
throws SyntaxError {
final var vectorEncoding = enumValueOrBadRequest(VectorEncoding.class, vecEncStr);
final var vectorSimilarityFunction =
enumValueOrBadRequest(VectorSimilarityFunction.class, vecSimFuncStr);
final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding);
final ValueSource v1 = fp.parseValueSource(vectorEncodingFlag);
final ValueSource v2 = fp.parseValueSource(vectorEncodingFlag);
return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2);
}

/**
* If <code>field2Name</code> is null, then expects to find a constant vector as the only
* remaining arg in the function parser.
*/
private ValueSource handle2ArgsVariant(FunctionQParser fp, String field1Name, String field2Name)
throws SyntaxError {

final SchemaField field1 = fp.req.getSchema().getField(field1Name);
final DenseVectorField field1Type = requireVectorType(field1);

final var vectorEncoding = field1Type.getVectorEncoding();
final var vectorSimilarityFunction = field1Type.getSimilarityFunction();

final ValueSource v1 = field1Type.getValueSource(field1, fp);
final ValueSource v2;

if (null == field2Name) {
final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding);
v2 = fp.parseValueSource(vectorEncodingFlag);

} else {
final SchemaField field2 = fp.req.getSchema().getField(field2Name);
final DenseVectorField field2Type = requireVectorType(field2);
if (vectorEncoding != field2Type.getVectorEncoding()
|| vectorSimilarityFunction != field2Type.getSimilarityFunction()) {
throw new SolrException(
BAD_REQUEST,
String.format(
Locale.ROOT,
"Invalid arguments: vector field %s and vector field %s must have the same vectorEncoding and similarityFunction",
field1.getName(),
field2.getName()));
}
v2 = field2Type.getValueSource(field2, fp);
}
return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2);
}

private ValueSource createSimilarityFunction(
VectorSimilarityFunction functionName,
VectorEncoding vectorEncoding,
ValueSource v1,
ValueSource v2)
throws SyntaxError {
switch (vectorEncoding) {
case FLOAT32:
return new FloatVectorSimilarityFunction(functionName, v1, v2);
case BYTE:
return new ByteVectorSimilarityFunction(functionName, v1, v2);
default:
throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
}
}

private DenseVectorField requireVectorType(final SchemaField field) throws SyntaxError {
final FieldType fieldType = field.getType();
if (fieldType instanceof DenseVectorField) {
return (DenseVectorField) field.getType();
}
throw new SolrException(
BAD_REQUEST,
String.format(
Locale.ROOT,
"Type mismatch: Expected [%s], but found a different field type for field: [%s]",
DenseVectorField.class.getSimpleName(),
field.getName()));
}

/**
* Helper method that returns the correct Enum instance for the <code>arg</code> String, or throws
* a {@link ErrorCode#BAD_REQUEST} with specifics on the "Invalid argument"
*/
private static <T extends Enum<T>> T enumValueOrBadRequest(
final Class<T> enumClass, final String arg) throws SolrException {
assert null != enumClass;
try {
return Enum.valueOf(enumClass, arg);
} catch (IllegalArgumentException | NullPointerException e) {
throw new SolrException(
BAD_REQUEST,
String.format(
Locale.ROOT,
"Invalid argument: %s is not a valid %s. Expected one of %s",
arg,
enumClass.getSimpleName(),
Arrays.toString(enumClass.getEnumConstants())));
}
}
}
66 changes: 60 additions & 6 deletions solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -911,13 +911,67 @@ public void testFuncVector() throws Exception {
}

public void testFuncKnnVector() throws Exception {
assertFuncEquals(
"vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
"vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])");
try (SolrQueryRequest req =
req(
"v1", "[1,2,3]",
"v2", " [1,2,3] ",
"v3", " [1, 2, 3] ")) {
assertFuncEquals(
req,
"vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
"vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])",
"vectorSimilarity(FLOAT32, COSINE,$v1, [4, 5, 6])",
"vectorSimilarity(FLOAT32, COSINE, $v2 , [4, 5, 6])",
"vectorSimilarity(FLOAT32, COSINE, $v3 , [4, 5, 6])");
}

assertFuncEquals(
"vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
"vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])");
try (SolrQueryRequest req =
req(
"f1", "bar_i",
"f2", " bar_i ",
"f3", " field(bar_i) ")) {
assertFuncEquals(
req,
"vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
"vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])",
"vectorSimilarity(BYTE, EUCLIDEAN,$f1, [4, 5, 6])",
"vectorSimilarity(BYTE, EUCLIDEAN, $f1, [4, 5, 6])",
"vectorSimilarity(BYTE, EUCLIDEAN, $f2, [4, 5, 6])",
"vectorSimilarity(BYTE, EUCLIDEAN, $f3, [4, 5, 6])");
}

try (SolrQueryRequest req =
req(
"f", "vector",
"v1", "[1,2,3,4]",
"v2", " [1, 2, 3, 4]")) {
assertFuncEquals(
req,
"vectorSimilarity(FLOAT32,COSINE,vector,[1,2,3,4])",
"vectorSimilarity(FLOAT32,COSINE,vector,$v1)",
"vectorSimilarity(FLOAT32,COSINE,vector, $v1)",
"vectorSimilarity(FLOAT32,COSINE,vector,$v2)",
"vectorSimilarity(FLOAT32,COSINE,vector, $v2)",
"vectorSimilarity(vector,[1,2,3,4])",
"vectorSimilarity( vector,[1,2,3,4])",
"vectorSimilarity( $f,[1,2,3,4])",
"vectorSimilarity(vector,$v1)",
"vectorSimilarity(vector, $v1)",
"vectorSimilarity( $f, $v1)",
"vectorSimilarity(vector,$v2)",
"vectorSimilarity(vector, $v2)");
}

// contrived, but helps us test the param resolution
// for both field names in the 2arg usecase
try (SolrQueryRequest req = req("f", "vector")) {
assertFuncEquals(
req,
"vectorSimilarity($f, $f)",
"vectorSimilarity($f, vector)",
"vectorSimilarity(vector, $f)",
"vectorSimilarity(vector, vector)");
}
}

public void testFuncQuery() throws Exception {
Expand Down
Loading

0 comments on commit 768c5af

Please sign in to comment.