From 3c83ee54aec00dc6106214d8dc460c1f717663c3 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Fri, 24 Nov 2023 16:59:35 +0100 Subject: [PATCH] Adding a new option to the MustacheScriptEngine to support missing parameter detection. --- .../mustache/CustomMustacheFactory.java | 36 +++++++++++++++---- .../CustomReflectionObjectHandler.java | 18 ++++++++++ .../script/mustache/MustacheScriptEngine.java | 33 ++++++++++++++--- .../mustache/MustacheScriptEngineTests.java | 33 ++++++++++++++++- .../java/org/elasticsearch/script/Script.java | 5 +++ .../ml/inference/ltr/LearnToRankService.java | 7 ++-- 6 files changed, 115 insertions(+), 17 deletions(-) diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java index 73669ccacdbc6..21224b9a235f1 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java @@ -63,16 +63,12 @@ public final class CustomMustacheFactory extends DefaultMustacheFactory { private final Encoder encoder; - public CustomMustacheFactory(String mediaType) { + private CustomMustacheFactory(String mediaType, boolean detectMissingParams) { super(); - setObjectHandler(new CustomReflectionObjectHandler()); + setObjectHandler(new CustomReflectionObjectHandler(detectMissingParams)); this.encoder = createEncoder(mediaType); } - public CustomMustacheFactory() { - this(DEFAULT_MEDIA_TYPE); - } - @Override public void encode(String value, Writer writer) { try { @@ -95,12 +91,15 @@ public MustacheVisitor createMustacheVisitor() { return new CustomMustacheVisitor(this); } + public static Builder builder() { + return new Builder(); + } + class CustomMustacheVisitor extends DefaultMustacheVisitor { CustomMustacheVisitor(DefaultMustacheFactory df) { super(df); } - @Override public void iterable(TemplateContext templateContext, String variable, Mustache mustache) { if (ToJsonCode.match(variable)) { @@ -360,4 +359,27 @@ public void encode(String s, Writer writer) throws IOException { writer.write(URLEncoder.encode(s, StandardCharsets.UTF_8)); } } + + /** + * Build a new {@link CustomMustacheFactory} object. + */ + static class Builder { + private String mediaType = DEFAULT_MEDIA_TYPE; + private boolean detectMissingParams = false; + + private Builder() {} + + public Builder mediaType(String mediaType) { + this.mediaType = mediaType; + return this; + } + + public Builder detectMissingParams(boolean detectMissingParams) { + this.detectMissingParams = detectMissingParams; + return this; + } + public CustomMustacheFactory build() { + return new CustomMustacheFactory(mediaType, detectMissingParams); + } + } } diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java index c1e87fdc0970e..07fab2ea58420 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java @@ -8,7 +8,9 @@ package org.elasticsearch.script.mustache; +import com.github.mustachejava.reflect.MissingWrapper; import com.github.mustachejava.reflect.ReflectionObjectHandler; +import com.github.mustachejava.util.Wrapper; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.util.Maps; @@ -19,10 +21,16 @@ import java.util.AbstractMap; import java.util.Collection; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Set; final class CustomReflectionObjectHandler extends ReflectionObjectHandler { + private final boolean detectMissingParams; + + CustomReflectionObjectHandler(boolean detectMissingParams) { + this.detectMissingParams = detectMissingParams; + } @Override public Object coerce(Object object) { @@ -41,6 +49,16 @@ public Object coerce(Object object) { } } + public Wrapper find(String name, List scopes) { + Wrapper wrapper = super.find(name, scopes); + + if (detectMissingParams && wrapper instanceof MissingWrapper) { + throw new MustacheScriptEngine.InvalidParameterException("Parameter [" + name + "] is missing"); + } + + return wrapper; + } + @Override @SuppressWarnings("rawtypes") protected AccessibleObject findMember(Class sClass, String name) { diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java index c6f60c48c4ab4..2fd28df3a0312 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java @@ -72,10 +72,20 @@ public Set> getSupportedContexts() { } private static CustomMustacheFactory createMustacheFactory(Map options) { - if (options == null || options.isEmpty() || options.containsKey(Script.CONTENT_TYPE_OPTION) == false) { - return new CustomMustacheFactory(); + CustomMustacheFactory.Builder builder = CustomMustacheFactory.builder(); + if (options == null || options.isEmpty()) { + return builder.build(); } - return new CustomMustacheFactory(options.get(Script.CONTENT_TYPE_OPTION)); + + if (options.containsKey(Script.CONTENT_TYPE_OPTION)) { + builder.mediaType(options.get(Script.CONTENT_TYPE_OPTION)); + } + + if (options.containsKey(Script.DETECT_MISSING_PARAMS_OPTION)) { + builder.detectMissingParams(Boolean.valueOf(options.get(Script.DETECT_MISSING_PARAMS_OPTION))); + } + + return builder.build(); } @Override @@ -107,10 +117,25 @@ public String execute() { try { template.execute(writer, params); } catch (Exception e) { - logger.error(() -> format("Error running %s", template), e); + if (logException(e)) { + logger.error(() -> format("Error running %s", template), e); + } throw new GeneralScriptException("Error running " + template, e); } return writer.toString(); } + + public boolean logException(Throwable e) { + if (e instanceof InvalidParameterException) { + return false; + } + return e.getCause() == null || logException(e.getCause()); + } + } + + static class InvalidParameterException extends MustacheException { + InvalidParameterException(String message) { + super(message, null, null); + } } } diff --git a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java index 0d3e881e54a56..144304d5d36ed 100644 --- a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java +++ b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java @@ -9,6 +9,7 @@ import com.github.mustachejava.MustacheFactory; +import org.elasticsearch.script.GeneralScriptException; import org.elasticsearch.script.Script; import org.elasticsearch.script.TemplateScript; import org.elasticsearch.test.ESTestCase; @@ -18,10 +19,13 @@ import java.io.IOException; import java.io.StringWriter; +import java.util.Collections; import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.startsWith; /** * Mustache based templating test @@ -33,7 +37,7 @@ public class MustacheScriptEngineTests extends ESTestCase { @Before public void setup() { qe = new MustacheScriptEngine(); - factory = new CustomMustacheFactory(); + factory = CustomMustacheFactory.builder().build(); } public void testSimpleParameterReplace() { @@ -196,6 +200,33 @@ public void testSimple() throws IOException { assertThat(TemplateScript.execute(), equalTo("{\"match_all\":{}}")); } + public void testDetectMissingParam() throws IOException { + Map scriptOptions = Map.ofEntries(Map.entry(Script.DETECT_MISSING_PARAMS_OPTION, "true")); + String source = "{\"match\": { \"field\": \"{{query_string}}\" }"; + TemplateScript.Factory compiled = qe.compile(null, source, TemplateScript.CONTEXT, scriptOptions); + + // fails when a param is missing and the DETECT_MISSING_PARAMS_OPTION option is set to true. + { + Map params = Collections.emptyMap(); + GeneralScriptException e = expectThrows(GeneralScriptException.class, () -> compiled.newInstance(params).execute()); + assertThat(e.getRootCause(), instanceOf(MustacheScriptEngine.InvalidParameterException.class)); + assertThat(e.getRootCause().getMessage(), startsWith("Parameter [query_string] is missing")); + } + + // fails when params is null and the DETECT_MISSING_PARAMS_OPTION option is set to true. + { + GeneralScriptException e = expectThrows(GeneralScriptException.class, () -> compiled.newInstance(null).execute()); + assertThat(e.getRootCause(), instanceOf(MustacheScriptEngine.InvalidParameterException.class)); + assertThat(e.getRootCause().getMessage(), startsWith("Parameter [query_string] is missing")); + } + + // works as expected when params are specified and the DETECT_MISSING_PARAMS_OPTION option is set to true + { + Map params = Map.ofEntries(Map.entry("query_string", "foo")); + assertThat(compiled.newInstance(params).execute(), equalTo("{\"match\": { \"field\": \"foo\" }")); + } + } + public void testParseTemplateAsSingleStringWithConditionalClause() throws IOException { String templateString = """ { diff --git a/server/src/main/java/org/elasticsearch/script/Script.java b/server/src/main/java/org/elasticsearch/script/Script.java index d21cdc50e00b5..8a9638c60711e 100644 --- a/server/src/main/java/org/elasticsearch/script/Script.java +++ b/server/src/main/java/org/elasticsearch/script/Script.java @@ -94,6 +94,11 @@ public final class Script implements ToXContentObject, Writeable { */ public static final String CONTENT_TYPE_OPTION = "content_type"; + /** + * Compiler option to enable missing parameters detection. + */ + public static final String DETECT_MISSING_PARAMS_OPTION = "detect_missing_params"; + /** * Standard {@link ParseField} for outer level of script queries. */ diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java index 1443ccd687620..7ff9920d1e7c8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java @@ -126,11 +126,6 @@ private LearnToRankConfig applyParams(LearnToRankConfig config, Map featureExtractorBuilders = new ArrayList<>(); for (LearnToRankFeatureExtractorBuilder featureExtractorBuilder : config.getFeatureExtractorBuilders()) { @@ -178,6 +173,8 @@ private QueryExtractorBuilder applyParams(QueryExtractorBuilder queryExtractorBu Script script = new Script(ScriptType.INLINE, DEFAULT_TEMPLATE_LANG, templateSource, Collections.emptyMap()); String parsedTemplate = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params).execute(); + System.out.println(templateSource); + System.out.println(parsedTemplate); // TODO: handle missing params. XContentParser parser = XContentType.JSON.xContent().createParser(parserConfiguration, parsedTemplate);