From a4d9c1bec8df85cea8955a7123c07839728da3c2 Mon Sep 17 00:00:00 2001 From: Jack Conradson Date: Thu, 26 Aug 2021 12:09:43 -0700 Subject: [PATCH] Add a direct sub classes data structure to the Painless lookup (#76955) This change has two main components. The first is to have method/field resolution for compile-time and run-time use the same code path for now. This removes copying of member methods between super and sub classes and instead does a resolution through the class hierarchy. This allows us to correctly implement the next change. The second is a data structure that allows for the lookup of direct sub classes for all allow listed classes/interfaces within Painless. --- .../painless/lookup/PainlessLookup.java | 86 ++++++++----- .../lookup/PainlessLookupBuilder.java | 106 ++++++++-------- .../elasticsearch/painless/LookupTests.java | 116 ++++++++++++++++++ .../org.elasticsearch.painless.lookup | 35 ++++++ 4 files changed, 258 insertions(+), 85 deletions(-) create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/LookupTests.java create mode 100644 modules/lang-painless/src/test/resources/org/elasticsearch/painless/org.elasticsearch.painless.lookup diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookup.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookup.java index d94185d8014ae..95b0735524b8c 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookup.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookup.java @@ -11,6 +11,10 @@ import org.elasticsearch.common.util.CollectionUtils; import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -27,6 +31,7 @@ public final class PainlessLookup { private final Map> javaClassNamesToClasses; private final Map> canonicalClassNamesToClasses; private final Map, PainlessClass> classesToPainlessClasses; + private final Map, Set>> classesToDirectSubClasses; private final Map painlessMethodKeysToImportedPainlessMethods; private final Map painlessMethodKeysToPainlessClassBindings; @@ -36,6 +41,7 @@ public final class PainlessLookup { Map> javaClassNamesToClasses, Map> canonicalClassNamesToClasses, Map, PainlessClass> classesToPainlessClasses, + Map, Set>> classesToDirectSubClasses, Map painlessMethodKeysToImportedPainlessMethods, Map painlessMethodKeysToPainlessClassBindings, Map painlessMethodKeysToPainlessInstanceBindings) { @@ -43,6 +49,7 @@ public final class PainlessLookup { Objects.requireNonNull(javaClassNamesToClasses); Objects.requireNonNull(canonicalClassNamesToClasses); Objects.requireNonNull(classesToPainlessClasses); + Objects.requireNonNull(classesToDirectSubClasses); Objects.requireNonNull(painlessMethodKeysToImportedPainlessMethods); Objects.requireNonNull(painlessMethodKeysToPainlessClassBindings); @@ -51,6 +58,7 @@ public final class PainlessLookup { this.javaClassNamesToClasses = javaClassNamesToClasses; this.canonicalClassNamesToClasses = CollectionUtils.copyMap(canonicalClassNamesToClasses); this.classesToPainlessClasses = CollectionUtils.copyMap(classesToPainlessClasses); + this.classesToDirectSubClasses = CollectionUtils.copyMap(classesToDirectSubClasses); this.painlessMethodKeysToImportedPainlessMethods = CollectionUtils.copyMap(painlessMethodKeysToImportedPainlessMethods); this.painlessMethodKeysToPainlessClassBindings = CollectionUtils.copyMap(painlessMethodKeysToPainlessClassBindings); @@ -77,6 +85,10 @@ public Set> getClasses() { return classesToPainlessClasses.keySet(); } + public Set> getDirectSubClasses(Class superClass) { + return classesToDirectSubClasses.get(superClass); + } + public Set getImportedPainlessMethodsKeys() { return painlessMethodKeysToImportedPainlessMethods.keySet(); } @@ -144,16 +156,12 @@ public PainlessMethod lookupPainlessMethod(Class targetClass, boolean isStati targetClass = typeToBoxedType(targetClass); } - PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetClass); String painlessMethodKey = buildPainlessMethodKey(methodName, methodArity); + Function objectLookup = isStatic ? + targetPainlessClass -> targetPainlessClass.staticMethods.get(painlessMethodKey) : + targetPainlessClass -> targetPainlessClass.methods.get(painlessMethodKey); - if (targetPainlessClass == null) { - return null; - } - - return isStatic ? - targetPainlessClass.staticMethods.get(painlessMethodKey) : - targetPainlessClass.methods.get(painlessMethodKey); + return lookupPainlessObject(targetClass, objectLookup); } public PainlessField lookupPainlessField(String targetCanonicalClassName, boolean isStatic, String fieldName) { @@ -172,22 +180,12 @@ public PainlessField lookupPainlessField(Class targetClass, boolean isStatic, Objects.requireNonNull(targetClass); Objects.requireNonNull(fieldName); - PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetClass); String painlessFieldKey = buildPainlessFieldKey(fieldName); + Function objectLookup = isStatic ? + targetPainlessClass -> targetPainlessClass.staticFields.get(painlessFieldKey) : + targetPainlessClass -> targetPainlessClass.fields.get(painlessFieldKey); - if (targetPainlessClass == null) { - return null; - } - - PainlessField painlessField = isStatic ? - targetPainlessClass.staticFields.get(painlessFieldKey) : - targetPainlessClass.fields.get(painlessFieldKey); - - if (painlessField == null) { - return null; - } - - return painlessField; + return lookupPainlessObject(targetClass, objectLookup); } public PainlessMethod lookupImportedPainlessMethod(String methodName, int arity) { @@ -232,7 +230,7 @@ public PainlessMethod lookupRuntimePainlessMethod(Class originalTargetClass, Function objectLookup = targetPainlessClass -> targetPainlessClass.runtimeMethods.get(painlessMethodKey); - return lookupRuntimePainlessObject(originalTargetClass, objectLookup); + return lookupPainlessObject(originalTargetClass, objectLookup); } public MethodHandle lookupRuntimeGetterMethodHandle(Class originalTargetClass, String getterName) { @@ -241,7 +239,7 @@ public MethodHandle lookupRuntimeGetterMethodHandle(Class originalTargetClass Function objectLookup = targetPainlessClass -> targetPainlessClass.getterMethodHandles.get(getterName); - return lookupRuntimePainlessObject(originalTargetClass, objectLookup); + return lookupPainlessObject(originalTargetClass, objectLookup); } public MethodHandle lookupRuntimeSetterMethodHandle(Class originalTargetClass, String setterName) { @@ -250,10 +248,13 @@ public MethodHandle lookupRuntimeSetterMethodHandle(Class originalTargetClass Function objectLookup = targetPainlessClass -> targetPainlessClass.setterMethodHandles.get(setterName); - return lookupRuntimePainlessObject(originalTargetClass, objectLookup); + return lookupPainlessObject(originalTargetClass, objectLookup); } - private T lookupRuntimePainlessObject(Class originalTargetClass, Function objectLookup) { + private T lookupPainlessObject(Class originalTargetClass, Function objectLookup) { + Objects.requireNonNull(originalTargetClass); + Objects.requireNonNull(objectLookup); + Class currentTargetClass = originalTargetClass; while (currentTargetClass != null) { @@ -270,17 +271,38 @@ private T lookupRuntimePainlessObject(Class originalTargetClass, Function currentTargetClass = currentTargetClass.getSuperclass(); } + if (originalTargetClass.isInterface()) { + PainlessClass targetPainlessClass = classesToPainlessClasses.get(Object.class); + + if (targetPainlessClass != null) { + T painlessObject = objectLookup.apply(targetPainlessClass); + + if (painlessObject != null) { + return painlessObject; + } + } + } + currentTargetClass = originalTargetClass; + Set> resolvedInterfaces = new HashSet<>(); while (currentTargetClass != null) { - for (Class targetInterface : currentTargetClass.getInterfaces()) { - PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetInterface); + List> targetInterfaces = new ArrayList<>(Arrays.asList(currentTargetClass.getInterfaces())); + + while (targetInterfaces.isEmpty() == false) { + Class targetInterface = targetInterfaces.remove(0); + + if (resolvedInterfaces.add(targetInterface)) { + PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetInterface); + + if (targetPainlessClass != null) { + T painlessObject = objectLookup.apply(targetPainlessClass); - if (targetPainlessClass != null) { - T painlessObject = objectLookup.apply(targetPainlessClass); + if (painlessObject != null) { + return painlessObject; + } - if (painlessObject != null) { - return painlessObject; + targetInterfaces.addAll(Arrays.asList(targetInterface.getInterfaces())); } } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookupBuilder.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookupBuilder.java index a7390b5415870..b819b1e134048 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookupBuilder.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookupBuilder.java @@ -42,11 +42,14 @@ import java.security.SecureClassLoader; import java.security.cert.Certificate; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.Supplier; import java.util.regex.Pattern; @@ -189,6 +192,7 @@ public static PainlessLookup buildFromWhitelists(List whitelists) { // of the values of javaClassNamesToClasses. private final Map> canonicalClassNamesToClasses; private final Map, PainlessClassBuilder> classesToPainlessClassBuilders; + private final Map, Set>> classesToDirectSubClasses; private final Map painlessMethodKeysToImportedPainlessMethods; private final Map painlessMethodKeysToPainlessClassBindings; @@ -198,6 +202,7 @@ public PainlessLookupBuilder() { javaClassNamesToClasses = new HashMap<>(); canonicalClassNamesToClasses = new HashMap<>(); classesToPainlessClassBuilders = new HashMap<>(); + classesToDirectSubClasses = new HashMap<>(); painlessMethodKeysToImportedPainlessMethods = new HashMap<>(); painlessMethodKeysToPainlessClassBindings = new HashMap<>(); @@ -1255,7 +1260,7 @@ public void addPainlessInstanceBinding( } public PainlessLookup build() { - copyPainlessClassMembers(); + buildPainlessClassHierarchy(); setFunctionalInterfaceMethods(); generateRuntimeMethods(); cacheRuntimeHandles(); @@ -1286,71 +1291,66 @@ public PainlessLookup build() { javaClassNamesToClasses, canonicalClassNamesToClasses, classesToPainlessClasses, + classesToDirectSubClasses, painlessMethodKeysToImportedPainlessMethods, painlessMethodKeysToPainlessClassBindings, painlessMethodKeysToPainlessInstanceBindings); } - private void copyPainlessClassMembers() { - for (Class parentClass : classesToPainlessClassBuilders.keySet()) { - copyPainlessInterfaceMembers(parentClass, parentClass); - - Class childClass = parentClass.getSuperclass(); - - while (childClass != null) { - if (classesToPainlessClassBuilders.containsKey(childClass)) { - copyPainlessClassMembers(childClass, parentClass); - } - - copyPainlessInterfaceMembers(childClass, parentClass); - childClass = childClass.getSuperclass(); - } - } - - for (Class javaClass : classesToPainlessClassBuilders.keySet()) { - if (javaClass.isInterface()) { - copyPainlessClassMembers(Object.class, javaClass); - } - } - } - - private void copyPainlessInterfaceMembers(Class parentClass, Class targetClass) { - for (Class childClass : parentClass.getInterfaces()) { - if (classesToPainlessClassBuilders.containsKey(childClass)) { - copyPainlessClassMembers(childClass, targetClass); - } - - copyPainlessInterfaceMembers(childClass, targetClass); + private void buildPainlessClassHierarchy() { + for (Class targetClass : classesToPainlessClassBuilders.keySet()) { + classesToDirectSubClasses.put(targetClass, new HashSet<>()); } - } - private void copyPainlessClassMembers(Class originalClass, Class targetClass) { - PainlessClassBuilder originalPainlessClassBuilder = classesToPainlessClassBuilders.get(originalClass); - PainlessClassBuilder targetPainlessClassBuilder = classesToPainlessClassBuilders.get(targetClass); + for (Class subClass : classesToPainlessClassBuilders.keySet()) { + List> superInterfaces = new ArrayList<>(Arrays.asList(subClass.getInterfaces())); - Objects.requireNonNull(originalPainlessClassBuilder); - Objects.requireNonNull(targetPainlessClassBuilder); + // we check for Object.class as part of the allow listed classes because + // it is possible for the compiler to work without Object + if (subClass.isInterface() && superInterfaces.isEmpty() && classesToPainlessClassBuilders.containsKey(Object.class)) { + classesToDirectSubClasses.get(Object.class).add(subClass); + } else { + Class superClass = subClass.getSuperclass(); + + // this finds the nearest super class for a given sub class + // because the allow list may have gaps between classes + // example: + // class A {} // allowed + // class B extends A // not allowed + // class C extends B // allowed + // in this case C is considered a direct sub class of A + while (superClass != null) { + if (classesToPainlessClassBuilders.containsKey(superClass)) { + break; + } else { + // this ensures all interfaces from a sub class that + // is not allow listed are checked if they are + // considered a direct super class of the sub class + // because these interfaces may still be allow listed + // even if their sub class is not + superInterfaces.addAll(Arrays.asList(superClass.getInterfaces())); + } - for (Map.Entry painlessMethodEntry : originalPainlessClassBuilder.methods.entrySet()) { - String painlessMethodKey = painlessMethodEntry.getKey(); - PainlessMethod newPainlessMethod = painlessMethodEntry.getValue(); - PainlessMethod existingPainlessMethod = targetPainlessClassBuilder.methods.get(painlessMethodKey); + superClass = superClass.getSuperclass(); + } - if (existingPainlessMethod == null || existingPainlessMethod.targetClass != newPainlessMethod.targetClass && - existingPainlessMethod.targetClass.isAssignableFrom(newPainlessMethod.targetClass)) { - targetPainlessClassBuilder.methods.put(painlessMethodKey.intern(), newPainlessMethod); + if (superClass != null) { + classesToDirectSubClasses.get(superClass).add(subClass); + } } - } - for (Map.Entry painlessFieldEntry : originalPainlessClassBuilder.fields.entrySet()) { - String painlessFieldKey = painlessFieldEntry.getKey(); - PainlessField newPainlessField = painlessFieldEntry.getValue(); - PainlessField existingPainlessField = targetPainlessClassBuilder.fields.get(painlessFieldKey); + Set> resolvedInterfaces = new HashSet<>(); + + while (superInterfaces.isEmpty() == false) { + Class superInterface = superInterfaces.remove(0); - if (existingPainlessField == null || - existingPainlessField.javaField.getDeclaringClass() != newPainlessField.javaField.getDeclaringClass() && - existingPainlessField.javaField.getDeclaringClass().isAssignableFrom(newPainlessField.javaField.getDeclaringClass())) { - targetPainlessClassBuilder.fields.put(painlessFieldKey.intern(), newPainlessField); + if (resolvedInterfaces.add(superInterface)) { + if (classesToPainlessClassBuilders.containsKey(superInterface)) { + classesToDirectSubClasses.get(superInterface).add(subClass); + } else { + superInterfaces.addAll(Arrays.asList(superInterface.getInterfaces())); + } + } } } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/LookupTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/LookupTests.java new file mode 100644 index 0000000000000..a0ebbb5b25024 --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/LookupTests.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.painless; + +import org.elasticsearch.painless.lookup.PainlessLookup; +import org.elasticsearch.painless.lookup.PainlessLookupBuilder; +import org.elasticsearch.painless.spi.WhitelistLoader; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.util.Collections; +import java.util.Set; + +public class LookupTests extends ESTestCase { + + protected PainlessLookup painlessLookup; + + @Before + public void setup() { + painlessLookup = PainlessLookupBuilder.buildFromWhitelists(Collections.singletonList( + WhitelistLoader.loadFromResourceFiles(PainlessPlugin.class, "org.elasticsearch.painless.lookup") + )); + } + + public static class A { } // in whitelist + public static class B extends A { } // not in whitelist + public static class C extends B { } // in whitelist + public static class D extends B { } // in whitelist + + public interface Z { } // in whitelist + public interface Y { } // not in whitelist + public interface X extends Y, Z { } // not in whitelist + public interface V extends Y, Z { } // in whitelist + public interface U extends X { } // in whitelist + public interface T extends V { } // in whitelist + public interface S extends U, X { } // in whitelist + + public static class AA implements X { } // in whitelist + public static class AB extends AA implements S { } // not in whitelist + public static class AC extends AB implements V { } // in whitelist + public static class AD implements X, S, T { } // in whitelist + + public void testDirectSubClasses() { + Set> directSubClasses = painlessLookup.getDirectSubClasses(Object.class); + assertEquals(4, directSubClasses.size()); + assertTrue(directSubClasses.contains(A.class)); + assertTrue(directSubClasses.contains(Z.class)); + assertTrue(directSubClasses.contains(AA.class)); + assertTrue(directSubClasses.contains(AD.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(A.class); + assertEquals(2, directSubClasses.size()); + assertTrue(directSubClasses.contains(D.class)); + assertTrue(directSubClasses.contains(C.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(B.class); + assertNull(directSubClasses); + + directSubClasses = painlessLookup.getDirectSubClasses(C.class); + assertTrue(directSubClasses.isEmpty()); + + directSubClasses = painlessLookup.getDirectSubClasses(D.class); + assertTrue(directSubClasses.isEmpty()); + + directSubClasses = painlessLookup.getDirectSubClasses(Z.class); + assertEquals(5, directSubClasses.size()); + assertTrue(directSubClasses.contains(V.class)); + assertTrue(directSubClasses.contains(U.class)); + assertTrue(directSubClasses.contains(S.class)); + assertTrue(directSubClasses.contains(AA.class)); + assertTrue(directSubClasses.contains(AD.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(Y.class); + assertNull(directSubClasses); + + directSubClasses = painlessLookup.getDirectSubClasses(X.class); + assertNull(directSubClasses); + + directSubClasses = painlessLookup.getDirectSubClasses(V.class); + assertEquals(2, directSubClasses.size()); + assertTrue(directSubClasses.contains(T.class)); + assertTrue(directSubClasses.contains(AC.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(U.class); + assertEquals(1, directSubClasses.size()); + assertTrue(directSubClasses.contains(S.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(T.class); + assertEquals(1, directSubClasses.size()); + assertTrue(directSubClasses.contains(AD.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(S.class); + assertEquals(2, directSubClasses.size()); + assertTrue(directSubClasses.contains(AC.class)); + assertTrue(directSubClasses.contains(AD.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(AA.class); + assertEquals(1, directSubClasses.size()); + assertTrue(directSubClasses.contains(AC.class)); + + directSubClasses = painlessLookup.getDirectSubClasses(AB.class); + assertNull(directSubClasses); + + directSubClasses = painlessLookup.getDirectSubClasses(AC.class); + assertTrue(directSubClasses.isEmpty()); + + directSubClasses = painlessLookup.getDirectSubClasses(AD.class); + assertTrue(directSubClasses.isEmpty()); + } +} diff --git a/modules/lang-painless/src/test/resources/org/elasticsearch/painless/org.elasticsearch.painless.lookup b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/org.elasticsearch.painless.lookup new file mode 100644 index 0000000000000..b6a5adc6208b7 --- /dev/null +++ b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/org.elasticsearch.painless.lookup @@ -0,0 +1,35 @@ +class java.lang.Object { +} + +class org.elasticsearch.painless.LookupTests$A { +} + +class org.elasticsearch.painless.LookupTests$C { +} + +class org.elasticsearch.painless.LookupTests$D { +} + +class org.elasticsearch.painless.LookupTests$Z { +} + +class org.elasticsearch.painless.LookupTests$V { +} + +class org.elasticsearch.painless.LookupTests$U { +} + +class org.elasticsearch.painless.LookupTests$T { +} + +class org.elasticsearch.painless.LookupTests$S { +} + +class org.elasticsearch.painless.LookupTests$AA { +} + +class org.elasticsearch.painless.LookupTests$AC { +} + +class org.elasticsearch.painless.LookupTests$AD { +} \ No newline at end of file