From 0078300dfcda5d7e9c567070bb833ea74e7f8cb0 Mon Sep 17 00:00:00 2001 From: aprudhomme Date: Tue, 28 Jan 2020 08:04:23 -0800 Subject: [PATCH] Add optional extra logging for native scripts (#268) * Add optional extra logging for native scripts * Added Extra Logging section to docs --- docs/advanced-functionality.rst | 45 +++++ .../feature/store/ExtraLoggingSupplier.java | 26 +++ .../es/ltr/feature/store/ScriptFeature.java | 24 ++- .../es/ltr/logging/LoggingFetchSubPhase.java | 24 ++- .../es/ltr/query/DerivedExpressionQuery.java | 4 +- .../o19s/es/ltr/query/LtrRewritableQuery.java | 6 +- .../o19s/es/ltr/query/LtrRewriteContext.java | 35 ++++ .../com/o19s/es/ltr/query/RankerQuery.java | 3 +- .../com/o19s/es/ltr/ranker/LogLtrRanker.java | 7 + .../es/ltr/action/BaseIntegrationTest.java | 42 ++++ .../store/ExtraLoggingSupplierTests.java | 44 +++++ .../com/o19s/es/ltr/logging/LoggingIT.java | 179 +++++++++++++++++- 12 files changed, 424 insertions(+), 15 deletions(-) create mode 100644 src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java create mode 100644 src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java create mode 100644 src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java diff --git a/docs/advanced-functionality.rst b/docs/advanced-functionality.rst index 313672b5..2155aacf 100644 --- a/docs/advanced-functionality.rst +++ b/docs/advanced-functionality.rst @@ -167,3 +167,48 @@ Characteristics of the internal cache can be controlled with these node settings ltr.caches.expire_after_write: 10m # Evict cache entries 10 minutes after access (defaults to 1hour, set to 0 to disable) ltr.caches.expire_after_read: 10m + +============================= +Extra Logging +============================= + +As described in :doc:`logging-features`, it is possible to use the logging extension to return the feature values with each document. For native scripts, it is also possible to return extra arbitrary information with the logged features. + +For native scripts, the parameter :code:`extra_logging` is injected into the script parameters. The parameter value is a `Supplier `_ <`Map `_>, which provides a non-null :code:`Map` **only** during the logging fetch phase. Any values added to this Map will be returned with the logged features:: + + @Override + public double runAsDouble() { + ... + Map extraLoggingMap = ((Supplier>) getParams().get("extra_logging")).get(); + if (extraLoggingMap != null) { + extraLoggingMap.put("extra_float", 10.0f); + extraLoggingMap.put("extra_string", "additional_info"); + } + ... + } + +If (and only if) the extra logging Map is accessed, it will be returned as an additional entry with the logged features:: + + { + "log_entry1": [ + { + "name": "title_query" + "value": 9.510193 + }, + { + "name": "body_query" + "value": 10.7808075 + }, + { + "name": "user_rating", + "value": 7.8 + }, + { + "name": "extra_logging", + "value": { + "extra_float": 10.0, + "extra_string": "additional_info" + } + } + ] + } \ No newline at end of file diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java b/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java new file mode 100644 index 00000000..de897fef --- /dev/null +++ b/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java @@ -0,0 +1,26 @@ +package com.o19s.es.ltr.feature.store; + +import java.util.Map; +import java.util.function.Supplier; + + +public class ExtraLoggingSupplier implements Supplier> { + protected Supplier> supplier; + + public void setSupplier(Supplier> supplier) { + this.supplier = supplier; + } + + /** + * Return a Map to add additional information to be returned when logging feature values. + * + * This Map will only be non-null during the LoggingFetchSubPhase. + */ + @Override + public Map get() { + if (supplier != null) { + return supplier.get(); + } + return null; + } +} \ No newline at end of file diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java index 728b29e2..e3f3f44a 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java @@ -4,7 +4,8 @@ import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.query.LtrRewritableQuery; -import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.query.LtrRewriteContext; +import com.o19s.es.ltr.ranker.LogLtrRanker; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; @@ -31,12 +32,12 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.function.Supplier; import java.util.stream.Collectors; public class ScriptFeature implements Feature { public static final String TEMPLATE_LANGUAGE = "script_feature"; public static final String FEATURE_VECTOR = "feature_vector"; + public static final String EXTRA_LOGGING = "extra_logging"; public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params"; private final String name; @@ -109,11 +110,13 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map nparams = new HashMap<>(); nparams.putAll(baseScriptParams); nparams.putAll(queryTimeParams); nparams.putAll(extraQueryTimeParams); nparams.put(FEATURE_VECTOR, supplier); + nparams.put(EXTRA_LOGGING, extraLoggingSupplier); Script script = new Script(this.script.getType(), this.script.getLang(), this.script.getIdOrCode(), this.script.getOptions(), nparams); ScoreScript.Factory factoryFactory = context.getQueryShardContext().getScriptService().compile(script, ScoreScript.CONTEXT); @@ -122,15 +125,17 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map vectorSupplier) throws IOException { - supplier.set(vectorSupplier); + public Query ltrRewrite(LtrRewriteContext context) throws IOException { + supplier.set(context.getFeatureVectorSupplier()); + + LogLtrRanker.LogConsumer consumer = context.getLogConsumer(); + if (consumer != null) { + extraLoggingSupplier.setSupplier(consumer::getExtraLoggingMap); + } else { + extraLoggingSupplier.setSupplier(() -> null); + } return this; } } diff --git a/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java b/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java index f9a24568..190eab44 100644 --- a/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java +++ b/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java @@ -177,6 +177,7 @@ private Tuple toLogger(LoggingSearchExtBuilder.LogS static class HitLogConsumer implements LogLtrRanker.LogConsumer { private static final String FIELD_NAME = "_ltrlog"; + private static final String EXTRA_LOGGING_NAME = "extra_logging"; private final String name; private final FeatureSet set; private final boolean missingAsZero; @@ -192,6 +193,7 @@ static class HitLogConsumer implements LogLtrRanker.LogConsumer { // ] private List> currentLog; private SearchHit currentHit; + private Map extraLogging; HitLogConsumer(String name, FeatureSet set, boolean missingAsZero) { @@ -201,7 +203,9 @@ static class HitLogConsumer implements LogLtrRanker.LogConsumer { } private void rebuild() { - List> ini = new ArrayList<>(set.size()); + // Allocate one Map per feature, plus one placeholder for an extra logging Map + // that will only be added if used. + List> ini = new ArrayList<>(set.size() + 1); for (int i = 0; i < set.size(); i++) { Map defaultKeyVal = new HashMap<>(); @@ -212,6 +216,7 @@ private void rebuild() { ini.add(i, defaultKeyVal); } currentLog = ini; + extraLogging = null; } @Override @@ -221,6 +226,23 @@ public void accept(int featureOrdinal, float score) { currentLog.get(featureOrdinal).put("value", score); } + /** + * Return Map to store additional logging information returned with the feature values. + * + * The Map is created on first access. + */ + @Override + public Map getExtraLoggingMap() { + if (extraLogging == null) { + extraLogging = new HashMap<>(); + Map logEntry = new HashMap<>(); + logEntry.put("name", EXTRA_LOGGING_NAME); + logEntry.put("value", extraLogging); + currentLog.add(logEntry); + } + return extraLogging; + } + void nextDoc(SearchHit hit) { if (hit.fieldsOrNull() == null) { hit.fields(new HashMap<>()); diff --git a/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java b/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java index f834b2c4..d6ebab5b 100644 --- a/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java @@ -66,8 +66,8 @@ public boolean equals(Object obj) { } @Override - public Query ltrRewrite(Supplier vectorSuppler) { - return new FVDerivedExpressionQuery(this, vectorSuppler); + public Query ltrRewrite(LtrRewriteContext context) { + return new FVDerivedExpressionQuery(this, context.getFeatureVectorSupplier()); } @Override diff --git a/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java b/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java index b4059ce6..0bceca5b 100644 --- a/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java @@ -1,14 +1,12 @@ package com.o19s.es.ltr.query; -import com.o19s.es.ltr.ranker.LtrRanker; import org.apache.lucene.search.Query; import java.io.IOException; -import java.util.function.Supplier; public interface LtrRewritableQuery { /** - * Rewrite the query so that it holds the vectorSupplier + * Rewrite the query so that it holds the vectorSupplier and provide extra logging support */ - Query ltrRewrite(Supplier vectorSuppler) throws IOException; + Query ltrRewrite(LtrRewriteContext context) throws IOException; } diff --git a/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java b/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java new file mode 100644 index 00000000..d44bcded --- /dev/null +++ b/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java @@ -0,0 +1,35 @@ +package com.o19s.es.ltr.query; + +import com.o19s.es.ltr.ranker.LogLtrRanker; +import com.o19s.es.ltr.ranker.LtrRanker; + +import java.util.function.Supplier; + +/** + * Contains context needed to rewrite queries to holds the vectorSupplier and provide extra logging support + */ +public class LtrRewriteContext { + private final Supplier vectorSupplier; + private final LtrRanker ranker; + + public LtrRewriteContext(LtrRanker ranker, Supplier vectorSupplier) { + this.ranker = ranker; + this.vectorSupplier = vectorSupplier; + } + + public Supplier getFeatureVectorSupplier() { + return vectorSupplier; + } + + /** + * Get LogConsumer used during the LoggingFetchSubPhase + * + * The returned consumer will only be non-null during the logging fetch phase + */ + public LogLtrRanker.LogConsumer getLogConsumer() { + if (ranker instanceof LogLtrRanker) { + return ((LogLtrRanker)ranker).getLogConsumer(); + } + return null; + } +} diff --git a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java index 8f834f22..db942923 100644 --- a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java @@ -194,9 +194,10 @@ public boolean isCacheable(LeafReaderContext ctx) { // Hopefully elastic never runs MutableSupplier vectorSupplier = new Suppliers.FeatureVectorSupplier(); FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier); + LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier); for (Query q : queries) { if (q instanceof LtrRewritableQuery) { - q = ((LtrRewritableQuery)q).ltrRewrite(vectorSupplier); + q = ((LtrRewritableQuery)q).ltrRewrite(context); } weights.add(searcher.createWeight(q, ScoreMode.COMPLETE, boost)); } diff --git a/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java b/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java index d885e4f9..912828c2 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java +++ b/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java @@ -16,6 +16,8 @@ package com.o19s.es.ltr.ranker; +import java.util.Map; + public class LogLtrRanker implements LtrRanker { private final LogConsumer logger; private final LtrRanker ranker; @@ -79,9 +81,14 @@ void reset(LtrRanker ranker) { } } + public LogConsumer getLogConsumer() { + return logger; + } + @FunctionalInterface public interface LogConsumer { void accept(int featureOrdinal, float score); + default Map getExtraLoggingMap() {return null;} default void reset() {} } } diff --git a/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java b/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java index 90ca1183..ecbc3c82 100644 --- a/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java +++ b/src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java @@ -43,7 +43,9 @@ import java.util.Collection; import java.util.Map; import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; +import static com.o19s.es.ltr.feature.store.ScriptFeature.EXTRA_LOGGING; import static com.o19s.es.ltr.feature.store.ScriptFeature.FEATURE_VECTOR; public abstract class BaseIntegrationTest extends ESSingleNodeTestCase { @@ -168,6 +170,9 @@ public FactoryType compile(String scriptName, String scriptSource, if (!p.containsKey(FEATURE_VECTOR)) { throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]"); } + if (!p.containsKey(EXTRA_LOGGING)) { + throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]"); + } if (!p.containsKey(DEPDENDENT_FEATURE)) { throw new IllegalArgumentException("Missing parameter [depdendent_feature ]"); } @@ -198,6 +203,43 @@ public boolean needs_score() { return context.factoryClazz.cast(factory); } + else if (scriptSource.equals(FEATURE_EXTRACTOR + "_extra_logging")) { + ScoreScript.Factory factory = (p, lookup) -> + new ScoreScript.LeafFactory() { + { + if (!p.containsKey(FEATURE_VECTOR)) { + throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]"); + } + if (!p.containsKey(EXTRA_LOGGING)) { + throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]"); + } + } + + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + return new ScoreScript(p, lookup, ctx) { + + @Override + public double execute(ExplanationHolder explanation) { + Map extraLoggingMap = ((Supplier>) getParams() + .get(EXTRA_LOGGING)).get(); + if (extraLoggingMap != null) { + extraLoggingMap.put("extra_float", 10.0f); + extraLoggingMap.put("extra_string", "additional_info"); + } + return 1.0d; + } + }; + } + + @Override + public boolean needs_score() { + return false; + } + }; + + return context.factoryClazz.cast(factory); + } throw new IllegalArgumentException("Unknown script name " + scriptSource); } }; diff --git a/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java b/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java new file mode 100644 index 00000000..7e44025e --- /dev/null +++ b/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java @@ -0,0 +1,44 @@ +package com.o19s.es.ltr.feature.store; + +import com.o19s.es.ltr.ranker.LogLtrRanker; +import org.apache.lucene.util.LuceneTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class ExtraLoggingSupplierTests extends LuceneTestCase { + public void testGetWithConsumerNotSet() { + ExtraLoggingSupplier supplier = new ExtraLoggingSupplier(); + assertNull(supplier.get()); + } + + public void testGetWillNullConsumerSet() { + ExtraLoggingSupplier supplier = new ExtraLoggingSupplier(); + supplier.setSupplier(null); + assertNull(supplier.get()); + } + + public void testGetWithSuppliedNull() { + ExtraLoggingSupplier supplier = new ExtraLoggingSupplier(); + supplier.setSupplier(() -> null); + assertNull(supplier.get()); + } + + public void testGetWithSuppliedMap() { + Map extraLoggingMap = new HashMap<>(); + + LogLtrRanker.LogConsumer consumer = new LogLtrRanker.LogConsumer() { + @Override + public void accept(int featureOrdinal, float score) {} + + @Override + public Map getExtraLoggingMap() { + return extraLoggingMap; + } + }; + + ExtraLoggingSupplier supplier = new ExtraLoggingSupplier(); + supplier.setSupplier(consumer::getExtraLoggingMap); + assertTrue(supplier.get() == extraLoggingMap); + } +} diff --git a/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java b/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java index 6da2a4eb..c2ce262d 100644 --- a/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java +++ b/src/test/java/com/o19s/es/ltr/logging/LoggingIT.java @@ -18,6 +18,7 @@ import com.o19s.es.ltr.LtrTestUtils; import com.o19s.es.ltr.action.BaseIntegrationTest; +import com.o19s.es.ltr.feature.store.ScriptFeature; import com.o19s.es.ltr.feature.store.StoredFeature; import com.o19s.es.ltr.feature.store.StoredFeatureSet; import com.o19s.es.ltr.feature.store.StoredLtrModel; @@ -42,6 +43,7 @@ import org.elasticsearch.search.rescore.QueryRescorerBuilder; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -75,6 +77,29 @@ public void prepareModels() throws Exception { LinearRankerParserTests.generateRandomModelString(set), true)); addElement(model); } + public void prepareModelsExtraLogging() throws Exception { + List features = new ArrayList<>(3); + features.add(new StoredFeature("text_feature1", Collections.singletonList("query"), "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString())); + features.add(new StoredFeature("text_feature2", Collections.singletonList("query"), "mustache", + QueryBuilders.matchQuery("field2", "{{query}}").toString())); + features.add(new StoredFeature("numeric_feature1", Collections.singletonList("query"), "mustache", + new FunctionScoreQueryBuilder(QueryBuilders.matchAllQuery(), new FieldValueFactorFunctionBuilder("scorefield1") + .factor(FACTOR) + .modifier(FieldValueFactorFunction.Modifier.LN2P) + .missing(0F)).scoreMode(FunctionScoreQuery.ScoreMode.MULTIPLY).toString())); + features.add(new StoredFeature("derived_feature", Collections.singletonList("query"), "derived_expression", + "100")); + features.add(new StoredFeature("extra_logging_feature", Arrays.asList("query"), ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"native\", \"source\": \"feature_extractor_extra_logging\", \"params\": {}}")); + + StoredFeatureSet set = new StoredFeatureSet("my_set", features); + addElement(set); + StoredLtrModel model = new StoredLtrModel("my_model", set, + new StoredLtrModel.LtrModelDefinition("model/linear", + LinearRankerParserTests.generateRandomModelString(set), true)); + addElement(model); + } public void testFailures() throws Exception { prepareModels(); buildIndex(); @@ -213,6 +238,82 @@ public void testLog() throws Exception { assertSearchHits(docs, resp3); } + public void testLogExtraLogging() throws Exception { + prepareModelsExtraLogging(); + Map docs = buildIndex(); + + Map params = new HashMap<>(); + params.put("query", "found"); + List idsColl = new ArrayList<>(docs.keySet()); + Collections.shuffle(idsColl, random()); + String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); + StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); + + StoredLtrQueryBuilder sbuilder_rescore = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .featureSetName("my_set") + .params(params) + .queryName("test_rescore") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery("test").addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addQueryLogging("first_log", "test", false) + .addRescoreLogging("second_log", 0, true))); + + SearchResponse resp = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get(); + assertSearchHitsExtraLogging(docs, resp); + sbuilder.featureSetName(null); + sbuilder.modelName("my_model"); + sbuilder.boost(random().nextInt(3)); + sbuilder_rescore.featureSetName(null); + sbuilder_rescore.modelName("my_model"); + sbuilder_rescore.boost(random().nextInt(3)); + + query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery("test").addIds(ids)); + sourceBuilder = new SearchSourceBuilder().query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addQueryLogging("first_log", "test", false) + .addRescoreLogging("second_log", 0, true))); + + SearchResponse resp2 = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get(); + assertSearchHitsExtraLogging(docs, resp2); + + query = QueryBuilders.boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .must( + QueryBuilders.nestedQuery( + "nesteddocs1", + QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("nesteddocs1.field1", "nestedvalue")), + ScoreMode.None + ).innerHit(new InnerHitBuilder()) + ); + sourceBuilder = new SearchSourceBuilder().query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addQueryLogging("first_log", "test", false) + .addRescoreLogging("second_log", 0, true))); + SearchResponse resp3 = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get(); + assertSearchHitsExtraLogging(docs, resp3); + } + protected void assertSearchHits(Map docs, SearchResponse resp) { for (SearchHit hit: resp.getHits()) { assertTrue(hit.getFields().containsKey("_ltrlog")); @@ -223,6 +324,64 @@ protected void assertSearchHits(Map docs, SearchResponse resp) { List> log1 = logs.get("first_log"); List> log2 = logs.get("second_log"); Doc d = docs.get(hit.getId()); + + assertEquals(4, log1.size()); + assertEquals(4, log2.size()); + if (d.field1.equals("found")) { + assertEquals(log1.get(0).get("name"), "text_feature1"); + assertEquals(log2.get(0).get("name"), "text_feature1"); + + assertTrue((Float)log1.get(0).get("value") > 0F); + assertTrue((Float)log2.get(0).get("value") > 0F); + + assertEquals(log1.get(1).get("name"), "text_feature2"); + assertFalse(log1.get(1).containsKey("value")); + + assertEquals(log2.get(1).get("name"), "text_feature2"); + assertEquals(log2.get(1).get("value"), 0F); + + } else { + assertEquals(log1.get(0).get("name"), "text_feature1"); + assertEquals(log2.get(0).get("name"), "text_feature1"); + + assertTrue((Float)log1.get(1).get("value") > 0F); + assertTrue((Float)log2.get(1).get("value") > 0F); + + assertEquals(log1.get(0).get("name"), "text_feature1"); + assertEquals(log2.get(0).get("name"), "text_feature1"); + + assertEquals(0F, (Float)log2.get(0).get("value"), 0F); + } + float score = (float) Math.log1p((d.scorefield1 * FACTOR) + 1); + assertEquals(log1.get(2).get("name"), "numeric_feature1"); + assertEquals(log2.get(2).get("name"), "numeric_feature1"); + + assertEquals(score, (Float)log1.get(2).get("value"), Math.ulp(score)); + assertEquals(score, (Float)log2.get(2).get("value"), Math.ulp(score)); + + assertEquals(log1.get(3).get("name"), "derived_feature"); + assertEquals(log2.get(3).get("name"), "derived_feature"); + + assertEquals(100.0, (Float) log1.get(3).get("value"), Math.ulp(100.0)); + assertEquals(100.0, (Float) log2.get(3).get("value"), Math.ulp(100.0)); + + } + } + + @SuppressWarnings("unchecked") + protected void assertSearchHitsExtraLogging(Map docs, SearchResponse resp) { + for (SearchHit hit: resp.getHits()) { + assertTrue(hit.getFields().containsKey("_ltrlog")); + Map>> logs = hit.getFields().get("_ltrlog").getValue(); + assertTrue(logs.containsKey("first_log")); + assertTrue(logs.containsKey("second_log")); + + List> log1 = logs.get("first_log"); + List> log2 = logs.get("second_log"); + Doc d = docs.get(hit.getId()); + + assertEquals(6, log1.size()); + assertEquals(6, log2.size()); if (d.field1.equals("found")) { assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); @@ -258,9 +417,27 @@ protected void assertSearchHits(Map docs, SearchResponse resp) { assertEquals(log1.get(3).get("name"), "derived_feature"); assertEquals(log2.get(3).get("name"), "derived_feature"); - assertEquals(100.0, (Float)log1.get(3).get("value"), Math.ulp(100.0)); + assertEquals(100.0, (Float) log1.get(3).get("value"), Math.ulp(100.0)); assertEquals(100.0, (Float) log2.get(3).get("value"), Math.ulp(100.0)); + assertEquals(log1.get(4).get("name"), "extra_logging_feature"); + assertEquals(log2.get(4).get("name"), "extra_logging_feature"); + + assertEquals(1.0, (Float) log1.get(4).get("value"), Math.ulp(1.0)); + assertEquals(1.0, (Float) log2.get(4).get("value"), Math.ulp(1.0)); + + assertEquals(log1.get(5).get("name"), "extra_logging"); + assertEquals(log2.get(5).get("name"), "extra_logging"); + + Map extraMap1 = (Map) log1.get(5).get("value"); + Map extraMap2 = (Map) log2.get(5).get("value"); + + assertEquals(2, extraMap1.size()); + assertEquals(2, extraMap2.size()); + assertEquals(10.0f, extraMap1.get("extra_float")); + assertEquals(10.0f, extraMap2.get("extra_float")); + assertEquals("additional_info", extraMap1.get("extra_string")); + assertEquals("additional_info", extraMap2.get("extra_string")); } }