Skip to content

Commit

Permalink
Add optional extra logging for native scripts (#268)
Browse files Browse the repository at this point in the history
* Add optional extra logging for native scripts

* Added Extra Logging section to docs
  • Loading branch information
aprudhomme authored and nomoa committed Jan 28, 2020
1 parent 337f92b commit 0078300
Show file tree
Hide file tree
Showing 12 changed files with 424 additions and 15 deletions.
45 changes: 45 additions & 0 deletions docs/advanced-functionality.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,48 @@ Characteristics of the internal cache can be controlled with these node settings
ltr.caches.expire_after_write: 10m
# Evict cache entries 10 minutes after access (defaults to 1hour, set to 0 to disable)
ltr.caches.expire_after_read: 10m

=============================
Extra Logging
=============================

As described in :doc:`logging-features`, it is possible to use the logging extension to return the feature values with each document. For native scripts, it is also possible to return extra arbitrary information with the logged features.

For native scripts, the parameter :code:`extra_logging` is injected into the script parameters. The parameter value is a `Supplier <https://docs.oracle.com/javase/8/docs/api/java/util/function/Supplier.html>`_ <`Map <https://docs.oracle.com/javase/8/docs/api/java/util/Map.html>`_>, which provides a non-null :code:`Map<String,Object>` **only** during the logging fetch phase. Any values added to this Map will be returned with the logged features::

@Override
public double runAsDouble() {
...
Map<String,Object> extraLoggingMap = ((Supplier<Map<String,Object>>) getParams().get("extra_logging")).get();
if (extraLoggingMap != null) {
extraLoggingMap.put("extra_float", 10.0f);
extraLoggingMap.put("extra_string", "additional_info");
}
...
}

If (and only if) the extra logging Map is accessed, it will be returned as an additional entry with the logged features::

{
"log_entry1": [
{
"name": "title_query",
"value": 9.510193
},
{
"name": "body_query",
"value": 10.7808075
},
{
"name": "user_rating",
"value": 7.8
},
{
"name": "extra_logging",
"value": {
"extra_float": 10.0,
"extra_string": "additional_info"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.o19s.es.ltr.feature.store;

import java.util.Map;
import java.util.function.Supplier;


/**
 * A {@link Supplier} of the extra logging Map that forwards to a delegate supplier
 * which can be installed (or replaced) after construction.
 */
public class ExtraLoggingSupplier implements Supplier<Map<String,Object>> {
    /** Delegate consulted by {@link #get()}; {@code null} until {@link #setSupplier} is called. */
    protected Supplier<Map<String,Object>> supplier;

    /**
     * Install the delegate that {@link #get()} forwards to.
     *
     * @param supplier the delegate to use; when {@code null}, {@link #get()} returns {@code null}
     */
    public void setSupplier(Supplier<Map<String,Object>> supplier) {
        this.supplier = supplier;
    }

    /**
     * Return a Map to add additional information to be returned when logging feature values.
     *
     * This Map will only be non-null during the LoggingFetchSubPhase.
     *
     * @return whatever the delegate supplies, or {@code null} when no delegate is set
     */
    @Override
    public Map<String, Object> get() {
        final Supplier<Map<String,Object>> delegate = supplier;
        return delegate == null ? null : delegate.get();
    }
}
24 changes: 18 additions & 6 deletions src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import com.o19s.es.ltr.feature.Feature;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.query.LtrRewritableQuery;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.query.LtrRewriteContext;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -31,12 +32,12 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class ScriptFeature implements Feature {
public static final String TEMPLATE_LANGUAGE = "script_feature";
public static final String FEATURE_VECTOR = "feature_vector";
public static final String EXTRA_LOGGING = "extra_logging";
public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";

private final String name;
Expand Down Expand Up @@ -109,11 +110,13 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin


FeatureSupplier supplier = new FeatureSupplier(featureSet);
ExtraLoggingSupplier extraLoggingSupplier = new ExtraLoggingSupplier();
Map<String, Object> nparams = new HashMap<>();
nparams.putAll(baseScriptParams);
nparams.putAll(queryTimeParams);
nparams.putAll(extraQueryTimeParams);
nparams.put(FEATURE_VECTOR, supplier);
nparams.put(EXTRA_LOGGING, extraLoggingSupplier);
Script script = new Script(this.script.getType(), this.script.getLang(),
this.script.getIdOrCode(), this.script.getOptions(), nparams);
ScoreScript.Factory factoryFactory = context.getQueryShardContext().getScriptService().compile(script, ScoreScript.CONTEXT);
Expand All @@ -122,15 +125,17 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<Strin
context.getQueryShardContext().index().getName(),
context.getQueryShardContext().getShardId(),
context.getQueryShardContext().indexVersionCreated());
return new LtrScript(function, supplier);
return new LtrScript(function, supplier, extraLoggingSupplier);
}

static class LtrScript extends Query implements LtrRewritableQuery {
private final ScriptScoreFunction function;
private final FeatureSupplier supplier;
LtrScript(ScriptScoreFunction function, FeatureSupplier supplier) {
private final ExtraLoggingSupplier extraLoggingSupplier;
LtrScript(ScriptScoreFunction function, FeatureSupplier supplier, ExtraLoggingSupplier extraLoggingSupplier) {
this.function = function;
this.supplier = supplier;
this.extraLoggingSupplier = extraLoggingSupplier;
}

@SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
Expand Down Expand Up @@ -161,8 +166,15 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
}

@Override
public Query ltrRewrite(Supplier<LtrRanker.FeatureVector> vectorSupplier) throws IOException {
supplier.set(vectorSupplier);
public Query ltrRewrite(LtrRewriteContext context) throws IOException {
supplier.set(context.getFeatureVectorSupplier());

LogLtrRanker.LogConsumer consumer = context.getLogConsumer();
if (consumer != null) {
extraLoggingSupplier.setSupplier(consumer::getExtraLoggingMap);
} else {
extraLoggingSupplier.setSupplier(() -> null);
}
return this;
}
}
Expand Down
24 changes: 23 additions & 1 deletion src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ private Tuple<RankerQuery, HitLogConsumer> toLogger(LoggingSearchExtBuilder.LogS

static class HitLogConsumer implements LogLtrRanker.LogConsumer {
private static final String FIELD_NAME = "_ltrlog";
private static final String EXTRA_LOGGING_NAME = "extra_logging";
private final String name;
private final FeatureSet set;
private final boolean missingAsZero;
Expand All @@ -192,6 +193,7 @@ static class HitLogConsumer implements LogLtrRanker.LogConsumer {
// ]
private List<Map<String, Object>> currentLog;
private SearchHit currentHit;
private Map<String,Object> extraLogging;


HitLogConsumer(String name, FeatureSet set, boolean missingAsZero) {
Expand All @@ -201,7 +203,9 @@ static class HitLogConsumer implements LogLtrRanker.LogConsumer {
}

private void rebuild() {
List<Map<String, Object>> ini = new ArrayList<>(set.size());
// Allocate one Map per feature, plus one placeholder for an extra logging Map
// that will only be added if used.
List<Map<String, Object>> ini = new ArrayList<>(set.size() + 1);

for (int i = 0; i < set.size(); i++) {
Map<String, Object> defaultKeyVal = new HashMap<>();
Expand All @@ -212,6 +216,7 @@ private void rebuild() {
ini.add(i, defaultKeyVal);
}
currentLog = ini;
extraLogging = null;
}

@Override
Expand All @@ -221,6 +226,23 @@ public void accept(int featureOrdinal, float score) {
currentLog.get(featureOrdinal).put("value", score);
}

/**
 * Return the Map that holds additional logging information returned with the feature values.
 *
 * The Map is created on first access: that first call also appends an entry to the
 * current log whose "name" is EXTRA_LOGGING_NAME and whose "value" is the new Map.
 * While the Map exists, further calls hand back the same instance without touching the log.
 */
@Override
public Map<String,Object> getExtraLoggingMap() {
    // Fast path: the map (and its log entry) already exist.
    if (extraLogging != null) {
        return extraLogging;
    }

    extraLogging = new HashMap<>();

    // Register the map in the log so it is returned next to the feature entries.
    Map<String,Object> extraEntry = new HashMap<>();
    extraEntry.put("name", EXTRA_LOGGING_NAME);
    extraEntry.put("value", extraLogging);
    currentLog.add(extraEntry);

    return extraLogging;
}

void nextDoc(SearchHit hit) {
if (hit.fieldsOrNull() == null) {
hit.fields(new HashMap<>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ public boolean equals(Object obj) {
}

@Override
public Query ltrRewrite(Supplier<LtrRanker.FeatureVector> vectorSuppler) {
return new FVDerivedExpressionQuery(this, vectorSuppler);
public Query ltrRewrite(LtrRewriteContext context) {
return new FVDerivedExpressionQuery(this, context.getFeatureVectorSupplier());
}

@Override
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package com.o19s.es.ltr.query;

import com.o19s.es.ltr.ranker.LtrRanker;
import org.apache.lucene.search.Query;

import java.io.IOException;
import java.util.function.Supplier;

public interface LtrRewritableQuery {
/**
* Rewrite the query so that it holds the vectorSupplier
* Rewrite the query so that it holds the vectorSupplier and provide extra logging support
*/
Query ltrRewrite(Supplier<LtrRanker.FeatureVector> vectorSuppler) throws IOException;
Query ltrRewrite(LtrRewriteContext context) throws IOException;
}
35 changes: 35 additions & 0 deletions src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.o19s.es.ltr.query;

import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.ltr.ranker.LtrRanker;

import java.util.function.Supplier;

/**
 * Context needed to rewrite queries so that they hold the feature vector supplier
 * and can provide extra logging support.
 */
public class LtrRewriteContext {
    private final Supplier<LtrRanker.FeatureVector> vectorSupplier;
    private final LtrRanker ranker;

    /**
     * @param ranker the ranker in use; inspected by {@link #getLogConsumer()}
     * @param vectorSupplier the supplier returned by {@link #getFeatureVectorSupplier()}
     */
    public LtrRewriteContext(LtrRanker ranker, Supplier<LtrRanker.FeatureVector> vectorSupplier) {
        this.vectorSupplier = vectorSupplier;
        this.ranker = ranker;
    }

    /**
     * @return the feature vector supplier given at construction
     */
    public Supplier<LtrRanker.FeatureVector> getFeatureVectorSupplier() {
        return this.vectorSupplier;
    }

    /**
     * Get LogConsumer used during the LoggingFetchSubPhase
     *
     * The returned consumer will only be non-null during the logging fetch phase
     *
     * @return the consumer of the ranker when it is a {@link LogLtrRanker}, otherwise {@code null}
     */
    public LogLtrRanker.LogConsumer getLogConsumer() {
        // Only a LogLtrRanker carries a log consumer.
        if (!(ranker instanceof LogLtrRanker)) {
            return null;
        }
        LogLtrRanker loggingRanker = (LogLtrRanker) ranker;
        return loggingRanker.getLogConsumer();
    }
}
3 changes: 2 additions & 1 deletion src/main/java/com/o19s/es/ltr/query/RankerQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,10 @@ public boolean isCacheable(LeafReaderContext ctx) {
// Hopefully elastic never runs
MutableSupplier<LtrRanker.FeatureVector> vectorSupplier = new Suppliers.FeatureVectorSupplier();
FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(ranker, vectorSupplier);
LtrRewriteContext context = new LtrRewriteContext(ranker, vectorSupplier);
for (Query q : queries) {
if (q instanceof LtrRewritableQuery) {
q = ((LtrRewritableQuery)q).ltrRewrite(vectorSupplier);
q = ((LtrRewritableQuery)q).ltrRewrite(context);
}
weights.add(searcher.createWeight(q, ScoreMode.COMPLETE, boost));
}
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.o19s.es.ltr.ranker;

import java.util.Map;

public class LogLtrRanker implements LtrRanker {
private final LogConsumer logger;
private final LtrRanker ranker;
Expand Down Expand Up @@ -79,9 +81,14 @@ void reset(LtrRanker ranker) {
}
}

/**
 * @return the {@link LogConsumer} held by this ranker
 */
public LogConsumer getLogConsumer() {
    return this.logger;
}

/**
 * Receives feature scores as they are logged. Only {@link #accept(int, float)} is
 * abstract, so implementations may be written as lambdas.
 */
@FunctionalInterface
public interface LogConsumer {
    /**
     * Receive the score of one feature.
     *
     * @param featureOrdinal ordinal of the feature
     * @param score score of that feature
     */
    void accept(int featureOrdinal, float score);

    /**
     * Return a Map for extra logging information.
     *
     * @return {@code null} in this default implementation
     */
    default Map<String,Object> getExtraLoggingMap() {
        return null;
    }

    /**
     * Reset the consumer; this default implementation does nothing.
     */
    default void reset() {
    }
}
}
42 changes: 42 additions & 0 deletions src/test/java/com/o19s/es/ltr/action/BaseIntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static com.o19s.es.ltr.feature.store.ScriptFeature.EXTRA_LOGGING;
import static com.o19s.es.ltr.feature.store.ScriptFeature.FEATURE_VECTOR;

public abstract class BaseIntegrationTest extends ESSingleNodeTestCase {
Expand Down Expand Up @@ -168,6 +170,9 @@ public <FactoryType> FactoryType compile(String scriptName, String scriptSource,
if (!p.containsKey(FEATURE_VECTOR)) {
throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]");
}
if (!p.containsKey(EXTRA_LOGGING)) {
throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]");
}
if (!p.containsKey(DEPDENDENT_FEATURE)) {
throw new IllegalArgumentException("Missing parameter [depdendent_feature ]");
}
Expand Down Expand Up @@ -198,6 +203,43 @@ public boolean needs_score() {

return context.factoryClazz.cast(factory);
}
else if (scriptSource.equals(FEATURE_EXTRACTOR + "_extra_logging")) {
ScoreScript.Factory factory = (p, lookup) ->
new ScoreScript.LeafFactory() {
{
if (!p.containsKey(FEATURE_VECTOR)) {
throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]");
}
if (!p.containsKey(EXTRA_LOGGING)) {
throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]");
}
}

@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
    return new ScoreScript(p, lookup, ctx) {

        /**
         * Adds two fixed entries to the extra logging Map when one is supplied,
         * and always scores 1.0.
         */
        @Override
        public double execute(ExplanationHolder explanation) {
            // The EXTRA_LOGGING parameter is read as a Supplier of the logging Map.
            @SuppressWarnings("unchecked")
            Supplier<Map<String,Object>> loggingSupplier =
                    (Supplier<Map<String,Object>>) getParams().get(EXTRA_LOGGING);
            Map<String,Object> extras = loggingSupplier.get();
            if (extras != null) {
                extras.put("extra_float", 10.0f);
                extras.put("extra_string", "additional_info");
            }
            return 1.0d;
        }
    };
}

@Override
public boolean needs_score() {
return false;
}
};

return context.factoryClazz.cast(factory);
}
throw new IllegalArgumentException("Unknown script name " + scriptSource);
}
};
Expand Down
Loading

0 comments on commit 0078300

Please sign in to comment.