Skip to content

Commit

Permalink
Change query clause name to neural_sparse
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Sep 28, 2023
1 parent 67ced0d commit d66746e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
Expand Down Expand Up @@ -81,7 +81,7 @@ public Collection<Object> createComponents(
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
NeuralSparseQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}
Expand All @@ -91,7 +91,7 @@ public List<QuerySpec<?>> getQueries() {
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
new QuerySpec<>(NeuralSparseQueryBuilder.NAME, NeuralSparseQueryBuilder::new, NeuralSparseQueryBuilder::fromXContent)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import com.google.common.annotations.VisibleForTesting;

/**
* SparseEncodingQueryBuilder is responsible for handling "sparse_encoding" query types. It uses an ML SPARSE_ENCODING model
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML SPARSE_ENCODING model
* or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed
* to Lucene FeatureQuery wrapped by Lucene BooleanQuery.
*/
Expand All @@ -55,8 +55,8 @@
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
public class SparseEncodingQueryBuilder extends AbstractQueryBuilder<SparseEncodingQueryBuilder> {
public static final String NAME = "sparse_encoding";
public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder> {
public static final String NAME = "neural_sparse";
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
Expand All @@ -65,7 +65,7 @@ public class SparseEncodingQueryBuilder extends AbstractQueryBuilder<SparseEncod
private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
SparseEncodingQueryBuilder.ML_CLIENT = mlClient;
NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}

private String fieldName;
Expand All @@ -79,7 +79,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
* @param in StreamInput to initialize object from
* @throws IOException thrown if unable to read from input stream
*/
public SparseEncodingQueryBuilder(StreamInput in) throws IOException {
public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
Expand Down Expand Up @@ -115,8 +115,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* @return NeuralQueryBuilder
* @throws IOException can be thrown by parser
*/
public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder();
public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throws IOException {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder();
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "First token of " + NAME + "query must be START_OBJECT");
}
Expand Down Expand Up @@ -150,7 +150,7 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr
return sparseEncodingQueryBuilder;
}

private static void parseQueryParams(XContentParser parser, SparseEncodingQueryBuilder sparseEncodingQueryBuilder) throws IOException {
private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBuilder sparseEncodingQueryBuilder) throws IOException {
XContentParser.Token token;
String currentFieldName = "";
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -200,7 +200,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
}, actionListener::onFailure)
))
);
return new SparseEncodingQueryBuilder().fieldName(fieldName)
return new NeuralSparseQueryBuilder().fieldName(fieldName)
.queryText(queryText)
.modelId(modelId)
.queryTokensSupplier(queryTokensSetOnce::get);
Expand Down Expand Up @@ -254,7 +254,7 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
}

@Override
protected boolean doEquals(SparseEncodingQueryBuilder obj) {
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -42,7 +42,7 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.test.OpenSearchTestCase;

public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase {
public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {

private static final String FIELD_NAME = "testField";
private static final String QUERY_TEXT = "Hello world!";
Expand Down Expand Up @@ -71,7 +71,7 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
Expand Down Expand Up @@ -102,7 +102,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
Expand Down Expand Up @@ -137,7 +137,7 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(ParsingException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(ParsingException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
Expand All @@ -158,7 +158,7 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
Expand All @@ -179,7 +179,7 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
Expand All @@ -206,13 +206,13 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IOException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testToXContent() {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT);

Expand Down Expand Up @@ -243,7 +243,7 @@ public void testToXContent() {

@SneakyThrows
public void testStreams() {
SparseEncodingQueryBuilder original = new SparseEncodingQueryBuilder();
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
original.modelId(MODEL_ID);
Expand All @@ -260,7 +260,7 @@ public void testStreams() {
)
);

SparseEncodingQueryBuilder copy = new SparseEncodingQueryBuilder(filterStreamInput);
NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

Expand All @@ -276,54 +276,54 @@ public void testHashAndEquals() {
String queryName1 = "query-1";
String queryName2 = "query-2";

SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName(
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(
fieldName1
).queryText(queryText1).modelId(modelId1);

// Identical to sparseEncodingQueryBuilder_baseline except diff field name
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query text
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText2)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId2)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff boost
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost2)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query name
SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
Expand Down Expand Up @@ -356,7 +356,7 @@ public void testHashAndEquals() {

@SneakyThrows
public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID);
Map<String, Float> expectedMap = Map.of("1", 1f, "2", 2f);
Expand All @@ -366,7 +366,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
listener.onResponse(List.of(Map.of("response", List.of(expectedMap))));
return null;
}).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any());
SparseEncodingQueryBuilder.initialize(mlCommonsClientAccessor);
NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor);

final CountDownLatch inProgressLatch = new CountDownLatch(1);
QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
Expand All @@ -382,15 +382,15 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
return null;
}).when(queryRewriteContext).registerAsyncAction(any());

SparseEncodingQueryBuilder queryBuilder = (SparseEncodingQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
assertNotNull(queryBuilder.queryTokensSupplier());
assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS));
assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get());
}

@SneakyThrows
public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
Expand Down
Loading

0 comments on commit d66746e

Please sign in to comment.