Skip to content

Commit

Permalink
Forbid expensive query parts in ranking evaluation (#30151)
Browse files Browse the repository at this point in the history
Currently the ranking evaluation API accepts the full query syntax for
the queries specified in the evaluation set and executes them via multi
search. This potentially runs costly aggregations and suggestions too.
This change adds checks that forbid using aggregations, suggesters, 
highlighters and the explain and profile options in the queries that are 
run as part of the ranking evaluation since they are irrelevent in the 
context of this API.
  • Loading branch information
Christoph Büscher authored May 14, 2018
1 parent 41148e4 commit cc93131
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public class RankEvalSpec implements Writeable, ToXContentObject {
/** Default max number of requests. */
private static final int MAX_CONCURRENT_SEARCHES = 10;
/** optional: Templates to base test requests on */
private Map<String, Script> templates = new HashMap<>();
private final Map<String, Script> templates = new HashMap<>();

public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric, Collection<ScriptWithId> templates) {
this.metric = Objects.requireNonNull(metric, "Cannot evaluate ranking if no evaluation metric is provided.");
Expand All @@ -68,8 +68,8 @@ public RankEvalSpec(List<RatedRequest> ratedRequests, EvaluationMetric metric, C
this.ratedRequests = ratedRequests;
if (templates == null || templates.isEmpty()) {
for (RatedRequest request : ratedRequests) {
if (request.getTestRequest() == null) {
throw new IllegalStateException("Cannot evaluate ranking if neither template nor test request is "
if (request.getEvaluationRequest() == null) {
throw new IllegalStateException("Cannot evaluate ranking if neither template nor evaluation request is "
+ "provided. Seen for request id: " + request.getId());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,12 @@ public class RatedRequest implements Writeable, ToXContentObject {
private final String id;
private final List<String> summaryFields;
private final List<RatedDocument> ratedDocs;
// Search request to execute for this rated request. This can be null if template and corresponding parameters are supplied.
/**
* Search request to execute for this rated request. This can be null in
* case the query is supplied as a template with corresponding parameters
*/
@Nullable
private SearchSourceBuilder testRequest;
private final SearchSourceBuilder evaluationRequest;
/**
* Map of parameters to use for filling a query template, can be used
* instead of providing testRequest.
Expand All @@ -86,27 +89,49 @@ public class RatedRequest implements Writeable, ToXContentObject {
@Nullable
private String templateId;

private RatedRequest(String id, List<RatedDocument> ratedDocs, SearchSourceBuilder testRequest,
/**
* Create a rated request with template ids and parameters.
*
* @param id a unique name for this rated request
* @param ratedDocs a list of document ratings
* @param params template parameters
* @param templateId a templare id
*/
public RatedRequest(String id, List<RatedDocument> ratedDocs, Map<String, Object> params,
String templateId) {
this(id, ratedDocs, null, params, templateId);
}

/**
* Create a rated request using a {@link SearchSourceBuilder} to define the
* evaluated query.
*
* @param id a unique name for this rated request
* @param ratedDocs a list of document ratings
* @param evaluatedQuery the query that is evaluated
*/
public RatedRequest(String id, List<RatedDocument> ratedDocs, SearchSourceBuilder evaluatedQuery) {
this(id, ratedDocs, evaluatedQuery, new HashMap<>(), null);
}

private RatedRequest(String id, List<RatedDocument> ratedDocs, SearchSourceBuilder evaluatedQuery,
Map<String, Object> params, String templateId) {
if (params != null && (params.size() > 0 && testRequest != null)) {
if (params != null && (params.size() > 0 && evaluatedQuery != null)) {
throw new IllegalArgumentException(
"Ambiguous rated request: Set both, verbatim test request and test request "
+ "template parameters.");
"Ambiguous rated request: Set both, verbatim test request and test request " + "template parameters.");
}
if (templateId != null && testRequest != null) {
if (templateId != null && evaluatedQuery != null) {
throw new IllegalArgumentException(
"Ambiguous rated request: Set both, verbatim test request and test request "
+ "template parameters.");
"Ambiguous rated request: Set both, verbatim test request and test request " + "template parameters.");
}
if ((params == null || params.size() < 1) && testRequest == null) {
throw new IllegalArgumentException(
"Need to set at least test request or test request template parameters.");
if ((params == null || params.size() < 1) && evaluatedQuery == null) {
throw new IllegalArgumentException("Need to set at least test request or test request template parameters.");
}
if ((params != null && params.size() > 0) && templateId == null) {
throw new IllegalArgumentException(
"If template parameters are supplied need to set id of template to apply "
+ "them to too.");
throw new IllegalArgumentException("If template parameters are supplied need to set id of template to apply " + "them to too.");
}
validateEvaluatedQuery(evaluatedQuery);

// check that not two documents with same _index/id are specified
Set<DocumentKey> docKeys = new HashSet<>();
for (RatedDocument doc : ratedDocs) {
Expand All @@ -118,7 +143,7 @@ private RatedRequest(String id, List<RatedDocument> ratedDocs, SearchSourceBuild
}

this.id = id;
this.testRequest = testRequest;
this.evaluationRequest = evaluatedQuery;
this.ratedDocs = new ArrayList<>(ratedDocs);
if (params != null) {
this.params = new HashMap<>(params);
Expand All @@ -129,18 +154,30 @@ private RatedRequest(String id, List<RatedDocument> ratedDocs, SearchSourceBuild
this.summaryFields = new ArrayList<>();
}

public RatedRequest(String id, List<RatedDocument> ratedDocs, Map<String, Object> params,
String templateId) {
this(id, ratedDocs, null, params, templateId);
}

public RatedRequest(String id, List<RatedDocument> ratedDocs, SearchSourceBuilder testRequest) {
this(id, ratedDocs, testRequest, new HashMap<>(), null);
static void validateEvaluatedQuery(SearchSourceBuilder evaluationRequest) {
// ensure that testRequest, if set, does not contain aggregation, suggest or highlighting section
if (evaluationRequest != null) {
if (evaluationRequest.suggest() != null) {
throw new IllegalArgumentException("Query in rated requests should not contain a suggest section.");
}
if (evaluationRequest.aggregations() != null) {
throw new IllegalArgumentException("Query in rated requests should not contain aggregations.");
}
if (evaluationRequest.highlighter() != null) {
throw new IllegalArgumentException("Query in rated requests should not contain a highlighter section.");
}
if (evaluationRequest.explain() != null && evaluationRequest.explain()) {
throw new IllegalArgumentException("Query in rated requests should not use explain.");
}
if (evaluationRequest.profile()) {
throw new IllegalArgumentException("Query in rated requests should not use profile.");
}
}
}

public RatedRequest(StreamInput in) throws IOException {
RatedRequest(StreamInput in) throws IOException {
this.id = in.readString();
testRequest = in.readOptionalWriteable(SearchSourceBuilder::new);
evaluationRequest = in.readOptionalWriteable(SearchSourceBuilder::new);

int intentSize = in.readInt();
ratedDocs = new ArrayList<>(intentSize);
Expand All @@ -159,7 +196,7 @@ public RatedRequest(StreamInput in) throws IOException {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
out.writeOptionalWriteable(testRequest);
out.writeOptionalWriteable(evaluationRequest);

out.writeInt(ratedDocs.size());
for (RatedDocument ratedDoc : ratedDocs) {
Expand All @@ -173,8 +210,8 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(this.templateId);
}

public SearchSourceBuilder getTestRequest() {
return testRequest;
public SearchSourceBuilder getEvaluationRequest() {
return evaluationRequest;
}

/** return the user supplied request id */
Expand Down Expand Up @@ -240,8 +277,8 @@ public static RatedRequest fromXContent(XContentParser parser) {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ID_FIELD.getPreferredName(), this.id);
if (testRequest != null) {
builder.field(REQUEST_FIELD.getPreferredName(), this.testRequest);
if (evaluationRequest != null) {
builder.field(REQUEST_FIELD.getPreferredName(), this.evaluationRequest);
}
builder.startArray(RATINGS_FIELD.getPreferredName());
for (RatedDocument doc : this.ratedDocs) {
Expand Down Expand Up @@ -285,7 +322,7 @@ public final boolean equals(Object obj) {

RatedRequest other = (RatedRequest) obj;

return Objects.equals(id, other.id) && Objects.equals(testRequest, other.testRequest)
return Objects.equals(id, other.id) && Objects.equals(evaluationRequest, other.evaluationRequest)
&& Objects.equals(summaryFields, other.summaryFields)
&& Objects.equals(ratedDocs, other.ratedDocs)
&& Objects.equals(params, other.params)
Expand All @@ -294,7 +331,7 @@ public final boolean equals(Object obj) {

@Override
public final int hashCode() {
return Objects.hash(id, testRequest, summaryFields, ratedDocs, params,
return Objects.hash(id, evaluationRequest, summaryFields, ratedDocs, params,
templateId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import java.util.concurrent.ConcurrentHashMap;

import static org.elasticsearch.common.xcontent.XContentHelper.createParser;
import static org.elasticsearch.index.rankeval.RatedRequest.validateEvaluatedQuery;

/**
* Instances of this class execute a collection of search intents (read: user
Expand Down Expand Up @@ -99,15 +100,17 @@ protected void doExecute(RankEvalRequest request, ActionListener<RankEvalRespons
msearchRequest.maxConcurrentSearchRequests(evaluationSpecification.getMaxConcurrentSearches());
List<RatedRequest> ratedRequestsInSearch = new ArrayList<>();
for (RatedRequest ratedRequest : ratedRequests) {
SearchSourceBuilder ratedSearchSource = ratedRequest.getTestRequest();
if (ratedSearchSource == null) {
SearchSourceBuilder evaluationRequest = ratedRequest.getEvaluationRequest();
if (evaluationRequest == null) {
Map<String, Object> params = ratedRequest.getParams();
String templateId = ratedRequest.getTemplateId();
TemplateScript.Factory templateScript = scriptsWithoutParams.get(templateId);
String resolvedRequest = templateScript.newInstance(params).execute();
try (XContentParser subParser = createParser(namedXContentRegistry,
LoggingDeprecationHandler.INSTANCE, new BytesArray(resolvedRequest), XContentType.JSON)) {
ratedSearchSource = SearchSourceBuilder.fromXContent(subParser, false);
evaluationRequest = SearchSourceBuilder.fromXContent(subParser, false);
// check for parts that should not be part of a ranking evaluation request
validateEvaluatedQuery(evaluationRequest);
} catch (IOException e) {
// if we fail parsing, put the exception into the errors map and continue
errors.put(ratedRequest.getId(), e);
Expand All @@ -116,17 +119,17 @@ LoggingDeprecationHandler.INSTANCE, new BytesArray(resolvedRequest), XContentTyp
}

if (metric.forcedSearchSize().isPresent()) {
ratedSearchSource.size(metric.forcedSearchSize().get());
evaluationRequest.size(metric.forcedSearchSize().get());
}

ratedRequestsInSearch.add(ratedRequest);
List<String> summaryFields = ratedRequest.getSummaryFields();
if (summaryFields.isEmpty()) {
ratedSearchSource.fetchSource(false);
evaluationRequest.fetchSource(false);
} else {
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
evaluationRequest.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
}
SearchRequest searchRequest = new SearchRequest(request.indices(), ratedSearchSource);
SearchRequest searchRequest = new SearchRequest(request.indices(), evaluationRequest);
searchRequest.indicesOptions(request.indicesOptions());
msearchRequest.add(searchRequest);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.suggest.SuggestBuilder;
import org.elasticsearch.search.suggest.SuggestBuilders;
import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
Expand Down Expand Up @@ -165,7 +169,7 @@ public void testEqualsAndHash() throws IOException {

private static RatedRequest mutateTestItem(RatedRequest original) {
String id = original.getId();
SearchSourceBuilder testRequest = original.getTestRequest();
SearchSourceBuilder evaluationRequest = original.getEvaluationRequest();
List<RatedDocument> ratedDocs = original.getRatedDocs();
Map<String, Object> params = original.getParams();
List<String> summaryFields = original.getSummaryFields();
Expand All @@ -177,11 +181,11 @@ private static RatedRequest mutateTestItem(RatedRequest original) {
id = randomValueOtherThan(id, () -> randomAlphaOfLength(10));
break;
case 1:
if (testRequest != null) {
int size = randomValueOtherThan(testRequest.size(), () -> randomInt(Integer.MAX_VALUE));
testRequest = new SearchSourceBuilder();
testRequest.size(size);
testRequest.query(new MatchAllQueryBuilder());
if (evaluationRequest != null) {
int size = randomValueOtherThan(evaluationRequest.size(), () -> randomInt(Integer.MAX_VALUE));
evaluationRequest = new SearchSourceBuilder();
evaluationRequest.size(size);
evaluationRequest.query(new MatchAllQueryBuilder());
} else {
if (randomBoolean()) {
Map<String, Object> mutated = new HashMap<>();
Expand All @@ -204,10 +208,10 @@ private static RatedRequest mutateTestItem(RatedRequest original) {
}

RatedRequest ratedRequest;
if (testRequest == null) {
if (evaluationRequest == null) {
ratedRequest = new RatedRequest(id, ratedDocs, params, templateId);
} else {
ratedRequest = new RatedRequest(id, ratedDocs, testRequest);
ratedRequest = new RatedRequest(id, ratedDocs, evaluationRequest);
}
ratedRequest.addSummaryFields(summaryFields);

Expand Down Expand Up @@ -258,6 +262,44 @@ public void testSettingTemplateIdNoParamsThrows() {
expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, null, "templateId"));
}

public void testAggsNotAllowed() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
SearchSourceBuilder query = new SearchSourceBuilder();
query.aggregation(AggregationBuilders.terms("fieldName"));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query));
assertEquals("Query in rated requests should not contain aggregations.", e.getMessage());
}

public void testSuggestionsNotAllowed() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
SearchSourceBuilder query = new SearchSourceBuilder();
query.suggest(new SuggestBuilder().addSuggestion("id", SuggestBuilders.completionSuggestion("fieldname")));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query));
assertEquals("Query in rated requests should not contain a suggest section.", e.getMessage());
}

public void testHighlighterNotAllowed() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
SearchSourceBuilder query = new SearchSourceBuilder();
query.highlighter(new HighlightBuilder().field("field"));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new RatedRequest("id", ratedDocs, query));
assertEquals("Query in rated requests should not contain a highlighter section.", e.getMessage());
}

public void testExplainNotAllowed() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder().explain(true)));
assertEquals("Query in rated requests should not use explain.", e.getMessage());
}

public void testProfileNotAllowed() {
List<RatedDocument> ratedDocs = Arrays.asList(new RatedDocument("index1", "id1", 1));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> new RatedRequest("id", ratedDocs, new SearchSourceBuilder().profile(true)));
assertEquals("Query in rated requests should not use profile.", e.getMessage());
}

/**
* test that modifying the order of index/docId to make sure it doesn't
* matter for parsing xContent
Expand Down Expand Up @@ -287,7 +329,7 @@ public void testParseFromXContent() throws IOException {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, querySpecString)) {
RatedRequest specification = RatedRequest.fromXContent(parser);
assertEquals("my_qa_query", specification.getId());
assertNotNull(specification.getTestRequest());
assertNotNull(specification.getEvaluationRequest());
List<RatedDocument> ratedDocs = specification.getRatedDocs();
assertEquals(3, ratedDocs.size());
for (int i = 0; i < 3; i++) {
Expand Down
Loading

0 comments on commit cc93131

Please sign in to comment.