From d48765c3e135caeeef7160121229820103233bb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Fri, 15 Jun 2018 11:56:16 +0200 Subject: [PATCH] Add details section for dcg ranking metric (#31177) While the other two ranking evaluation metrics (precision and reciprocal rank) already provide a more detailed output for how their score is calculated, the discounted cumulative gain metric (dcg) and its normalized variant have been lacking this until now. It's not really clear which level of detail might be useful for debugging and understanding the final metric calculation, but this change adds a `metric_details` section to REST output that contains some information about the evaluation details. --- .../client/RestHighLevelClientTests.java | 6 +- .../rankeval/DiscountedCumulativeGain.java | 142 ++++++++++++++++-- .../RankEvalNamedXContentProvider.java | 2 + .../index/rankeval/RankEvalPlugin.java | 5 +- .../DiscountedCumulativeGainTests.java | 26 +++- .../index/rankeval/EvalQueryQualityTests.java | 14 +- 6 files changed, 174 insertions(+), 21 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index ea8f9df665e81..2d8ca045e4638 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; + import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.HttpHost; @@ -608,7 +609,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(7, namedXContents.size()); + assertEquals(8, namedXContents.size()); Map, 
Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -626,9 +627,10 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); - assertEquals(Integer.valueOf(2), categories.get(MetricDetail.class)); + assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class)); assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); + assertTrue(names.contains(DiscountedCumulativeGain.NAME)); } private static class TrackingActionListener implements ActionListener { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java index 13926d7d362ff..01a6e35299b29 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java @@ -36,6 +36,7 @@ import java.util.Optional; import java.util.stream.Collectors; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; @@ -129,26 +130,31 @@ public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, .collect(Collectors.toList()); List ratedHits = joinHitsWithRatings(hits, ratedDocs); List ratingsInSearchHits = new ArrayList<>(ratedHits.size()); + int unratedResults = 0; for (RatedSearchHit hit : ratedHits) { - // unknownDocRating might be null, which means it will be unrated docs are - // ignored in the dcg calculation - // we still need to add them as a 
placeholder so the rank of the subsequent - // ratings is correct + // unknownDocRating might be null, in which case unrated docs will be ignored in the dcg calculation. + // we still need to add them as a placeholder so the rank of the subsequent ratings is correct ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating)); + if (hit.getRating().isPresent() == false) { + unratedResults++; + } } - double dcg = computeDCG(ratingsInSearchHits); + final double dcg = computeDCG(ratingsInSearchHits); + double result = dcg; + double idcg = 0; if (normalize) { Collections.sort(allRatings, Comparator.nullsLast(Collections.reverseOrder())); - double idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size()))); - if (idcg > 0) { - dcg = dcg / idcg; + idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size()))); + if (idcg != 0) { + result = dcg / idcg; } else { - dcg = 0; + result = 0; } } - EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, dcg); + EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, result); evalQueryQuality.addHitsAndRatings(ratedHits); + evalQueryQuality.setMetricDetails(new Detail(dcg, idcg, unratedResults)); return evalQueryQuality; } @@ -167,7 +173,7 @@ private static double computeDCG(List ratings) { private static final ParseField K_FIELD = new ParseField("k"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at", false, + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg", false, args -> { Boolean normalized = (Boolean) args[0]; Integer optK = (Integer) args[2]; @@ -217,4 +223,118 @@ public final boolean equals(Object obj) { public final int hashCode() { return Objects.hash(normalize, 
unknownDocRating, k); } + + public static final class Detail implements MetricDetail { + + private static ParseField DCG_FIELD = new ParseField("dcg"); + private static ParseField IDCG_FIELD = new ParseField("ideal_dcg"); + private static ParseField NDCG_FIELD = new ParseField("normalized_dcg"); + private static ParseField UNRATED_FIELD = new ParseField("unrated_docs"); + private final double dcg; + private final double idcg; + private final int unratedDocs; + + Detail(double dcg, double idcg, int unratedDocs) { + this.dcg = dcg; + this.idcg = idcg; + this.unratedDocs = unratedDocs; + } + + Detail(StreamInput in) throws IOException { + this.dcg = in.readDouble(); + this.idcg = in.readDouble(); + this.unratedDocs = in.readVInt(); + } + + @Override + public + String getMetricName() { + return NAME; + } + + @Override + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(DCG_FIELD.getPreferredName(), this.dcg); + if (this.idcg != 0) { + builder.field(IDCG_FIELD.getPreferredName(), this.idcg); + builder.field(NDCG_FIELD.getPreferredName(), this.dcg / this.idcg); + } + builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs); + return builder; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, args -> { + return new Detail((Double) args[0], (Double) args[1] != null ? 
(Double) args[1] : 0.0d, (Integer) args[2]); + }); + + static { + PARSER.declareDouble(constructorArg(), DCG_FIELD); + PARSER.declareDouble(optionalConstructorArg(), IDCG_FIELD); + PARSER.declareInt(constructorArg(), UNRATED_FIELD); + } + + public static Detail fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(this.dcg); + out.writeDouble(this.idcg); + out.writeVInt(this.unratedDocs); + } + + @Override + public String getWriteableName() { + return NAME; + } + + /** + * @return the discounted cumulative gain + */ + public double getDCG() { + return this.dcg; + } + + /** + * @return the ideal discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required + */ + public double getIDCG() { + return this.idcg; + } + + /** + * @return the normalized discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required + */ + public double getNDCG() { + return (this.idcg != 0) ? 
this.dcg / this.idcg : 0; + } + + /** + * @return the number of unrated documents in the search results + */ + public Object getUnratedDocs() { + return this.unratedDocs; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj; + return (this.dcg == other.dcg && + this.idcg == other.idcg && + this.unratedDocs == other.unratedDocs); + } + + @Override + public int hashCode() { + return Objects.hash(this.dcg, this.idcg, this.unratedDocs); + } + } } + diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java index c5785ca3847d4..f2176113cdf9d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java @@ -41,6 +41,8 @@ public List getNamedXContentParsers() { PrecisionAtK.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME), MeanReciprocalRank.Detail::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME), + DiscountedCumulativeGain.Detail::fromXContent)); return namedXContent; } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 884cf3bafdcda..8ac2b7fbee528 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -61,8 +61,9 @@ public List 
getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new)); - namedWriteables - .add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new)); return namedWriteables; } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java index 64337786b1eb6..24ac600a11398 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; @@ -254,9 +255,8 @@ private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRati public static DiscountedCumulativeGain createTestItem() { boolean normalize = randomBoolean(); - Integer unknownDocRating = Integer.valueOf(randomIntBetween(0, 1000)); - - return new DiscountedCumulativeGain(normalize, unknownDocRating, 10); + Integer unknownDocRating = frequently() ? 
Integer.valueOf(randomIntBetween(0, 1000)) : null; + return new DiscountedCumulativeGain(normalize, unknownDocRating, randomIntBetween(1, 10)); } public void testXContentRoundtrip() throws IOException { @@ -283,7 +283,25 @@ public void testXContentParsingIsNotLenient() throws IOException { parser.nextToken(); XContentParseException exception = expectThrows(XContentParseException.class, () -> DiscountedCumulativeGain.fromXContent(parser)); - assertThat(exception.getMessage(), containsString("[dcg_at] unknown field")); + assertThat(exception.getMessage(), containsString("[dcg] unknown field")); + } + } + + public void testMetricDetails() { + double dcg = randomDoubleBetween(0, 1, true); + double idcg = randomBoolean() ? 0.0 : randomDoubleBetween(0, 1, true); + double expectedNdcg = idcg != 0 ? dcg / idcg : 0.0; + int unratedDocs = randomIntBetween(0, 100); + DiscountedCumulativeGain.Detail detail = new DiscountedCumulativeGain.Detail(dcg, idcg, unratedDocs); + assertEquals(dcg, detail.getDCG(), 0.0); + assertEquals(idcg, detail.getIDCG(), 0.0); + assertEquals(expectedNdcg, detail.getNDCG(), 0.0); + assertEquals(unratedDocs, detail.getUnratedDocs()); + if (idcg != 0) { + assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"ideal_dcg\":" + idcg + ",\"normalized_dcg\":" + expectedNdcg + + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail)); + } else { + assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail)); } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java index 112cf4eaaf72e..e9fae6b5c63ee 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java @@ -68,10 +68,20 @@ public static EvalQueryQuality 
randomEvalQueryQuality() { EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, true)); if (randomBoolean()) { - if (randomBoolean()) { + int metricDetail = randomIntBetween(0, 2); + switch (metricDetail) { + case 0: evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000))); - } else { + break; + case 1: evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000))); + break; + case 2: + evalQueryQuality.setMetricDetails(new DiscountedCumulativeGain.Detail(randomDoubleBetween(0, 1, true), + randomBoolean() ? randomDoubleBetween(0, 1, true) : 0, randomInt())); + break; + default: + throw new IllegalArgumentException("illegal randomized value in test"); } } evalQueryQuality.addHitsAndRatings(ratedHits);