diff --git a/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldType.java b/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldType.java index b369c0f44df2d..f3f561f55e587 100644 --- a/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldType.java +++ b/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldType.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.common.time.DateMathParser; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.index.mapper.DateFieldMapper.DateFieldType; import org.elasticsearch.index.mapper.DateFieldMapper.Resolution; @@ -23,6 +24,7 @@ import org.elasticsearch.search.lookup.SearchLookup; import org.elasticsearch.xpack.runtimefields.DateScriptFieldScript; import org.elasticsearch.xpack.runtimefields.fielddata.ScriptDateFieldData; +import org.elasticsearch.xpack.runtimefields.query.LongScriptFieldDistanceFeatureQuery; import org.elasticsearch.xpack.runtimefields.query.LongScriptFieldExistsQuery; import org.elasticsearch.xpack.runtimefields.query.LongScriptFieldRangeQuery; import org.elasticsearch.xpack.runtimefields.query.LongScriptFieldTermQuery; @@ -86,6 +88,30 @@ private DateScriptFieldScript.LeafFactory leafFactory(SearchLookup lookup) { return scriptFactory.newFactory(script.getParams(), lookup); } + @Override + public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) { + checkAllowExpensiveQueries(context); + return DateFieldType.handleNow(context, now -> { + long originLong = DateFieldType.parseToLong( + origin, + true, + null, + dateTimeFormatter.toDateMathParser(), + now, + DateFieldMapper.Resolution.MILLISECONDS + ); + TimeValue pivotTime = TimeValue.parseTimeValue(pivot, "distance_feature.pivot"); + return new LongScriptFieldDistanceFeatureQuery( + script, + leafFactory(context.lookup())::newInstance, + name(), + originLong, + pivotTime.getMillis(), + boost + ); + }); + } + @Override public Query existsQuery(QueryShardContext context) { checkAllowExpensiveQueries(context); diff --git a/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQuery.java b/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQuery.java new file mode 100644 index 0000000000000..d91d6b98240c2 --- /dev/null +++ b/x-pack/plugin/runtime-fields/src/main/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQuery.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.runtimefields.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; +import org.apache.lucene.search.Weight; +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.script.Script; +import org.elasticsearch.xpack.runtimefields.AbstractLongScriptFieldScript; + +import java.io.IOException; +import java.util.Objects; +import java.util.Set; + +public final class LongScriptFieldDistanceFeatureQuery extends AbstractScriptFieldQuery { + private final CheckedFunction leafFactory; + private final long origin; + private final long pivot; + private final float boost; + + public LongScriptFieldDistanceFeatureQuery( + Script script, + CheckedFunction leafFactory, + String fieldName, + long origin, + long pivot, + float boost + ) { + super(script, fieldName); + this.leafFactory = leafFactory; + this.origin = origin; + this.pivot = pivot; + this.boost = boost; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new Weight(this) { + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + return new DistanceScorer(this, leafFactory.apply(context), context.reader().maxDoc(), boost); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + AbstractLongScriptFieldScript script = leafFactory.apply(context); + script.runForDoc(doc); + long value = valueWithMinAbsoluteDistance(script); + float weight = LongScriptFieldDistanceFeatureQuery.this.boost * boost; + float score = score(weight, distanceFor(value)); + return Explanation.match( + score, + "Distance score, computed as weight * pivot / (pivot + abs(value - origin)) from:", + Explanation.match(weight, "weight"), + Explanation.match(pivot, "pivot"), + Explanation.match(origin, "origin"), + Explanation.match(value, "current value") + ); + } + }; + } + + private class DistanceScorer extends Scorer { + private final AbstractLongScriptFieldScript script; + private final TwoPhaseIterator twoPhase; + private final DocIdSetIterator disi; + private final float weight; + + protected DistanceScorer(Weight weight, AbstractLongScriptFieldScript script, int maxDoc, float boost) { + super(weight); + this.script = script; + twoPhase = new TwoPhaseIterator(DocIdSetIterator.all(maxDoc)) { + @Override + public boolean matches() throws IOException { + script.runForDoc(approximation().docID()); + return script.count() > 0; + } + + @Override + public float matchCost() { + return MATCH_COST; + } + }; + disi = TwoPhaseIterator.asDocIdSetIterator(twoPhase); + this.weight = LongScriptFieldDistanceFeatureQuery.this.boost * boost; + } + + @Override + public int docID() { + return disi.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return disi; + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return twoPhase; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return weight; + } + + @Override + public float score() throws IOException { + if (script.count() == 0) { + return 0; + } + return LongScriptFieldDistanceFeatureQuery.this.score(weight, (double) minAbsoluteDistance(script)); + } + } + + long minAbsoluteDistance(AbstractLongScriptFieldScript script) { + long minDistance = Long.MAX_VALUE; + for (int i = 0; i < script.count(); i++) { + minDistance = Math.min(minDistance, distanceFor(script.values()[i])); + } + return minDistance; + } + + long valueWithMinAbsoluteDistance(AbstractLongScriptFieldScript script) { + long minDistance = Long.MAX_VALUE; + long minDistanceValue = Long.MAX_VALUE; + for (int i = 0; i < script.count(); i++) { + long distance = distanceFor(script.values()[i]); + if (distance < minDistance) { + minDistance = distance; + minDistanceValue = script.values()[i]; + } + } + return minDistanceValue; + } + + long distanceFor(long value) { + long distance = Math.max(value, origin) - Math.min(value, origin); + if (distance < 0) { + // The distance doesn't fit into signed long so clamp it to MAX_VALUE + return Long.MAX_VALUE; + } + return distance; + } + + float score(float weight, double distance) { + return (float) (weight * (pivot / (pivot + distance))); + } + + @Override + public String toString(String field) { + StringBuilder b = new StringBuilder(); + if (false == fieldName().equals(field)) { + b.append(fieldName()).append(":"); + } + b.append(getClass().getSimpleName()); + b.append("(origin=").append(origin); + b.append(",pivot=").append(pivot); + b.append(",boost=").append(boost).append(")"); + return b.toString(); + + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), origin, pivot); + } + + @Override + public boolean equals(Object obj) { + if (false == super.equals(obj)) { + return false; + } + LongScriptFieldDistanceFeatureQuery other = (LongScriptFieldDistanceFeatureQuery) obj; + return origin == other.origin && pivot == other.pivot; + } + + @Override + public void visit(QueryVisitor visitor) { + // No subclasses contain any Terms because those have to be strings. + if (visitor.acceptField(fieldName())) { + visitor.visitLeaf(this); + } + } + + long origin() { + return origin; + } + + long pivot() { + return pivot; + } + + float boost() { + return boost; + } +} diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/AbstractScriptMappedFieldTypeTestCase.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/AbstractScriptMappedFieldTypeTestCase.java index 684057c9c90af..a44a6b5b5ae28 100644 --- a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/AbstractScriptMappedFieldTypeTestCase.java +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/AbstractScriptMappedFieldTypeTestCase.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.runtimefields.mapper; +import org.apache.lucene.index.IndexReader; import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.query.QueryShardContext; @@ -124,4 +125,8 @@ private void assertQueryOnlyOnText(String queryName, ThrowingRunnable buildQuery ) ); } + + protected String readSource(IndexReader reader, int docId) throws IOException { + return reader.document(docId).getBinaryValue("_source").utf8ToString(); + } } diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldTypeTests.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldTypeTests.java index cebe31f7a5c25..168aee9e6e666 100644 --- a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldTypeTests.java +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/mapper/ScriptDateMappedFieldTypeTests.java @@ -12,13 +12,16 @@ import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; @@ -59,6 +62,8 @@ import java.util.function.BiConsumer; import static java.util.Collections.emptyMap; +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -146,18 +151,9 @@ public void testSort() throws IOException { ScriptDateFieldData ifd = simpleMappedFieldType().fielddataBuilder("test", mockContext()::lookup).build(null, null, null); SortField sf = ifd.sortField(null, MultiValueMode.MIN, null, false); TopFieldDocs docs = searcher.search(new MatchAllDocsQuery(), 3, new Sort(sf)); - assertThat( - reader.document(docs.scoreDocs[0].doc).getBinaryValue("_source").utf8ToString(), - equalTo("{\"timestamp\": [1595432181351]}") - ); - assertThat( - reader.document(docs.scoreDocs[1].doc).getBinaryValue("_source").utf8ToString(), - equalTo("{\"timestamp\": [1595432181354]}") - ); - assertThat( - reader.document(docs.scoreDocs[2].doc).getBinaryValue("_source").utf8ToString(), - equalTo("{\"timestamp\": [1595432181356]}") - ); + assertThat(readSource(reader, docs.scoreDocs[0].doc), equalTo("{\"timestamp\": [1595432181351]}")); + assertThat(readSource(reader, docs.scoreDocs[1].doc), equalTo("{\"timestamp\": [1595432181354]}")); + assertThat(readSource(reader, docs.scoreDocs[2].doc), equalTo("{\"timestamp\": [1595432181356]}")); } } } @@ -192,6 +188,42 @@ public double execute(ExplanationHolder explanation) { } } + public void testDistanceFeatureQuery() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + iw.addDocuments( + List.of( + List.of(new StoredField("_source", new BytesRef("{\"timestamp\": [1595432181354]}"))), + List.of(new StoredField("_source", new BytesRef("{\"timestamp\": [1595432181351]}"))), + List.of(new StoredField("_source", new BytesRef("{\"timestamp\": [1595432181356, 1]}"))), + List.of(new StoredField("_source", new BytesRef("{\"timestamp\": []}"))) + ) + ); + try (DirectoryReader reader = iw.getReader()) { + IndexSearcher searcher = newSearcher(reader); + Query query = simpleMappedFieldType().distanceFeatureQuery(1595432181354L, "1ms", 1, mockContext()); + TopDocs docs = searcher.search(query, 4); + assertThat(docs.scoreDocs, arrayWithSize(3)); + assertThat(readSource(reader, docs.scoreDocs[0].doc), equalTo("{\"timestamp\": [1595432181354]}")); + assertThat(docs.scoreDocs[0].score, equalTo(1.0F)); + assertThat(readSource(reader, docs.scoreDocs[1].doc), equalTo("{\"timestamp\": [1595432181356, 1]}")); + assertThat((double) docs.scoreDocs[1].score, closeTo(.333, .001)); + assertThat(readSource(reader, docs.scoreDocs[2].doc), equalTo("{\"timestamp\": [1595432181351]}")); + assertThat((double) docs.scoreDocs[2].score, closeTo(.250, .001)); + Explanation explanation = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0F) + .explain(reader.leaves().get(0), docs.scoreDocs[0].doc); + assertThat(explanation.toString(), containsString("1.0 = Distance score, computed as weight * pivot / (pivot")); + assertThat(explanation.toString(), containsString("1.0 = weight")); + assertThat(explanation.toString(), containsString("1 = pivot")); + assertThat(explanation.toString(), containsString("1595432181354 = origin")); + assertThat(explanation.toString(), containsString("1595432181354 = current value")); + } + } + } + + public void testDistanceFeatureQueryIsExpensive() throws IOException { + checkExpensiveQuery((ft, ctx) -> ft.distanceFeatureQuery(randomLong(), randomAlphaOfLength(5), randomFloat(), ctx)); + } + @Override public void testExistsQuery() throws IOException { try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { @@ -409,7 +441,7 @@ private DateScriptFieldScript.Factory factory(String code) { @Override public void execute() { for (Object timestamp : (List) getSource().get("timestamp")) { - new DateScriptFieldScript.Millis(this).millis((Long) timestamp); + new DateScriptFieldScript.Millis(this).millis(((Number) timestamp).longValue()); } } }; diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractDoubleScriptFieldQueryTestCase.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractDoubleScriptFieldQueryTestCase.java index 59bd856c3cfd6..ce05013318771 100644 --- a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractDoubleScriptFieldQueryTestCase.java +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractDoubleScriptFieldQueryTestCase.java @@ -6,17 +6,8 @@ package org.elasticsearch.xpack.runtimefields.query; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.util.automaton.ByteRunAutomaton; import org.elasticsearch.xpack.runtimefields.DoubleScriptFieldScript; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Supplier; - -import static org.hamcrest.Matchers.equalTo; import static org.mockito.Mockito.mock; public abstract class AbstractDoubleScriptFieldQueryTestCase extends @@ -26,24 +17,6 @@ public abstract class AbstractDoubleScriptFieldQueryTestCase leavesVisited = new ArrayList<>(); - query.visit(new QueryVisitor() { - @Override - public void consumeTerms(Query query, Term... terms) { - fail(); - } - - @Override - public void consumeTermsMatching(Query query, String field, Supplier automaton) { - fail(); - } - - @Override - public void visitLeaf(Query query) { - leavesVisited.add(query); - } - }); - assertThat(leavesVisited, equalTo(List.of(query))); + assertEmptyVisit(); } } diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractIpScriptFieldQueryTestCase.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractIpScriptFieldQueryTestCase.java index 1f2ec87024896..d57b8820752d8 100644 --- a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractIpScriptFieldQueryTestCase.java +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractIpScriptFieldQueryTestCase.java @@ -7,19 +7,11 @@ package org.elasticsearch.xpack.runtimefields.query; import org.apache.lucene.document.InetAddressPoint; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.automaton.ByteRunAutomaton; import org.elasticsearch.xpack.runtimefields.IpScriptFieldScript; import java.net.InetAddress; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Supplier; -import static org.hamcrest.Matchers.equalTo; import static org.mockito.Mockito.mock; public abstract class AbstractIpScriptFieldQueryTestCase extends AbstractScriptFieldQueryTestCase { @@ -28,25 +20,7 @@ public abstract class AbstractIpScriptFieldQueryTestCase leavesVisited = new ArrayList<>(); - query.visit(new QueryVisitor() { - @Override - public void consumeTerms(Query query, Term... terms) { - fail(); - } - - @Override - public void consumeTermsMatching(Query query, String field, Supplier automaton) { - fail(); - } - - @Override - public void visitLeaf(Query query) { - leavesVisited.add(query); - } - }); - assertThat(leavesVisited, equalTo(List.of(query))); + assertEmptyVisit(); } protected static BytesRef encode(InetAddress addr) { diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractLongScriptFieldQueryTestCase.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractLongScriptFieldQueryTestCase.java index 82a3543321fc8..7958d906a8594 100644 --- a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractLongScriptFieldQueryTestCase.java +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractLongScriptFieldQueryTestCase.java @@ -7,19 +7,10 @@ package org.elasticsearch.xpack.runtimefields.query; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.util.automaton.ByteRunAutomaton; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.xpack.runtimefields.AbstractLongScriptFieldScript; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Supplier; - -import static org.hamcrest.Matchers.equalTo; public abstract class AbstractLongScriptFieldQueryTestCase extends AbstractScriptFieldQueryTestCase< T> { @@ -27,24 +18,6 @@ public abstract class AbstractLongScriptFieldQueryTestCase leavesVisited = new ArrayList<>(); - query.visit(new QueryVisitor() { - @Override - public void consumeTerms(Query query, Term... terms) { - fail(); - } - - @Override - public void consumeTermsMatching(Query query, String field, Supplier automaton) { - fail(); - } - - @Override - public void visitLeaf(Query query) { - leavesVisited.add(query); - } - }); - assertThat(leavesVisited, equalTo(List.of(query))); + assertEmptyVisit(); } } diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractScriptFieldQueryTestCase.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractScriptFieldQueryTestCase.java index 743450cd5f933..b8ad9406b8d00 100644 --- a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractScriptFieldQueryTestCase.java +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/AbstractScriptFieldQueryTestCase.java @@ -6,10 +6,19 @@ package org.elasticsearch.xpack.runtimefields.query; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.util.automaton.ByteRunAutomaton; import org.elasticsearch.script.Script; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.EqualsHashCodeTestUtils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + import static org.hamcrest.Matchers.equalTo; public abstract class AbstractScriptFieldQueryTestCase extends ESTestCase { @@ -27,7 +36,7 @@ public final void testEqualsAndHashCode() { EqualsHashCodeTestUtils.checkEqualsAndHashCode(createTestInstance(), this::copy, this::mutate); } - public abstract void testMatches(); + public abstract void testMatches() throws IOException; public final void testToString() { T query = createTestInstance(); @@ -38,4 +47,26 @@ public final void testToString() { protected abstract void assertToString(T query); public abstract void testVisit(); + + protected final void assertEmptyVisit() { + T query = createTestInstance(); + List leavesVisited = new ArrayList<>(); + query.visit(new QueryVisitor() { + @Override + public void consumeTerms(Query query, Term... terms) { + fail(); + } + + @Override + public void consumeTermsMatching(Query query, String field, Supplier automaton) { + fail(); + } + + @Override + public void visitLeaf(Query query) { + leavesVisited.add(query); + } + }); + assertThat(leavesVisited, equalTo(List.of(query))); + } } diff --git a/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQueryTests.java b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQueryTests.java new file mode 100644 index 0000000000000..160305b400813 --- /dev/null +++ b/x-pack/plugin/runtime-fields/src/test/java/org/elasticsearch/xpack/runtimefields/query/LongScriptFieldDistanceFeatureQueryTests.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.runtimefields.query; + +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.lookup.SearchLookup; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.runtimefields.AbstractLongScriptFieldScript; +import org.elasticsearch.xpack.runtimefields.DateScriptFieldScript; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; + +public class LongScriptFieldDistanceFeatureQueryTests extends AbstractScriptFieldQueryTestCase { + private final CheckedFunction leafFactory = ctx -> null; + + @Override + protected LongScriptFieldDistanceFeatureQuery createTestInstance() { + long origin = randomLong(); + long pivot = randomValueOtherThan(origin, ESTestCase::randomLong); + return new LongScriptFieldDistanceFeatureQuery(randomScript(), leafFactory, randomAlphaOfLength(5), origin, pivot, randomFloat()); + } + + @Override + protected LongScriptFieldDistanceFeatureQuery copy(LongScriptFieldDistanceFeatureQuery orig) { + return new LongScriptFieldDistanceFeatureQuery( + orig.script(), + leafFactory, + orig.fieldName(), + orig.origin(), + orig.pivot(), + orig.boost() + ); + } + + @Override + protected LongScriptFieldDistanceFeatureQuery mutate(LongScriptFieldDistanceFeatureQuery orig) { + Script script = orig.script(); + String fieldName = orig.fieldName(); + long origin = orig.origin(); + long pivot = orig.pivot(); + float boost = orig.boost(); + switch (randomInt(4)) { + case 0: + script = randomValueOtherThan(script, this::randomScript); + break; + case 1: + fieldName += "modified"; + break; + case 2: + origin = randomValueOtherThan(origin, () -> randomValueOtherThan(orig.pivot(), ESTestCase::randomLong)); + break; + case 3: + pivot = randomValueOtherThan(origin, () -> randomValueOtherThan(orig.pivot(), ESTestCase::randomLong)); + break; + case 4: + boost = randomValueOtherThan(boost, ESTestCase::randomFloat); + break; + default: + fail(); + } + return new LongScriptFieldDistanceFeatureQuery(script, leafFactory, fieldName, origin, pivot, boost); + } + + @Override + public void testMatches() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + iw.addDocument(List.of(new StoredField("_source", new BytesRef("{\"timestamp\": [1595432181354]}")))); + iw.addDocument(List.of(new StoredField("_source", new BytesRef("{\"timestamp\": [1595432181351]}")))); + try (DirectoryReader reader = iw.getReader()) { + IndexSearcher searcher = newSearcher(reader); + CheckedFunction leafFactory = + ctx -> new DateScriptFieldScript(Map.of(), new SearchLookup(null, null), ctx) { + @Override + public void execute() { + for (Object timestamp : (List) getSource().get("timestamp")) { + new DateScriptFieldScript.Millis(this).millis(((Number) timestamp).longValue()); + } + } + }; + LongScriptFieldDistanceFeatureQuery query = new LongScriptFieldDistanceFeatureQuery( + randomScript(), + leafFactory, + "test", + 1595432181351L, + 6L, + between(1, 100) + ); + TopDocs td = searcher.search(query, 1); + assertThat(td.scoreDocs[0].score, equalTo(query.boost())); + assertThat(td.scoreDocs[0].doc, equalTo(1)); + } + } + } + + public void testMaxScore() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + iw.addDocument(List.of()); + try (DirectoryReader reader = iw.getReader()) { + IndexSearcher searcher = newSearcher(reader); + LongScriptFieldDistanceFeatureQuery query = createTestInstance(); + float boost = randomFloat(); + assertThat( + query.createWeight(searcher, ScoreMode.COMPLETE, boost).scorer(reader.leaves().get(0)).getMaxScore(randomInt()), + equalTo(query.boost() * boost) + ); + } + } + } + + @Override + protected void assertToString(LongScriptFieldDistanceFeatureQuery query) { + assertThat( + query.toString(query.fieldName()), + equalTo( + "LongScriptFieldDistanceFeatureQuery(origin=" + query.origin() + ",pivot=" + query.pivot() + ",boost=" + query.boost() + ")" + ) + ); + } + + @Override + public final void testVisit() { + assertEmptyVisit(); + } +}