Skip to content

Commit

Permalink
[Feature] Support for Default Model Id (#337)
Browse files Browse the repository at this point in the history
* Support for default Model Id

Signed-off-by: Varun Jain <[email protected]>

* Support for Default Model id

Signed-off-by: Varun Jain <[email protected]>

* Support for default model Id

Signed-off-by: Varun Jain <[email protected]>

* Removing wildcard Imports

Signed-off-by: Varun Jain <[email protected]>

* Typo fix

Signed-off-by: Varun Jain <[email protected]>

* Integ test cases

Signed-off-by: Varun Jain <[email protected]>

* Fixing Integ Test case

Signed-off-by: Varun Jain <[email protected]>

* Addressing Comments

Signed-off-by: Varun Jain <[email protected]>

* Added Visitor test cases and addressed comments

Signed-off-by: Varun Jain <[email protected]>

* Comments Addressed of Jack

Signed-off-by: Varun Jain <[email protected]>

* Addressed changes requested by Martin

Signed-off-by: Varun Jain <[email protected]>

* Addressed changes requested by Martin

Signed-off-by: Varun Jain <[email protected]>

* Fixing test cases

Signed-off-by: Varun Jain <[email protected]>

* Increasing test coverage

Signed-off-by: Varun Jain <[email protected]>

* Renaming and addressing comments of Martin

Signed-off-by: Varun Jain <[email protected]>

* Addressing Comments of Navneet

Signed-off-by: Varun Jain <[email protected]>

* Updating tests

Signed-off-by: Varun Jain <[email protected]>

---------

Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun authored Sep 29, 2023
1 parent 2c5d150 commit 9c37b0e
Show file tree
Hide file tree
Showing 17 changed files with 636 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Enabled support for applying default modelId in neural search query ([#337](https://github.com/opensearch-project/neural-search/pull/337)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
Expand All @@ -43,6 +44,7 @@
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
Expand All @@ -52,6 +54,7 @@
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand Down Expand Up @@ -80,6 +83,7 @@ public Collection<Object> createComponents(
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralSearchClusterUtil.instance().initialize(clusterService);
NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
Expand Down Expand Up @@ -136,4 +140,11 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(
Parameters parameters
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import static org.opensearch.ingest.ConfigurationUtils.*;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;

import java.util.Map;

import lombok.Getter;
import lombok.Setter;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.Nullable;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

/**
* Neural Search Query Request Processor, It modifies the search request with neural query clause
* and adds model Id if not present in the search query.
*/
@Setter
@Getter
public class NeuralQueryEnricherProcessor extends AbstractProcessor implements SearchRequestProcessor {

/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "neural_query_enricher";

private final String modelId;

private final Map<String, Object> neuralFieldDefaultIdMap;

/**
* Returns the type of the processor.
*
* @return The processor type.
*/
@Override
public String getType() {
return TYPE;
}

private NeuralQueryEnricherProcessor(
String tag,
String description,
boolean ignoreFailure,
@Nullable String modelId,
@Nullable Map<String, Object> neuralFieldDefaultIdMap
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.neuralFieldDefaultIdMap = neuralFieldDefaultIdMap;
}

/**
* Processes the Search Request.
*
* @return The Search Request.
*/
@Override
public SearchRequest processRequest(SearchRequest searchRequest) {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
return searchRequest;
}

public static class Factory implements Processor.Factory<SearchRequestProcessor> {
private static final String DEFAULT_MODEL_ID = "default_model_id";
private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id";

/**
* Create the processor object.
*
* @return NeuralQueryEnricherProcessor
*/
@Override
public NeuralQueryEnricherProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws IllegalArgumentException {
String modelId = readOptionalStringProperty(TYPE, tag, config, DEFAULT_MODEL_ID);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID);

if (modelId == null && neuralInfoMap == null) {
throw new IllegalArgumentException("model Id or neural info map either of them should be provided");
}

return new NeuralQueryEnricherProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
Expand All @@ -37,6 +38,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.google.common.annotations.VisibleForTesting;

Expand Down Expand Up @@ -82,6 +84,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;

/**
* Constructor from stream input
Expand All @@ -93,7 +96,12 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
this.modelId = in.readOptionalString();
} else {
this.modelId = in.readString();
}
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
Expand All @@ -102,7 +110,12 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
out.writeString(this.modelId);
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);
}
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
}
Expand All @@ -112,7 +125,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
if (modelId != null) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
xContentBuilder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
Expand Down Expand Up @@ -164,8 +179,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
}
requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
}
return neuralQueryBuilder;
}

Expand Down Expand Up @@ -258,4 +274,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.query.visitor;

import java.util.Map;

import lombok.AllArgsConstructor;

import org.apache.lucene.search.BooleanClause;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilderVisitor;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

/**
* Neural Search Query Visitor. It visits each and every component of query buikder tree.
*/
@AllArgsConstructor
public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {

private final String modelId;
private final Map<String, Object> neuralFieldMap;

/**
* Accept method accepts every query builder from the search request,
* and processes it if the required conditions in accept method are satisfied.
*/
@Override
public void accept(QueryBuilder queryBuilder) {
if (queryBuilder instanceof NeuralQueryBuilder) {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder;
if (neuralQueryBuilder.modelId() == null) {
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(
"model id must be provided in neural query or a default model id must be set in search request processor"
);
}
}
}
}

/**
* Retrieves the child visitor from the Visitor object.
*
* @return The sub Query Visitor
*/
@Override
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) {
return this;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;

/**
* Class abstracts information related to underlying OpenSearch cluster
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class NeuralSearchClusterUtil {
private ClusterService clusterService;

private static NeuralSearchClusterUtil instance;

/**
* Return instance of the cluster context, must be initialized first for proper usage
* @return instance of cluster context
*/
public static synchronized NeuralSearchClusterUtil instance() {
if (instance == null) {
instance = new NeuralSearchClusterUtil();
}
return instance;
}

/**
* Initializes instance of cluster context by injecting dependencies
* @param clusterService
*/
public void initialize(final ClusterService clusterService) {
this.clusterService = clusterService;
}

/**
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
return this.clusterService.state().getNodes().getMinNodeVersion();
}

}
Loading

0 comments on commit 9c37b0e

Please sign in to comment.