forked from opensearch-project/neural-search
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Reranker feature (opensearch-project#591)
* Adding support for generic re-ranker interface and opensearch ml re-ranker for improving search relavancy. (opensearch-project#494) Signed-off-by: HenryL27 <[email protected]> Co-authored-by: Heemin Kim <[email protected]> Signed-off-by: Martin Gaievski <[email protected]> --------- Signed-off-by: HenryL27 <[email protected]> Signed-off-by: Martin Gaievski <[email protected]> Co-authored-by: HenryL27 <[email protected]> Co-authored-by: Heemin Kim <[email protected]> (cherry picked from commit 1bb48e2)
- Loading branch information
1 parent
ac04063
commit 0c5f387
Showing
21 changed files
with
1,999 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
140 changes: 140 additions & 0 deletions
140
src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.factory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Set; | ||
import java.util.StringJoiner; | ||
|
||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; | ||
import org.opensearch.neuralsearch.processor.rerank.RerankType; | ||
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
|
||
import com.google.common.collect.Sets; | ||
|
||
import lombok.AllArgsConstructor; | ||
|
||
/** | ||
* Factory for rerank processors. Must: | ||
* - Instantiate the right kind of rerank processor | ||
* - Instantiate the appropriate context source fetchers | ||
*/ | ||
@AllArgsConstructor | ||
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> { | ||
|
||
public static final String RERANK_PROCESSOR_TYPE = "rerank"; | ||
public static final String CONTEXT_CONFIG_FIELD = "context"; | ||
|
||
private final MLCommonsClientAccessor clientAccessor; | ||
private final ClusterService clusterService; | ||
|
||
@Override | ||
public SearchResponseProcessor create( | ||
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, | ||
final String tag, | ||
final String description, | ||
final boolean ignoreFailure, | ||
final Map<String, Object> config, | ||
final Processor.PipelineContext pipelineContext | ||
) { | ||
RerankType type = findRerankType(config); | ||
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); | ||
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers( | ||
config, | ||
includeQueryContextFetcher, | ||
tag, | ||
clusterService | ||
); | ||
switch (type) { | ||
case ML_OPENSEARCH: | ||
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); | ||
String modelId = ConfigurationUtils.readStringProperty( | ||
RERANK_PROCESSOR_TYPE, | ||
tag, | ||
rerankerConfig, | ||
MLOpenSearchRerankProcessor.MODEL_ID_FIELD | ||
); | ||
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); | ||
default: | ||
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); | ||
} | ||
} | ||
|
||
private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException { | ||
// Set of rerank type labels in the config | ||
Set<String> rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet()); | ||
// A rerank type must be provided | ||
if (rerankTypes.size() == 0) { | ||
StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]"); | ||
for (RerankType t : RerankType.values()) { | ||
msgBuilder.add(t.getLabel()); | ||
} | ||
throw new IllegalArgumentException(msgBuilder.toString()); | ||
} | ||
// Only one rerank type may be provided | ||
if (rerankTypes.size() > 1) { | ||
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted."); | ||
rerankTypes.forEach(rt -> msgBuilder.add(rt)); | ||
throw new IllegalArgumentException(msgBuilder.toString()); | ||
} | ||
return RerankType.from(rerankTypes.iterator().next()); | ||
} | ||
|
||
/** | ||
* Factory class for context fetchers. Constructs a list of context fetchers | ||
* specified in the pipeline config (and maybe the query context fetcher) | ||
*/ | ||
private static class ContextFetcherFactory { | ||
|
||
/** | ||
* Map rerank types to whether they should include the query context source fetcher | ||
* @param type the constructing RerankType | ||
* @return does this RerankType depend on the QueryContextSourceFetcher? | ||
*/ | ||
public static boolean shouldIncludeQueryContextFetcher(RerankType type) { | ||
return type == RerankType.ML_OPENSEARCH; | ||
} | ||
|
||
/** | ||
* Create necessary queryContextFetchers for this processor | ||
* @param config processor config object. Look for "context" field to find fetchers | ||
* @param includeQueryContextFetcher should I include the queryContextFetcher? | ||
* @return list of contextFetchers for the processor to use | ||
*/ | ||
public static List<ContextSourceFetcher> createFetchers( | ||
Map<String, Object> config, | ||
boolean includeQueryContextFetcher, | ||
String tag, | ||
final ClusterService clusterService | ||
) { | ||
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD); | ||
List<ContextSourceFetcher> fetchers = new ArrayList<>(); | ||
for (String key : contextConfig.keySet()) { | ||
Object cfg = contextConfig.get(key); | ||
switch (key) { | ||
case DocumentContextSourceFetcher.NAME: | ||
fetchers.add(DocumentContextSourceFetcher.create(cfg, clusterService)); | ||
break; | ||
default: | ||
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); | ||
} | ||
} | ||
if (includeQueryContextFetcher) { | ||
fetchers.add(new QueryContextSourceFetcher(clusterService)); | ||
} | ||
return fetchers; | ||
} | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.rerank; | ||
|
||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; | ||
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; | ||
|
||
/** | ||
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore | ||
*/ | ||
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor { | ||
|
||
public static final String MODEL_ID_FIELD = "model_id"; | ||
|
||
protected final String modelId; | ||
|
||
protected final MLCommonsClientAccessor mlCommonsClientAccessor; | ||
|
||
/** | ||
* Constructor | ||
* @param description | ||
* @param tag | ||
* @param ignoreFailure | ||
* @param modelId id of TEXT_SIMILARITY model | ||
* @param contextSourceFetchers | ||
* @param mlCommonsClientAccessor | ||
*/ | ||
public MLOpenSearchRerankProcessor( | ||
final String description, | ||
final String tag, | ||
final boolean ignoreFailure, | ||
final String modelId, | ||
final List<ContextSourceFetcher> contextSourceFetchers, | ||
final MLCommonsClientAccessor mlCommonsClientAccessor | ||
) { | ||
super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers); | ||
this.modelId = modelId; | ||
this.mlCommonsClientAccessor = mlCommonsClientAccessor; | ||
} | ||
|
||
@Override | ||
public void rescoreSearchResponse( | ||
final SearchResponse response, | ||
final Map<String, Object> rerankingContext, | ||
final ActionListener<List<Float>> listener | ||
) { | ||
Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); | ||
if (!(ctxObj instanceof List<?>)) { | ||
listener.onFailure( | ||
new IllegalStateException( | ||
String.format( | ||
Locale.ROOT, | ||
"No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", | ||
RerankProcessorFactory.CONTEXT_CONFIG_FIELD, | ||
DocumentContextSourceFetcher.NAME | ||
) | ||
) | ||
); | ||
return; | ||
} | ||
List<?> ctxList = (List<?>) ctxObj; | ||
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); | ||
mlCommonsClientAccessor.inferenceSimilarity( | ||
modelId, | ||
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), | ||
contexts, | ||
listener | ||
); | ||
} | ||
|
||
} |
Oops, something went wrong.