Commit 9340557: Adjust the format of final message
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Nov 1, 2024
Parent: 7cba653 · Commit: 9340557
Showing 21 changed files with 272 additions and 155 deletions.
CHANGELOG.md (1 change: 1 addition, 0 deletions)
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
### Enhancements
+ - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
### Bug Fixes
### Infrastructure
### Documentation
org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -32,7 +32,7 @@
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
- import org.opensearch.neuralsearch.processor.ProcessorExplainPublisher;
+ import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
@@ -82,7 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
- public static final String PROCESSOR_EXPLAIN = "processor_explain";
+ public static final String EXPLAIN_RESPONSE_KEY = "explain_response";

@Override
public Collection<Object> createComponents(
@@ -185,7 +185,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
- ProcessorExplainPublisher.TYPE,
+ ExplainResponseProcessor.TYPE,
new ProcessorExplainPublisherFactory()
);
}
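Since response processors are registered and resolved by their TYPE string rather than by class, the rename stays transparent to pipeline resolution as long as the TYPE key is updated consistently; note that the factory class keeps its old ProcessorExplainPublisherFactory name in this commit. A minimal sketch of that keyed-registration pattern, using hypothetical stand-in types rather than the real SearchPlugin interfaces:

```java
import java.util.Map;
import java.util.function.Supplier;

public class ProcessorRegistrySketch {
    // Stand-in for SearchResponseProcessor; the real interface lives in org.opensearch.search.pipeline.
    interface ResponseProcessor {
        String describe();
    }

    public static void main(String[] args) {
        // Processors are resolved by TYPE key, so callers never reference the class name directly.
        Map<String, Supplier<ResponseProcessor>> factories = Map.of(
            "explain_response_processor", () -> () -> "publishes the explain payload into the search response"
        );
        ResponseProcessor processor = factories.get("explain_response_processor").get();
        System.out.println(processor.describe());
    }
}
```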
org/opensearch/neuralsearch/processor/ExplainResponseProcessor.java (renamed from ProcessorExplainPublisher.java)
@@ -23,14 +23,14 @@
import java.util.Map;
import java.util.Objects;

- import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN;
+ import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR;

@Getter
@AllArgsConstructor
- public class ProcessorExplainPublisher implements SearchResponseProcessor {
+ public class ExplainResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "processor_explain_publisher";
public static final String TYPE = "explain_response_processor";

private final String description;
private final String tag;
@@ -43,10 +43,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) {
- if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(PROCESSOR_EXPLAIN)))) {
+ if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) {
return response;
}
- ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(PROCESSOR_EXPLAIN);
+ ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ProcessorExplainDto.ExplanationType, Object> explainPayload = processorExplainDto.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
Explanation processorExplanation = processorExplainDto.getExplanation();
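The renamed processor is otherwise a pass-through: it only decorates the response when the per-request pipeline context carries an explain payload under EXPLAIN_RESPONSE_KEY. A compact sketch of that early-return guard, with a plain Map standing in for PipelineProcessingContext and a String for ProcessorExplainDto:

```java
import java.util.Map;
import java.util.Objects;

public class ExplainGuardSketch {
    static final String EXPLAIN_RESPONSE_KEY = "explain_response";

    // Mirrors the guard in processResponse: no context or no payload means the response passes through unchanged.
    static String process(Map<String, Object> requestContext, String response) {
        if (Objects.isNull(requestContext) || Objects.isNull(requestContext.get(EXPLAIN_RESPONSE_KEY))) {
            return response;
        }
        return response + " [explain: " + requestContext.get(EXPLAIN_RESPONSE_KEY) + "]";
    }

    public static void main(String[] args) {
        System.out.println(process(null, "hits"));     // hits
        System.out.println(process(Map.of(), "hits")); // hits
        System.out.println(process(Map.of(EXPLAIN_RESPONSE_KEY, "normalization details"), "hits"));
    }
}
```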
org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -40,7 +40,7 @@
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

- import static org.opensearch.neuralsearch.plugin.NeuralSearch.PROCESSOR_EXPLAIN;
+ import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.topLevelExpalantionForCombinedScore;
import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;
@@ -152,7 +152,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
.build();
// store explain object to pipeline context
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
- pipelineProcessingContext.setAttribute(PROCESSOR_EXPLAIN, processorExplainDto);
+ pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto);
}

}
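This is the producer half of the contract sketched above: the normalization workflow stores the DTO under the same EXPLAIN_RESPONSE_KEY that ExplainResponseProcessor later reads back, so the constant rename has to land on both sides at once. A rough illustration with a plain Map as the context:

```java
import java.util.HashMap;
import java.util.Map;

public class ExplainHandoffSketch {
    static final String EXPLAIN_RESPONSE_KEY = "explain_response";

    public static void main(String[] args) {
        // Workflow side: store the explain object on the per-request pipeline context.
        Map<String, Object> pipelineProcessingContext = new HashMap<>();
        pipelineProcessingContext.put(EXPLAIN_RESPONSE_KEY, "top-level combined-score explanation");
        // Response-processor side: retrieve it by the same shared key.
        System.out.println(pipelineProcessingContext.get(EXPLAIN_RESPONSE_KEY));
    }
}
```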
org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java
@@ -11,6 +11,7 @@
import lombok.ToString;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

+ import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;

/**
@@ -20,7 +21,6 @@
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "arithmetic_mean";
- public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java
@@ -11,6 +11,7 @@
import lombok.ToString;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

+ import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;

/**
@@ -20,7 +21,6 @@
public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "geometric_mean";
- public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java
@@ -11,6 +11,7 @@
import lombok.ToString;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

+ import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.describeCombinationTechnique;

/**
@@ -20,7 +21,6 @@
public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "harmonic_mean";
- public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java
@@ -23,8 +23,8 @@
* Collection of utility methods for score combination technique classes
*/
@Log4j2
- class ScoreCombinationUtil {
- private static final String PARAM_NAME_WEIGHTS = "weights";
+ public class ScoreCombinationUtil {
+ public static final String PARAM_NAME_WEIGHTS = "weights";
private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;

/**
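The three mean techniques above now import PARAM_NAME_WEIGHTS from ScoreCombinationUtil instead of each declaring a private copy, which is why both the class and the constant become public here. For orientation, a self-contained sketch of the weighted formulas those techniques implement; the zero-score and validation handling is a simplifying assumption, and the authoritative behavior lives in the technique classes themselves:

```java
public class WeightedMeansSketch {
    // Weighted arithmetic mean: sum(w_i * s_i) / sum(w_i).
    static float arithmeticMean(float[] scores, float[] weights) {
        float num = 0f, den = 0f;
        for (int i = 0; i < scores.length; i++) {
            num += weights[i] * scores[i];
            den += weights[i];
        }
        return den == 0f ? 0.0f : num / den;
    }

    // Weighted geometric mean: exp(sum(w_i * ln s_i) / sum(w_i)); non-positive scores are skipped.
    static float geometricMean(float[] scores, float[] weights) {
        float num = 0f, den = 0f;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > 0f) {
                num += weights[i] * (float) Math.log(scores[i]);
                den += weights[i];
            }
        }
        return den == 0f ? 0.0f : (float) Math.exp(num / den);
    }

    // Weighted harmonic mean: sum(w_i) / sum(w_i / s_i); non-positive scores are skipped.
    static float harmonicMean(float[] scores, float[] weights) {
        float num = 0f, den = 0f;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > 0f) {
                num += weights[i];
                den += weights[i] / scores[i];
            }
        }
        return den == 0f ? 0.0f : num / den;
    }

    public static void main(String[] args) {
        float[] scores = { 0.9f, 0.4f };
        float[] weights = { 0.7f, 0.3f };
        System.out.printf("arithmetic=%.4f geometric=%.4f harmonic=%.4f%n",
            arithmeticMean(scores, weights), geometricMean(scores, weights), harmonicMean(scores, weights));
    }
}
```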
org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -5,11 +5,9 @@
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
- import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
- import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.Objects;
@@ -31,6 +29,8 @@
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;

+ import static org.opensearch.neuralsearch.processor.explain.ExplainUtils.getScoreCombinationExplainDetailsForDocument;

/**
* Abstracts combination of scores in query search results.
*/
@@ -64,7 +64,7 @@ public class ScoreCombiner {
* Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score",
* other steps are same for all techniques.
*
- * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled.
+ * @param combineScoresDTO contains details of query top docs, score combination technique and sort is enabled or disabled.
*/
public void combineScores(final CombineScoresDto combineScoresDTO) {
// iterate over results from each shard. Every CompoundTopDocs object has results from
@@ -106,7 +106,6 @@ private void combineShardScores(
updateQueryTopDocsWithCombinedScores(
compoundQueryTopDocs,
topDocsPerSubQuery,
- normalizedScoresPerDoc,
combinedNormalizedScoresByDocId,
sortedDocsIds,
getDocIdSortFieldsMap(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sort),
@@ -129,7 +128,7 @@ private boolean isSortOrderByScore(Sort sort) {
}

/**
- * @param sort sort criteria
+ * @param sort sort criteria
* @param topDocsPerSubQuery top docs per subquery
* @return list of top field docs which is deduced by typcasting top docs to top field docs.
*/
@@ -149,9 +148,9 @@ private List<TopFieldDocs> getTopFieldDocs(final Sort sort, final List<TopDocs>
}

/**
- * @param compoundTopDocs top docs that represent on shard
+ * @param compoundTopDocs top docs that represent on shard
* @param combinedNormalizedScoresByDocId docId to normalized scores map
- * @param sort sort criteria
+ * @param sort sort criteria
* @return map of docId and sort fields if sorting is enabled.
*/
private Map<Integer, Object[]> getDocIdSortFieldsMap(
@@ -290,7 +289,6 @@ private Map<Integer, Float> combineScoresAndGetCombinedNormalizedScoresPerDocume
private void updateQueryTopDocsWithCombinedScores(
final CompoundTopDocs compoundQueryTopDocs,
final List<TopDocs> topDocsPerSubQuery,
- Map<Integer, float[]> normalizedScoresPerDoc,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final Collection<Integer> sortedScores,
Map<Integer, Object[]> docIdSortFieldMap,
@@ -322,21 +320,21 @@ private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final lon

public Map<SearchShard, List<ExplainDetails>> explain(
final List<CompoundTopDocs> queryTopDocs,
- ScoreCombinationTechnique combinationTechnique,
- Sort sort
+ final ScoreCombinationTechnique combinationTechnique,
+ final Sort sort
) {
// In case of duplicate keys, keep the first value
- HashMap<SearchShard, List<ExplainDetails>> map = new HashMap<>();
+ HashMap<SearchShard, List<ExplainDetails>> explanations = new HashMap<>();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
for (Map.Entry<SearchShard, List<ExplainDetails>> docIdAtSearchShardExplainDetailsEntry : explainByShard(
combinationTechnique,
compoundQueryTopDocs,
sort
).entrySet()) {
- map.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue());
+ explanations.putIfAbsent(docIdAtSearchShardExplainDetailsEntry.getKey(), docIdAtSearchShardExplainDetailsEntry.getValue());
}
}
- return map;
+ return explanations;
}

private Map<SearchShard, List<ExplainDetails>> explainByShard(
@@ -349,31 +347,20 @@ private Map<SearchShard, List<ExplainDetails>> explainByShard(
}
// - create map of normalized scores results returned from the single shard
Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs());
- Map<SearchShard, List<ExplainDetails>> explainsForShard = new HashMap<>();
Map<Integer, Float> combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
- SearchShard searchShard = compoundQueryTopDocs.getSearchShard();
- explainsForShard.put(searchShard, new ArrayList<>());
- for (Integer docId : sortedDocsIds) {
-     float combinedScore = combinedNormalizedScoresByDocId.get(docId);
-     explainsForShard.get(searchShard)
-         .add(
-             new ExplainDetails(
-                 combinedScore,
-                 String.format(
-                     Locale.ROOT,
-                     "source scores: %s, combined score %s",
-                     Arrays.toString(normalizedScoresPerDoc.get(docId)),
-                     combinedScore
-                 ),
-                 docId
-             )
-         );
- }
-
- return explainsForShard;
+ List<ExplainDetails> listOfExplainsForShard = sortedDocsIds.stream()
+     .map(
+         docId -> getScoreCombinationExplainDetailsForDocument(
+             docId,
+             combinedNormalizedScoresByDocId,
+             normalizedScoresPerDoc.get(docId)
+         )
+     )
+     .toList();
+ return Map.of(compoundQueryTopDocs.getSearchShard(), listOfExplainsForShard);
}

private Collection<Integer> getSortedDocsIds(
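Beyond moving the message formatting into ExplainUtils.getScoreCombinationExplainDetailsForDocument, the refactored explain() path changes how per-shard results are merged: explainByShard now returns an immutable single-entry map, and the caller keeps the first value when a shard key repeats. A small sketch of that first-value-wins merge, with String and List stand-ins for SearchShard and ExplainDetails:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ExplainMergeSketch {
    public static void main(String[] args) {
        Map<String, List<String>> explanations = new HashMap<>();
        // The first entry for a shard is kept ...
        explanations.putIfAbsent("shard-0", List.of("doc 1: source scores [0.9, 0.4], combined score 0.75"));
        // ... and a duplicate key for the same shard is ignored.
        explanations.putIfAbsent("shard-0", List.of("doc 1: source scores [0.2, 0.1], combined score 0.15"));
        System.out.println(explanations);
    }
}
```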
(Diff truncated: the remaining 12 of the 21 changed files are not shown.)
