Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support for Default Model Id #337

Merged
merged 19 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Enabled support for applying default modelId in neural search query ([#337](https://github.com/opensearch-project/neural-search/pull/337)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
Expand All @@ -43,6 +44,7 @@
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
Expand All @@ -52,6 +54,7 @@
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand Down Expand Up @@ -80,6 +83,7 @@
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralSearchClusterUtil.instance().initialize(clusterService);

Check warning on line 86 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L86

Added line #L86 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should start getting rid of these kind of inits and move towards singleton pattern without this kind of inits.

May be an AI for maintainers

NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
Expand Down Expand Up @@ -136,4 +140,11 @@
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED);
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> getRequestProcessors(
Parameters parameters
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import static org.opensearch.ingest.ConfigurationUtils.*;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;

import java.util.Map;

import lombok.Getter;
import lombok.Setter;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.Nullable;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

/**
* Neural Search Query Request Processor, It modifies the search request with neural query clause
* and adds model Id if not present in the search query.
*/
@Setter
@Getter
public class NeuralQueryEnricherProcessor extends AbstractProcessor implements SearchRequestProcessor {

/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "neural_query_enricher";

private final String modelId;

private final Map<String, Object> neuralFieldDefaultIdMap;

/**
* Returns the type of the processor.
*
* @return The processor type.
*/
@Override
public String getType() {
return TYPE;
}

private NeuralQueryEnricherProcessor(
String tag,
String description,
boolean ignoreFailure,
@Nullable String modelId,
@Nullable Map<String, Object> neuralFieldDefaultIdMap
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.neuralFieldDefaultIdMap = neuralFieldDefaultIdMap;
}

/**
* Processes the Search Request.
*
* @return The Search Request.
*/
@Override
public SearchRequest processRequest(SearchRequest searchRequest) {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
return searchRequest;
}

public static class Factory implements Processor.Factory<SearchRequestProcessor> {
private static final String DEFAULT_MODEL_ID = "default_model_id";
private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id";

/**
* Create the processor object.
*
* @return NeuralQueryEnricherProcessor
*/
@Override
public NeuralQueryEnricherProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws IllegalArgumentException {
String modelId = readOptionalStringProperty(TYPE, tag, config, DEFAULT_MODEL_ID);
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID);

if (modelId == null && neuralInfoMap == null) {
throw new IllegalArgumentException("model Id or neural info map either of them should be provided");
}

return new NeuralQueryEnricherProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
Expand All @@ -37,6 +38,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;

import com.google.common.annotations.VisibleForTesting;

Expand Down Expand Up @@ -82,6 +84,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;

/**
* Constructor from stream input
Expand All @@ -93,7 +96,12 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
this.modelId = in.readOptionalString();
} else {
this.modelId = in.readString();
}
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
Expand All @@ -102,7 +110,12 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
out.writeString(this.modelId);
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);
}
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
}
Expand All @@ -112,7 +125,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
if (modelId != null) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
xContentBuilder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
Expand Down Expand Up @@ -164,8 +179,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx
}
requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query");
requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query");
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");

if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query");
}
return neuralQueryBuilder;
}

Expand Down Expand Up @@ -258,4 +274,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.query.visitor;

import java.util.Map;

import lombok.AllArgsConstructor;

import org.apache.lucene.search.BooleanClause;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilderVisitor;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

/**
* Neural Search Query Visitor. It visits each and every component of query buikder tree.
*/
@AllArgsConstructor
public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {

private final String modelId;
private final Map<String, Object> neuralFieldMap;

/**
* Accept method accepts every query builder from the search request,
* and processes it if the required conditions in accept method are satisfied.
*/
@Override
public void accept(QueryBuilder queryBuilder) {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
if (queryBuilder instanceof NeuralQueryBuilder) {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder;
if (neuralQueryBuilder.modelId() == null) {
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(
"model id must be provided in neural query or a default model id must be set in search request processor"
);
}
}
}
}

/**
* Retrieves the child visitor from the Visitor object.
*
* @return The sub Query Visitor
*/
@Override
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) {
return this;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;

/**
* Class abstracts information related to underlying OpenSearch cluster
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class NeuralSearchClusterUtil {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
private ClusterService clusterService;

private static NeuralSearchClusterUtil instance;

/**
* Return instance of the cluster context, must be initialized first for proper usage
* @return instance of cluster context
*/
public static synchronized NeuralSearchClusterUtil instance() {
if (instance == null) {
instance = new NeuralSearchClusterUtil();
}
return instance;
}

/**
* Initializes instance of cluster context by injecting dependencies
* @param clusterService
*/
public void initialize(final ClusterService clusterService) {
this.clusterService = clusterService;
}

/**
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
return this.clusterService.state().getNodes().getMinNodeVersion();
}

}
Loading
Loading