Skip to content

Commit

Permalink
Updating error handling for compound retrievers (elastic#115277)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpailis authored and jfreden committed Nov 4, 2024
1 parent 9282bd1 commit 71c24ba
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.search.MultiSearchRequest;
Expand All @@ -20,6 +22,7 @@
import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
Expand Down Expand Up @@ -121,10 +124,17 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
public void onResponse(MultiSearchResponse items) {
List<ScoreDoc[]> topDocs = new ArrayList<>();
List<Exception> failures = new ArrayList<>();
// capture the max status code returned by any of the responses
int statusCode = RestStatus.OK.getStatus();
List<String> retrieversWithFailures = new ArrayList<>();
for (int i = 0; i < items.getResponses().length; i++) {
var item = items.getResponses()[i];
if (item.isFailure()) {
failures.add(item.getFailure());
retrieversWithFailures.add(innerRetrievers.get(i).retriever().getName());
if (ExceptionsHelper.status(item.getFailure()).getStatus() > statusCode) {
statusCode = ExceptionsHelper.status(item.getFailure()).getStatus();
}
} else {
assert item.getResponse() != null;
var rankDocs = getRankDocs(item.getResponse());
Expand All @@ -133,7 +143,14 @@ public void onResponse(MultiSearchResponse items) {
}
}
if (false == failures.isEmpty()) {
IllegalStateException ex = new IllegalStateException("Search failed - some nested retrievers returned errors.");
assert statusCode != RestStatus.OK.getStatus();
final String errMessage = "["
+ getName()
+ "] search failed - retrievers '"
+ retrieversWithFailures
+ "' returned errors. "
+ "All failures are attached as suppressed exceptions.";
Exception ex = new ElasticsearchStatusException(errMessage, RestStatus.fromCode(statusCode));
failures.forEach(ex::addSuppressed);
listener.onFailure(ex);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
package org.elasticsearch.xpack.rank.rrf;

import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
Expand All @@ -18,6 +20,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -47,7 +50,6 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

Expand Down Expand Up @@ -589,11 +591,11 @@ public void testRRFExplainWithAnotherNestedRRF() {
});
}

public void testRRFInnerRetrieverSearchError() {
public void testRRFInnerRetrieverAll4xxSearchErrors() {
final int rankWindowSize = 100;
final int rankConstant = 10;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will throw an error during evaluation
// this will throw a 4xx error during evaluation
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
);
Expand All @@ -615,10 +617,57 @@ public void testRRFInnerRetrieverSearchError() {
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
Exception ex = expectThrows(IllegalStateException.class, req::get);
assertThat(ex, instanceOf(IllegalStateException.class));
assertThat(ex.getMessage(), containsString("Search failed - some nested retrievers returned errors"));
assertThat(ex.getSuppressed().length, greaterThan(0));
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
assertThat(
ex.getMessage(),
containsString(
"[rrf] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions."
)
);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
assertThat(ex.getSuppressed().length, equalTo(1));
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
}

public void testRRFInnerRetrieverMultipleErrorsOne5xx() {
final int rankWindowSize = 100;
final int rankConstant = 10;
SearchSourceBuilder source = new SearchSourceBuilder();
// this will throw a 4xx error during evaluation
StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(
QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10))
);
// this will throw a 5xx error
TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") {
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param"));
}
};
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(standard0, null),
new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null)
),
rankWindowSize,
rankConstant
)
);
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
Exception ex = expectThrows(ElasticsearchStatusException.class, req::get);
assertThat(ex, instanceOf(ElasticsearchStatusException.class));
assertThat(
ex.getMessage(),
containsString(
"[rrf] search failed - retrievers '[standard, test]' returned errors. All failures are attached as suppressed exceptions."
)
);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(ex.getSuppressed().length, equalTo(2));
assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class));
assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class));
}

public void testRRFInnerRetrieverErrorWhenExtractingToSource() {
Expand Down

0 comments on commit 71c24ba

Please sign in to comment.