-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support for Default Model Id (#337)
* Support for default Model Id Signed-off-by: Varun Jain <[email protected]> * Support for Default Model id Signed-off-by: Varun Jain <[email protected]> * Support for default model Id Signed-off-by: Varun Jain <[email protected]> * Removing wildcard Imports Signed-off-by: Varun Jain <[email protected]> * Typo fix Signed-off-by: Varun Jain <[email protected]> * Integ test cases Signed-off-by: Varun Jain <[email protected]> * Fixing Integ Test case Signed-off-by: Varun Jain <[email protected]> * Addressing Comments Signed-off-by: Varun Jain <[email protected]> * Added Visitor test cases and addressed comments Signed-off-by: Varun Jain <[email protected]> * Comments Addressed of Jack Signed-off-by: Varun Jain <[email protected]> * Addressed changes requested by Martin Signed-off-by: Varun Jain <[email protected]> * Addressed changes requested by Martin Signed-off-by: Varun Jain <[email protected]> * Fixing test cases Signed-off-by: Varun Jain <[email protected]> * Increasing test coverage Signed-off-by: Varun Jain <[email protected]> * Renaming and addressing comments of Martin Signed-off-by: Varun Jain <[email protected]> * Addressing Comments of Navneet Signed-off-by: Varun Jain <[email protected]> * Updating tests Signed-off-by: Varun Jain <[email protected]> --------- Signed-off-by: Varun Jain <[email protected]>
- Loading branch information
1 parent
2c5d150
commit 9c37b0e
Showing
17 changed files
with
636 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.processor; | ||
|
||
import static org.opensearch.ingest.ConfigurationUtils.*; | ||
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; | ||
|
||
import java.util.Map; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.common.Nullable; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor; | ||
import org.opensearch.search.pipeline.AbstractProcessor; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchRequestProcessor; | ||
|
||
/** | ||
* Neural Search Query Request Processor, It modifies the search request with neural query clause | ||
* and adds model Id if not present in the search query. | ||
*/ | ||
@Setter | ||
@Getter | ||
public class NeuralQueryEnricherProcessor extends AbstractProcessor implements SearchRequestProcessor { | ||
|
||
/** | ||
* Key to reference this processor type from a search pipeline. | ||
*/ | ||
public static final String TYPE = "neural_query_enricher"; | ||
|
||
private final String modelId; | ||
|
||
private final Map<String, Object> neuralFieldDefaultIdMap; | ||
|
||
/** | ||
* Returns the type of the processor. | ||
* | ||
* @return The processor type. | ||
*/ | ||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
private NeuralQueryEnricherProcessor( | ||
String tag, | ||
String description, | ||
boolean ignoreFailure, | ||
@Nullable String modelId, | ||
@Nullable Map<String, Object> neuralFieldDefaultIdMap | ||
) { | ||
super(tag, description, ignoreFailure); | ||
this.modelId = modelId; | ||
this.neuralFieldDefaultIdMap = neuralFieldDefaultIdMap; | ||
} | ||
|
||
/** | ||
* Processes the Search Request. | ||
* | ||
* @return The Search Request. | ||
*/ | ||
@Override | ||
public SearchRequest processRequest(SearchRequest searchRequest) { | ||
QueryBuilder queryBuilder = searchRequest.source().query(); | ||
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap)); | ||
return searchRequest; | ||
} | ||
|
||
public static class Factory implements Processor.Factory<SearchRequestProcessor> { | ||
private static final String DEFAULT_MODEL_ID = "default_model_id"; | ||
private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id"; | ||
|
||
/** | ||
* Create the processor object. | ||
* | ||
* @return NeuralQueryEnricherProcessor | ||
*/ | ||
@Override | ||
public NeuralQueryEnricherProcessor create( | ||
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories, | ||
String tag, | ||
String description, | ||
boolean ignoreFailure, | ||
Map<String, Object> config, | ||
PipelineContext pipelineContext | ||
) throws IllegalArgumentException { | ||
String modelId = readOptionalStringProperty(TYPE, tag, config, DEFAULT_MODEL_ID); | ||
Map<String, Object> neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID); | ||
|
||
if (modelId == null && neuralInfoMap == null) { | ||
throw new IllegalArgumentException("model Id or neural info map either of them should be provided"); | ||
} | ||
|
||
return new NeuralQueryEnricherProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.query.visitor; | ||
|
||
import java.util.Map; | ||
|
||
import lombok.AllArgsConstructor; | ||
|
||
import org.apache.lucene.search.BooleanClause; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryBuilderVisitor; | ||
import org.opensearch.neuralsearch.query.NeuralQueryBuilder; | ||
|
||
/** | ||
* Neural Search Query Visitor. It visits each and every component of query buikder tree. | ||
*/ | ||
@AllArgsConstructor | ||
public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { | ||
|
||
private final String modelId; | ||
private final Map<String, Object> neuralFieldMap; | ||
|
||
/** | ||
* Accept method accepts every query builder from the search request, | ||
* and processes it if the required conditions in accept method are satisfied. | ||
*/ | ||
@Override | ||
public void accept(QueryBuilder queryBuilder) { | ||
if (queryBuilder instanceof NeuralQueryBuilder) { | ||
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder; | ||
if (neuralQueryBuilder.modelId() == null) { | ||
if (neuralFieldMap != null | ||
&& neuralQueryBuilder.fieldName() != null | ||
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) { | ||
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName()); | ||
neuralQueryBuilder.modelId(fieldDefaultModelId); | ||
} else if (modelId != null) { | ||
neuralQueryBuilder.modelId(modelId); | ||
} else { | ||
throw new IllegalArgumentException( | ||
"model id must be provided in neural query or a default model id must be set in search request processor" | ||
); | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Retrieves the child visitor from the Visitor object. | ||
* | ||
* @return The sub Query Visitor | ||
*/ | ||
@Override | ||
public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) { | ||
return this; | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.util; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.NoArgsConstructor; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
import org.opensearch.Version; | ||
import org.opensearch.cluster.service.ClusterService; | ||
|
||
/** | ||
* Class abstracts information related to underlying OpenSearch cluster | ||
*/ | ||
@NoArgsConstructor(access = AccessLevel.PRIVATE) | ||
@Log4j2 | ||
public class NeuralSearchClusterUtil { | ||
private ClusterService clusterService; | ||
|
||
private static NeuralSearchClusterUtil instance; | ||
|
||
/** | ||
* Return instance of the cluster context, must be initialized first for proper usage | ||
* @return instance of cluster context | ||
*/ | ||
public static synchronized NeuralSearchClusterUtil instance() { | ||
if (instance == null) { | ||
instance = new NeuralSearchClusterUtil(); | ||
} | ||
return instance; | ||
} | ||
|
||
/** | ||
* Initializes instance of cluster context by injecting dependencies | ||
* @param clusterService | ||
*/ | ||
public void initialize(final ClusterService clusterService) { | ||
this.clusterService = clusterService; | ||
} | ||
|
||
/** | ||
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster | ||
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version | ||
*/ | ||
public Version getClusterMinVersion() { | ||
return this.clusterService.state().getNodes().getMinNodeVersion(); | ||
} | ||
|
||
} |
Oops, something went wrong.