Skip to content

Commit

Permalink
Add a direct sub classes data structure to the Painless lookup (#76955)
Browse files Browse the repository at this point in the history
This change has two main components.

The first is to have method/field resolution for compile-time and run-time use the same code path for
now. This removes copying of member methods between super and sub classes and instead does a
resolution through the class hierarchy. This allows us to correctly implement the next change.

The second is a data structure that allows for the lookup of direct sub classes for all allow listed
classes/interfaces within Painless.
  • Loading branch information
jdconrad committed Aug 26, 2021
1 parent 71a5982 commit a4d9c1b
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import org.elasticsearch.common.util.CollectionUtils;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
Expand All @@ -27,6 +31,7 @@ public final class PainlessLookup {
private final Map<String, Class<?>> javaClassNamesToClasses;
private final Map<String, Class<?>> canonicalClassNamesToClasses;
private final Map<Class<?>, PainlessClass> classesToPainlessClasses;
private final Map<Class<?>, Set<Class<?>>> classesToDirectSubClasses;

private final Map<String, PainlessMethod> painlessMethodKeysToImportedPainlessMethods;
private final Map<String, PainlessClassBinding> painlessMethodKeysToPainlessClassBindings;
Expand All @@ -36,13 +41,15 @@ public final class PainlessLookup {
Map<String, Class<?>> javaClassNamesToClasses,
Map<String, Class<?>> canonicalClassNamesToClasses,
Map<Class<?>, PainlessClass> classesToPainlessClasses,
Map<Class<?>, Set<Class<?>>> classesToDirectSubClasses,
Map<String, PainlessMethod> painlessMethodKeysToImportedPainlessMethods,
Map<String, PainlessClassBinding> painlessMethodKeysToPainlessClassBindings,
Map<String, PainlessInstanceBinding> painlessMethodKeysToPainlessInstanceBindings) {

Objects.requireNonNull(javaClassNamesToClasses);
Objects.requireNonNull(canonicalClassNamesToClasses);
Objects.requireNonNull(classesToPainlessClasses);
Objects.requireNonNull(classesToDirectSubClasses);

Objects.requireNonNull(painlessMethodKeysToImportedPainlessMethods);
Objects.requireNonNull(painlessMethodKeysToPainlessClassBindings);
Expand All @@ -51,6 +58,7 @@ public final class PainlessLookup {
this.javaClassNamesToClasses = javaClassNamesToClasses;
this.canonicalClassNamesToClasses = CollectionUtils.copyMap(canonicalClassNamesToClasses);
this.classesToPainlessClasses = CollectionUtils.copyMap(classesToPainlessClasses);
this.classesToDirectSubClasses = CollectionUtils.copyMap(classesToDirectSubClasses);

this.painlessMethodKeysToImportedPainlessMethods = CollectionUtils.copyMap(painlessMethodKeysToImportedPainlessMethods);
this.painlessMethodKeysToPainlessClassBindings = CollectionUtils.copyMap(painlessMethodKeysToPainlessClassBindings);
Expand All @@ -77,6 +85,10 @@ public Set<Class<?>> getClasses() {
return classesToPainlessClasses.keySet();
}

public Set<Class<?>> getDirectSubClasses(Class<?> superClass) {
return classesToDirectSubClasses.get(superClass);
}

public Set<String> getImportedPainlessMethodsKeys() {
return painlessMethodKeysToImportedPainlessMethods.keySet();
}
Expand Down Expand Up @@ -144,16 +156,12 @@ public PainlessMethod lookupPainlessMethod(Class<?> targetClass, boolean isStati
targetClass = typeToBoxedType(targetClass);
}

PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetClass);
String painlessMethodKey = buildPainlessMethodKey(methodName, methodArity);
Function<PainlessClass, PainlessMethod> objectLookup = isStatic ?
targetPainlessClass -> targetPainlessClass.staticMethods.get(painlessMethodKey) :
targetPainlessClass -> targetPainlessClass.methods.get(painlessMethodKey);

if (targetPainlessClass == null) {
return null;
}

return isStatic ?
targetPainlessClass.staticMethods.get(painlessMethodKey) :
targetPainlessClass.methods.get(painlessMethodKey);
return lookupPainlessObject(targetClass, objectLookup);
}

public PainlessField lookupPainlessField(String targetCanonicalClassName, boolean isStatic, String fieldName) {
Expand All @@ -172,22 +180,12 @@ public PainlessField lookupPainlessField(Class<?> targetClass, boolean isStatic,
Objects.requireNonNull(targetClass);
Objects.requireNonNull(fieldName);

PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetClass);
String painlessFieldKey = buildPainlessFieldKey(fieldName);
Function<PainlessClass, PainlessField> objectLookup = isStatic ?
targetPainlessClass -> targetPainlessClass.staticFields.get(painlessFieldKey) :
targetPainlessClass -> targetPainlessClass.fields.get(painlessFieldKey);

if (targetPainlessClass == null) {
return null;
}

PainlessField painlessField = isStatic ?
targetPainlessClass.staticFields.get(painlessFieldKey) :
targetPainlessClass.fields.get(painlessFieldKey);

if (painlessField == null) {
return null;
}

return painlessField;
return lookupPainlessObject(targetClass, objectLookup);
}

public PainlessMethod lookupImportedPainlessMethod(String methodName, int arity) {
Expand Down Expand Up @@ -232,7 +230,7 @@ public PainlessMethod lookupRuntimePainlessMethod(Class<?> originalTargetClass,
Function<PainlessClass, PainlessMethod> objectLookup =
targetPainlessClass -> targetPainlessClass.runtimeMethods.get(painlessMethodKey);

return lookupRuntimePainlessObject(originalTargetClass, objectLookup);
return lookupPainlessObject(originalTargetClass, objectLookup);
}

public MethodHandle lookupRuntimeGetterMethodHandle(Class<?> originalTargetClass, String getterName) {
Expand All @@ -241,7 +239,7 @@ public MethodHandle lookupRuntimeGetterMethodHandle(Class<?> originalTargetClass

Function<PainlessClass, MethodHandle> objectLookup = targetPainlessClass -> targetPainlessClass.getterMethodHandles.get(getterName);

return lookupRuntimePainlessObject(originalTargetClass, objectLookup);
return lookupPainlessObject(originalTargetClass, objectLookup);
}

public MethodHandle lookupRuntimeSetterMethodHandle(Class<?> originalTargetClass, String setterName) {
Expand All @@ -250,10 +248,13 @@ public MethodHandle lookupRuntimeSetterMethodHandle(Class<?> originalTargetClass

Function<PainlessClass, MethodHandle> objectLookup = targetPainlessClass -> targetPainlessClass.setterMethodHandles.get(setterName);

return lookupRuntimePainlessObject(originalTargetClass, objectLookup);
return lookupPainlessObject(originalTargetClass, objectLookup);
}

private <T> T lookupRuntimePainlessObject(Class<?> originalTargetClass, Function<PainlessClass, T> objectLookup) {
private <T> T lookupPainlessObject(Class<?> originalTargetClass, Function<PainlessClass, T> objectLookup) {
Objects.requireNonNull(originalTargetClass);
Objects.requireNonNull(objectLookup);

Class<?> currentTargetClass = originalTargetClass;

while (currentTargetClass != null) {
Expand All @@ -270,17 +271,38 @@ private <T> T lookupRuntimePainlessObject(Class<?> originalTargetClass, Function
currentTargetClass = currentTargetClass.getSuperclass();
}

if (originalTargetClass.isInterface()) {
PainlessClass targetPainlessClass = classesToPainlessClasses.get(Object.class);

if (targetPainlessClass != null) {
T painlessObject = objectLookup.apply(targetPainlessClass);

if (painlessObject != null) {
return painlessObject;
}
}
}

currentTargetClass = originalTargetClass;
Set<Class<?>> resolvedInterfaces = new HashSet<>();

while (currentTargetClass != null) {
for (Class<?> targetInterface : currentTargetClass.getInterfaces()) {
PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetInterface);
List<Class<?>> targetInterfaces = new ArrayList<>(Arrays.asList(currentTargetClass.getInterfaces()));

while (targetInterfaces.isEmpty() == false) {
Class<?> targetInterface = targetInterfaces.remove(0);

if (resolvedInterfaces.add(targetInterface)) {
PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetInterface);

if (targetPainlessClass != null) {
T painlessObject = objectLookup.apply(targetPainlessClass);

if (targetPainlessClass != null) {
T painlessObject = objectLookup.apply(targetPainlessClass);
if (painlessObject != null) {
return painlessObject;
}

if (painlessObject != null) {
return painlessObject;
targetInterfaces.addAll(Arrays.asList(targetInterface.getInterfaces()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@
import java.security.SecureClassLoader;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -189,6 +192,7 @@ public static PainlessLookup buildFromWhitelists(List<Whitelist> whitelists) {
// of the values of javaClassNamesToClasses.
private final Map<String, Class<?>> canonicalClassNamesToClasses;
private final Map<Class<?>, PainlessClassBuilder> classesToPainlessClassBuilders;
private final Map<Class<?>, Set<Class<?>>> classesToDirectSubClasses;

private final Map<String, PainlessMethod> painlessMethodKeysToImportedPainlessMethods;
private final Map<String, PainlessClassBinding> painlessMethodKeysToPainlessClassBindings;
Expand All @@ -198,6 +202,7 @@ public PainlessLookupBuilder() {
javaClassNamesToClasses = new HashMap<>();
canonicalClassNamesToClasses = new HashMap<>();
classesToPainlessClassBuilders = new HashMap<>();
classesToDirectSubClasses = new HashMap<>();

painlessMethodKeysToImportedPainlessMethods = new HashMap<>();
painlessMethodKeysToPainlessClassBindings = new HashMap<>();
Expand Down Expand Up @@ -1255,7 +1260,7 @@ public void addPainlessInstanceBinding(
}

public PainlessLookup build() {
copyPainlessClassMembers();
buildPainlessClassHierarchy();
setFunctionalInterfaceMethods();
generateRuntimeMethods();
cacheRuntimeHandles();
Expand Down Expand Up @@ -1286,71 +1291,66 @@ public PainlessLookup build() {
javaClassNamesToClasses,
canonicalClassNamesToClasses,
classesToPainlessClasses,
classesToDirectSubClasses,
painlessMethodKeysToImportedPainlessMethods,
painlessMethodKeysToPainlessClassBindings,
painlessMethodKeysToPainlessInstanceBindings);
}

private void copyPainlessClassMembers() {
for (Class<?> parentClass : classesToPainlessClassBuilders.keySet()) {
copyPainlessInterfaceMembers(parentClass, parentClass);

Class<?> childClass = parentClass.getSuperclass();

while (childClass != null) {
if (classesToPainlessClassBuilders.containsKey(childClass)) {
copyPainlessClassMembers(childClass, parentClass);
}

copyPainlessInterfaceMembers(childClass, parentClass);
childClass = childClass.getSuperclass();
}
}

for (Class<?> javaClass : classesToPainlessClassBuilders.keySet()) {
if (javaClass.isInterface()) {
copyPainlessClassMembers(Object.class, javaClass);
}
}
}

private void copyPainlessInterfaceMembers(Class<?> parentClass, Class<?> targetClass) {
for (Class<?> childClass : parentClass.getInterfaces()) {
if (classesToPainlessClassBuilders.containsKey(childClass)) {
copyPainlessClassMembers(childClass, targetClass);
}

copyPainlessInterfaceMembers(childClass, targetClass);
private void buildPainlessClassHierarchy() {
for (Class<?> targetClass : classesToPainlessClassBuilders.keySet()) {
classesToDirectSubClasses.put(targetClass, new HashSet<>());
}
}

private void copyPainlessClassMembers(Class<?> originalClass, Class<?> targetClass) {
PainlessClassBuilder originalPainlessClassBuilder = classesToPainlessClassBuilders.get(originalClass);
PainlessClassBuilder targetPainlessClassBuilder = classesToPainlessClassBuilders.get(targetClass);
for (Class<?> subClass : classesToPainlessClassBuilders.keySet()) {
List<Class<?>> superInterfaces = new ArrayList<>(Arrays.asList(subClass.getInterfaces()));

Objects.requireNonNull(originalPainlessClassBuilder);
Objects.requireNonNull(targetPainlessClassBuilder);
// we check for Object.class as part of the allow listed classes because
// it is possible for the compiler to work without Object
if (subClass.isInterface() && superInterfaces.isEmpty() && classesToPainlessClassBuilders.containsKey(Object.class)) {
classesToDirectSubClasses.get(Object.class).add(subClass);
} else {
Class<?> superClass = subClass.getSuperclass();

// this finds the nearest super class for a given sub class
// because the allow list may have gaps between classes
// example:
// class A {} // allowed
// class B extends A // not allowed
// class C extends B // allowed
// in this case C is considered a direct sub class of A
while (superClass != null) {
if (classesToPainlessClassBuilders.containsKey(superClass)) {
break;
} else {
// this ensures all interfaces from a sub class that
// is not allow listed are checked if they are
// considered a direct super class of the sub class
// because these interfaces may still be allow listed
// even if their sub class is not
superInterfaces.addAll(Arrays.asList(superClass.getInterfaces()));
}

for (Map.Entry<String, PainlessMethod> painlessMethodEntry : originalPainlessClassBuilder.methods.entrySet()) {
String painlessMethodKey = painlessMethodEntry.getKey();
PainlessMethod newPainlessMethod = painlessMethodEntry.getValue();
PainlessMethod existingPainlessMethod = targetPainlessClassBuilder.methods.get(painlessMethodKey);
superClass = superClass.getSuperclass();
}

if (existingPainlessMethod == null || existingPainlessMethod.targetClass != newPainlessMethod.targetClass &&
existingPainlessMethod.targetClass.isAssignableFrom(newPainlessMethod.targetClass)) {
targetPainlessClassBuilder.methods.put(painlessMethodKey.intern(), newPainlessMethod);
if (superClass != null) {
classesToDirectSubClasses.get(superClass).add(subClass);
}
}
}

for (Map.Entry<String, PainlessField> painlessFieldEntry : originalPainlessClassBuilder.fields.entrySet()) {
String painlessFieldKey = painlessFieldEntry.getKey();
PainlessField newPainlessField = painlessFieldEntry.getValue();
PainlessField existingPainlessField = targetPainlessClassBuilder.fields.get(painlessFieldKey);
Set<Class<?>> resolvedInterfaces = new HashSet<>();

while (superInterfaces.isEmpty() == false) {
Class<?> superInterface = superInterfaces.remove(0);

if (existingPainlessField == null ||
existingPainlessField.javaField.getDeclaringClass() != newPainlessField.javaField.getDeclaringClass() &&
existingPainlessField.javaField.getDeclaringClass().isAssignableFrom(newPainlessField.javaField.getDeclaringClass())) {
targetPainlessClassBuilder.fields.put(painlessFieldKey.intern(), newPainlessField);
if (resolvedInterfaces.add(superInterface)) {
if (classesToPainlessClassBuilders.containsKey(superInterface)) {
classesToDirectSubClasses.get(superInterface).add(subClass);
} else {
superInterfaces.addAll(Arrays.asList(superInterface.getInterfaces()));
}
}
}
}
}
Expand Down
Loading

0 comments on commit a4d9c1b

Please sign in to comment.