Skip to content

Commit

Permalink
Drop boost from runtime distance feature query (elastic#63949)
Browse files Browse the repository at this point in the history
This drops the `boost` parameter of the `distance_feature` query builder
internally, relying on our query building infrastructure to wrap the
query in a `boosting` query.

Relates to elastic#63767
  • Loading branch information
nik9000 authored Oct 21, 2020
1 parent e566730 commit f2bcc77
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,11 @@ public static long parseToLong(
}

@Override
public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
long originLong = parseToLong(origin, true, null, null, context::nowInMillis);
TimeValue pivotTime = TimeValue.parseTimeValue(pivot, "distance_feature.pivot");
return resolution.distanceFeatureQuery(name(), boost, originLong, pivotTime);
// As we already apply boost in AbstractQueryBuilder::toQuery, we always passing a boost of 1.0 to distanceFeatureQuery
return resolution.distanceFeatureQuery(name(), 1.0f, originLong, pivotTime);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S
}

@Override
public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
GeoPoint originGeoPoint;
if (origin instanceof GeoPoint) {
originGeoPoint = (GeoPoint) origin;
Expand All @@ -205,7 +205,8 @@ public Query distanceFeatureQuery(Object origin, String pivot, float boost, Quer
"Must be of type [geo_point] or [string] for geo_point fields!");
}
double pivotDouble = DistanceUnit.DEFAULT.parse(pivot, DistanceUnit.DEFAULT);
return LatLonPoint.newDistanceFeatureQuery(name(), boost, originGeoPoint.lat(), originGeoPoint.lon(), pivotDouble);
// As we already apply boost in AbstractQueryBuilder::toQuery, we always passing a boost of 1.0 to distanceFeatureQuery
return LatLonPoint.newDistanceFeatureQuery(name(), 1.0f, originGeoPoint.lat(), originGeoPoint.lon(), pivotDouble);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ public SpanQuery spanPrefixQuery(String value, SpanMultiTermQueryWrapper.SpanRew
+ "] which is of type [" + typeName() + "]");
}

public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
throw new IllegalArgumentException("Illegal data type of [" + typeName() + "]!"+
"[" + DistanceFeatureQueryBuilder.NAME + "] query can only be run on a date, date_nanos or geo_point field type!");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
if (fieldType == null) {
return Queries.newMatchNoDocsQuery("Can't run [" + NAME + "] query on unmapped fields!");
}
// As we already apply boost in AbstractQueryBuilder::toQuery, we always passing a boost of 1.0 to distanceFeatureQuery
return fieldType.distanceFeatureQuery(origin.origin(), pivot, 1.0f, context);
return fieldType.distanceFeatureQuery(origin.origin(), pivot, context);
}

String fieldName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public DateScriptFieldData.Builder fielddataBuilder(String fullyQualifiedIndexNa
}

@Override
public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
public Query distanceFeatureQuery(Object origin, String pivot, QueryShardContext context) {
checkAllowExpensiveQueries(context);
return DateFieldType.handleNow(context, now -> {
long originLong = DateFieldType.parseToLong(
Expand All @@ -98,8 +98,7 @@ public Query distanceFeatureQuery(Object origin, String pivot, float boost, Quer
leafFactory(context)::newInstance,
name(),
originLong,
pivotTime.getMillis(),
boost
pivotTime.getMillis()
);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,17 @@
public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFieldQuery<AbstractLongFieldScript> {
private final long origin;
private final long pivot;
private final float boost;

public LongScriptFieldDistanceFeatureQuery(
Script script,
Function<LeafReaderContext, AbstractLongFieldScript> leafFactory,
String fieldName,
long origin,
long pivot,
float boost
long pivot
) {
super(script, fieldName, leafFactory);
this.origin = origin;
this.pivot = pivot;
this.boost = boost;
}

@Override
Expand Down Expand Up @@ -70,12 +67,11 @@ public Explanation explain(LeafReaderContext context, int doc) {
AbstractLongFieldScript script = scriptContextFunction().apply(context);
script.runForDoc(doc);
long value = valueWithMinAbsoluteDistance(script);
float weight = LongScriptFieldDistanceFeatureQuery.this.boost * boost;
float score = score(weight, distanceFor(value));
float score = score(boost, distanceFor(value));
return Explanation.match(
score,
"Distance score, computed as weight * pivot / (pivot + abs(value - origin)) from:",
Explanation.match(weight, "weight"),
Explanation.match(boost, "weight"),
Explanation.match(pivot, "pivot"),
Explanation.match(origin, "origin"),
Explanation.match(value, "current value")
Expand Down Expand Up @@ -105,7 +101,7 @@ public float matchCost() {
}
};
disi = TwoPhaseIterator.asDocIdSetIterator(twoPhase);
this.weight = LongScriptFieldDistanceFeatureQuery.this.boost * boost;
this.weight = boost;
}

@Override
Expand Down Expand Up @@ -179,15 +175,14 @@ public String toString(String field) {
}
b.append(getClass().getSimpleName());
b.append("(origin=").append(origin);
b.append(",pivot=").append(pivot);
b.append(",boost=").append(boost).append(")");
b.append(",pivot=").append(pivot).append(")");
return b.toString();

}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), origin, pivot, boost);
return Objects.hash(super.hashCode(), origin, pivot);
}

@Override
Expand All @@ -196,7 +191,7 @@ public boolean equals(Object obj) {
return false;
}
LongScriptFieldDistanceFeatureQuery other = (LongScriptFieldDistanceFeatureQuery) obj;
return origin == other.origin && pivot == other.pivot && boost == other.boost;
return origin == other.origin && pivot == other.pivot;
}

@Override
Expand All @@ -214,8 +209,4 @@ long origin() {
long pivot() {
return pivot;
}

float boost() {
return boost;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ public void testDistanceFeatureQuery() throws IOException {
);
try (DirectoryReader reader = iw.getReader()) {
IndexSearcher searcher = newSearcher(reader);
Query query = simpleMappedFieldType().distanceFeatureQuery(1595432181354L, "1ms", 1, mockContext());
Query query = simpleMappedFieldType().distanceFeatureQuery(1595432181354L, "1ms", mockContext());
TopDocs docs = searcher.search(query, 4);
assertThat(docs.scoreDocs, arrayWithSize(3));
assertThat(readSource(reader, docs.scoreDocs[0].doc), equalTo("{\"timestamp\": [1595432181354]}"));
Expand Down Expand Up @@ -228,7 +228,7 @@ public void testDistanceFeatureQueryInLoop() throws IOException {
}

private Query randomDistanceFeatureQuery(MappedFieldType ft, QueryShardContext ctx) {
return ft.distanceFeatureQuery(randomDate(), randomTimeValue(), randomFloat(), ctx);
return ft.distanceFeatureQuery(randomDate(), randomTimeValue(), ctx);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,12 @@ public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFiel
protected LongScriptFieldDistanceFeatureQuery createTestInstance() {
long origin = randomLong();
long pivot = randomValueOtherThan(origin, ESTestCase::randomLong);
return new LongScriptFieldDistanceFeatureQuery(randomScript(), leafFactory, randomAlphaOfLength(5), origin, pivot, randomFloat());
return new LongScriptFieldDistanceFeatureQuery(randomScript(), leafFactory, randomAlphaOfLength(5), origin, pivot);
}

@Override
protected LongScriptFieldDistanceFeatureQuery copy(LongScriptFieldDistanceFeatureQuery orig) {
return new LongScriptFieldDistanceFeatureQuery(
orig.script(),
leafFactory,
orig.fieldName(),
orig.origin(),
orig.pivot(),
orig.boost()
);
return new LongScriptFieldDistanceFeatureQuery(orig.script(), leafFactory, orig.fieldName(), orig.origin(), orig.pivot());
}

@Override
Expand All @@ -56,8 +49,7 @@ protected LongScriptFieldDistanceFeatureQuery mutate(LongScriptFieldDistanceFeat
String fieldName = orig.fieldName();
long origin = orig.origin();
long pivot = orig.pivot();
float boost = orig.boost();
switch (randomInt(4)) {
switch (randomInt(3)) {
case 0:
script = randomValueOtherThan(script, this::randomScript);
break;
Expand All @@ -70,13 +62,10 @@ protected LongScriptFieldDistanceFeatureQuery mutate(LongScriptFieldDistanceFeat
case 3:
pivot = randomValueOtherThan(origin, () -> randomValueOtherThan(orig.pivot(), ESTestCase::randomLong));
break;
case 4:
boost = randomValueOtherThan(boost, ESTestCase::randomFloat);
break;
default:
fail();
}
return new LongScriptFieldDistanceFeatureQuery(script, leafFactory, fieldName, origin, pivot, boost);
return new LongScriptFieldDistanceFeatureQuery(script, leafFactory, fieldName, origin, pivot);
}

@Override
Expand Down Expand Up @@ -105,12 +94,13 @@ public void execute() {
leafFactory,
"test",
1595432181351L,
6L,
between(1, 100)
3L
);
TopDocs td = searcher.search(query, 1);
assertThat(td.scoreDocs[0].score, equalTo(query.boost()));
TopDocs td = searcher.search(query, 2);
assertThat(td.scoreDocs[0].score, equalTo(1.0f));
assertThat(td.scoreDocs[0].doc, equalTo(1));
assertThat(td.scoreDocs[1].score, equalTo(.5f));
assertThat(td.scoreDocs[1].doc, equalTo(0));
}
}
}
Expand All @@ -124,7 +114,7 @@ public void testMaxScore() throws IOException {
float boost = randomFloat();
assertThat(
query.createWeight(searcher, ScoreMode.COMPLETE, boost).scorer(reader.leaves().get(0)).getMaxScore(randomInt()),
equalTo(query.boost() * boost)
equalTo(boost)
);
}
}
Expand All @@ -134,9 +124,7 @@ public void testMaxScore() throws IOException {
protected void assertToString(LongScriptFieldDistanceFeatureQuery query) {
assertThat(
query.toString(query.fieldName()),
equalTo(
"LongScriptFieldDistanceFeatureQuery(origin=" + query.origin() + ",pivot=" + query.pivot() + ",boost=" + query.boost() + ")"
)
equalTo("LongScriptFieldDistanceFeatureQuery(origin=" + query.origin() + ",pivot=" + query.pivot() + ")")
);
}

Expand Down

0 comments on commit f2bcc77

Please sign in to comment.