Skip to content

Commit

Permalink
enhancements: support neural_sparse query by tokens (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#693)

* enhancements: support neural sparse query by tokens

Signed-off-by: zhichao-aws <[email protected]>

* add empty string logics for streams

Signed-off-by: zhichao-aws <[email protected]>

* fix fromXContent error case logic

Signed-off-by: zhichao-aws <[email protected]>

* add ut

Signed-off-by: zhichao-aws <[email protected]>

* ut it

Signed-off-by: zhichao-aws <[email protected]>

* updates for comments

Signed-off-by: zhichao-aws <[email protected]>

* nit

Signed-off-by: zhichao-aws <[email protected]>

* fix comments

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored Apr 22, 2024
1 parent 7b0229d commit f5a47a6
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- BWC tests for text chunking processor ([#661](https://github.com/opensearch-project/neural-search/pull/661))
- Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670))
- Allowing query by raw tokens in neural_sparse query ([#693](https://github.com/opensearch-project/neural-search/pull/693))
### Bug Fixes
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -62,26 +63,26 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens");
@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
// We use max_token_score field to help WAND scorer prune query clause in lucene 9.7. But in lucene 9.8 the inner
// logics change, this field is not needed any more.
@VisibleForTesting
@Deprecated
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();

private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}

private String fieldName;
private String queryText;
private String modelId;
private Float maxTokenScore;
private Supplier<Map<String, Float>> queryTokensSupplier;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

public static void initialize(MLCommonsClientAccessor mlClient) {
NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}

/**
* Constructor from stream input
*
Expand All @@ -102,21 +103,31 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
}
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
if (StringUtils.EMPTY.equals(this.queryText)) {
this.queryText = null;
}
if (StringUtils.EMPTY.equals(this.modelId)) {
this.modelId = null;
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeString(queryText);
out.writeString(this.fieldName);
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
out.writeString(StringUtils.defaultString(this.queryText, StringUtils.EMPTY));
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);
out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY));
}
out.writeOptionalFloat(maxTokenScore);
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
} else {
out.writeBoolean(false);
}
Expand All @@ -126,11 +137,16 @@ protected void doWriteTo(StreamOutput out) throws IOException {
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
if (Objects.nonNull(queryText)) {
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
}
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
if (Objects.nonNull(maxTokenScore)) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
if (Objects.nonNull(queryTokensSupplier) && Objects.nonNull(queryTokensSupplier.get())) {
xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokensSupplier.get());
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand All @@ -144,6 +160,16 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* "max_token_score": float (optional)
* }
*
* or
* "SAMPLE_FIELD": {
* "query_tokens": {
* "token_a": float,
* "token_b": float,
* ...
* }
* }
*
*
* @param parser XContentParser
* @return NeuralQueryBuilder
* @throws IOException can be thrown by parser
Expand Down Expand Up @@ -171,16 +197,40 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw
}

requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query");
requireValue(
sparseEncodingQueryBuilder.queryText(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME)
);
if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
if (Objects.isNull(sparseEncodingQueryBuilder.queryTokensSupplier())) {
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
sparseEncodingQueryBuilder.queryText(),
String.format(
Locale.ROOT,
"either %s field or %s field must be provided for [%s] query",
QUERY_TEXT_FIELD.getPreferredName(),
QUERY_TOKENS_FIELD.getPreferredName(),
NAME
)
);
if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format(
Locale.ROOT,
"using %s, %s field must be provided for [%s] query",
QUERY_TEXT_FIELD.getPreferredName(),
MODEL_ID_FIELD.getPreferredName(),
NAME
)
);
}
}

if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.queryText())) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "%s field can not be empty", QUERY_TEXT_FIELD.getPreferredName())
);
}
if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.modelId())) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName()));
}

return sparseEncodingQueryBuilder;
}

Expand All @@ -207,6 +257,9 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName)
);
}
} else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
Map<String, Float> queryTokens = parser.map(HashMap::new, XContentParser::floatValue);
sparseEncodingQueryBuilder.queryTokensSupplier(() -> queryTokens);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -293,14 +346,14 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false;
if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore);
if (queryTokensSupplier != null) {
if (Objects.nonNull(queryTokensSupplier)) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TOKENS_FIELD;

import java.io.IOException;
import java.util.List;
Expand All @@ -23,6 +24,7 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import org.apache.commons.lang.StringUtils;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
Expand Down Expand Up @@ -95,6 +97,32 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() {
assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId());
}

@SneakyThrows
public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() {
/*
{
"VECTOR_FIELD": {
"query_tokens": {
"token_a": float_score_a,
"token_b": float_score_b
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS_SUPPLIER.get())
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TOKENS_SUPPLIER.get(), sparseEncodingQueryBuilder.queryTokensSupplier().get());
}

@SneakyThrows
public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
/*
Expand Down Expand Up @@ -276,13 +304,56 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
public void testFromXContent_whenBuildWithEmptyQuery_thenFail() {
/*
{
"VECTOR_FIELD": {
"query_text": ""
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), StringUtils.EMPTY)
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
public void testFromXContent_whenBuildWithEmptyModelId_thenFail() {
/*
{
"VECTOR_FIELD": {
"model_id": ""
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(MODEL_ID_FIELD.getPreferredName(), StringUtils.EMPTY)
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testToXContent() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT)
.maxTokenScore(MAX_TOKEN_SCORE);
.maxTokenScore(MAX_TOKEN_SCORE)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand All @@ -308,15 +379,27 @@ public void testToXContent() {
assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0);
Map<String, Double> parsedQueryTokens = (Map<String, Double>) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName());
assertEquals(QUERY_TOKENS_SUPPLIER.get().keySet(), parsedQueryTokens.keySet());
for (Map.Entry<String, Float> entry : QUERY_TOKENS_SUPPLIER.get().entrySet()) {
assertEquals(entry.getValue(), parsedQueryTokens.get(entry.getKey()).floatValue(), 0);
}
}

public void testStreams_whenCurrentVersion_thenSuccess() {
setUpClusterService(Version.CURRENT);
testStreams();
testStreamsWithQueryTokensOnly();
}

public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() {
setUpClusterService(Version.V_2_12_0);
testStreams();
testStreamsWithQueryTokensOnly();
}

@SneakyThrows
public void testStreams() {
private void testStreams() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
Expand Down Expand Up @@ -356,6 +439,26 @@ public void testStreams() {
assertEquals(original, copy);
}

@SneakyThrows
private void testStreamsWithQueryTokensOnly() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);

BytesStreamOutput streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);

FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
);

NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

public void testHashAndEquals() {
String fieldName1 = "field 1";
String fieldName2 = "field 2";
Expand Down Expand Up @@ -459,6 +562,18 @@ public void testHashAndEquals() {
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);

// Identical to sparseEncodingQueryBuilder_baseline except null query text
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except null model id
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.boost(boost1)
.queryName(queryName1);

assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());

Expand Down Expand Up @@ -491,6 +606,12 @@ public void testHashAndEquals() {

assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullQueryText);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullQueryText.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullModelId);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullModelId.hashCode());
}

@SneakyThrows
Expand Down
Loading

0 comments on commit f5a47a6

Please sign in to comment.