From bdf5d9b80cd5ab2a642c5126326807cf8df469fa Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 13 Nov 2023 18:03:36 -0800 Subject: [PATCH 01/27] Add rerank processor interfaces Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 38 +++++ .../rerank/CrossEncoderRerankProcessor.java | 115 +++++++++++++++ .../processor/rerank/RerankProcessor.java | 64 +++++++++ .../processor/rerank/RerankType.java | 48 +++++++ .../rerank/RescoringRerankProcessor.java | 136 ++++++++++++++++++ .../query/ext/RerankSearchExtBuilder.java | 98 +++++++++++++ 6 files changed, 499 insertions(+) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java new file mode 100644 index 000000000..7449743ff --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; + +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +public class RerankProcessorFactory implements Processor.Factory { + + @Override + public SearchResponseProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) { + return null; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java new file mode 100644 index 000000000..ea2152378 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ObjectPath; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; + +public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { + + public static final String MODEL_ID_FIELD = "model_id"; + public static final String QUERY_TEXT_FIELD = "query_text"; + public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; + public static final String RERANK_CONTEXT_FIELD = "rerank_context_field"; + + protected final String modelId; + protected final String rerank_context; + + protected final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + public CrossEncoderRerankProcessor( + String description, + String tag, + boolean ignoreFailure, + String modelId, + String rerank_context, + MLCommonsClientAccessor mlCommonsClientAccessor, + Environment environment + ) { + super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure); + this.modelId = modelId; + this.rerank_context = rerank_context; + this.mlCommonsClientAccessor = mlCommonsClientAccessor; + this.environment = environment; + } + + @Override + public void generateScoringContext( + SearchRequest searchRequest, + SearchResponse searchResponse, + ActionListener> listener + ) { + try { + List exts = searchRequest.source().ext(); + Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); + Map scoringContext = new HashMap<>(); + if (params.containsKey(QUERY_TEXT_FIELD)) { + if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + throw new IllegalArgumentException("Cannot specify both \"query_text\" and \"query_text_path\""); + } + scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); + } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + String path = (String) params.get(QUERY_TEXT_PATH_FIELD); + // Convert query to a map with io/xcontent shenanigans + PipedOutputStream os = new PipedOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(os); + searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); + PipedInputStream is = new PipedInputStream(os); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, is); + Map map = parser.map(); + // Get the text at the path + Object queryText = ObjectPath.eval(path, map); + if (!(queryText instanceof String)) { + throw new IllegalArgumentException("query_text_path must point to a string field"); + } + scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); + } else { + throw new IllegalArgumentException("Must specify either \"query_text\" or \"query_text_path\""); + } + listener.onResponse(scoringContext); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + public void rescoreSearchResponse(SearchResponse response, Map scoringContext, ActionListener> listener) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'rescoreSearchResponse'"); + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java new file mode 100644 index 000000000..62ab61da4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +public interface RerankProcessor extends SearchResponseProcessor { + + /** + * Generate the information that this processor needs in order to rerank. + * That could be as simple as grabbing a field from the search request or + * as complicated as a lookup to some external service + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + public void generateScoringContext( + SearchRequest searchRequest, + SearchResponse searchResponse, + ActionListener> listener + ); + + /** + * Given the scoring context generated by the processor and the search results, + * rerank the search results. Do so asynchronously. + * @param searchResponse the search results to rerank + * @param scoringContext the information this processor needs in order to rerank + * @param listener be async + */ + public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener); + + @Override + default void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { + try { + generateScoringContext( + request, + response, + ActionListener.wrap(context -> { rerank(response, context, responseListener); }, e -> { responseListener.onFailure(e); }) + ); + } catch (Exception e) { + responseListener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java new file mode 100644 index 000000000..6bfb9feed --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import lombok.Getter; + +/** + * enum for distinguishing various reranking methods + */ +public enum RerankType { + + CROSS_ENCODER("cross-encoder"); + + @Getter + private final String label; + + private RerankType(String label) { + this.label = label; + } + + /** + * Construct a RerankType from the label + * @param label label of a RerankType + * @return RerankType represented by the label + */ + public static RerankType from(String label) { + try { + return RerankType.valueOf(label); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong rerank type name: " + label); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java new file mode 100644 index 000000000..f88c02d0d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -0,0 +1,136 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import lombok.AllArgsConstructor; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.profile.SearchProfileShardResults; + +@AllArgsConstructor +public abstract class RescoringRerankProcessor implements RerankProcessor { + + private final RerankType type; + private final String description; + private final String tag; + private final boolean ignoreFailure; + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + throw new UnsupportedOperationException("Use asyncProcessResponse unless you can guarantee to not deadlock yourself"); + } + + @Override + public String getType() { + return "rerank-" + type.getLabel(); + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return ignoreFailure; + } + + /** + * Generate a list of new scores for all of the documents, given the scoring context + * @param response search results to rescore + * @param scoringContext extra information needed to score the search results; e.g. model id + * @param listener be async. recieves the list of new scores + */ + public abstract void rescoreSearchResponse( + SearchResponse response, + Map scoringContext, + ActionListener> listener + ); + + @Override + public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener) { + try { + rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { + // Assign new scores + SearchHit[] hits = searchResponse.getHits().getHits(); + assert (hits.length == scores.size()); + for (int i = 0; i < hits.length; i++) { + hits[i].score(scores.get(i)); + } + // Re-sort by the new scores + Collections.sort(Arrays.asList(hits), new Comparator() { + @Override + public int compare(SearchHit hit1, SearchHit hit2) { + return Float.compare(hit1.getScore(), hit2.getScore()); + } + }); + // Reconstruct the search response, replacing the max score + SearchHits newHits = new SearchHits( + hits, + searchResponse.getHits().getTotalHits(), + hits[0].getScore(), + searchResponse.getHits().getSortFields(), + searchResponse.getHits().getCollapseField(), + searchResponse.getHits().getCollapseValues() + ); + SearchResponseSections newInternalResponse = new SearchResponseSections( + newHits, + searchResponse.getAggregations(), + searchResponse.getSuggest(), + searchResponse.isTimedOut(), + searchResponse.isTerminatedEarly(), + new SearchProfileShardResults(searchResponse.getProfileResults()), + searchResponse.getNumReducePhases(), + searchResponse.getInternalResponse().getSearchExtBuilders() + ); + SearchResponse newResponse = new SearchResponse( + newInternalResponse, + searchResponse.getScrollId(), + searchResponse.getTotalShards(), + searchResponse.getSuccessfulShards(), + searchResponse.getSkippedShards(), + searchResponse.getTook().millis(), + searchResponse.getPhaseTook(), + searchResponse.getShardFailures(), + searchResponse.getClusters(), + searchResponse.pointInTimeId() + ); + listener.onResponse(newResponse); + }, e -> { listener.onFailure(e); })); + } catch (Exception e) { + listener.onFailure(e); + } + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java new file mode 100644 index 000000000..ad3756aa8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.query.ext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +@AllArgsConstructor +public class RerankSearchExtBuilder extends SearchExtBuilder { + + public final static String PARAM_FIELD_NAME = "rerank"; + @Getter + protected Map params; + + public RerankSearchExtBuilder(StreamInput in) throws IOException { + params = in.readMap(); + } + + @Override + public String getWriteableName() { + return PARAM_FIELD_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(params); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PARAM_FIELD_NAME, this.params); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.params); + } + + @Override + public boolean equals(Object obj) { + return (obj instanceof RerankSearchExtBuilder) && params.equals(((RerankSearchExtBuilder) obj).params); + } + + /** + * Pick out the first RerankSearchExtBuilder from a list of SearchExtBuilders + * @param builders list of SearchExtBuilders + * @return the RerankSearchExtBuilder + */ + public static RerankSearchExtBuilder fromExtBuilderList(List builders) { + Optional b = builders.stream().filter(bldr -> bldr instanceof RerankSearchExtBuilder).findFirst(); + if (b.isPresent()) { + return (RerankSearchExtBuilder) b.get(); + } else { + return null; + } + } + + /** + * Parse XContent to rerankSearchExtBuilder + * @param parser parser parsing this searchExt + * @return RerankSearchExtBuilder represented by this searchExt + * @throws IOException if problems parsing + */ + public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException { + return new RerankSearchExtBuilder(parser.map()); + } + +} From 17cff659c8cf09d2f96b36b03644eca80b306a55 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 14 Nov 2023 11:45:02 -0800 Subject: [PATCH 02/27] add cross-encoder specific logic and factory Signed-off-by: HenryL27 --- .../ml/MLCommonsClientAccessor.java | 55 +++++++++++++++++++ .../neuralsearch/plugin/NeuralSearch.java | 22 ++++++++ .../factory/RerankProcessorFactory.java | 35 +++++++++++- .../rerank/CrossEncoderRerankProcessor.java | 32 +++++++---- .../processor/rerank/RerankProcessor.java | 2 + .../rerank/RescoringRerankProcessor.java | 2 +- 6 files changed, 135 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index e12211d28..ea9b7f0d8 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -13,12 +13,18 @@ import java.util.Map; import java.util.stream.Collectors; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -132,6 +138,25 @@ public void inferenceSentences( retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); } + /** + * Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the + * {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing + * the similarity scores of the texts w.r.t. the query text, in the order of the input texts. + * + * @param modelId {@link String} ML-Commons Model Id + * @param queryText {@link String} The query to compare all the inputText to + * @param inputText {@link List} of {@link String} The texts to compare to the query + * @param listener {@link ActionListener} receives the result of the inference + */ + public void inferenceSimilarity( + @NonNull final String modelId, + @NonNull final String queryText, + @NonNull final List inputText, + @NonNull final ActionListener> listener + ) { + retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener); + } + private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, @@ -173,12 +198,42 @@ private void retryableInferenceSentencesWithVectorResult( })); } + private void retryableInferenceSimilarityWithVectorResult( + final String modelId, + final String queryText, + final List inputText, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLTextPairsInput(queryText, inputText); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList()); + listener.onResponse(scores); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener); + } else { + listener.onFailure(e); + } + })); + } + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } + private MLInput createMLTextPairsInput(final List> pairs) { + final MLInputDataset inputDataset = new TextSimilarityInputDataSet(pairs); + return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset); + } + + private MLInput createMLTextPairsInput(final String query, final List inputText) { + List> pairs = inputText.stream().map(text -> Pair.of(query, text)).collect(Collectors.toList()); + return createMLTextPairsInput(pairs); + } + private List> buildVectorFromResponse(MLOutput mlOutput) { final List> vector = new ArrayList<>(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index aacb8d2e6..a77118f43 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -34,14 +34,17 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.plugins.ActionPlugin; @@ -54,6 +57,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -150,4 +154,22 @@ public Map> getResponseProcessors( + Parameters parameters + ) { + return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor)); + } + + @Override + public List> getSearchExts() { + return List.of( + new SearchExtSpec<>( + RerankSearchExtBuilder.PARAM_FIELD_NAME, + in -> new RerankSearchExtBuilder(in), + parser -> RerankSearchExtBuilder.parse(parser) + ) + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 7449743ff..ed1d56b4b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -19,11 +19,21 @@ import java.util.Map; +import lombok.AllArgsConstructor; + +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankType; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; +@AllArgsConstructor public class RerankProcessorFactory implements Processor.Factory { + public static final String RERANK_PROCESSOR_TYPE = "rerank"; + + private final MLCommonsClientAccessor clientAccessor; + @Override public SearchResponseProcessor create( final Map> processorFactories, @@ -33,6 +43,29 @@ public SearchResponseProcessor create( final Map config, final Processor.PipelineContext pipelineContext ) { - return null; + RerankType type = findRerankType(config); + switch (type) { + case CROSS_ENCODER: + @SuppressWarnings("unchecked") + Map rerankerConfig = (Map) config.get(type.getLabel()); + String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); + String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD); + return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor); + default: + throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel()); + } + } + + private RerankType findRerankType(final Map config) throws IllegalArgumentException { + for (String key : config.keySet()) { + try { + RerankType attempt = RerankType.from(key); + return attempt; + } catch (IllegalArgumentException e) { + // Assume it's just a different field in the config, so don't do anything. + // If we get to the end and there were no valid RerankTypes, then we can panic. + } + } + throw new IllegalArgumentException("no rerank type found"); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index ea2152378..61193ea36 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -19,6 +19,7 @@ import java.io.PipedInputStream; import java.io.PipedOutputStream; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,10 +33,10 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.env.Environment; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.SearchHit; public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { @@ -45,26 +46,22 @@ public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { public static final String RERANK_CONTEXT_FIELD = "rerank_context_field"; protected final String modelId; - protected final String rerank_context; + protected final String rerankContext; protected final MLCommonsClientAccessor mlCommonsClientAccessor; - private final Environment environment; - public CrossEncoderRerankProcessor( String description, String tag, boolean ignoreFailure, String modelId, - String rerank_context, - MLCommonsClientAccessor mlCommonsClientAccessor, - Environment environment + String rerankContext, + MLCommonsClientAccessor mlCommonsClientAccessor ) { super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure); this.modelId = modelId; - this.rerank_context = rerank_context; + this.rerankContext = rerankContext; this.mlCommonsClientAccessor = mlCommonsClientAccessor; - this.environment = environment; } @Override @@ -108,8 +105,21 @@ public void generateScoringContext( @Override public void rescoreSearchResponse(SearchResponse response, Map scoringContext, ActionListener> listener) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'rescoreSearchResponse'"); + List contexts = new ArrayList<>(); + for (SearchHit hit : response.getHits()) { + contexts.add(contextFromSearchHit(hit)); + } + mlCommonsClientAccessor.inferenceSimilarity(modelId, (String) scoringContext.get(QUERY_TEXT_FIELD), contexts, listener); + } + + private String contextFromSearchHit(final SearchHit hit) { + if (hit.getFields().containsKey(this.rerankContext)) { + return (String) hit.field(this.rerankContext).getValue(); + } else if (hit.getSourceAsMap().containsKey(this.rerankContext)) { + return (String) hit.getSourceAsMap().get(this.rerankContext); + } else { + return null; + } } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 62ab61da4..d458c0ca2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -26,6 +26,8 @@ public interface RerankProcessor extends SearchResponseProcessor { + public static final String TYPE = "rerank"; + /** * Generate the information that this processor needs in order to rerank. * That could be as simple as grabbing a field from the search request or diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index f88c02d0d..907c26c5d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -48,7 +48,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp @Override public String getType() { - return "rerank-" + type.getLabel(); + return TYPE; } @Override From 8d476dba02d87936799666f943b0595ffbeb6b0e Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 16 Nov 2023 15:41:22 -0800 Subject: [PATCH 03/27] add unittests Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 11 +- .../rerank/CrossEncoderRerankProcessor.java | 26 +- .../processor/rerank/RerankType.java | 10 +- .../rerank/RescoringRerankProcessor.java | 7 +- .../query/ext/RerankSearchExtBuilder.java | 9 +- .../ml/MLCommonsClientAccessorTests.java | 82 +++++ .../factory/RerankProcessorFactoryTests.java | 144 +++++++++ .../CrossEncoderRerankProcessorTests.java | 304 ++++++++++++++++++ .../ext/RerankSearchExtBuilderTests.java | 102 ++++++ .../neuralsearch/BaseNeuralSearchIT.java | 17 + 10 files changed, 691 insertions(+), 21 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index ed1d56b4b..03e7c8154 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -27,6 +27,8 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; +import com.google.common.annotations.VisibleForTesting; + @AllArgsConstructor public class RerankProcessorFactory implements Processor.Factory { @@ -49,14 +51,21 @@ public SearchResponseProcessor create( @SuppressWarnings("unchecked") Map rerankerConfig = (Map) config.get(type.getLabel()); String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); + if (modelId == null) { + throw new IllegalArgumentException(CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified"); + } String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD); + if (rerankContext == null) { + throw new IllegalArgumentException(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified"); + } return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor); default: throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel()); } } - private RerankType findRerankType(final Map config) throws IllegalArgumentException { + @VisibleForTesting + RerankType findRerankType(final Map config) throws IllegalArgumentException { for (String key : config.keySet()) { try { RerankType attempt = RerankType.from(key); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index 61193ea36..3c60e6570 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -17,13 +17,15 @@ */ package org.opensearch.neuralsearch.processor.rerank; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import lombok.extern.log4j.Log4j2; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.xcontent.XContentType; @@ -38,6 +40,7 @@ import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHit; +@Log4j2 public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { public static final String MODEL_ID_FIELD = "model_id"; @@ -76,26 +79,29 @@ public void generateScoringContext( Map scoringContext = new HashMap<>(); if (params.containsKey(QUERY_TEXT_FIELD)) { if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { - throw new IllegalArgumentException("Cannot specify both \"query_text\" and \"query_text_path\""); + throw new IllegalArgumentException( + "Cannot specify both \"" + QUERY_TEXT_FIELD + "\" and \"" + QUERY_TEXT_PATH_FIELD + "\"" + ); } scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { String path = (String) params.get(QUERY_TEXT_PATH_FIELD); // Convert query to a map with io/xcontent shenanigans - PipedOutputStream os = new PipedOutputStream(); - XContentBuilder builder = XContentType.CBOR.contentBuilder(os); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); - PipedInputStream is = new PipedInputStream(os); - XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, is); + builder.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); Map map = parser.map(); // Get the text at the path Object queryText = ObjectPath.eval(path, map); if (!(queryText instanceof String)) { - throw new IllegalArgumentException("query_text_path must point to a string field"); + throw new IllegalArgumentException(QUERY_TEXT_PATH_FIELD + " must point to a string field"); } scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); } else { - throw new IllegalArgumentException("Must specify either \"query_text\" or \"query_text_path\""); + throw new IllegalArgumentException("Must specify either \"" + QUERY_TEXT_FIELD + "\" or \"" + QUERY_TEXT_PATH_FIELD + "\""); } listener.onResponse(scoringContext); } catch (Exception e) { @@ -115,7 +121,7 @@ public void rescoreSearchResponse(SearchResponse response, Map s private String contextFromSearchHit(final SearchHit hit) { if (hit.getFields().containsKey(this.rerankContext)) { return (String) hit.field(this.rerankContext).getValue(); - } else if (hit.getSourceAsMap().containsKey(this.rerankContext)) { + } else if (hit.hasSource() && hit.getSourceAsMap().containsKey(this.rerankContext)) { return (String) hit.getSourceAsMap().get(this.rerankContext); } else { return null; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index 6bfb9feed..e474c4b11 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -17,6 +17,9 @@ */ package org.opensearch.neuralsearch.processor.rerank; +import java.util.Arrays; +import java.util.Optional; + import lombok.Getter; /** @@ -39,9 +42,10 @@ private RerankType(String label) { * @return RerankType represented by the label */ public static RerankType from(String label) { - try { - return RerankType.valueOf(label); - } catch (Exception e) { + Optional typeMaybe = Arrays.stream(RerankType.values()).filter(rrt -> rrt.label.equals(label)).findFirst(); + if (typeMaybe.isPresent()) { + return typeMaybe.get(); + } else { throw new IllegalArgumentException("Wrong rerank type name: " + label); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index 907c26c5d..c1479b2c9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -84,7 +84,9 @@ public void rerank(SearchResponse searchResponse, Map scoringCon rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); - assert (hits.length == scores.size()); + if (hits.length != scores.size()) { + throw new Exception("scores and hits are not the same length"); + } for (int i = 0; i < hits.length; i++) { hits[i].score(scores.get(i)); } @@ -92,7 +94,8 @@ public void rerank(SearchResponse searchResponse, Map scoringCon Collections.sort(Arrays.asList(hits), new Comparator() { @Override public int compare(SearchHit hit1, SearchHit hit2) { - return Float.compare(hit1.getScore(), hit2.getScore()); + // backwards to sort DESC + return Float.compare(hit2.getScore(), hit1.getScore()); } }); // Reconstruct the search response, replacing the max score diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index ad3756aa8..56623f768 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -55,10 +55,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(PARAM_FIELD_NAME, this.params); - builder.endObject(); - return builder; + return builder.field(PARAM_FIELD_NAME, this.params); } @Override @@ -92,7 +89,9 @@ public static RerankSearchExtBuilder fromExtBuilderList(List b * @throws IOException if problems parsing */ public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException { - return new RerankSearchExtBuilder(parser.map()); + @SuppressWarnings("unchecked") + RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map) parser.map().get(PARAM_FIELD_NAME)); + return ans; } } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index d9b816597..3749e63dc 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -325,6 +325,71 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } + public void testInferenceSimilarity_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createManyModelTensorOutputs(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { + final RuntimeException exception = new RuntimeException(); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(exception); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -352,4 +417,21 @@ private ModelTensorOutput createModelTensorOutput(final Map map) tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createManyModelTensorOutputs(final Float[] output) { + final List tensorsList = new ArrayList<>(); + for (Float score : output) { + List tensorList = new ArrayList<>(); + String name = "logits"; + Number[] data = new Number[] { score }; + long[] shape = new long[] { 1 }; + MLResultDataType dataType = MLResultDataType.FLOAT32; + MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name()); + ModelTensor tensor = ModelTensor.builder().name(name).data(data).shape(shape).dataType(mlResultDataType).build(); + tensorList.add(tensor); + tensorsList.add(new ModelTensors(tensorList)); + } + ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + return modelTensorOutput; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java new file mode 100644 index 000000000..b663b3c93 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -0,0 +1,144 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; + +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.test.OpenSearchTestCase; + +@Log4j2 +public class RerankProcessorFactoryTests extends OpenSearchTestCase { + + final String TAG = "default-tag"; + final String DESC = "processor description"; + + RerankProcessorFactory factory; + + @Mock + MLCommonsClientAccessor clientAccessor; + + @Mock + PipelineContext pipelineContext; + + @Before + public void setup() { + pipelineContext = mock(PipelineContext.class); + clientAccessor = mock(MLCommonsClientAccessor.class); + factory = new RerankProcessorFactory(clientAccessor); + } + + public void testRerankProcessorFactory_EmptyConfig_ThenFail() { + Map config = Map.of(); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_NonExistentType_ThenFail() { + Map config = Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_HappyPath() { + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation" + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { + Map config = Map.of( + "poafn aorr;anv", + Map.of(";oawhls", "aowirhg "), + RerankType.CROSS_ENCODER.getLabel(), + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation", + "pqiohg rpowierhg", + "pw;oith4pt3ih go" + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { + Map config = Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of()); + assertThrows( + CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id") + ); + assertThrows( + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation") + ); + assertThrows( + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java new file mode 100644 index 000000000..5bbb5c38c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java @@ -0,0 +1,304 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.test.OpenSearchTestCase; + +@Log4j2 +public class CrossEncoderRerankProcessorTests extends OpenSearchTestCase { + + @Mock + SearchRequest request; + + SearchResponse response; + + @Mock + MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + PipelineContext pipelineContext; + + RerankProcessorFactory factory; + + CrossEncoderRerankProcessor processor; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + factory = new RerankProcessorFactory(mlCommonsClientAccessor); + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation" + ) + ); + processor = (CrossEncoderRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for reranking with a cross encoder", + false, + config, + pipelineContext + ); + } + + private void setupParams(Map params) { + SearchSourceBuilder ssb = new SearchSourceBuilder(); + NeuralQueryBuilder nqb = new NeuralQueryBuilder(); + nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins"); + ssb.query(nqb); + List exts = List.of(new RerankSearchExtBuilder(params)); + ssb.ext(exts); + doReturn(ssb).when(request).source(); + } + + private void setupSimilarityRescoring() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + List scores = List.of(1f, 2f, 3f); + listener.onResponse(scores); + return null; + }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + } + + private void setupSearchResults() throws IOException { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("text_representation", "source passage") + .endObject(); + SearchHit sourceHit = new SearchHit(0, "0", Map.of(), Map.of()); + sourceHit.sourceRef(BytesReference.bytes(sourceContent)); + sourceHit.score(1.5f); + + DocumentField field = new DocumentField("text_representation", List.of("field passage")); + SearchHit fieldHit = new SearchHit(1, "1", Map.of("text_representation", field), Map.of()); + fieldHit.score(1.7f); + + SearchHit nullHit = new SearchHit(2, "2", Map.of(), Map.of()); + nullHit.score(0f); + + SearchHit[] hitArray = new SearchHit[] { fieldHit, sourceHit, nullHit }; + + SearchHits searchHits = new SearchHits(hitArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new Clusters(1, 1, 0), null); + } + + public void testScoringContext_QueryText_ThenSucceed() { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().containsKey(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD).equals("query text")); + } + + public void testScoringContext_QueryTextPath_ThenSucceed() { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().containsKey(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD).equals("Question about dolphins")); + } + + public void testScoringContext_QueryTextAndPath_ThenFail() { + setupParams( + Map.of( + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, + "query.neural.embedding.query_text", + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, + "query text" + ) + ); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + "Cannot specify both \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD + + "\" and \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + + "\"" + )); + } + + public void testScoringContext_NoQueryInfo_ThenFail() { + setupParams(Map.of()); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + "Must specify either \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD + + "\" or \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + + "\"" + )); + } + + public void testScoringContext_QueryTextPath_BadPointer_ThenFail() { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + " must point to a string field")); + } + + public void testRescoreSearchResponse_HappyPath() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + processor.rescoreSearchResponse(response, scoringContext, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 3); + assert (argCaptor.getValue().get(0) == 1f); + assert (argCaptor.getValue().get(1) == 2f); + assert (argCaptor.getValue().get(2) == 3f); + } + + public void testRerank_HappyPath() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + processor.rerank(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse rsp = argCaptor.getValue(); + assert (rsp.getHits().getAt(0).docId() == 2); + assert (rsp.getHits().getAt(0).getScore() == 3f); + assert (rsp.getHits().getAt(1).docId() == 0); + assert (rsp.getHits().getAt(1).getScore() == 2f); + assert (rsp.getHits().getAt(2).docId() == 1); + assert (rsp.getHits().getAt(2).getScore() == 1f); + } + + public void testRerank_ScoresAndHitsHaveDiffLengths() throws IOException { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + List scores = List.of(1f, 2f); + listener.onResponse(scores); + return null; + }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + processor.rerank(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("scores and hits are not the same length")); + } + + public void testBasics() throws IOException { + assert (processor.getTag().equals("rerank processor")); + assert (processor.getDescription().equals("processor for reranking with a cross encoder")); + assert (!processor.isIgnoreFailure()); + assertThrows( + "Use asyncProcessResponse unless you can guarantee to not deadlock yourself", + UnsupportedOperationException.class, + () -> processor.processResponse(request, response) + ); + } + + public void testProcessResponseAsync() throws IOException { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + processor.processResponseAsync(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse rsp = argCaptor.getValue(); + assert (rsp.getHits().getAt(0).docId() == 2); + assert (rsp.getHits().getAt(0).getScore() == 3f); + assert (rsp.getHits().getAt(1).docId() == 0); + assert (rsp.getHits().getAt(1).getScore() == 2f); + assert (rsp.getHits().getAt(2).docId() == 1); + assert (rsp.getHits().getAt(2).getScore() == 1f); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java new file mode 100644 index 000000000..f6a22b675 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.query.ext; + +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.test.OpenSearchTestCase; + +@Log4j2 +public class RerankSearchExtBuilderTests extends OpenSearchTestCase { + + Map params; + + @Before + public void setup() { + params = Map.of("query_text", "question about the meaning of life, the universe, and everything"); + } + + public void testStreaming() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + b1.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + RerankSearchExtBuilder b2 = new RerankSearchExtBuilder(in); + assert (b2.getParams().equals(params)); + assert (b1.equals(b2)); + } + + public void testToXContent() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + XContentBuilder builder = XContentType.JSON.contentBuilder(); + builder.startObject(); + b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + builder.endObject(); + String extString = builder.toString(); + log.info(extString); + XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); + RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser); + assert (b2.getParams().equals(params)); + assert (b1.equals(b2)); + } + + public void testPullFromListOfExtBuilders() { + RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params); + SearchExtBuilder otherBuilder = mock(SearchExtBuilder.class); + assert (!builder.equals(otherBuilder)); + List builders1 = List.of(otherBuilder, builder); + List builders2 = List.of(otherBuilder); + List builders3 = List.of(); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders1).equals(builder)); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders2) == null); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders3) == null); + } + + public void testHash() { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + RerankSearchExtBuilder b2 = new RerankSearchExtBuilder(params); + RerankSearchExtBuilder b3 = new RerankSearchExtBuilder(Map.of()); + assert (b1.hashCode() == b2.hashCode()); + assert (b1.hashCode() != b3.hashCode()); + assert (!b1.equals(b3)); + } + + public void testWriteableName() { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + assert (b1.getWriteableName().equals(RerankSearchExtBuilder.PARAM_FIELD_NAME)); + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 680d90b65..4c1c17fb8 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -315,6 +315,23 @@ protected void createSearchRequestProcessor(String modelId, String pipelineName) assertEquals("true", node.get("acknowledged").toString()); } + protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception { + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + toHttpEntity( + String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource(configPath).toURI())), + modelId + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + /** * Get the number of documents in a particular index * From 4efa46370141d6618317cdd74bd69f158213e2a5 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Fri, 17 Nov 2023 18:12:19 -0800 Subject: [PATCH 04/27] add integration test Signed-off-by: HenryL27 --- .../ml/MLCommonsClientAccessor.java | 10 +- .../factory/RerankProcessorFactory.java | 4 +- .../rerank/RescoringRerankProcessor.java | 3 + .../query/ext/RerankSearchExtBuilder.java | 5 +- .../factory/RerankProcessorFactoryTests.java | 62 +++++--- .../rerank/CrossEncoderRerankProcessorIT.java | 143 ++++++++++++++++++ .../CrossEncoderRerankProcessorTests.java | 19 ++- .../ext/RerankSearchExtBuilderTests.java | 30 ++-- ...ossEncoderRerankPipelineConfiguration.json | 13 ++ .../UploadCrossEncoderModelRequestBody.json | 16 ++ .../neuralsearch/BaseNeuralSearchIT.java | 22 +-- 11 files changed, 256 insertions(+), 71 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java create mode 100644 src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json create mode 100644 src/test/resources/processor/UploadCrossEncoderModelRequestBody.json diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index ea9b7f0d8..a6170d308 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -17,7 +17,6 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang3.tuple.Pair; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -224,14 +223,9 @@ private MLInput createMLTextInput(final List targetResponseFilters, List return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } - private MLInput createMLTextPairsInput(final List> pairs) { - final MLInputDataset inputDataset = new TextSimilarityInputDataSet(pairs); - return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset); - } - private MLInput createMLTextPairsInput(final String query, final List inputText) { - List> pairs = inputText.stream().map(text -> Pair.of(query, text)).collect(Collectors.toList()); - return createMLTextPairsInput(pairs); + final MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText); + return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset); } private List> buildVectorFromResponse(MLOutput mlOutput) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 03e7c8154..65c4a3b28 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -20,6 +20,7 @@ import java.util.Map; import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; @@ -29,6 +30,7 @@ import com.google.common.annotations.VisibleForTesting; +@Log4j2 @AllArgsConstructor public class RerankProcessorFactory implements Processor.Factory { @@ -49,7 +51,7 @@ public SearchResponseProcessor create( switch (type) { case CROSS_ENCODER: @SuppressWarnings("unchecked") - Map rerankerConfig = (Map) config.get(type.getLabel()); + Map rerankerConfig = (Map) config.remove(type.getLabel()); String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); if (modelId == null) { throw new IllegalArgumentException(CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified"); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index c1479b2c9..92a0d9610 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -24,6 +24,7 @@ import java.util.Map; import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -33,6 +34,7 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.profile.SearchProfileShardResults; +@Log4j2 @AllArgsConstructor public abstract class RescoringRerankProcessor implements RerankProcessor { @@ -80,6 +82,7 @@ public abstract void rescoreSearchResponse( @Override public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener) { + log.info("==================RERANKING=================="); try { rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { // Assign new scores diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index 56623f768..915b2c858 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -25,6 +25,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.extern.log4j.Log4j2; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -32,6 +33,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; +@Log4j2 @AllArgsConstructor public class RerankSearchExtBuilder extends SearchExtBuilder { @@ -89,8 +91,7 @@ public static RerankSearchExtBuilder fromExtBuilderList(List b * @throws IOException if problems parsing */ public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException { - @SuppressWarnings("unchecked") - RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map) parser.map().get(PARAM_FIELD_NAME)); + RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map) parser.map()); return ans; } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index b663b3c93..080137d35 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -19,6 +19,7 @@ import static org.mockito.Mockito.mock; +import java.util.HashMap; import java.util.Map; import lombok.extern.log4j.Log4j2; @@ -55,7 +56,7 @@ public void setup() { } public void testRerankProcessorFactory_EmptyConfig_ThenFail() { - Map config = Map.of(); + Map config = new HashMap<>(Map.of()); assertThrows( "no rerank type found", IllegalArgumentException.class, @@ -64,7 +65,9 @@ public void testRerankProcessorFactory_EmptyConfig_ThenFail() { } public void testRerankProcessorFactory_NonExistentType_ThenFail() { - Map config = Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")); + Map config = new HashMap<>( + Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")) + ); assertThrows( "no rerank type found", IllegalArgumentException.class, @@ -73,13 +76,17 @@ public void testRerankProcessorFactory_NonExistentType_ThenFail() { } public void testRerankProcessorFactory_CrossEncoder_HappyPath() { - Map config = Map.of( - RerankType.CROSS_ENCODER.getLabel(), + Map config = new HashMap<>( Map.of( - CrossEncoderRerankProcessor.MODEL_ID_FIELD, - "model-id", - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, - "text_representation" + RerankType.CROSS_ENCODER.getLabel(), + new HashMap<>( + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation" + ) + ) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); @@ -89,17 +96,21 @@ public void testRerankProcessorFactory_CrossEncoder_HappyPath() { } public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { - Map config = Map.of( - "poafn aorr;anv", - Map.of(";oawhls", "aowirhg "), - RerankType.CROSS_ENCODER.getLabel(), + Map config = new HashMap<>( Map.of( - CrossEncoderRerankProcessor.MODEL_ID_FIELD, - "model-id", - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, - "text_representation", - "pqiohg rpowierhg", - "pw;oith4pt3ih go" + "poafn aorr;anv", + Map.of(";oawhls", "aowirhg "), + RerankType.CROSS_ENCODER.getLabel(), + new HashMap<>( + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation", + "pqiohg rpowierhg", + "pw;oith4pt3ih go" + ) + ) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); @@ -109,7 +120,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { } public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { - Map config = Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of()); + Map config = new HashMap<>(Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of())); assertThrows( CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", IllegalArgumentException.class, @@ -118,9 +129,8 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { } public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { - Map config = Map.of( - RerankType.CROSS_ENCODER.getLabel(), - Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id") + Map config = new HashMap<>( + Map.of(RerankType.CROSS_ENCODER.getLabel(), new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); assertThrows( CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", @@ -130,9 +140,11 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { } public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { - Map config = Map.of( - RerankType.CROSS_ENCODER.getLabel(), - Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation") + Map config = new HashMap<>( + Map.of( + RerankType.CROSS_ENCODER.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation")) + ) ); assertThrows( CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java new file mode 100644 index 000000000..3ec9f0c18 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; + +import com.google.common.collect.ImmutableList; + +@Log4j2 +public class CrossEncoderRerankProcessorIT extends BaseNeuralSearchIT { + + final static String PIPELINE_NAME = "rerank-ce-pipeline"; + final static String INDEX_NAME = "rerank-test"; + final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; + final static String TEXT_REP_2 = "Fish like to eat plankton"; + final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + deleteSearchPipeline(PIPELINE_NAME); + findDeployedModels().forEach(this::deleteModel); + deleteIndex(INDEX_NAME); + } + + public void testCrossEncoderRerankProcessor() throws Exception { + String modelId = uploadCrossEncoderModel(); + loadModel(modelId); + createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/CrossEncoderRerankPipelineConfiguration.json"); + setupIndex(); + runQueries(); + } + + private String uploadCrossEncoderModel() throws Exception { + String requestBody = Files.readString( + Path.of(classLoader.getResource("processor/UploadCrossEncoderModelRequestBody.json").toURI()) + ); + return uploadModel(requestBody); + } + + private void setupIndex() throws Exception { + createIndexWithConfiguration(INDEX_NAME, INDEX_CONFIG, PIPELINE_NAME); + Response response1 = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(String.format(LOCALE, "{\"text_representation\": \"%s\"}", TEXT_REP_1)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Response response2 = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(String.format(LOCALE, "{\"text_representation\": \"%s\"}", TEXT_REP_2)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response1.getEntity()), + false + ); + assertEquals("created", map.get("result")); + map = XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response2.getEntity()), false); + assertEquals("created", map.get("result")); + } + + private void runQueries() throws Exception { + Map response1 = search("What do fish eat?"); + @SuppressWarnings("unchecked") + List> hits = (List>) ((Map) response1.get("hits")).get("hits"); + @SuppressWarnings("unchecked") + Map hit0Source = (Map) hits.get(0).get("_source"); + assert ((String) hit0Source.get("text_representation")).equals(TEXT_REP_2); + @SuppressWarnings("unchecked") + Map hit1Source = (Map) hits.get(1).get("_source"); + assert ((String) hit1Source.get("text_representation")).equals(TEXT_REP_1); + + Map response2 = search("Who loves fish?"); + @SuppressWarnings("unchecked") + List> hits2 = (List>) ((Map) response2.get("hits")).get("hits"); + @SuppressWarnings("unchecked") + Map hit2Source = (Map) hits2.get(0).get("_source"); + assert ((String) hit2Source.get("text_representation")).equals(TEXT_REP_1); + @SuppressWarnings("unchecked") + Map hit3Source = (Map) hits2.get(1).get("_source"); + assert ((String) hit3Source.get("text_representation")).equals(TEXT_REP_2); + } + + private Map search(String queryText) throws Exception { + String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_text\":\"%s\"}}}"; + String jsonQuery = String.format(LOCALE, jsonQueryFrame, queryText); + log.info(jsonQuery); + Request request = new Request("POST", "/" + INDEX_NAME + "/_search"); + request.addParameter("search_pipeline", PIPELINE_NAME); + request.setJsonEntity(jsonQuery); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + + return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java index 5bbb5c38c..cfb90b5ec 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java @@ -27,6 +27,7 @@ import static org.mockito.Mockito.verify; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,13 +80,17 @@ public class CrossEncoderRerankProcessorTests extends OpenSearchTestCase { public void setup() { MockitoAnnotations.openMocks(this); factory = new RerankProcessorFactory(mlCommonsClientAccessor); - Map config = Map.of( - RerankType.CROSS_ENCODER.getLabel(), + Map config = new HashMap<>( Map.of( - CrossEncoderRerankProcessor.MODEL_ID_FIELD, - "model-id", - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, - "text_representation" + RerankType.CROSS_ENCODER.getLabel(), + new HashMap<>( + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation" + ) + ) ) ); processor = (CrossEncoderRerankProcessor) factory.create( @@ -103,7 +108,7 @@ private void setupParams(Map params) { NeuralQueryBuilder nqb = new NeuralQueryBuilder(); nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins"); ssb.query(nqb); - List exts = List.of(new RerankSearchExtBuilder(params)); + List exts = List.of(new RerankSearchExtBuilder(new HashMap<>(params))); ssb.ext(exts); doReturn(ssb).when(request).source(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java index f6a22b675..8c24a5a8d 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -27,15 +27,11 @@ import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.BytesStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; import org.opensearch.test.OpenSearchTestCase; @@ -60,19 +56,19 @@ public void testStreaming() throws IOException { assert (b1.equals(b2)); } - public void testToXContent() throws IOException { - RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); - XContentBuilder builder = XContentType.JSON.contentBuilder(); - builder.startObject(); - b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); - builder.endObject(); - String extString = builder.toString(); - log.info(extString); - XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); - RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser); - assert (b2.getParams().equals(params)); - assert (b1.equals(b2)); - } + // public void testToXContent() throws IOException { + // RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params)); + // XContentBuilder builder = XContentType.JSON.contentBuilder(); + // builder.startObject(); + // b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + // builder.endObject(); + // String extString = builder.toString(); + // log.info(extString); + // XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); + // RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser); + // assert (b2.getParams().equals(params)); + // assert (b1.equals(b2)); + // } public void testPullFromListOfExtBuilders() { RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params); diff --git a/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json b/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json new file mode 100644 index 000000000..5d5751683 --- /dev/null +++ b/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json @@ -0,0 +1,13 @@ +{ + "description": "Pipeline for reranking with a cross encoder", + "response_processors": [ + { + "rerank": { + "cross-encoder": { + "model_id": "%s", + "rerank_context_field": "text_representation" + } + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/processor/UploadCrossEncoderModelRequestBody.json b/src/test/resources/processor/UploadCrossEncoderModelRequestBody.json new file mode 100644 index 000000000..897354616 --- /dev/null +++ b/src/test/resources/processor/UploadCrossEncoderModelRequestBody.json @@ -0,0 +1,16 @@ +{ + "name": "ms-marco-TinyBERT-L-2-v2", + "version": "1.0.0", + "function_name": "TEXT_SIMILARITY", + "description": "test model", + "model_format": "TORCH_SCRIPT", + "model_group_id": "", + "model_content_hash_value": "90e39a926101d1a4e542aade0794319404689b12acfd5d7e65c03d91c668b5cf", + "model_config": { + "model_type": "bert", + "embedding_dimension": 1, + "framework_type": "huggingface_transformers", + "all_config": "nobody will read this" + }, + "url": "https://github.com/HenryL27/ml-commons/blob/cross-encoder/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE.zip?raw=true" +} \ No newline at end of file diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 4c1c17fb8..10a5cb68a 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -317,19 +317,19 @@ protected void createSearchRequestProcessor(String modelId, String pipelineName) protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception { Response pipelineCreateResponse = makeRequest( - client(), - "PUT", - "/_search/pipeline/" + pipelineName, - null, - toHttpEntity( - String.format( - LOCALE, - Files.readString(Path.of(classLoader.getResource(configPath).toURI())), - modelId - ) - ), + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + toHttpEntity(String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(configPath).toURI())), modelId)), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); + Map node = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(pipelineCreateResponse.getEntity()), + false + ); + assertEquals("true", node.get("acknowledged").toString()); } /** From de9676110b0c11c7107468346998fd3405ac25b7 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 30 Nov 2023 17:05:53 -0800 Subject: [PATCH 05/27] use string.format() instead of concatenation Signed-off-by: HenryL27 --- .../processor/factory/RerankProcessorFactory.java | 13 ++++++++++--- .../rerank/CrossEncoderRerankProcessor.java | 11 ++++++++--- .../neuralsearch/processor/rerank/RerankType.java | 3 ++- .../processor/rerank/RescoringRerankProcessor.java | 1 - 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 65c4a3b28..5d2834999 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -17,6 +17,7 @@ */ package org.opensearch.neuralsearch.processor.factory; +import java.util.Locale; import java.util.Map; import lombok.AllArgsConstructor; @@ -54,15 +55,21 @@ public SearchResponseProcessor create( Map rerankerConfig = (Map) config.remove(type.getLabel()); String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); if (modelId == null) { - throw new IllegalArgumentException(CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified"); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s must be specified", CrossEncoderRerankProcessor.MODEL_ID_FIELD) + ); } String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD); if (rerankContext == null) { - throw new IllegalArgumentException(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified"); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s must be specified", CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD) + ); } return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor); default: - throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel()); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "could not find constructor for reranker type %s", type.getLabel()) + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index 3c60e6570..33698d385 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import lombok.extern.log4j.Log4j2; @@ -80,7 +81,7 @@ public void generateScoringContext( if (params.containsKey(QUERY_TEXT_FIELD)) { if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { throw new IllegalArgumentException( - "Cannot specify both \"" + QUERY_TEXT_FIELD + "\" and \"" + QUERY_TEXT_PATH_FIELD + "\"" + String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); @@ -97,11 +98,15 @@ public void generateScoringContext( // Get the text at the path Object queryText = ObjectPath.eval(path, map); if (!(queryText instanceof String)) { - throw new IllegalArgumentException(QUERY_TEXT_PATH_FIELD + " must point to a string field"); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD) + ); } scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); } else { - throw new IllegalArgumentException("Must specify either \"" + QUERY_TEXT_FIELD + "\" or \"" + QUERY_TEXT_PATH_FIELD + "\""); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) + ); } listener.onResponse(scoringContext); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index e474c4b11..cfbeb8906 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -18,6 +18,7 @@ package org.opensearch.neuralsearch.processor.rerank; import java.util.Arrays; +import java.util.Locale; import java.util.Optional; import lombok.Getter; @@ -46,7 +47,7 @@ public static RerankType from(String label) { if (typeMaybe.isPresent()) { return typeMaybe.get(); } else { - throw new IllegalArgumentException("Wrong rerank type name: " + label); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index 92a0d9610..ca50f0f79 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -82,7 +82,6 @@ public abstract void rescoreSearchResponse( @Override public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener) { - log.info("==================RERANKING=================="); try { rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { // Assign new scores From 6f85824a0c5167944d4dd49dfa31e294c00f6fde Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Fri, 1 Dec 2023 09:00:34 -0800 Subject: [PATCH 06/27] rename generateScoringContext to generateRerankingContext Signed-off-by: HenryL27 --- .../rerank/CrossEncoderRerankProcessor.java | 14 +++++++------- .../processor/rerank/RerankProcessor.java | 8 ++++---- .../processor/rerank/RescoringRerankProcessor.java | 8 ++++---- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index 33698d385..62053f9e3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -69,7 +69,7 @@ public CrossEncoderRerankProcessor( } @Override - public void generateScoringContext( + public void generateRerankingContext( SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener @@ -77,14 +77,14 @@ public void generateScoringContext( try { List exts = searchRequest.source().ext(); Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); - Map scoringContext = new HashMap<>(); + Map rerankingContext = new HashMap<>(); if (params.containsKey(QUERY_TEXT_FIELD)) { if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } - scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); + rerankingContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { String path = (String) params.get(QUERY_TEXT_PATH_FIELD); // Convert query to a map with io/xcontent shenanigans @@ -102,25 +102,25 @@ public void generateScoringContext( String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD) ); } - scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); + rerankingContext.put(QUERY_TEXT_FIELD, (String) queryText); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } - listener.onResponse(scoringContext); + listener.onResponse(rerankingContext); } catch (Exception e) { listener.onFailure(e); } } @Override - public void rescoreSearchResponse(SearchResponse response, Map scoringContext, ActionListener> listener) { + public void rescoreSearchResponse(SearchResponse response, Map rerankingContext, ActionListener> listener) { List contexts = new ArrayList<>(); for (SearchHit hit : response.getHits()) { contexts.add(contextFromSearchHit(hit)); } - mlCommonsClientAccessor.inferenceSimilarity(modelId, (String) scoringContext.get(QUERY_TEXT_FIELD), contexts, listener); + mlCommonsClientAccessor.inferenceSimilarity(modelId, (String) rerankingContext.get(QUERY_TEXT_FIELD), contexts, listener); } private String contextFromSearchHit(final SearchHit hit) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index d458c0ca2..29c6fab24 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -36,7 +36,7 @@ public interface RerankProcessor extends SearchResponseProcessor { * @param searchResponse the search results, in case they're relevant * @param listener be async */ - public void generateScoringContext( + public void generateRerankingContext( SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener @@ -46,15 +46,15 @@ public void generateScoringContext( * Given the scoring context generated by the processor and the search results, * rerank the search results. Do so asynchronously. * @param searchResponse the search results to rerank - * @param scoringContext the information this processor needs in order to rerank + * @param rerankingContext the information this processor needs in order to rerank * @param listener be async */ - public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener); + public void rerank(SearchResponse searchResponse, Map rerankingContext, ActionListener listener); @Override default void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { try { - generateScoringContext( + generateRerankingContext( request, response, ActionListener.wrap(context -> { rerank(response, context, responseListener); }, e -> { responseListener.onFailure(e); }) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index ca50f0f79..9b85f2ba9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -71,19 +71,19 @@ public boolean isIgnoreFailure() { /** * Generate a list of new scores for all of the documents, given the scoring context * @param response search results to rescore - * @param scoringContext extra information needed to score the search results; e.g. model id + * @param rerankingContext extra information needed to score the search results; e.g. model id * @param listener be async. recieves the list of new scores */ public abstract void rescoreSearchResponse( SearchResponse response, - Map scoringContext, + Map rerankingContext, ActionListener> listener ); @Override - public void rerank(SearchResponse searchResponse, Map scoringContext, ActionListener listener) { + public void rerank(SearchResponse searchResponse, Map rerankingContext, ActionListener listener) { try { - rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { + rescoreSearchResponse(searchResponse, rerankingContext, ActionListener.wrap(scores -> { // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); if (hits.length != scores.size()) { From a30180c467ab2a12dc0333be009efe45e08b265a Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Fri, 1 Dec 2023 11:26:02 -0800 Subject: [PATCH 07/27] add name change in test too. whoops Signed-off-by: HenryL27 --- .../rerank/CrossEncoderRerankProcessorTests.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java index cfb90b5ec..a700cb6d0 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java @@ -149,7 +149,7 @@ public void testScoringContext_QueryText_ThenSucceed() { setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); - processor.generateScoringContext(request, response, listener); + processor.generateRerankingContext(request, response, listener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); verify(listener, times(1)).onResponse(argCaptor.capture()); @@ -161,7 +161,7 @@ public void testScoringContext_QueryTextPath_ThenSucceed() { setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); - processor.generateScoringContext(request, response, listener); + processor.generateRerankingContext(request, response, listener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); verify(listener, times(1)).onResponse(argCaptor.capture()); @@ -180,7 +180,7 @@ public void testScoringContext_QueryTextAndPath_ThenFail() { ); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); - processor.generateScoringContext(request, response, listener); + processor.generateRerankingContext(request, response, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue() instanceof IllegalArgumentException); @@ -199,7 +199,7 @@ public void testScoringContext_NoQueryInfo_ThenFail() { setupParams(Map.of()); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); - processor.generateScoringContext(request, response, listener); + processor.generateRerankingContext(request, response, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue() instanceof IllegalArgumentException); @@ -218,7 +218,7 @@ public void testScoringContext_QueryTextPath_BadPointer_ThenFail() { setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); - processor.generateScoringContext(request, response, listener); + processor.generateRerankingContext(request, response, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue() instanceof IllegalArgumentException); From b8820ec55d72227d451eed0121256829d55e7bb1 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 4 Dec 2023 11:32:03 -0800 Subject: [PATCH 08/27] start refactoring with contextSaourceFetchers Signed-off-by: HenryL27 --- .../processor/rerank/RerankProcessor.java | 39 ++++++++++++++++--- .../rerank/RescoringRerankProcessor.java | 6 +-- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 29c6fab24..cc8f67fe6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -17,21 +17,28 @@ */ package org.opensearch.neuralsearch.processor.rerank; +import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.search.pipeline.SearchResponseProcessor; -public interface RerankProcessor extends SearchResponseProcessor { +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public abstract class RerankProcessor implements SearchResponseProcessor { public static final String TYPE = "rerank"; + protected List contextSourceFetchers; + /** * Generate the information that this processor needs in order to rerank. - * That could be as simple as grabbing a field from the search request or - * as complicated as a lookup to some external service + * Concurrently hit all contextSourceFetchers * @param searchRequest the search query * @param searchResponse the search results, in case they're relevant * @param listener be async @@ -40,7 +47,22 @@ public void generateRerankingContext( SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener - ); + ) { + Map overallContext = new ConcurrentHashMap<>(); + AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size()); + for(ContextSourceFetcher csf : contextSourceFetchers) { + csf.fetchContext(searchRequest, searchResponse, ActionListener.wrap( + context -> { + overallContext.putAll(context); + if (successfulContexts.decrementAndGet() == 0) { + listener.onResponse(overallContext); + } + }, e -> { + listener.onFailure(e); + } + )); + } + } /** * Given the scoring context generated by the processor and the search results, @@ -49,10 +71,15 @@ public void generateRerankingContext( * @param rerankingContext the information this processor needs in order to rerank * @param listener be async */ - public void rerank(SearchResponse searchResponse, Map rerankingContext, ActionListener listener); + public abstract void rerank(SearchResponse searchResponse, Map rerankingContext, ActionListener listener); + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + throw new UnsupportedOperationException("Use asyncProcessResponse unless you can guarantee to not deadlock yourself"); + } @Override - default void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { + public void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { try { generateRerankingContext( request, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index 9b85f2ba9..f465254da 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -36,17 +36,13 @@ @Log4j2 @AllArgsConstructor -public abstract class RescoringRerankProcessor implements RerankProcessor { +public abstract class RescoringRerankProcessor extends RerankProcessor { private final RerankType type; private final String description; private final String tag; private final boolean ignoreFailure; - @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { - throw new UnsupportedOperationException("Use asyncProcessResponse unless you can guarantee to not deadlock yourself"); - } @Override public String getType() { From 5e1c00b70e04876fbcafc988c61f09192c5e8447 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 5 Dec 2023 11:22:17 -0800 Subject: [PATCH 09/27] refactor to use contextSourceFetchers to get context Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 62 ++++++++-- .../rerank/ContextSourceFetcher.java | 39 +++++++ .../rerank/CrossEncoderRerankProcessor.java | 109 ++++-------------- .../rerank/DocumentContextSourceFetcher.java | 75 ++++++++++++ .../rerank/QueryContextSourceFetcher.java | 90 +++++++++++++++ .../processor/rerank/RerankProcessor.java | 41 ++++--- .../processor/rerank/RerankType.java | 2 +- .../rerank/RescoringRerankProcessor.java | 35 ++---- .../factory/RerankProcessorFactoryTests.java | 97 ++++++++++++---- .../CrossEncoderRerankProcessorTests.java | 103 ++++++++++++----- ...ossEncoderRerankPipelineConfiguration.json | 8 +- 11 files changed, 469 insertions(+), 192 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 5d2834999..448a8e600 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -17,14 +17,20 @@ */ package org.opensearch.neuralsearch.processor.factory; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.QueryContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankType; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -36,6 +42,7 @@ public class RerankProcessorFactory implements Processor.Factory { public static final String RERANK_PROCESSOR_TYPE = "rerank"; + public static final String CONTEXT_CONFIG_FIELD = "context"; private final MLCommonsClientAccessor clientAccessor; @@ -49,8 +56,10 @@ public SearchResponseProcessor create( final Processor.PipelineContext pipelineContext ) { RerankType type = findRerankType(config); + boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); + List contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher); switch (type) { - case CROSS_ENCODER: + case TEXT_SIMILARITY: @SuppressWarnings("unchecked") Map rerankerConfig = (Map) config.remove(type.getLabel()); String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); @@ -59,13 +68,7 @@ public SearchResponseProcessor create( String.format(Locale.ROOT, "%s must be specified", CrossEncoderRerankProcessor.MODEL_ID_FIELD) ); } - String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD); - if (rerankContext == null) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "%s must be specified", CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD) - ); - } - return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor); + return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); default: throw new IllegalArgumentException( String.format(Locale.ROOT, "could not find constructor for reranker type %s", type.getLabel()) @@ -86,4 +89,47 @@ RerankType findRerankType(final Map config) throws IllegalArgume } throw new IllegalArgumentException("no rerank type found"); } + + protected static class ContextFetcherFactory { + + public static boolean shouldIncludeQueryContextFetcher(RerankType type) { + switch (type) { + case TEXT_SIMILARITY: + return true; + default: + return false; + } + } + + public static List createFetchers(Map config, boolean includeQueryContextFetcher) { + List fetchers = new ArrayList<>(); + @SuppressWarnings("unchecked") + Map contextConfig = (Map) config.remove(CONTEXT_CONFIG_FIELD); + if (contextConfig == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field must be provided", CONTEXT_CONFIG_FIELD)); + } + for (String key : contextConfig.keySet()) { + switch (key) { + case DocumentContextSourceFetcher.NAME: + Object cfg = contextConfig.get(key); + if (!(cfg instanceof List)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of strings", key)); + } + List fields = (List) contextConfig.get(key); + if (fields.size() == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", key)); + } + List strfields = fields.stream().map(field -> (String) field).collect(Collectors.toList()); + fetchers.add(new DocumentContextSourceFetcher(strfields)); + break; + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); + } + } + if (includeQueryContextFetcher) { + fetchers.add(new QueryContextSourceFetcher()); + } + return fetchers; + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java new file mode 100644 index 000000000..5732d8b2f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; + +public interface ContextSourceFetcher { + + /** + * Fetch the information needed in order to rerank. + * That could be as simple as grabbing a field from the search request or + * as complicated as a lookup to some external service + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener); + + public String getName(); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index 62053f9e3..8e38e2aaa 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -17,40 +17,21 @@ */ package org.opensearch.neuralsearch.processor.rerank; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; -import lombok.extern.log4j.Log4j2; - -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ObjectPath; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; -import org.opensearch.search.SearchExtBuilder; -import org.opensearch.search.SearchHit; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; -@Log4j2 public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { public static final String MODEL_ID_FIELD = "model_id"; - public static final String QUERY_TEXT_FIELD = "query_text"; - public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; - public static final String RERANK_CONTEXT_FIELD = "rerank_context_field"; protected final String modelId; - protected final String rerankContext; protected final MLCommonsClientAccessor mlCommonsClientAccessor; @@ -59,78 +40,38 @@ public CrossEncoderRerankProcessor( String tag, boolean ignoreFailure, String modelId, - String rerankContext, + List contextSourceFetchers, MLCommonsClientAccessor mlCommonsClientAccessor ) { - super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure); + super(RerankType.TEXT_SIMILARITY, description, tag, ignoreFailure, contextSourceFetchers); this.modelId = modelId; - this.rerankContext = rerankContext; this.mlCommonsClientAccessor = mlCommonsClientAccessor; } - @Override - public void generateRerankingContext( - SearchRequest searchRequest, - SearchResponse searchResponse, - ActionListener> listener - ) { - try { - List exts = searchRequest.source().ext(); - Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); - Map rerankingContext = new HashMap<>(); - if (params.containsKey(QUERY_TEXT_FIELD)) { - if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) - ); - } - rerankingContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); - } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { - String path = (String) params.get(QUERY_TEXT_PATH_FIELD); - // Convert query to a map with io/xcontent shenanigans - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); - searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.close(); - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); - Map map = parser.map(); - // Get the text at the path - Object queryText = ObjectPath.eval(path, map); - if (!(queryText instanceof String)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD) - ); - } - rerankingContext.put(QUERY_TEXT_FIELD, (String) queryText); - } else { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) - ); - } - listener.onResponse(rerankingContext); - } catch (Exception e) { - listener.onFailure(e); - } - } - @Override public void rescoreSearchResponse(SearchResponse response, Map rerankingContext, ActionListener> listener) { - List contexts = new ArrayList<>(); - for (SearchHit hit : response.getHits()) { - contexts.add(contextFromSearchHit(hit)); - } - mlCommonsClientAccessor.inferenceSimilarity(modelId, (String) rerankingContext.get(QUERY_TEXT_FIELD), contexts, listener); - } - - private String contextFromSearchHit(final SearchHit hit) { - if (hit.getFields().containsKey(this.rerankContext)) { - return (String) hit.field(this.rerankContext).getValue(); - } else if (hit.hasSource() && hit.getSourceAsMap().containsKey(this.rerankContext)) { - return (String) hit.getSourceAsMap().get(this.rerankContext); - } else { - return null; + Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); + if (!(ctxObj instanceof List)) { + listener.onFailure( + new IllegalStateException( + String.format( + Locale.ROOT, + "No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + DocumentContextSourceFetcher.NAME + ) + ) + ); + return; } + List ctxList = (List) ctxObj; + List contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); + mlCommonsClientAccessor.inferenceSimilarity( + modelId, + (String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), + contexts, + listener + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java new file mode 100644 index 000000000..3e2ae636b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; + +@Log4j2 +@AllArgsConstructor +public class DocumentContextSourceFetcher implements ContextSourceFetcher { + + public static final String NAME = "document_fields"; + public static final String DOCUMENT_CONTEXT_LIST_FIELD = "document_context_list"; + + List contextFields; + + /** + * Fetch the information needed in order to rerank. + * That could be as simple as grabbing a field from the search request or + * as complicated as a lookup to some external service + * @param searchRequest the search query + * @param searchResponse the search results, in case they're relevant + * @param listener be async + */ + public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener) { + List contexts = new ArrayList<>(); + for (SearchHit hit : searchResponse.getHits()) { + StringBuilder ctx = new StringBuilder(); + for (String field : this.contextFields) { + ctx.append(contextFromSearchHit(hit, field)); + } + contexts.add(ctx.toString()); + } + listener.onResponse(new HashMap<>(Map.of(DOCUMENT_CONTEXT_LIST_FIELD, contexts))); + } + + private String contextFromSearchHit(final SearchHit hit, final String field) { + if (hit.getFields().containsKey(field)) { + return (String) hit.field(field).getValue(); + } else if (hit.hasSource() && hit.getSourceAsMap().containsKey(field)) { + return (String) hit.getSourceAsMap().get(field); + } else { + return ""; + } + } + + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java new file mode 100644 index 000000000..a4a7a26ba --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ObjectPath; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; + +public class QueryContextSourceFetcher implements ContextSourceFetcher { + + public static final String NAME = "query_context"; + public static final String QUERY_TEXT_FIELD = "query_text"; + public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; + + @Override + public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener) { + try { + List exts = searchRequest.source().ext(); + Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); + Map scoringContext = new HashMap<>(); + if (params.containsKey(QUERY_TEXT_FIELD)) { + if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) + ); + } + scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); + } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + String path = (String) params.get(QUERY_TEXT_PATH_FIELD); + // Convert query to a map with io/xcontent shenanigans + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); + searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); + Map map = parser.map(); + // Get the text at the path + Object queryText = ObjectPath.eval(path, map); + if (!(queryText instanceof String)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD) + ); + } + scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) + ); + } + listener.onResponse(scoringContext); + } catch (Exception e) { + listener.onFailure(e); + } + } + + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index cc8f67fe6..02f471451 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -22,18 +22,28 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.search.pipeline.SearchResponseProcessor; -import lombok.AllArgsConstructor; - +@Log4j2 @AllArgsConstructor public abstract class RerankProcessor implements SearchResponseProcessor { public static final String TYPE = "rerank"; + protected final RerankType subType; + @Getter + private final String description; + @Getter + private final String tag; + @Getter + private final boolean ignoreFailure; protected List contextSourceFetchers; /** @@ -50,20 +60,21 @@ public void generateRerankingContext( ) { Map overallContext = new ConcurrentHashMap<>(); AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size()); - for(ContextSourceFetcher csf : contextSourceFetchers) { - csf.fetchContext(searchRequest, searchResponse, ActionListener.wrap( - context -> { - overallContext.putAll(context); - if (successfulContexts.decrementAndGet() == 0) { - listener.onResponse(overallContext); - } - }, e -> { - listener.onFailure(e); + for (ContextSourceFetcher csf : contextSourceFetchers) { + csf.fetchContext(searchRequest, searchResponse, ActionListener.wrap(context -> { + overallContext.putAll(context); + if (successfulContexts.decrementAndGet() == 0) { + listener.onResponse(overallContext); } - )); + }, e -> { listener.onFailure(e); })); } } + @Override + public String getType() { + return TYPE; + } + /** * Given the scoring context generated by the processor and the search results, * rerank the search results. Do so asynchronously. @@ -71,7 +82,11 @@ public void generateRerankingContext( * @param rerankingContext the information this processor needs in order to rerank * @param listener be async */ - public abstract void rerank(SearchResponse searchResponse, Map rerankingContext, ActionListener listener); + public abstract void rerank( + SearchResponse searchResponse, + Map rerankingContext, + ActionListener listener + ); @Override public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index cfbeb8906..45221a2c5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -28,7 +28,7 @@ */ public enum RerankType { - CROSS_ENCODER("cross-encoder"); + TEXT_SIMILARITY("text_similarity"); @Getter private final String label; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index f465254da..ca3b84749 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -23,10 +23,8 @@ import java.util.List; import java.util.Map; -import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.core.action.ActionListener; @@ -35,33 +33,16 @@ import org.opensearch.search.profile.SearchProfileShardResults; @Log4j2 -@AllArgsConstructor public abstract class RescoringRerankProcessor extends RerankProcessor { - private final RerankType type; - private final String description; - private final String tag; - private final boolean ignoreFailure; - - - @Override - public String getType() { - return TYPE; - } - - @Override - public String getTag() { - return tag; - } - - @Override - public String getDescription() { - return description; - } - - @Override - public boolean isIgnoreFailure() { - return ignoreFailure; + public RescoringRerankProcessor( + RerankType type, + String description, + String tag, + boolean ignoreFailure, + List contextSourceFetchers + ) { + super(type, description, tag, ignoreFailure, contextSourceFetchers); } /** diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index 080137d35..54bd37b69 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -19,7 +19,10 @@ import static org.mockito.Mockito.mock; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; +import java.util.Locale; import java.util.Map; import lombok.extern.log4j.Log4j2; @@ -28,6 +31,7 @@ import org.mockito.Mock; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankType; import org.opensearch.search.pipeline.Processor.PipelineContext; @@ -78,15 +82,10 @@ public void testRerankProcessorFactory_NonExistentType_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_HappyPath() { Map config = new HashMap<>( Map.of( - RerankType.CROSS_ENCODER.getLabel(), - new HashMap<>( - Map.of( - CrossEncoderRerankProcessor.MODEL_ID_FIELD, - "model-id", - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, - "text_representation" - ) - ) + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); @@ -100,27 +99,43 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { Map.of( "poafn aorr;anv", Map.of(";oawhls", "aowirhg "), - RerankType.CROSS_ENCODER.getLabel(), + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>( Map.of( - CrossEncoderRerankProcessor.MODEL_ID_FIELD, - "model-id", - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, - "text_representation", + DocumentContextSourceFetcher.NAME, + new ArrayList<>(List.of("text_representation")), "pqiohg rpowierhg", "pw;oith4pt3ih go" ) ) ) ); - SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); - assert (processor instanceof RerankProcessor); - assert (processor instanceof CrossEncoderRerankProcessor); - assert (processor.getType().equals(RerankProcessor.TYPE)); + assertThrows( + String.format(Locale.ROOT, "unrecognized context field: %s", "pqiohg rpowierhg"), + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); } public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { - Map config = new HashMap<>(Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of())); + Map config = new HashMap<>(Map.of(RerankType.TEXT_SIMILARITY.getLabel(), Map.of())); assertThrows( CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", IllegalArgumentException.class, @@ -130,10 +145,10 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { Map config = new HashMap<>( - Map.of(RerankType.CROSS_ENCODER.getLabel(), new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id"))) + Map.of(RerankType.TEXT_SIMILARITY.getLabel(), new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); assertThrows( - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", + String.format(Locale.ROOT, "%s field must be provided", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), IllegalArgumentException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); @@ -142,12 +157,46 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { Map config = new HashMap<>( Map.of( - RerankType.CROSS_ENCODER.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation")) + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + assertThrows( + CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, "text_representation")) + ) + ); + assertThrows( + String.format(Locale.ROOT, "%s must be a list of strings", DocumentContextSourceFetcher.NAME), + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_EmptyContextDocField_ThenFail() { + Map config = new HashMap<>( + Map.of( + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>())) ) ); assertThrows( - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", + String.format(Locale.ROOT, "%s must be nonempty", DocumentContextSourceFetcher.NAME), IllegalArgumentException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java index a700cb6d0..a6585b719 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java @@ -27,8 +27,10 @@ import static org.mockito.Mockito.verify; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import lombok.extern.log4j.Log4j2; @@ -82,15 +84,10 @@ public void setup() { factory = new RerankProcessorFactory(mlCommonsClientAccessor); Map config = new HashMap<>( Map.of( - RerankType.CROSS_ENCODER.getLabel(), - new HashMap<>( - Map.of( - CrossEncoderRerankProcessor.MODEL_ID_FIELD, - "model-id", - CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, - "text_representation" - ) - ) + RerankType.TEXT_SIMILARITY.getLabel(), + new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); processor = (CrossEncoderRerankProcessor) factory.create( @@ -145,39 +142,42 @@ private void setupSearchResults() throws IOException { response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new Clusters(1, 1, 0), null); } - public void testScoringContext_QueryText_ThenSucceed() { - setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); + public void testScoringContext_QueryText_ThenSucceed() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text")); + setupSearchResults(); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); processor.generateRerankingContext(request, response, listener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); verify(listener, times(1)).onResponse(argCaptor.capture()); - assert (argCaptor.getValue().containsKey(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD)); - assert (argCaptor.getValue().get(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD).equals("query text")); + assert (argCaptor.getValue().containsKey(QueryContextSourceFetcher.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(QueryContextSourceFetcher.QUERY_TEXT_FIELD).equals("query text")); } - public void testScoringContext_QueryTextPath_ThenSucceed() { - setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); + public void testScoringContext_QueryTextPath_ThenSucceed() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); + setupSearchResults(); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); processor.generateRerankingContext(request, response, listener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); verify(listener, times(1)).onResponse(argCaptor.capture()); - assert (argCaptor.getValue().containsKey(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD)); - assert (argCaptor.getValue().get(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD).equals("Question about dolphins")); + assert (argCaptor.getValue().containsKey(QueryContextSourceFetcher.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(QueryContextSourceFetcher.QUERY_TEXT_FIELD).equals("Question about dolphins")); } - public void testScoringContext_QueryTextAndPath_ThenFail() { + public void testScoringContext_QueryTextAndPath_ThenFail() throws IOException { setupParams( Map.of( - CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text", - CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, + QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text" ) ); + setupSearchResults(); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); processor.generateRerankingContext(request, response, listener); @@ -188,15 +188,16 @@ public void testScoringContext_QueryTextAndPath_ThenFail() { .getMessage() .equals( "Cannot specify both \"" - + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD + + QueryContextSourceFetcher.QUERY_TEXT_FIELD + "\" and \"" - + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + "\"" )); } - public void testScoringContext_NoQueryInfo_ThenFail() { + public void testScoringContext_NoQueryInfo_ThenFail() throws IOException { setupParams(Map.of()); + setupSearchResults(); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); processor.generateRerankingContext(request, response, listener); @@ -207,15 +208,16 @@ public void testScoringContext_NoQueryInfo_ThenFail() { .getMessage() .equals( "Must specify either \"" - + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD + + QueryContextSourceFetcher.QUERY_TEXT_FIELD + "\" or \"" - + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + + QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + "\"" )); } - public void testScoringContext_QueryTextPath_BadPointer_ThenFail() { - setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); + public void testScoringContext_QueryTextPath_BadPointer_ThenFail() throws IOException { + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); + setupSearchResults(); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); processor.generateRerankingContext(request, response, listener); @@ -224,7 +226,7 @@ public void testScoringContext_QueryTextPath_BadPointer_ThenFail() { assert (argCaptor.getValue() instanceof IllegalArgumentException); assert (argCaptor.getValue() .getMessage() - .equals(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + " must point to a string field")); + .equals(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD + " must point to a string field")); } public void testRescoreSearchResponse_HappyPath() throws IOException { @@ -232,7 +234,12 @@ public void testRescoreSearchResponse_HappyPath() throws IOException { setupSearchResults(); @SuppressWarnings("unchecked") ActionListener> listener = mock(ActionListener.class); - Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + Map scoringContext = Map.of( + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text", + DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD, + new ArrayList<>(List.of("dummy", "dummy", "dummy")) + ); processor.rescoreSearchResponse(response, scoringContext, listener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); @@ -243,12 +250,39 @@ public void testRescoreSearchResponse_HappyPath() throws IOException { assert (argCaptor.getValue().get(2) == 3f); } + public void testRescoreSearchResponse_NoContextList_ThenFail() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + Map scoringContext = Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text"); + processor.rescoreSearchResponse(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalStateException); + assert (argCaptor.getValue() + .getMessage() + .equals( + String.format( + Locale.ROOT, + "No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + DocumentContextSourceFetcher.NAME + ) + )); + } + public void testRerank_HappyPath() throws IOException { setupSimilarityRescoring(); setupSearchResults(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + Map scoringContext = Map.of( + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text", + DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD, + new ArrayList<>(List.of("dummy", "dummy", "dummy")) + ); processor.rerank(response, scoringContext, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); verify(listener, times(1)).onResponse(argCaptor.capture()); @@ -271,7 +305,12 @@ public void testRerank_ScoresAndHitsHaveDiffLengths() throws IOException { setupSearchResults(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + Map scoringContext = Map.of( + QueryContextSourceFetcher.QUERY_TEXT_FIELD, + "query text", + DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD, + new ArrayList<>(List.of("dummy", "dummy", "dummy")) + ); processor.rerank(response, scoringContext, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(argCaptor.capture()); @@ -290,7 +329,7 @@ public void testBasics() throws IOException { } public void testProcessResponseAsync() throws IOException { - setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); + setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text")); setupSimilarityRescoring(); setupSearchResults(); @SuppressWarnings("unchecked") diff --git a/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json b/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json index 5d5751683..e9d0c6e2f 100644 --- a/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json +++ b/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json @@ -3,9 +3,11 @@ "response_processors": [ { "rerank": { - "cross-encoder": { - "model_id": "%s", - "rerank_context_field": "text_representation" + "text_similarity": { + "model_id": "%s" + }, + "context": { + "document_fields": ["text_representation"] } } } From 2976807aa5847604227ea529e59f108581405339 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 5 Dec 2023 11:46:11 -0800 Subject: [PATCH 10/27] rename CrossEncoder to TextSimilarity Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 8 +++---- ...ava => TextSimilarityRerankProcessor.java} | 4 ++-- .../factory/RerankProcessorFactoryTests.java | 24 +++++++++---------- ...a => TextSimilarityRerankProcessorIT.java} | 12 +++++----- ...> TextSimilarityRerankProcessorTests.java} | 8 +++---- ...imilarityRerankPipelineConfiguration.json} | 0 ...UploadTextSimilarityModelRequestBody.json} | 0 7 files changed, 28 insertions(+), 28 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/rerank/{CrossEncoderRerankProcessor.java => TextSimilarityRerankProcessor.java} (95%) rename src/test/java/org/opensearch/neuralsearch/processor/rerank/{CrossEncoderRerankProcessorIT.java => TextSimilarityRerankProcessorIT.java} (93%) rename src/test/java/org/opensearch/neuralsearch/processor/rerank/{CrossEncoderRerankProcessorTests.java => TextSimilarityRerankProcessorTests.java} (98%) rename src/test/resources/processor/{CrossEncoderRerankPipelineConfiguration.json => TextSimilarityRerankPipelineConfiguration.json} (100%) rename src/test/resources/processor/{UploadCrossEncoderModelRequestBody.json => UploadTextSimilarityModelRequestBody.json} (100%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 448a8e600..99bb24331 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -28,10 +28,10 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; -import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.QueryContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.neuralsearch.processor.rerank.TextSimilarityRerankProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -62,13 +62,13 @@ public SearchResponseProcessor create( case TEXT_SIMILARITY: @SuppressWarnings("unchecked") Map rerankerConfig = (Map) config.remove(type.getLabel()); - String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); + String modelId = rerankerConfig.get(TextSimilarityRerankProcessor.MODEL_ID_FIELD); if (modelId == null) { throw new IllegalArgumentException( - String.format(Locale.ROOT, "%s must be specified", CrossEncoderRerankProcessor.MODEL_ID_FIELD) + String.format(Locale.ROOT, "%s must be specified", TextSimilarityRerankProcessor.MODEL_ID_FIELD) ); } - return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); + return new TextSimilarityRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); default: throw new IllegalArgumentException( String.format(Locale.ROOT, "could not find constructor for reranker type %s", type.getLabel()) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java similarity index 95% rename from src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java index 8e38e2aaa..f95a794ba 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java @@ -27,7 +27,7 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; -public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { +public class TextSimilarityRerankProcessor extends RescoringRerankProcessor { public static final String MODEL_ID_FIELD = "model_id"; @@ -35,7 +35,7 @@ public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { protected final MLCommonsClientAccessor mlCommonsClientAccessor; - public CrossEncoderRerankProcessor( + public TextSimilarityRerankProcessor( String description, String tag, boolean ignoreFailure, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index 54bd37b69..13c27a6a4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -30,10 +30,10 @@ import org.junit.Before; import org.mockito.Mock; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.neuralsearch.processor.rerank.TextSimilarityRerankProcessor; import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.test.OpenSearchTestCase; @@ -70,7 +70,7 @@ public void testRerankProcessorFactory_EmptyConfig_ThenFail() { public void testRerankProcessorFactory_NonExistentType_ThenFail() { Map config = new HashMap<>( - Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")) + Map.of("jpeo rvgh we iorgn", Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")) ); assertThrows( "no rerank type found", @@ -83,14 +83,14 @@ public void testRerankProcessorFactory_CrossEncoder_HappyPath() { Map config = new HashMap<>( Map.of( RerankType.TEXT_SIMILARITY.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); assert (processor instanceof RerankProcessor); - assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor instanceof TextSimilarityRerankProcessor); assert (processor.getType().equals(RerankProcessor.TYPE)); } @@ -100,14 +100,14 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { "poafn aorr;anv", Map.of(";oawhls", "aowirhg "), RerankType.TEXT_SIMILARITY.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), + new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); assert (processor instanceof RerankProcessor); - assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor instanceof TextSimilarityRerankProcessor); assert (processor.getType().equals(RerankProcessor.TYPE)); } @@ -115,7 +115,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { Map config = new HashMap<>( Map.of( RerankType.TEXT_SIMILARITY.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>( Map.of( @@ -137,7 +137,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { Map config = new HashMap<>(Map.of(RerankType.TEXT_SIMILARITY.getLabel(), Map.of())); assertThrows( - CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", + TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified", IllegalArgumentException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); @@ -145,7 +145,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { Map config = new HashMap<>( - Map.of(RerankType.TEXT_SIMILARITY.getLabel(), new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id"))) + Map.of(RerankType.TEXT_SIMILARITY.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); assertThrows( String.format(Locale.ROOT, "%s field must be provided", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), @@ -164,7 +164,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { ) ); assertThrows( - CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", + TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified", IllegalArgumentException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); @@ -174,7 +174,7 @@ public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail( Map config = new HashMap<>( Map.of( RerankType.TEXT_SIMILARITY.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, "text_representation")) ) @@ -190,7 +190,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptyContextDocField_ThenFai Map config = new HashMap<>( Map.of( RerankType.TEXT_SIMILARITY.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>())) ) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java similarity index 93% rename from src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java rename to src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java index 3ec9f0c18..9cf62f929 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java @@ -39,9 +39,9 @@ import com.google.common.collect.ImmutableList; @Log4j2 -public class CrossEncoderRerankProcessorIT extends BaseNeuralSearchIT { +public class TextSimilarityRerankProcessorIT extends BaseNeuralSearchIT { - final static String PIPELINE_NAME = "rerank-ce-pipeline"; + final static String PIPELINE_NAME = "rerank-ts-pipeline"; final static String INDEX_NAME = "rerank-test"; final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; final static String TEXT_REP_2 = "Fish like to eat plankton"; @@ -61,16 +61,16 @@ public void tearDown() { } public void testCrossEncoderRerankProcessor() throws Exception { - String modelId = uploadCrossEncoderModel(); + String modelId = uploadTextSimilarityModel(); loadModel(modelId); - createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/CrossEncoderRerankPipelineConfiguration.json"); + createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/TextSimilarityRerankPipelineConfiguration.json"); setupIndex(); runQueries(); } - private String uploadCrossEncoderModel() throws Exception { + private String uploadTextSimilarityModel() throws Exception { String requestBody = Files.readString( - Path.of(classLoader.getResource("processor/UploadCrossEncoderModelRequestBody.json").toURI()) + Path.of(classLoader.getResource("processor/UploadTextSimilarityModelRequestBody.json").toURI()) ); return uploadModel(requestBody); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java similarity index 98% rename from src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java index a6585b719..a1ef46b91 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java @@ -61,7 +61,7 @@ import org.opensearch.test.OpenSearchTestCase; @Log4j2 -public class CrossEncoderRerankProcessorTests extends OpenSearchTestCase { +public class TextSimilarityRerankProcessorTests extends OpenSearchTestCase { @Mock SearchRequest request; @@ -76,7 +76,7 @@ public class CrossEncoderRerankProcessorTests extends OpenSearchTestCase { RerankProcessorFactory factory; - CrossEncoderRerankProcessor processor; + TextSimilarityRerankProcessor processor; @Before public void setup() { @@ -85,12 +85,12 @@ public void setup() { Map config = new HashMap<>( Map.of( RerankType.TEXT_SIMILARITY.getLabel(), - new HashMap<>(Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); - processor = (CrossEncoderRerankProcessor) factory.create( + processor = (TextSimilarityRerankProcessor) factory.create( Map.of(), "rerank processor", "processor for reranking with a cross encoder", diff --git a/src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json b/src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json similarity index 100% rename from src/test/resources/processor/CrossEncoderRerankPipelineConfiguration.json rename to src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json diff --git a/src/test/resources/processor/UploadCrossEncoderModelRequestBody.json b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json similarity index 100% rename from src/test/resources/processor/UploadCrossEncoderModelRequestBody.json rename to src/test/resources/processor/UploadTextSimilarityModelRequestBody.json From 5332fee05841facade638c2cef6fcb96c3fc9892 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 5 Dec 2023 12:55:39 -0800 Subject: [PATCH 11/27] add query_context layer to search ext Signed-off-by: HenryL27 --- .../rerank/QueryContextSourceFetcher.java | 19 ++++++++++++++----- .../TextSimilarityRerankProcessorIT.java | 2 +- .../TextSimilarityRerankProcessorTests.java | 4 +++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java index a4a7a26ba..45ba1efa1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -48,15 +48,24 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo List exts = searchRequest.source().ext(); Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); Map scoringContext = new HashMap<>(); - if (params.containsKey(QUERY_TEXT_FIELD)) { - if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { + if (!params.containsKey(NAME)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "must specify %s", NAME)); + } + Object ctxObj = params.remove(NAME); + if (!(ctxObj instanceof Map)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a map", NAME)); + } + @SuppressWarnings("unchecked") + Map ctxMap = (Map) ctxObj; + if (ctxMap.containsKey(QUERY_TEXT_FIELD)) { + if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } - scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); - } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { - String path = (String) params.get(QUERY_TEXT_PATH_FIELD); + scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); + } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { + String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); // Convert query to a map with io/xcontent shenanigans ByteArrayOutputStream baos = new ByteArrayOutputStream(); XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java index 9cf62f929..ea16c614e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java @@ -126,7 +126,7 @@ private void runQueries() throws Exception { } private Map search(String queryText) throws Exception { - String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_text\":\"%s\"}}}"; + String jsonQueryFrame = "{\"query\":{\"match_all\":{}},\"ext\":{\"rerank\":{\"query_context\": {\"query_text\":\"%s\"}}}}"; String jsonQuery = String.format(LOCALE, jsonQueryFrame, queryText); log.info(jsonQuery); Request request = new Request("POST", "/" + INDEX_NAME + "/_search"); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java index a1ef46b91..85c176a57 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java @@ -105,7 +105,9 @@ private void setupParams(Map params) { NeuralQueryBuilder nqb = new NeuralQueryBuilder(); nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins"); ssb.query(nqb); - List exts = List.of(new RerankSearchExtBuilder(new HashMap<>(params))); + List exts = List.of( + new RerankSearchExtBuilder(new HashMap<>(Map.of(QueryContextSourceFetcher.NAME, new HashMap<>(params)))) + ); ssb.ext(exts); doReturn(ssb).when(request).source(); } From aa1d5245a6b3f6e4a84f809b4065fdd8af56d82b Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 5 Dec 2023 13:42:46 -0800 Subject: [PATCH 12/27] add javadocs Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 22 +++++++++++++++++-- .../rerank/ContextSourceFetcher.java | 9 ++++++++ .../rerank/DocumentContextSourceFetcher.java | 7 ++++-- .../rerank/QueryContextSourceFetcher.java | 4 ++++ .../processor/rerank/RerankProcessor.java | 5 +++-- .../rerank/RescoringRerankProcessor.java | 14 +++++++++--- .../rerank/TextSimilarityRerankProcessor.java | 12 ++++++++++ 7 files changed, 64 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 99bb24331..5f55b15a8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -24,7 +24,6 @@ import java.util.stream.Collectors; import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; @@ -37,7 +36,11 @@ import com.google.common.annotations.VisibleForTesting; -@Log4j2 +/** + * Factory for rerank processors. Must: + * - Instantiate the right kind of rerank processor + * - Instantiate the appropriate context source fetchers + */ @AllArgsConstructor public class RerankProcessorFactory implements Processor.Factory { @@ -90,8 +93,17 @@ RerankType findRerankType(final Map config) throws IllegalArgume throw new IllegalArgumentException("no rerank type found"); } + /** + * Factory class for context fetchers. Constructs a list of context fetchers + * specified in the pipeline config (and maybe the query context fetcher) + */ protected static class ContextFetcherFactory { + /** + * Map rerank types to whether they should include the query context source fetcher + * @param type the constructing RerankType + * @return does this RerankType depend on the QueryContextSourceFetcher? + */ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { switch (type) { case TEXT_SIMILARITY: @@ -101,6 +113,12 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { } } + /** + * Create necessary queryContextFetchers for this processor + * @param config processor config object. Look for "context" field to find fetchers + * @param includeQueryContextFetcher should I include the queryContextFetcher? + * @return list of contextFetchers for the processor to use + */ public static List createFetchers(Map config, boolean includeQueryContextFetcher) { List fetchers = new ArrayList<>(); @SuppressWarnings("unchecked") diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java index 5732d8b2f..d81608fe5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java @@ -23,6 +23,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; +/** + * Interface that gets context from some source and puts it in a map + * for a reranking processor to use + */ public interface ContextSourceFetcher { /** @@ -35,5 +39,10 @@ public interface ContextSourceFetcher { */ public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener); + /** + * Get the name of the contextSourceFetcher. This will be used as the field + * name in the context config for the pipeline + * @return Name of the fetcher + */ public String getName(); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index 3e2ae636b..e25683085 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -23,14 +23,15 @@ import java.util.Map; import lombok.AllArgsConstructor; -import lombok.extern.log4j.Log4j2; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.search.SearchHit; -@Log4j2 +/** + * Context Source Fetcher that gets context from the search results (documents) + */ @AllArgsConstructor public class DocumentContextSourceFetcher implements ContextSourceFetcher { @@ -47,6 +48,7 @@ public class DocumentContextSourceFetcher implements ContextSourceFetcher { * @param searchResponse the search results, in case they're relevant * @param listener be async */ + @Override public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener) { List contexts = new ArrayList<>(); for (SearchHit hit : searchResponse.getHits()) { @@ -69,6 +71,7 @@ private String contextFromSearchHit(final SearchHit hit, final String field) { } } + @Override public String getName() { return NAME; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java index 45ba1efa1..42d5ceba5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -36,6 +36,9 @@ import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.search.SearchExtBuilder; +/** + * Context Source Fetcher that gets context from the rerank query ext. + */ public class QueryContextSourceFetcher implements ContextSourceFetcher { public static final String NAME = "query_context"; @@ -93,6 +96,7 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo } } + @Override public String getName() { return NAME; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 02f471451..47638920c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -24,14 +24,15 @@ import lombok.AllArgsConstructor; import lombok.Getter; -import lombok.extern.log4j.Log4j2; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.search.pipeline.SearchResponseProcessor; -@Log4j2 +/** + * Abstract base class for reranking processors + */ @AllArgsConstructor public abstract class RerankProcessor implements SearchResponseProcessor { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index ca3b84749..b6e21a136 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -23,8 +23,6 @@ import java.util.List; import java.util.Map; -import lombok.extern.log4j.Log4j2; - import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.core.action.ActionListener; @@ -32,9 +30,19 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.profile.SearchProfileShardResults; -@Log4j2 +/** + * RerankProcessor that rescores all the documents and re-sorts them using the new scores + */ public abstract class RescoringRerankProcessor extends RerankProcessor { + /** + * Constructor. pass through to RerankProcessor ctor. + * @param type + * @param description + * @param tag + * @param ignoreFailure + * @param contextSourceFetchers + */ public RescoringRerankProcessor( RerankType type, String description, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java index f95a794ba..1cba4866e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java @@ -27,6 +27,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +/** + * Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore + */ public class TextSimilarityRerankProcessor extends RescoringRerankProcessor { public static final String MODEL_ID_FIELD = "model_id"; @@ -35,6 +38,15 @@ public class TextSimilarityRerankProcessor extends RescoringRerankProcessor { protected final MLCommonsClientAccessor mlCommonsClientAccessor; + /** + * Constructor + * @param description + * @param tag + * @param ignoreFailure + * @param modelId id of TEXT_SIMILARITY model + * @param contextSourceFetchers + * @param mlCommonsClientAccessor + */ public TextSimilarityRerankProcessor( String description, String tag, From 77301d954bea10b2172c1291751ab5fa1d50b9e4 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 11 Dec 2023 12:22:12 -0800 Subject: [PATCH 13/27] update to new asyncProcessResponse api Signed-off-by: HenryL27 --- .../neuralsearch/processor/rerank/RerankProcessor.java | 8 +++++++- .../rerank/TextSimilarityRerankProcessorTests.java | 6 +++++- .../processor/UploadTextSimilarityModelRequestBody.json | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 47638920c..82e7dd268 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -28,6 +28,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.SearchResponseProcessor; /** @@ -95,7 +96,12 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp } @Override - public void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { + public void processResponseAsync( + SearchRequest request, + SearchResponse response, + PipelineProcessingContext ctx, + ActionListener responseListener + ) { try { generateRerankingContext( request, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java index 85c176a57..af951d507 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java @@ -57,6 +57,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.test.OpenSearchTestCase; @@ -74,6 +75,9 @@ public class TextSimilarityRerankProcessorTests extends OpenSearchTestCase { @Mock PipelineContext pipelineContext; + @Mock + PipelineProcessingContext ppctx; + RerankProcessorFactory factory; TextSimilarityRerankProcessor processor; @@ -336,7 +340,7 @@ public void testProcessResponseAsync() throws IOException { setupSearchResults(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - processor.processResponseAsync(request, response, listener); + processor.processResponseAsync(request, response, ppctx, listener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); verify(listener, times(1)).onResponse(argCaptor.capture()); SearchResponse rsp = argCaptor.getValue(); diff --git a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json index 897354616..9c28a10f4 100644 --- a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json +++ b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json @@ -12,5 +12,5 @@ "framework_type": "huggingface_transformers", "all_config": "nobody will read this" }, - "url": "https://github.com/HenryL27/ml-commons/blob/cross-encoder/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE.zip?raw=true" + "url": "https://github.com/opensearch-project/ml-commons/blob/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-torch_script.zip?raw=true" } \ No newline at end of file From e8de4122b028b426299409c1a2f2d8f7709507ee Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 18 Dec 2023 08:39:09 -0800 Subject: [PATCH 14/27] rename reranktype to ML_OPENSEARCH Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 4 ++-- .../processor/rerank/RerankType.java | 2 +- .../rerank/TextSimilarityRerankProcessor.java | 2 +- .../factory/RerankProcessorFactoryTests.java | 19 ++++++++----------- .../TextSimilarityRerankProcessorTests.java | 2 +- ...SimilarityRerankPipelineConfiguration.json | 2 +- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 5f55b15a8..342aa25b8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -62,7 +62,7 @@ public SearchResponseProcessor create( boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); List contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher); switch (type) { - case TEXT_SIMILARITY: + case ML_OPENSEARCH: @SuppressWarnings("unchecked") Map rerankerConfig = (Map) config.remove(type.getLabel()); String modelId = rerankerConfig.get(TextSimilarityRerankProcessor.MODEL_ID_FIELD); @@ -106,7 +106,7 @@ protected static class ContextFetcherFactory { */ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { switch (type) { - case TEXT_SIMILARITY: + case ML_OPENSEARCH: return true; default: return false; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index 45221a2c5..b21c97aac 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -28,7 +28,7 @@ */ public enum RerankType { - TEXT_SIMILARITY("text_similarity"); + ML_OPENSEARCH("ml_opensearch"); @Getter private final String label; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java index 1cba4866e..2c8f5f186 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java @@ -55,7 +55,7 @@ public TextSimilarityRerankProcessor( List contextSourceFetchers, MLCommonsClientAccessor mlCommonsClientAccessor ) { - super(RerankType.TEXT_SIMILARITY, description, tag, ignoreFailure, contextSourceFetchers); + super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers); this.modelId = modelId; this.mlCommonsClientAccessor = mlCommonsClientAccessor; } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index 13c27a6a4..4b070e757 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -25,8 +25,6 @@ import java.util.Locale; import java.util.Map; -import lombok.extern.log4j.Log4j2; - import org.junit.Before; import org.mockito.Mock; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -38,7 +36,6 @@ import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.test.OpenSearchTestCase; -@Log4j2 public class RerankProcessorFactoryTests extends OpenSearchTestCase { final String TAG = "default-tag"; @@ -82,7 +79,7 @@ public void testRerankProcessorFactory_NonExistentType_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_HappyPath() { Map config = new HashMap<>( Map.of( - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) @@ -99,7 +96,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { Map.of( "poafn aorr;anv", Map.of(";oawhls", "aowirhg "), - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) @@ -114,7 +111,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { Map config = new HashMap<>( Map.of( - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>( @@ -135,7 +132,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { } public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { - Map config = new HashMap<>(Map.of(RerankType.TEXT_SIMILARITY.getLabel(), Map.of())); + Map config = new HashMap<>(Map.of(RerankType.ML_OPENSEARCH.getLabel(), Map.of())); assertThrows( TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified", IllegalArgumentException.class, @@ -145,7 +142,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { Map config = new HashMap<>( - Map.of(RerankType.TEXT_SIMILARITY.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id"))) + Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); assertThrows( String.format(Locale.ROOT, "%s field must be provided", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), @@ -157,7 +154,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { Map config = new HashMap<>( Map.of( - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) @@ -173,7 +170,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail() { Map config = new HashMap<>( Map.of( - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, "text_representation")) @@ -189,7 +186,7 @@ public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail( public void testRerankProcessorFactory_CrossEncoder_EmptyContextDocField_ThenFail() { Map config = new HashMap<>( Map.of( - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>())) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java index af951d507..b06abfc84 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java @@ -88,7 +88,7 @@ public void setup() { factory = new RerankProcessorFactory(mlCommonsClientAccessor); Map config = new HashMap<>( Map.of( - RerankType.TEXT_SIMILARITY.getLabel(), + RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) diff --git a/src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json b/src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json index e9d0c6e2f..89682e23c 100644 --- a/src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json +++ b/src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json @@ -3,7 +3,7 @@ "response_processors": [ { "rerank": { - "text_similarity": { + "ml_opensearch": { "model_id": "%s" }, "context": { From a7090b2c72a3dd8ba5b426f4a2b4ce66b3580920 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 18 Dec 2023 09:38:06 -0800 Subject: [PATCH 15/27] improve error messages for bad rerank type config Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 28 +++++++++++++------ .../processor/rerank/RerankType.java | 26 +++++++++++++---- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 342aa25b8..7886357e5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -21,6 +21,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; import java.util.stream.Collectors; import lombok.AllArgsConstructor; @@ -35,6 +37,7 @@ import org.opensearch.search.pipeline.SearchResponseProcessor; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Sets; /** * Factory for rerank processors. Must: @@ -81,16 +84,25 @@ public SearchResponseProcessor create( @VisibleForTesting RerankType findRerankType(final Map config) throws IllegalArgumentException { - for (String key : config.keySet()) { - try { - RerankType attempt = RerankType.from(key); - return attempt; - } catch (IllegalArgumentException e) { - // Assume it's just a different field in the config, so don't do anything. - // If we get to the end and there were no valid RerankTypes, then we can panic. + // Set of rerank type labels in the config + Set rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet()); + // A rerank type must be provided + if (rerankTypes.size() == 0) { + StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]"); + for (RerankType t : RerankType.values()) { + msgBuilder.add(t.getLabel()); } + throw new IllegalArgumentException(msgBuilder.toString()); } - throw new IllegalArgumentException("no rerank type found"); + // Only one rerank type may be provided + if (rerankTypes.size() > 1) { + StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted."); + for (String rt : rerankTypes) { + msgBuilder.add(rt); + } + throw new IllegalArgumentException(msgBuilder.toString()); + } + return RerankType.from(rerankTypes.iterator().next()); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index b21c97aac..c74d49cc5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -17,9 +17,10 @@ */ package org.opensearch.neuralsearch.processor.rerank; -import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.Locale; -import java.util.Optional; +import java.util.Map; import lombok.Getter; @@ -37,17 +38,30 @@ private RerankType(String label) { this.label = label; } + private static final Map LABEL_MAP; + static { + Map labelMap = new HashMap<>(); + for (RerankType type : RerankType.values()) { + labelMap.put(type.getLabel(), type); + } + LABEL_MAP = Collections.unmodifiableMap(labelMap); + } + /** * Construct a RerankType from the label * @param label label of a RerankType * @return RerankType represented by the label */ public static RerankType from(String label) { - Optional typeMaybe = Arrays.stream(RerankType.values()).filter(rrt -> rrt.label.equals(label)).findFirst(); - if (typeMaybe.isPresent()) { - return typeMaybe.get(); - } else { + RerankType ans = LABEL_MAP.get(label); + if (ans == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); + } else { + return ans; } } + + public static Map labelMap() { + return LABEL_MAP; + } } From 797eaf6b00bea0e343c61ffd315b133f7b438f0d Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 18 Dec 2023 11:08:52 -0800 Subject: [PATCH 16/27] simplify configuration/factory logic Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 51 +++++++------------ .../rerank/ContextSourceFetcher.java | 1 + .../rerank/DocumentContextSourceFetcher.java | 19 +++++++ .../factory/RerankProcessorFactoryTests.java | 13 ++--- 4 files changed, 46 insertions(+), 38 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 7886357e5..a73e277f3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -23,10 +23,10 @@ import java.util.Map; import java.util.Set; import java.util.StringJoiner; -import java.util.stream.Collectors; import lombok.AllArgsConstructor; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; @@ -63,22 +63,19 @@ public SearchResponseProcessor create( ) { RerankType type = findRerankType(config); boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); - List contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher); + List contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag); switch (type) { case ML_OPENSEARCH: - @SuppressWarnings("unchecked") - Map rerankerConfig = (Map) config.remove(type.getLabel()); - String modelId = rerankerConfig.get(TextSimilarityRerankProcessor.MODEL_ID_FIELD); - if (modelId == null) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "%s must be specified", TextSimilarityRerankProcessor.MODEL_ID_FIELD) - ); - } + Map rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); + String modelId = ConfigurationUtils.readStringProperty( + RERANK_PROCESSOR_TYPE, + tag, + rerankerConfig, + TextSimilarityRerankProcessor.MODEL_ID_FIELD + ); return new TextSimilarityRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); default: - throw new IllegalArgumentException( - String.format(Locale.ROOT, "could not find constructor for reranker type %s", type.getLabel()) - ); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); } } @@ -97,9 +94,7 @@ RerankType findRerankType(final Map config) throws IllegalArgume // Only one rerank type may be provided if (rerankTypes.size() > 1) { StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted."); - for (String rt : rerankTypes) { - msgBuilder.add(rt); - } + rerankTypes.forEach(rt -> msgBuilder.add(rt)); throw new IllegalArgumentException(msgBuilder.toString()); } return RerankType.from(rerankTypes.iterator().next()); @@ -131,26 +126,18 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { * @param includeQueryContextFetcher should I include the queryContextFetcher? * @return list of contextFetchers for the processor to use */ - public static List createFetchers(Map config, boolean includeQueryContextFetcher) { + public static List createFetchers( + Map config, + boolean includeQueryContextFetcher, + String tag + ) { + Map contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD); List fetchers = new ArrayList<>(); - @SuppressWarnings("unchecked") - Map contextConfig = (Map) config.remove(CONTEXT_CONFIG_FIELD); - if (contextConfig == null) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field must be provided", CONTEXT_CONFIG_FIELD)); - } for (String key : contextConfig.keySet()) { + Object cfg = contextConfig.get(key); switch (key) { case DocumentContextSourceFetcher.NAME: - Object cfg = contextConfig.get(key); - if (!(cfg instanceof List)) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of strings", key)); - } - List fields = (List) contextConfig.get(key); - if (fields.size() == 0) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", key)); - } - List strfields = fields.stream().map(field -> (String) field).collect(Collectors.toList()); - fetchers.add(new DocumentContextSourceFetcher(strfields)); + fetchers.add(DocumentContextSourceFetcher.create(cfg)); break; default: throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java index d81608fe5..b3d6e372c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java @@ -45,4 +45,5 @@ public interface ContextSourceFetcher { * @return Name of the fetcher */ public String getName(); + } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index e25683085..7fdfabe12 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -20,7 +20,9 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; import lombok.AllArgsConstructor; @@ -75,4 +77,21 @@ private String contextFromSearchHit(final SearchHit hit, final String field) { public String getName() { return NAME; } + + /** + * Create a document context source fetcher from list of field names provided by configuration + * @param config configuration object grabbed from parsed API request. Should be a list of strings + * @return a new DocumentContextSourceFetcher or throws IllegalArgumentException if config is malformed + */ + public static DocumentContextSourceFetcher create(Object config) { + if (!(config instanceof List)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a list of field names", NAME)); + } + List fields = (List) config; + if (fields.size() == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME)); + } + List strfields = fields.stream().map(field -> (String) field).collect(Collectors.toList()); + return new DocumentContextSourceFetcher(strfields); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index 4b070e757..d768b6f35 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.mockito.Mock; +import org.opensearch.OpenSearchParseException; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; @@ -134,8 +135,8 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { Map config = new HashMap<>(Map.of(RerankType.ML_OPENSEARCH.getLabel(), Map.of())); assertThrows( - TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified", - IllegalArgumentException.class, + String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), + OpenSearchParseException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); } @@ -145,8 +146,8 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); assertThrows( - String.format(Locale.ROOT, "%s field must be provided", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), - IllegalArgumentException.class, + String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), + OpenSearchParseException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); } @@ -161,8 +162,8 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { ) ); assertThrows( - TextSimilarityRerankProcessor.MODEL_ID_FIELD + " must be specified", - IllegalArgumentException.class, + String.format(Locale.ROOT, "[%s] required property is missing", TextSimilarityRerankProcessor.MODEL_ID_FIELD), + OpenSearchParseException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); } From ddf286612a41fa380a6758ca9d1adab15103f5b8 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 18 Dec 2023 12:01:58 -0800 Subject: [PATCH 17/27] improve handling for non-flat-string context fields Signed-off-by: HenryL27 --- .../processor/factory/RerankProcessorFactory.java | 4 +--- .../rerank/DocumentContextSourceFetcher.java | 11 +++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index a73e277f3..d1c7042b6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -36,7 +36,6 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Sets; /** @@ -79,8 +78,7 @@ public SearchResponseProcessor create( } } - @VisibleForTesting - RerankType findRerankType(final Map config) throws IllegalArgumentException { + private RerankType findRerankType(final Map config) throws IllegalArgumentException { // Set of rerank type labels in the config Set rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet()); // A rerank type must be provided diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index 7fdfabe12..d5a1fef29 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -29,6 +29,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ObjectPath; import org.opensearch.search.SearchHit; /** @@ -65,9 +66,11 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo private String contextFromSearchHit(final SearchHit hit, final String field) { if (hit.getFields().containsKey(field)) { - return (String) hit.field(field).getValue(); + Object fieldValue = hit.field(field).getValue(); + return String.valueOf(fieldValue); } else if (hit.hasSource() && hit.getSourceAsMap().containsKey(field)) { - return (String) hit.getSourceAsMap().get(field); + Object sourceValue = ObjectPath.eval(field, hit.getSourceAsMap()); + return String.valueOf(sourceValue); } else { return ""; } @@ -91,7 +94,7 @@ public static DocumentContextSourceFetcher create(Object config) { if (fields.size() == 0) { throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be nonempty", NAME)); } - List strfields = fields.stream().map(field -> (String) field).collect(Collectors.toList()); - return new DocumentContextSourceFetcher(strfields); + List fieldsAsStrings = fields.stream().map(field -> (String) field).collect(Collectors.toList()); + return new DocumentContextSourceFetcher(fieldsAsStrings); } } From 14c8f893624de8816738e89548edb4600c882d17 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 18 Dec 2023 12:11:02 -0800 Subject: [PATCH 18/27] rename TextSimilarity files to MLOpenSearch files Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 6 ++--- ....java => MLOpenSearchRerankProcessor.java} | 4 ++-- .../factory/RerankProcessorFactoryTests.java | 22 +++++++++---------- ...ava => MLOpenSearchRerankProcessorIT.java} | 6 ++--- ... => MLOpenSearchRerankProcessorTests.java} | 8 +++---- ...ankMLOpenSearchPipelineConfiguration.json} | 0 6 files changed, 23 insertions(+), 23 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/rerank/{TextSimilarityRerankProcessor.java => MLOpenSearchRerankProcessor.java} (96%) rename src/test/java/org/opensearch/neuralsearch/processor/rerank/{TextSimilarityRerankProcessorIT.java => MLOpenSearchRerankProcessorIT.java} (96%) rename src/test/java/org/opensearch/neuralsearch/processor/rerank/{TextSimilarityRerankProcessorTests.java => MLOpenSearchRerankProcessorTests.java} (98%) rename src/test/resources/processor/{TextSimilarityRerankPipelineConfiguration.json => RerankMLOpenSearchPipelineConfiguration.json} (100%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index d1c7042b6..eb41789be 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -30,9 +30,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.QueryContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankType; -import org.opensearch.neuralsearch.processor.rerank.TextSimilarityRerankProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; @@ -70,9 +70,9 @@ public SearchResponseProcessor create( RERANK_PROCESSOR_TYPE, tag, rerankerConfig, - TextSimilarityRerankProcessor.MODEL_ID_FIELD + MLOpenSearchRerankProcessor.MODEL_ID_FIELD ); - return new TextSimilarityRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); + return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); default: throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java similarity index 96% rename from src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index 2c8f5f186..70ba9a050 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -30,7 +30,7 @@ /** * Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore */ -public class TextSimilarityRerankProcessor extends RescoringRerankProcessor { +public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor { public static final String MODEL_ID_FIELD = "model_id"; @@ -47,7 +47,7 @@ public class TextSimilarityRerankProcessor extends RescoringRerankProcessor { * @param contextSourceFetchers * @param mlCommonsClientAccessor */ - public TextSimilarityRerankProcessor( + public MLOpenSearchRerankProcessor( String description, String tag, boolean ignoreFailure, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index d768b6f35..d8d372cbc 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -30,9 +30,9 @@ import org.opensearch.OpenSearchParseException; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankType; -import org.opensearch.neuralsearch.processor.rerank.TextSimilarityRerankProcessor; import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.test.OpenSearchTestCase; @@ -68,7 +68,7 @@ public void testRerankProcessorFactory_EmptyConfig_ThenFail() { public void testRerankProcessorFactory_NonExistentType_ThenFail() { Map config = new HashMap<>( - Map.of("jpeo rvgh we iorgn", Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")) + Map.of("jpeo rvgh we iorgn", Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")) ); assertThrows( "no rerank type found", @@ -81,14 +81,14 @@ public void testRerankProcessorFactory_CrossEncoder_HappyPath() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), - new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); assert (processor instanceof RerankProcessor); - assert (processor instanceof TextSimilarityRerankProcessor); + assert (processor instanceof MLOpenSearchRerankProcessor); assert (processor.getType().equals(RerankProcessor.TYPE)); } @@ -98,14 +98,14 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { "poafn aorr;anv", Map.of(";oawhls", "aowirhg "), RerankType.ML_OPENSEARCH.getLabel(), - new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id", "pqiohg rpowierhg", "pw;oith4pt3ih go")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); assert (processor instanceof RerankProcessor); - assert (processor instanceof TextSimilarityRerankProcessor); + assert (processor instanceof MLOpenSearchRerankProcessor); assert (processor.getType().equals(RerankProcessor.TYPE)); } @@ -113,7 +113,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), - new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>( Map.of( @@ -143,7 +143,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { Map config = new HashMap<>( - Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id"))) + Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); assertThrows( String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), @@ -162,7 +162,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { ) ); assertThrows( - String.format(Locale.ROOT, "[%s] required property is missing", TextSimilarityRerankProcessor.MODEL_ID_FIELD), + String.format(Locale.ROOT, "[%s] required property is missing", MLOpenSearchRerankProcessor.MODEL_ID_FIELD), OpenSearchParseException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); @@ -172,7 +172,7 @@ public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail( Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), - new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, "text_representation")) ) @@ -188,7 +188,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptyContextDocField_ThenFai Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), - new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>())) ) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java similarity index 96% rename from src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java rename to src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index ea16c614e..3c2a37057 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -39,9 +39,9 @@ import com.google.common.collect.ImmutableList; @Log4j2 -public class TextSimilarityRerankProcessorIT extends BaseNeuralSearchIT { +public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { - final static String PIPELINE_NAME = "rerank-ts-pipeline"; + final static String PIPELINE_NAME = "rerank-mlos-pipeline"; final static String INDEX_NAME = "rerank-test"; final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; final static String TEXT_REP_2 = "Fish like to eat plankton"; @@ -63,7 +63,7 @@ public void tearDown() { public void testCrossEncoderRerankProcessor() throws Exception { String modelId = uploadTextSimilarityModel(); loadModel(modelId); - createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/TextSimilarityRerankPipelineConfiguration.json"); + createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); setupIndex(); runQueries(); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java similarity index 98% rename from src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index b06abfc84..bb589a600 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/TextSimilarityRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -62,7 +62,7 @@ import org.opensearch.test.OpenSearchTestCase; @Log4j2 -public class TextSimilarityRerankProcessorTests extends OpenSearchTestCase { +public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { @Mock SearchRequest request; @@ -80,7 +80,7 @@ public class TextSimilarityRerankProcessorTests extends OpenSearchTestCase { RerankProcessorFactory factory; - TextSimilarityRerankProcessor processor; + MLOpenSearchRerankProcessor processor; @Before public void setup() { @@ -89,12 +89,12 @@ public void setup() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), - new HashMap<>(Map.of(TextSimilarityRerankProcessor.MODEL_ID_FIELD, "model-id")), + new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")), RerankProcessorFactory.CONTEXT_CONFIG_FIELD, new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) ) ); - processor = (TextSimilarityRerankProcessor) factory.create( + processor = (MLOpenSearchRerankProcessor) factory.create( Map.of(), "rerank processor", "processor for reranking with a cross encoder", diff --git a/src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json b/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json similarity index 100% rename from src/test/resources/processor/TextSimilarityRerankPipelineConfiguration.json rename to src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json From 577f855ed1e631c6fc7d764deea83616a8ddb691 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 18 Dec 2023 17:19:08 -0800 Subject: [PATCH 19/27] apply spotless after rebase Signed-off-by: HenryL27 --- .../neuralsearch/ml/MLCommonsClientAccessor.java | 4 ---- .../processor/factory/RerankProcessorFactory.java | 4 ++-- .../processor/rerank/DocumentContextSourceFetcher.java | 4 ++-- .../neuralsearch/processor/rerank/RerankProcessor.java | 10 ++++++---- .../neuralsearch/query/ext/RerankSearchExtBuilder.java | 8 ++++---- .../rerank/MLOpenSearchRerankProcessorIT.java | 6 +++--- .../rerank/MLOpenSearchRerankProcessorTests.java | 4 ++-- .../query/ext/RerankSearchExtBuilderTests.java | 4 ++-- 8 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index a6170d308..f9ddf73a9 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -13,10 +13,6 @@ import java.util.Map; import java.util.stream.Collectors; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; -import lombok.extern.log4j.Log4j2; - import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index eb41789be..0c52a3588 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -24,8 +24,6 @@ import java.util.Set; import java.util.StringJoiner; -import lombok.AllArgsConstructor; - import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; @@ -38,6 +36,8 @@ import com.google.common.collect.Sets; +import lombok.AllArgsConstructor; + /** * Factory for rerank processors. Must: * - Instantiate the right kind of rerank processor diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index d5a1fef29..fb8650e51 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -24,14 +24,14 @@ import java.util.Map; import java.util.stream.Collectors; -import lombok.AllArgsConstructor; - import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.ObjectPath; import org.opensearch.search.SearchHit; +import lombok.AllArgsConstructor; + /** * Context Source Fetcher that gets context from the search results (documents) */ diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 82e7dd268..2ca3586e8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -22,15 +22,15 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import lombok.AllArgsConstructor; -import lombok.Getter; - import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.SearchResponseProcessor; +import lombok.AllArgsConstructor; +import lombok.Getter; + /** * Abstract base class for reranking processors */ @@ -106,7 +106,9 @@ public void processResponseAsync( generateRerankingContext( request, response, - ActionListener.wrap(context -> { rerank(response, context, responseListener); }, e -> { responseListener.onFailure(e); }) + ActionListener.wrap(context -> { rerank(response, context, responseListener); }, e -> { + responseListener.onFailure(e); + }) ); } catch (Exception e) { responseListener.onFailure(e); diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index 915b2c858..4f8f3cb94 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -23,16 +23,16 @@ import java.util.Objects; import java.util.Optional; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; - import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + @Log4j2 @AllArgsConstructor public class RerankSearchExtBuilder extends SearchExtBuilder { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index 3c2a37057..7d0972dcb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -22,9 +22,6 @@ import java.util.List; import java.util.Map; -import lombok.SneakyThrows; -import lombok.extern.log4j.Log4j2; - import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; @@ -38,6 +35,9 @@ import com.google.common.collect.ImmutableList; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index bb589a600..1c6446650 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -33,8 +33,6 @@ import java.util.Locale; import java.util.Map; -import lombok.extern.log4j.Log4j2; - import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -61,6 +59,8 @@ import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.test.OpenSearchTestCase; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java index 8c24a5a8d..13637b2ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -23,8 +23,6 @@ import java.util.List; import java.util.Map; -import lombok.extern.log4j.Log4j2; - import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesReference; @@ -35,6 +33,8 @@ import org.opensearch.search.SearchExtBuilder; import org.opensearch.test.OpenSearchTestCase; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class RerankSearchExtBuilderTests extends OpenSearchTestCase { From e3cf21898492d2401794ea7ce3446a10e899aabe Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 21 Dec 2023 10:59:26 -0800 Subject: [PATCH 20/27] update changelog Signed-off-by: HenryL27 --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ae6baab7..284cf64da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494)) ### Enhancements ### Bug Fixes - Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) From 7a6595f6f5b7f423639141241ec28e1a27d477ba Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 9 Jan 2024 15:55:29 -0800 Subject: [PATCH 21/27] after rebase Signed-off-by: HenryL27 --- .../processor/factory/RerankProcessorFactory.java | 13 ------------- .../processor/rerank/ContextSourceFetcher.java | 13 ------------- .../rerank/DocumentContextSourceFetcher.java | 13 ------------- .../rerank/MLOpenSearchRerankProcessor.java | 13 ------------- .../rerank/QueryContextSourceFetcher.java | 13 ------------- .../processor/rerank/RerankProcessor.java | 13 ------------- .../neuralsearch/processor/rerank/RerankType.java | 13 ------------- .../rerank/RescoringRerankProcessor.java | 13 ------------- .../query/ext/RerankSearchExtBuilder.java | 13 ------------- .../factory/RerankProcessorFactoryTests.java | 13 ------------- .../rerank/MLOpenSearchRerankProcessorIT.java | 15 +-------------- .../rerank/MLOpenSearchRerankProcessorTests.java | 13 ------------- .../query/ext/RerankSearchExtBuilderTests.java | 13 ------------- .../RerankMLOpenSearchPipelineConfiguration.json | 2 +- .../UploadTextSimilarityModelRequestBody.json | 2 +- 15 files changed, 3 insertions(+), 172 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 0c52a3588..5779bbf2c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.factory; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java index b3d6e372c..d8a576937 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index fb8650e51..7dc577502 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index 70ba9a050..c94ddf6ca 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java index 42d5ceba5..5000dd756 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 2ca3586e8..6104990ab 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index c74d49cc5..c1d64c5ae 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index b6e21a136..c52a5223d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index 4f8f3cb94..4c0026a89 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.query.ext; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index d8d372cbc..7488e2607 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.factory; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index 7d0972dcb..1fa6ac629 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; @@ -31,7 +18,7 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.rest.RestStatus; -import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.collect.ImmutableList; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index 1c6446650..53019d7ce 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.processor.rerank; diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java index 13637b2ca..e1724b014 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -1,19 +1,6 @@ /* - * Copyright 2023 Aryn * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package org.opensearch.neuralsearch.query.ext; diff --git a/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json b/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json index 89682e23c..fc6cfd124 100644 --- a/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json +++ b/src/test/resources/processor/RerankMLOpenSearchPipelineConfiguration.json @@ -12,4 +12,4 @@ } } ] -} \ No newline at end of file +} diff --git a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json index 9c28a10f4..82529202d 100644 --- a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json +++ b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json @@ -13,4 +13,4 @@ "all_config": "nobody will read this" }, "url": "https://github.com/opensearch-project/ml-commons/blob/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_similarity/TinyBERT-CE-torch_script.zip?raw=true" -} \ No newline at end of file +} From 708fb66e58f7dcfa3c394ffee3ea73a4b112ab9c Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 10 Jan 2024 10:15:26 -0800 Subject: [PATCH 22/27] Address pr comments and fix XContent in search ext Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 2 +- .../rerank/DocumentContextSourceFetcher.java | 10 ++++ .../rerank/QueryContextSourceFetcher.java | 28 +++++++--- .../rerank/RescoringRerankProcessor.java | 3 +- .../query/ext/RerankSearchExtBuilder.java | 34 ++++++++++-- .../factory/RerankProcessorFactoryTests.java | 6 +-- .../rerank/MLOpenSearchRerankProcessorIT.java | 20 +++---- .../MLOpenSearchRerankProcessorTests.java | 17 +++--- .../ext/RerankSearchExtBuilderTests.java | 52 ++++++++++++++----- .../UploadTextSimilarityModelRequestBody.json | 2 +- 10 files changed, 124 insertions(+), 50 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 5779bbf2c..953b91766 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -89,7 +89,7 @@ private RerankType findRerankType(final Map config) throws Illeg * Factory class for context fetchers. Constructs a list of context fetchers * specified in the pipeline config (and maybe the query context fetcher) */ - protected static class ContextFetcherFactory { + private static class ContextFetcherFactory { /** * Map rerank types to whether they should include the query context source fetcher diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java index 7dc577502..34fd42d86 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java @@ -18,10 +18,12 @@ import org.opensearch.search.SearchHit; import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; /** * Context Source Fetcher that gets context from the search results (documents) */ +@Log4j2 @AllArgsConstructor public class DocumentContextSourceFetcher implements ContextSourceFetcher { @@ -59,6 +61,14 @@ private String contextFromSearchHit(final SearchHit hit, final String field) { Object sourceValue = ObjectPath.eval(field, hit.getSourceAsMap()); return String.valueOf(sourceValue); } else { + log.warn( + String.format( + Locale.ROOT, + "Could not find field %s in document %s for reranking! Using the empty string instead.", + field, + hit.getId() + ) + ); return ""; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java index 5000dd756..b027e3f6f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java @@ -6,6 +6,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -56,14 +57,7 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); - // Convert query to a map with io/xcontent shenanigans - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); - searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.close(); - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); - Map map = parser.map(); + Map map = requestToMap(searchRequest); // Get the text at the path Object queryText = ObjectPath.eval(path, map); if (!(queryText instanceof String)) { @@ -87,4 +81,22 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo public String getName() { return NAME; } + + /** + * Convert a search request to a general map by streaming out as XContent and then back in, + * with the intention of representing the query as a user would see it + * @param request Search request to turn into xcontent + * @return Map representing the XContent-ified search request + * @throws IOException + */ + private static Map requestToMap(SearchRequest request) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); + request.source().toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); + Map map = parser.map(); + return map; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index c52a5223d..43efb795d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -59,8 +59,9 @@ public void rerank(SearchResponse searchResponse, Map rerankingC // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); if (hits.length != scores.size()) { - throw new Exception("scores and hits are not the same length"); + throw new RuntimeException("scores and hits are not the same length"); } + // NOTE: Assumes that the new scores came back in the same order for (int i = 0; i < hits.length; i++) { hits[i].score(scores.get(i)); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index 4c0026a89..3909c7499 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -18,9 +18,34 @@ import lombok.AllArgsConstructor; import lombok.Getter; -import lombok.extern.log4j.Log4j2; -@Log4j2 +/** + * Holds ext data from the query for reranking processors. Since + * there can be multiple kinds of rerank processors with different + * contexts, all we can assume is that there's keys and objects. + * e.g. ext might look like + * { + * "query": {blah}, + * "ext": { + * "rerank": { + * "query_context": { + * "query_text": "some question to rerank about" + * } + * } + * } + * } + * or + * { + * "query": {blah}, + * "ext": { + * "rerank": { + * "query_context": { + * "query_path": "query.neural.embedding.query_text" + * } + * } + * } + * } + */ @AllArgsConstructor public class RerankSearchExtBuilder extends SearchExtBuilder { @@ -44,7 +69,10 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.field(PARAM_FIELD_NAME, this.params); + for (String key : this.params.keySet()) { + builder.field(key, this.params.get(key)); + } + return builder; } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index 7488e2607..fa15eda46 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -29,13 +29,13 @@ public class RerankProcessorFactoryTests extends OpenSearchTestCase { final String TAG = "default-tag"; final String DESC = "processor description"; - RerankProcessorFactory factory; + private RerankProcessorFactory factory; @Mock - MLCommonsClientAccessor clientAccessor; + private MLCommonsClientAccessor clientAccessor; @Mock - PipelineContext pipelineContext; + private PipelineContext pipelineContext; @Before public void setup() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index 1fa6ac629..c5bdb77f4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -28,11 +28,12 @@ @Log4j2 public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { - final static String PIPELINE_NAME = "rerank-mlos-pipeline"; - final static String INDEX_NAME = "rerank-test"; - final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; - final static String TEXT_REP_2 = "Fish like to eat plankton"; - final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; + private final static String PIPELINE_NAME = "rerank-mlos-pipeline"; + private final static String INDEX_NAME = "rerank-test"; + private final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; + private final static String TEXT_REP_2 = "Fish like to eat plankton"; + private final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; + private String modelId; @After @SneakyThrows @@ -42,13 +43,14 @@ public void tearDown() { * this happens in case we leave model from previous test case. We use new model for every test, and old model * can be undeployed and deleted to free resources after each test case execution. */ + deleteModel(modelId); deleteSearchPipeline(PIPELINE_NAME); - findDeployedModels().forEach(this::deleteModel); deleteIndex(INDEX_NAME); } - public void testCrossEncoderRerankProcessor() throws Exception { - String modelId = uploadTextSimilarityModel(); + @SneakyThrows + public void testCrossEncoderRerankProcessor() { + modelId = uploadTextSimilarityModel(); loadModel(modelId); createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); setupIndex(); @@ -59,7 +61,7 @@ private String uploadTextSimilarityModel() throws Exception { String requestBody = Files.readString( Path.of(classLoader.getResource("processor/UploadTextSimilarityModelRequestBody.json").toURI()) ); - return uploadModel(requestBody); + return registerModelGroupAndUploadModel(requestBody); } private void setupIndex() throws Exception { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index 53019d7ce..018677b60 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -46,28 +46,25 @@ import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.test.OpenSearchTestCase; -import lombok.extern.log4j.Log4j2; - -@Log4j2 public class MLOpenSearchRerankProcessorTests extends OpenSearchTestCase { @Mock - SearchRequest request; + private SearchRequest request; - SearchResponse response; + private SearchResponse response; @Mock - MLCommonsClientAccessor mlCommonsClientAccessor; + private MLCommonsClientAccessor mlCommonsClientAccessor; @Mock - PipelineContext pipelineContext; + private PipelineContext pipelineContext; @Mock - PipelineProcessingContext ppctx; + private PipelineProcessingContext ppctx; - RerankProcessorFactory factory; + private RerankProcessorFactory factory; - MLOpenSearchRerankProcessor processor; + private MLOpenSearchRerankProcessor processor; @Before public void setup() { diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java index e1724b014..ea0af1eb5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -7,16 +7,23 @@ import static org.mockito.Mockito.mock; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.ParseField; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.BytesStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; import org.opensearch.test.OpenSearchTestCase; @@ -29,7 +36,20 @@ public class RerankSearchExtBuilderTests extends OpenSearchTestCase { @Before public void setup() { - params = Map.of("query_text", "question about the meaning of life, the universe, and everything"); + params = Map.of("query_context", Map.of("query_text", "question about the meaning of life, the universe, and everything")); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry( + SearchExtBuilder.class, + new ParseField(RerankSearchExtBuilder.PARAM_FIELD_NAME), + parser -> RerankSearchExtBuilder.parse(parser) + ) + ) + ); } public void testStreaming() throws IOException { @@ -43,19 +63,23 @@ public void testStreaming() throws IOException { assert (b1.equals(b2)); } - // public void testToXContent() throws IOException { - // RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params)); - // XContentBuilder builder = XContentType.JSON.contentBuilder(); - // builder.startObject(); - // b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); - // builder.endObject(); - // String extString = builder.toString(); - // log.info(extString); - // XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); - // RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser); - // assert (b2.getParams().equals(params)); - // assert (b1.equals(b2)); - // } + public void testToXContent() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(new HashMap<>(params)); + XContentBuilder builder = XContentType.JSON.contentBuilder(); + builder.startObject(); + b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + builder.endObject(); + String extString = builder.toString(); + log.info(extString); + XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); + SearchExtBuilder b2 = parser.namedObject(SearchExtBuilder.class, RerankSearchExtBuilder.PARAM_FIELD_NAME, parser); + assert (b2 instanceof RerankSearchExtBuilder); + RerankSearchExtBuilder b3 = (RerankSearchExtBuilder) b2; + log.info(b1.getParams().toString()); + log.info(b3.getParams().toString()); + assert (b3.getParams().equals(params)); + assert (b1.equals(b3)); + } public void testPullFromListOfExtBuilders() { RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params); diff --git a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json index 82529202d..3c23f6f21 100644 --- a/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json +++ b/src/test/resources/processor/UploadTextSimilarityModelRequestBody.json @@ -4,7 +4,7 @@ "function_name": "TEXT_SIMILARITY", "description": "test model", "model_format": "TORCH_SCRIPT", - "model_group_id": "", + "model_group_id": "%s", "model_content_hash_value": "90e39a926101d1a4e542aade0794319404689b12acfd5d7e65c03d91c668b5cf", "model_config": { "model_type": "bert", From 2d04075d4a583f8837da45f665fff8372c9f75fd Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 10 Jan 2024 10:22:58 -0800 Subject: [PATCH 23/27] move contextSourceFetchers to their own subdirectory Signed-off-by: HenryL27 --- .../processor/factory/RerankProcessorFactory.java | 6 +++--- .../processor/rerank/MLOpenSearchRerankProcessor.java | 3 +++ .../neuralsearch/processor/rerank/RerankProcessor.java | 1 + .../processor/rerank/RescoringRerankProcessor.java | 1 + .../rerank/{ => context}/ContextSourceFetcher.java | 2 +- .../rerank/{ => context}/DocumentContextSourceFetcher.java | 2 +- .../rerank/{ => context}/QueryContextSourceFetcher.java | 2 +- .../processor/factory/RerankProcessorFactoryTests.java | 2 +- .../processor/rerank/MLOpenSearchRerankProcessorTests.java | 2 ++ 9 files changed, 14 insertions(+), 7 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/rerank/{ => context}/ContextSourceFetcher.java (94%) rename src/main/java/org/opensearch/neuralsearch/processor/rerank/{ => context}/DocumentContextSourceFetcher.java (98%) rename src/main/java/org/opensearch/neuralsearch/processor/rerank/{ => context}/QueryContextSourceFetcher.java (98%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 953b91766..e551909aa 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -13,11 +13,11 @@ import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.rerank.ContextSourceFetcher; -import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; -import org.opensearch.neuralsearch.processor.rerank.QueryContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index c94ddf6ca..db6f43f95 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -13,6 +13,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; /** * Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 6104990ab..8b75068f6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -12,6 +12,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.SearchResponseProcessor; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index 43efb795d..41494c631 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -13,6 +13,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.core.action.ActionListener; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.profile.SearchProfileShardResults; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java similarity index 94% rename from src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java rename to src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java index d8a576937..6a98b6561 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.rerank; +package org.opensearch.neuralsearch.processor.rerank.context; import java.util.Map; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java similarity index 98% rename from src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java rename to src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java index 34fd42d86..2ed5b0713 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.rerank; +package org.opensearch.neuralsearch.processor.rerank.context; import java.util.ArrayList; import java.util.HashMap; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java similarity index 98% rename from src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java rename to src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java index b027e3f6f..18e5be20b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.rerank; +package org.opensearch.neuralsearch.processor.rerank.context; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index fa15eda46..ec74f831a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -16,10 +16,10 @@ import org.mockito.Mock; import org.opensearch.OpenSearchParseException; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.rerank.DocumentContextSourceFetcher; import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; import org.opensearch.search.pipeline.Processor.PipelineContext; import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index 018677b60..80297e5c7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -36,6 +36,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; +import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.search.SearchExtBuilder; From a39428b57828b8b0799bbc32963b9a8d54c94e49 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 11 Jan 2024 09:22:56 -0800 Subject: [PATCH 24/27] Apply suggestions from code review Co-authored-by: Martin Gaievski Signed-off-by: HenryL27 --- .../processor/factory/RerankProcessorFactory.java | 2 +- .../opensearch/neuralsearch/processor/rerank/RerankType.java | 5 ++++- .../processor/rerank/RescoringRerankProcessor.java | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index e551909aa..992e70eb0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -97,7 +97,7 @@ private static class ContextFetcherFactory { * @return does this RerankType depend on the QueryContextSourceFetcher? */ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { - switch (type) { + return ML_OPENSEARCH == type; case ML_OPENSEARCH: return true; default: diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index c1d64c5ae..2a19aeae6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -43,7 +43,10 @@ public static RerankType from(String label) { RerankType ans = LABEL_MAP.get(label); if (ans == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); - } else { + if (ans == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); + } + return ans; return ans; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index 41494c631..d6fdd3878 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -24,7 +24,7 @@ public abstract class RescoringRerankProcessor extends RerankProcessor { /** - * Constructor. pass through to RerankProcessor ctor. + * Constructor. pass through to RerankProcessor constructor. * @param type * @param description * @param tag @@ -60,7 +60,7 @@ public void rerank(SearchResponse searchResponse, Map rerankingC // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); if (hits.length != scores.size()) { - throw new RuntimeException("scores and hits are not the same length"); + throw new IllegalStateException("scores and hits are not the same length"); } // NOTE: Assumes that the new scores came back in the same order for (int i = 0; i < hits.length; i++) { From f46296546ac0299c9946274f68b63d778faae9e0 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 11 Jan 2024 09:24:09 -0800 Subject: [PATCH 25/27] CR changes Signed-off-by: HenryL27 --- CHANGELOG.md | 2 +- .../factory/RerankProcessorFactory.java | 7 +----- .../rerank/MLOpenSearchRerankProcessor.java | 14 ++++++------ .../processor/rerank/RerankProcessor.java | 22 +++++++++---------- .../processor/rerank/RerankType.java | 6 +---- .../rerank/RescoringRerankProcessor.java | 18 +++++++-------- .../rerank/context/ContextSourceFetcher.java | 2 +- .../context/DocumentContextSourceFetcher.java | 4 ++-- .../context/QueryContextSourceFetcher.java | 16 +++++++++----- 9 files changed, 43 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 284cf64da..b3c69c9fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features -- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494)) ### Enhancements ### Bug Fixes - Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) @@ -19,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.11...2.x) ### Features +- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494)) ### Enhancements ### Bug Fixes - Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524)) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 992e70eb0..b02666855 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -97,12 +97,7 @@ private static class ContextFetcherFactory { * @return does this RerankType depend on the QueryContextSourceFetcher? */ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { - return ML_OPENSEARCH == type; - case ML_OPENSEARCH: - return true; - default: - return false; - } + return type == RerankType.ML_OPENSEARCH; } /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index db6f43f95..4749a3bb5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -38,12 +38,12 @@ public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor { * @param mlCommonsClientAccessor */ public MLOpenSearchRerankProcessor( - String description, - String tag, - boolean ignoreFailure, - String modelId, - List contextSourceFetchers, - MLCommonsClientAccessor mlCommonsClientAccessor + final String description, + final String tag, + final boolean ignoreFailure, + final String modelId, + final List contextSourceFetchers, + final MLCommonsClientAccessor mlCommonsClientAccessor ) { super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers); this.modelId = modelId; @@ -51,7 +51,7 @@ public MLOpenSearchRerankProcessor( } @Override - public void rescoreSearchResponse(SearchResponse response, Map rerankingContext, ActionListener> listener) { + public void rescoreSearchResponse(final SearchResponse response, final Map rerankingContext, final ActionListener> listener) { Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); if (!(ctxObj instanceof List)) { listener.onFailure( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 8b75068f6..93a2c8416 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -44,9 +44,9 @@ public abstract class RerankProcessor implements SearchResponseProcessor { * @param listener be async */ public void generateRerankingContext( - SearchRequest searchRequest, - SearchResponse searchResponse, - ActionListener> listener + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener ) { Map overallContext = new ConcurrentHashMap<>(); AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size()); @@ -73,22 +73,22 @@ public String getType() { * @param listener be async */ public abstract void rerank( - SearchResponse searchResponse, - Map rerankingContext, - ActionListener listener + final SearchResponse searchResponse, + final Map rerankingContext, + final ActionListener listener ); @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + public SearchResponse processResponse(final SearchRequest request, final SearchResponse response) throws Exception { throw new UnsupportedOperationException("Use asyncProcessResponse unless you can guarantee to not deadlock yourself"); } @Override public void processResponseAsync( - SearchRequest request, - SearchResponse response, - PipelineProcessingContext ctx, - ActionListener responseListener + final SearchRequest request, + final SearchResponse response, + final PipelineProcessingContext ctx, + final ActionListener responseListener ) { try { generateRerankingContext( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index 2a19aeae6..2063242dd 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -39,16 +39,12 @@ private RerankType(String label) { * @param label label of a RerankType * @return RerankType represented by the label */ - public static RerankType from(String label) { + public static RerankType from(final String label) { RerankType ans = LABEL_MAP.get(label); if (ans == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); - if (ans == null) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Wrong rerank type name: %s", label)); } return ans; - return ans; - } } public static Map labelMap() { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index d6fdd3878..c4f69ac89 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -32,11 +32,11 @@ public abstract class RescoringRerankProcessor extends RerankProcessor { * @param contextSourceFetchers */ public RescoringRerankProcessor( - RerankType type, - String description, - String tag, - boolean ignoreFailure, - List contextSourceFetchers + final RerankType type, + final String description, + final String tag, + final boolean ignoreFailure, + final List contextSourceFetchers ) { super(type, description, tag, ignoreFailure, contextSourceFetchers); } @@ -48,13 +48,13 @@ public RescoringRerankProcessor( * @param listener be async. recieves the list of new scores */ public abstract void rescoreSearchResponse( - SearchResponse response, - Map rerankingContext, - ActionListener> listener + final SearchResponse response, + final Map rerankingContext, + final ActionListener> listener ); @Override - public void rerank(SearchResponse searchResponse, Map rerankingContext, ActionListener listener) { + public void rerank(final SearchResponse searchResponse, final Map rerankingContext, final ActionListener listener) { try { rescoreSearchResponse(searchResponse, rerankingContext, ActionListener.wrap(scores -> { // Assign new scores diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java index 6a98b6561..dd297aee6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java @@ -24,7 +24,7 @@ public interface ContextSourceFetcher { * @param searchResponse the search results, in case they're relevant * @param listener be async */ - public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener); + public void fetchContext(final SearchRequest searchRequest, final SearchResponse searchResponse, final ActionListener> listener); /** * Get the name of the contextSourceFetcher. This will be used as the field diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java index 2ed5b0713..2b1c09545 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java @@ -30,7 +30,7 @@ public class DocumentContextSourceFetcher implements ContextSourceFetcher { public static final String NAME = "document_fields"; public static final String DOCUMENT_CONTEXT_LIST_FIELD = "document_context_list"; - List contextFields; + private final List contextFields; /** * Fetch the information needed in order to rerank. @@ -41,7 +41,7 @@ public class DocumentContextSourceFetcher implements ContextSourceFetcher { * @param listener be async */ @Override - public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener) { + public void fetchContext(final SearchRequest searchRequest, final SearchResponse searchResponse, final ActionListener> listener) { List contexts = new ArrayList<>(); for (SearchHit hit : searchResponse.getHits()) { StringBuilder ctx = new StringBuilder(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java index 18e5be20b..1702a9914 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java @@ -34,11 +34,12 @@ public class QueryContextSourceFetcher implements ContextSourceFetcher { public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; @Override - public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener> listener) { + public void fetchContext(final SearchRequest searchRequest, final SearchResponse searchResponse, final ActionListener> listener) { try { + // Get RerankSearchExt query-specific context map List exts = searchRequest.source().ext(); Map params = RerankSearchExtBuilder.fromExtBuilderList(exts).getParams(); - Map scoringContext = new HashMap<>(); + Map rerankContext = new HashMap<>(); if (!params.containsKey(NAME)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "must specify %s", NAME)); } @@ -46,16 +47,19 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo if (!(ctxObj instanceof Map)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a map", NAME)); } + // Put query context into reranking context @SuppressWarnings("unchecked") Map ctxMap = (Map) ctxObj; if (ctxMap.containsKey(QUERY_TEXT_FIELD)) { + // Case "query_text": "" if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } - scoringContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); + rerankContext.put(QUERY_TEXT_FIELD, (String) ctxMap.get(QUERY_TEXT_FIELD)); } else if (ctxMap.containsKey(QUERY_TEXT_PATH_FIELD)) { + // Case "query_text_path": ser/de the query into a map and then find the text at the path specified String path = (String) ctxMap.get(QUERY_TEXT_PATH_FIELD); Map map = requestToMap(searchRequest); // Get the text at the path @@ -65,13 +69,13 @@ public void fetchContext(SearchRequest searchRequest, SearchResponse searchRespo String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD) ); } - scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); + rerankContext.put(QUERY_TEXT_FIELD, (String) queryText); } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD) ); } - listener.onResponse(scoringContext); + listener.onResponse(rerankContext); } catch (Exception e) { listener.onFailure(e); } @@ -89,7 +93,7 @@ public String getName() { * @return Map representing the XContent-ified search request * @throws IOException */ - private static Map requestToMap(SearchRequest request) throws IOException { + private static Map requestToMap(final SearchRequest request) throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); request.source().toXContent(builder, ToXContent.EMPTY_PARAMS); From db8bec1babfcaec9ecf6e8247ab9bd1a2ff0c5d3 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 11 Jan 2024 10:11:31 -0800 Subject: [PATCH 26/27] finish CR comments and fix broken unittest Signed-off-by: HenryL27 --- .../rerank/MLOpenSearchRerankProcessor.java | 6 ++++- .../rerank/RescoringRerankProcessor.java | 24 +++++++++++-------- .../rerank/context/ContextSourceFetcher.java | 8 +++++-- .../context/DocumentContextSourceFetcher.java | 6 ++++- .../context/QueryContextSourceFetcher.java | 6 ++++- .../factory/RerankProcessorFactoryTests.java | 20 ++++++++-------- .../rerank/MLOpenSearchRerankProcessorIT.java | 8 ++++++- .../MLOpenSearchRerankProcessorTests.java | 4 +++- 8 files changed, 55 insertions(+), 27 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java index 4749a3bb5..d8d9e8ec3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java @@ -51,7 +51,11 @@ public MLOpenSearchRerankProcessor( } @Override - public void rescoreSearchResponse(final SearchResponse response, final Map rerankingContext, final ActionListener> listener) { + public void rescoreSearchResponse( + final SearchResponse response, + final Map rerankingContext, + final ActionListener> listener + ) { Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); if (!(ctxObj instanceof List)) { listener.onFailure( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index c4f69ac89..27ccf51e6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -6,7 +6,6 @@ import java.util.Arrays; import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Map; @@ -54,11 +53,22 @@ public abstract void rescoreSearchResponse( ); @Override - public void rerank(final SearchResponse searchResponse, final Map rerankingContext, final ActionListener listener) { + public void rerank( + final SearchResponse searchResponse, + final Map rerankingContext, + final ActionListener listener + ) { try { + if (searchResponse.getHits().getTotalHits().value == 0) { + listener.onResponse(searchResponse); + return; + } rescoreSearchResponse(searchResponse, rerankingContext, ActionListener.wrap(scores -> { // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); + if (scores == null) { + throw new IllegalStateException("scores cannot be null"); + } if (hits.length != scores.size()) { throw new IllegalStateException("scores and hits are not the same length"); } @@ -66,14 +76,8 @@ public void rerank(final SearchResponse searchResponse, final Map() { - @Override - public int compare(SearchHit hit1, SearchHit hit2) { - // backwards to sort DESC - return Float.compare(hit2.getScore(), hit1.getScore()); - } - }); + // Re-sort by the new scores. Backwards comparison for desc ordering + Collections.sort(Arrays.asList(hits), (hit1, hit2) -> Float.compare(hit2.getScore(), hit1.getScore())); // Reconstruct the search response, replacing the max score SearchHits newHits = new SearchHits( hits, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java index dd297aee6..ddb4a08fb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/ContextSourceFetcher.java @@ -24,13 +24,17 @@ public interface ContextSourceFetcher { * @param searchResponse the search results, in case they're relevant * @param listener be async */ - public void fetchContext(final SearchRequest searchRequest, final SearchResponse searchResponse, final ActionListener> listener); + void fetchContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ); /** * Get the name of the contextSourceFetcher. This will be used as the field * name in the context config for the pipeline * @return Name of the fetcher */ - public String getName(); + String getName(); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java index 2b1c09545..857c1dd46 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/DocumentContextSourceFetcher.java @@ -41,7 +41,11 @@ public class DocumentContextSourceFetcher implements ContextSourceFetcher { * @param listener be async */ @Override - public void fetchContext(final SearchRequest searchRequest, final SearchResponse searchResponse, final ActionListener> listener) { + public void fetchContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ) { List contexts = new ArrayList<>(); for (SearchHit hit : searchResponse.getHits()) { StringBuilder ctx = new StringBuilder(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java index 1702a9914..d7463bcd1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/context/QueryContextSourceFetcher.java @@ -34,7 +34,11 @@ public class QueryContextSourceFetcher implements ContextSourceFetcher { public static final String QUERY_TEXT_PATH_FIELD = "query_text_path"; @Override - public void fetchContext(final SearchRequest searchRequest, final SearchResponse searchResponse, final ActionListener> listener) { + public void fetchContext( + final SearchRequest searchRequest, + final SearchResponse searchResponse, + final ActionListener> listener + ) { try { // Get RerankSearchExt query-specific context map List exts = searchRequest.source().ext(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index ec74f831a..ea37b2afb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -44,7 +44,7 @@ public void setup() { factory = new RerankProcessorFactory(clientAccessor); } - public void testRerankProcessorFactory_EmptyConfig_ThenFail() { + public void testRerankProcessorFactory_whenEmptyConfig_thenFail() { Map config = new HashMap<>(Map.of()); assertThrows( "no rerank type found", @@ -53,7 +53,7 @@ public void testRerankProcessorFactory_EmptyConfig_ThenFail() { ); } - public void testRerankProcessorFactory_NonExistentType_ThenFail() { + public void testRerankProcessorFactory_whenNonExistentType_thenFail() { Map config = new HashMap<>( Map.of("jpeo rvgh we iorgn", Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id")) ); @@ -64,7 +64,7 @@ public void testRerankProcessorFactory_NonExistentType_ThenFail() { ); } - public void testRerankProcessorFactory_CrossEncoder_HappyPath() { + public void testCrossEncoder_whenCorrectParams_thenSuccessful() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), @@ -79,7 +79,7 @@ public void testRerankProcessorFactory_CrossEncoder_HappyPath() { assert (processor.getType().equals(RerankProcessor.TYPE)); } - public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { + public void testCrossEncoder_whenMessyConfig_thenSuccessful() { Map config = new HashMap<>( Map.of( "poafn aorr;anv", @@ -96,7 +96,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { assert (processor.getType().equals(RerankProcessor.TYPE)); } - public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { + public void testCrossEncoder_whenMessyContext_thenFail() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), @@ -119,7 +119,7 @@ public void testRerankProcessorFactory_CrossEncoder_MessyContext_ThenFail() { ); } - public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { + public void testCrossEncoder_whenEmptySubConfig_thenFail() { Map config = new HashMap<>(Map.of(RerankType.ML_OPENSEARCH.getLabel(), Map.of())); assertThrows( String.format(Locale.ROOT, "[%s] required property is missing", RerankProcessorFactory.CONTEXT_CONFIG_FIELD), @@ -128,7 +128,7 @@ public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { ); } - public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { + public void testCrossEncoder_whenNoContextField_thenFail() { Map config = new HashMap<>( Map.of(RerankType.ML_OPENSEARCH.getLabel(), new HashMap<>(Map.of(MLOpenSearchRerankProcessor.MODEL_ID_FIELD, "model-id"))) ); @@ -139,7 +139,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { ); } - public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { + public void testCrossEncoder_whenNoModelId_thenFail() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), @@ -155,7 +155,7 @@ public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { ); } - public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail() { + public void testCrossEncoder_whenBadContextDocField_thenFail() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), @@ -171,7 +171,7 @@ public void testRerankProcessorFactory_CrossEncoder_BadContextDocField_ThenFail( ); } - public void testRerankProcessorFactory_CrossEncoder_EmptyContextDocField_ThenFail() { + public void testCrossEncoder_whenEmptyContextDocField_thenFail() { Map config = new HashMap<>( Map.of( RerankType.ML_OPENSEARCH.getLabel(), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index c5bdb77f4..7bde28f7b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -13,6 +13,7 @@ import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.After; +import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; @@ -48,10 +49,15 @@ public void tearDown() { deleteIndex(INDEX_NAME); } + @Before @SneakyThrows - public void testCrossEncoderRerankProcessor() { + public void setup() { modelId = uploadTextSimilarityModel(); loadModel(modelId); + } + + @SneakyThrows + public void testCrossEncoderRerankProcessor() { createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); setupIndex(); runQueries(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index 80297e5c7..ca31839d2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -20,6 +20,7 @@ import java.util.Locale; import java.util.Map; +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -128,8 +129,9 @@ private void setupSearchResults() throws IOException { nullHit.score(0f); SearchHit[] hitArray = new SearchHit[] { fieldHit, sourceHit, nullHit }; + TotalHits totalHits = new TotalHits(3, TotalHits.Relation.EQUAL_TO); - SearchHits searchHits = new SearchHits(hitArray, null, 1.0f); + SearchHits searchHits = new SearchHits(hitArray, totalHits, 1.0f); SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new Clusters(1, 1, 0), null); } From 7962ffab42d13f7ad39b717d6249f3b58b29d8e3 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 11 Jan 2024 10:24:29 -0800 Subject: [PATCH 27/27] fix unittest names Signed-off-by: HenryL27 --- .../rerank/MLOpenSearchRerankProcessorTests.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java index ca31839d2..50d0cf2bc 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorTests.java @@ -136,7 +136,7 @@ private void setupSearchResults() throws IOException { response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new Clusters(1, 1, 0), null); } - public void testScoringContext_QueryText_ThenSucceed() throws IOException { + public void testRerankContext_whenQueryText_thenSucceed() throws IOException { setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_FIELD, "query text")); setupSearchResults(); @SuppressWarnings("unchecked") @@ -149,7 +149,7 @@ public void testScoringContext_QueryText_ThenSucceed() throws IOException { assert (argCaptor.getValue().get(QueryContextSourceFetcher.QUERY_TEXT_FIELD).equals("query text")); } - public void testScoringContext_QueryTextPath_ThenSucceed() throws IOException { + public void testRerankContext_whenQueryTextPath_thenSucceed() throws IOException { setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); setupSearchResults(); @SuppressWarnings("unchecked") @@ -162,7 +162,7 @@ public void testScoringContext_QueryTextPath_ThenSucceed() throws IOException { assert (argCaptor.getValue().get(QueryContextSourceFetcher.QUERY_TEXT_FIELD).equals("Question about dolphins")); } - public void testScoringContext_QueryTextAndPath_ThenFail() throws IOException { + public void testRerankContext_whenQueryTextAndPath_thenFail() throws IOException { setupParams( Map.of( QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, @@ -189,7 +189,7 @@ public void testScoringContext_QueryTextAndPath_ThenFail() throws IOException { )); } - public void testScoringContext_NoQueryInfo_ThenFail() throws IOException { + public void testRerankContext_whenNoQueryInfo_thenFail() throws IOException { setupParams(Map.of()); setupSearchResults(); @SuppressWarnings("unchecked") @@ -209,7 +209,7 @@ public void testScoringContext_NoQueryInfo_ThenFail() throws IOException { )); } - public void testScoringContext_QueryTextPath_BadPointer_ThenFail() throws IOException { + public void testRerankContext_whenQueryTextPathIsBadPointer_thenFail() throws IOException { setupParams(Map.of(QueryContextSourceFetcher.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); setupSearchResults(); @SuppressWarnings("unchecked") @@ -244,7 +244,7 @@ public void testRescoreSearchResponse_HappyPath() throws IOException { assert (argCaptor.getValue().get(2) == 3f); } - public void testRescoreSearchResponse_NoContextList_ThenFail() throws IOException { + public void testRescoreSearchResponse_whenNoContextList_thenFail() throws IOException { setupSimilarityRescoring(); setupSearchResults(); @SuppressWarnings("unchecked") @@ -289,7 +289,7 @@ public void testRerank_HappyPath() throws IOException { assert (rsp.getHits().getAt(2).getScore() == 1f); } - public void testRerank_ScoresAndHitsHaveDiffLengths() throws IOException { + public void testRerank_whenScoresAndHitsHaveDiffLengths_thenFail() throws IOException { doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); List scores = List.of(1f, 2f);