Skip to content

Commit

Permalink
Add dynamic (duck) type resolution to Painless static types (#78575)
Browse files Browse the repository at this point in the history
This change adds dynamic (duck) type resolution to Painless static types using an annotation ( at 
dynamic_type ) to control which static types are allowed to be dynamically invoked. This annotation 
does not chain so any sub classes that also require dynamic type resolution must be annotated as 
well.
  • Loading branch information
jdconrad authored Oct 4, 2021
1 parent a019d41 commit f6fbeb8
Show file tree
Hide file tree
Showing 11 changed files with 331 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.painless.spi.annotation;

public class DynamicTypeAnnotation {

public static final String NAME = "dynamic_type";

public static final DynamicTypeAnnotation INSTANCE = new DynamicTypeAnnotation();

private DynamicTypeAnnotation() {

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.painless.spi.annotation;

import java.util.Map;

public class DynamicTypeAnnotationParser implements WhitelistAnnotationParser {

public static final DynamicTypeAnnotationParser INSTANCE = new DynamicTypeAnnotationParser();

private DynamicTypeAnnotationParser() {}

@Override
public Object parse(Map<String, String> arguments) {
if (arguments.isEmpty() == false) {
throw new IllegalArgumentException(
"unexpected parameters for [@" + DynamicTypeAnnotation.NAME + "] annotation, found " + arguments
);
}

return DynamicTypeAnnotation.INSTANCE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public interface WhitelistAnnotationParser {
new AbstractMap.SimpleEntry<>(NonDeterministicAnnotation.NAME, NonDeterministicAnnotationParser.INSTANCE),
new AbstractMap.SimpleEntry<>(InjectConstantAnnotation.NAME, InjectConstantAnnotationParser.INSTANCE),
new AbstractMap.SimpleEntry<>(CompileTimeOnlyAnnotation.NAME, CompileTimeOnlyAnnotationParser.INSTANCE),
new AbstractMap.SimpleEntry<>(AugmentedAnnotation.NAME, AugmentedAnnotationParser.INSTANCE)
new AbstractMap.SimpleEntry<>(AugmentedAnnotation.NAME, AugmentedAnnotationParser.INSTANCE),
new AbstractMap.SimpleEntry<>(DynamicTypeAnnotation.NAME, DynamicTypeAnnotationParser.INSTANCE)
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public final class PainlessClass {
public final Map<String, PainlessField> staticFields;
public final Map<String, PainlessField> fields;
public final PainlessMethod functionalInterfaceMethod;
public final Map<Class<?>, Object> annotations;

public final Map<String, PainlessMethod> runtimeMethods;
public final Map<String, MethodHandle> getterMethodHandles;
Expand All @@ -29,6 +30,7 @@ public final class PainlessClass {
Map<String, PainlessMethod> staticMethods, Map<String, PainlessMethod> methods,
Map<String, PainlessField> staticFields, Map<String, PainlessField> fields,
PainlessMethod functionalInterfaceMethod,
Map<Class<?>, Object> annotations,
Map<String, PainlessMethod> runtimeMethods,
Map<String, MethodHandle> getterMethodHandles, Map<String, MethodHandle> setterMethodHandles) {

Expand All @@ -38,6 +40,7 @@ public final class PainlessClass {
this.staticFields = Map.copyOf(staticFields);
this.fields = Map.copyOf(fields);
this.functionalInterfaceMethod = functionalInterfaceMethod;
this.annotations = annotations;

this.getterMethodHandles = Map.copyOf(getterMethodHandles);
this.setterMethodHandles = Map.copyOf(setterMethodHandles);
Expand All @@ -61,11 +64,12 @@ public boolean equals(Object object) {
Objects.equals(methods, that.methods) &&
Objects.equals(staticFields, that.staticFields) &&
Objects.equals(fields, that.fields) &&
Objects.equals(functionalInterfaceMethod, that.functionalInterfaceMethod);
Objects.equals(functionalInterfaceMethod, that.functionalInterfaceMethod) &&
Objects.equals(annotations, that.annotations);
}

@Override
public int hashCode() {
return Objects.hash(constructors, staticMethods, methods, staticFields, fields, functionalInterfaceMethod);
return Objects.hash(constructors, staticMethods, methods, staticFields, fields, functionalInterfaceMethod, annotations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ final class PainlessClassBuilder {
final Map<String, PainlessField> staticFields;
final Map<String, PainlessField> fields;
PainlessMethod functionalInterfaceMethod;
final Map<Class<?>, Object> annotations;

final Map<String, PainlessMethod> runtimeMethods;
final Map<String, MethodHandle> getterMethodHandles;
Expand All @@ -33,14 +34,15 @@ final class PainlessClassBuilder {
staticFields = new HashMap<>();
fields = new HashMap<>();
functionalInterfaceMethod = null;
annotations = new HashMap<>();

runtimeMethods = new HashMap<>();
getterMethodHandles = new HashMap<>();
setterMethodHandles = new HashMap<>();
}

PainlessClass build() {
return new PainlessClass(constructors, staticMethods, methods, staticFields, fields, functionalInterfaceMethod,
return new PainlessClass(constructors, staticMethods, methods, staticFields, fields, functionalInterfaceMethod, annotations,
runtimeMethods, getterMethodHandles, setterMethodHandles);
}

Expand All @@ -61,11 +63,12 @@ public boolean equals(Object object) {
Objects.equals(methods, that.methods) &&
Objects.equals(staticFields, that.staticFields) &&
Objects.equals(fields, that.fields) &&
Objects.equals(functionalInterfaceMethod, that.functionalInterfaceMethod);
Objects.equals(functionalInterfaceMethod, that.functionalInterfaceMethod) &&
Objects.equals(annotations, that.annotations);
}

@Override
public int hashCode() {
return Objects.hash(constructors, staticMethods, methods, staticFields, fields, functionalInterfaceMethod);
return Objects.hash(constructors, staticMethods, methods, staticFields, fields, functionalInterfaceMethod, annotations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public static PainlessLookup buildFromWhitelists(List<Whitelist> whitelists) {
origin = whitelistClass.origin;
painlessLookupBuilder.addPainlessClass(
whitelist.classLoader, whitelistClass.javaClassName,
whitelistClass.painlessAnnotations.containsKey(NoImportAnnotation.class) == false);
whitelistClass.painlessAnnotations);
}
}

Expand Down Expand Up @@ -236,7 +236,8 @@ private Class<?> loadClass(ClassLoader classLoader, String javaClassName, Suppli
}
}

public void addPainlessClass(ClassLoader classLoader, String javaClassName, boolean importClassName) {
public void addPainlessClass(ClassLoader classLoader, String javaClassName, Map<Class<?>, Object> annotations) {

Objects.requireNonNull(classLoader);
Objects.requireNonNull(javaClassName);

Expand All @@ -255,12 +256,12 @@ public void addPainlessClass(ClassLoader classLoader, String javaClassName, bool
clazz = loadClass(classLoader, javaClassName, () -> "class [" + javaClassName + "] not found");
}

addPainlessClass(clazz, importClassName);
addPainlessClass(clazz, annotations);
}

public void addPainlessClass(Class<?> clazz, boolean importClassName) {
public void addPainlessClass(Class<?> clazz, Map<Class<?>, Object> annotations) {
Objects.requireNonNull(clazz);
//Matcher m = new Matcher();
Objects.requireNonNull(annotations);

if (clazz == def.class) {
throw new IllegalArgumentException("cannot add reserved class [" + DEF_CLASS_NAME + "]");
Expand Down Expand Up @@ -296,13 +297,15 @@ public void addPainlessClass(Class<?> clazz, boolean importClassName) {

if (existingPainlessClassBuilder == null) {
PainlessClassBuilder painlessClassBuilder = new PainlessClassBuilder();
painlessClassBuilder.annotations.putAll(annotations);

canonicalClassNamesToClasses.put(canonicalClassName.intern(), clazz);
classesToPainlessClassBuilders.put(clazz, painlessClassBuilder);
}

String javaClassName = clazz.getName();
String importedCanonicalClassName = javaClassName.substring(javaClassName.lastIndexOf('.') + 1).replace('$', '.');
boolean importClassName = annotations.containsKey(NoImportAnnotation.class) == false;

if (canonicalClassName.equals(importedCanonicalClassName)) {
if (importClassName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.painless.lookup.PainlessConstructor;
import org.elasticsearch.painless.lookup.PainlessField;
import org.elasticsearch.painless.lookup.PainlessInstanceBinding;
import org.elasticsearch.painless.lookup.PainlessLookup;
import org.elasticsearch.painless.lookup.PainlessLookupUtility;
import org.elasticsearch.painless.lookup.PainlessMethod;
import org.elasticsearch.painless.lookup.def;
Expand Down Expand Up @@ -69,6 +70,7 @@
import org.elasticsearch.painless.node.SThrow;
import org.elasticsearch.painless.node.STry;
import org.elasticsearch.painless.node.SWhile;
import org.elasticsearch.painless.spi.annotation.DynamicTypeAnnotation;
import org.elasticsearch.painless.spi.annotation.NonDeterministicAnnotation;
import org.elasticsearch.painless.symbol.Decorations;
import org.elasticsearch.painless.symbol.Decorations.AllEscape;
Expand All @@ -83,6 +85,7 @@
import org.elasticsearch.painless.symbol.Decorations.ContinuousLoop;
import org.elasticsearch.painless.symbol.Decorations.DefOptimized;
import org.elasticsearch.painless.symbol.Decorations.DowncastPainlessCast;
import org.elasticsearch.painless.symbol.Decorations.DynamicInvocation;
import org.elasticsearch.painless.symbol.Decorations.EncodingDecoration;
import org.elasticsearch.painless.symbol.Decorations.Explicit;
import org.elasticsearch.painless.symbol.Decorations.ExpressionPainlessCast;
Expand Down Expand Up @@ -140,6 +143,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;

Expand Down Expand Up @@ -2525,9 +2529,8 @@ public void visitDot(EDot userDotNode, SemanticScope semanticScope) {
}
} else if (prefixValueType != null && prefixValueType.getValueType() == def.class) {
TargetType targetType = userDotNode.isNullSafe() ? null : semanticScope.getDecoration(userDotNode, TargetType.class);
// TODO: remove ZonedDateTime exception when JodaCompatibleDateTime is removed
valueType = targetType == null || targetType.getTargetType() == ZonedDateTime.class ||
semanticScope.getCondition(userDotNode, Explicit.class) ? def.class : targetType.getTargetType();
valueType = targetType == null || semanticScope.getCondition(userDotNode, Explicit.class) ?
def.class : targetType.getTargetType();

if (write) {
semanticScope.setCondition(userDotNode, DefOptimized.class);
Expand Down Expand Up @@ -2888,9 +2891,45 @@ public void visitCall(ECall userCallNode, SemanticScope semanticScope) {
"[" + semanticScope.getDecoration(userPrefixNode, PartialCanonicalTypeName.class).getPartialCanonicalTypeName() + "]"));
}

boolean dynamic = false;
PainlessMethod method = null;

if (prefixValueType != null) {
Class<?> type = prefixValueType.getValueType();
PainlessLookup lookup = semanticScope.getScriptScope().getPainlessLookup();

if (prefixValueType.getValueType() == def.class) {
dynamic = true;
} else {
method = lookup.lookupPainlessMethod(type, false, methodName, userArgumentsSize);

if (method == null) {
dynamic = lookup.lookupPainlessClass(type).annotations.containsKey(DynamicTypeAnnotation.class) &&
lookup.lookupPainlessSubClassesMethod(type, methodName, userArgumentsSize) != null;

if (dynamic == false) {
throw userCallNode.createError(new IllegalArgumentException("member method " +
"[" + prefixValueType.getValueCanonicalTypeName() + ", " + methodName + "/" + userArgumentsSize + "] " +
"not found"));
}
}
}
} else if (prefixStaticType != null) {
method = semanticScope.getScriptScope().getPainlessLookup().lookupPainlessMethod(
prefixStaticType.getStaticType(), true, methodName, userArgumentsSize);

if (method == null) {
throw userCallNode.createError(new IllegalArgumentException("static method " +
"[" + prefixStaticType.getStaticCanonicalTypeName() + ", " + methodName + "/" + userArgumentsSize + "] " +
"not found"));
}
} else {
throw userCallNode.createError(new IllegalStateException("value required: instead found no value"));
}

Class<?> valueType;

if (prefixValueType != null && prefixValueType.getValueType() == def.class) {
if (dynamic) {
for (AExpression userArgumentNode : userArgumentNodes) {
semanticScope.setCondition(userArgumentNode, Read.class);
semanticScope.setCondition(userArgumentNode, Internal.class);
Expand All @@ -2904,34 +2943,12 @@ public void visitCall(ECall userCallNode, SemanticScope semanticScope) {
}

TargetType targetType = userCallNode.isNullSafe() ? null : semanticScope.getDecoration(userCallNode, TargetType.class);
// TODO: remove ZonedDateTime exception when JodaCompatibleDateTime is removed
valueType = targetType == null || targetType.getTargetType() == ZonedDateTime.class ||
semanticScope.getCondition(userCallNode, Explicit.class) ? def.class : targetType.getTargetType();
} else {
PainlessMethod method;

if (prefixValueType != null) {
method = semanticScope.getScriptScope().getPainlessLookup().lookupPainlessMethod(
prefixValueType.getValueType(), false, methodName, userArgumentsSize);

if (method == null) {
throw userCallNode.createError(new IllegalArgumentException("member method " +
"[" + prefixValueType.getValueCanonicalTypeName() + ", " + methodName + "/" + userArgumentsSize + "] " +
"not found"));
}
} else if (prefixStaticType != null) {
method = semanticScope.getScriptScope().getPainlessLookup().lookupPainlessMethod(
prefixStaticType.getStaticType(), true, methodName, userArgumentsSize);

if (method == null) {
throw userCallNode.createError(new IllegalArgumentException("static method " +
"[" + prefixStaticType.getStaticCanonicalTypeName() + ", " + methodName + "/" + userArgumentsSize + "] " +
"not found"));
}
} else {
throw userCallNode.createError(new IllegalStateException("value required: instead found no value"));
}
valueType = targetType == null || semanticScope.getCondition(userCallNode, Explicit.class) ?
def.class : targetType.getTargetType();

semanticScope.setCondition(userCallNode, DynamicInvocation.class);
} else {
Objects.requireNonNull(method);
semanticScope.getScriptScope().markNonDeterministic(method.annotations.containsKey(NonDeterministicAnnotation.class));

for (int argument = 0; argument < userArgumentsSize; ++argument) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
import org.elasticsearch.painless.symbol.Decorations.CompoundType;
import org.elasticsearch.painless.symbol.Decorations.ContinuousLoop;
import org.elasticsearch.painless.symbol.Decorations.DowncastPainlessCast;
import org.elasticsearch.painless.symbol.Decorations.DynamicInvocation;
import org.elasticsearch.painless.symbol.Decorations.EncodingDecoration;
import org.elasticsearch.painless.symbol.Decorations.Explicit;
import org.elasticsearch.painless.symbol.Decorations.ExpressionPainlessCast;
Expand Down Expand Up @@ -1802,7 +1803,7 @@ public void visitCall(ECall userCallNode, ScriptScope scriptScope) {
ValueType prefixValueType = scriptScope.getDecoration(userCallNode.getPrefixNode(), ValueType.class);
Class<?> valueType = scriptScope.getDecoration(userCallNode, ValueType.class).getValueType();

if (prefixValueType != null && prefixValueType.getValueType() == def.class) {
if (scriptScope.getCondition(userCallNode, DynamicInvocation.class)) {
InvokeCallDefNode irCallSubDefNode = new InvokeCallDefNode(userCallNode.getLocation());

for (AExpression userArgumentNode : userCallNode.getArgumentNodes()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ public PainlessMethod getStandardPainlessMethod() {
}
}

public interface DynamicInvocation extends Condition {

}

public static class GetterPainlessMethod implements Decoration {

private final PainlessMethod getterPainlessMethod;
Expand Down
Loading

0 comments on commit f6fbeb8

Please sign in to comment.