Skip to content

Commit

Permalink
Replace bridge methods with filtered methods in Painless (#88100)
Browse files Browse the repository at this point in the history
The invokedynamic instruction does not perfectly follow the Painless casting model opting to add 
bridge methods where necessary to ensure symmetric behavior between compile-time and run-time 
casting using boxed types. This change replaces the specialized class loader and bridge methods using 
filtered method handles instead. This reduces the overall complexity of runtime casting.
  • Loading branch information
jdconrad authored Jul 5, 2022
1 parent 56befd0 commit 7234730
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.lang.invoke.MethodType;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -146,6 +147,8 @@ private ArrayLengthHelper() {}
/** factory for arraylength MethodHandle (intrinsic) from Java 9 (pkg-private for tests) */
static final MethodHandle JAVA9_ARRAY_LENGTH_MH_FACTORY;

public static final Map<Class<?>, MethodHandle> DEF_TO_BOXED_TYPE_IMPLICIT_CAST;

static {
final MethodHandles.Lookup methodHandlesLookup = MethodHandles.publicLookup();

Expand Down Expand Up @@ -182,6 +185,43 @@ private ArrayLengthHelper() {}
arrayLengthMHFactory = null;
}
JAVA9_ARRAY_LENGTH_MH_FACTORY = arrayLengthMHFactory;

Map<Class<?>, MethodHandle> defToBoxedTypeImplicitCast = new HashMap<>();

try {
defToBoxedTypeImplicitCast.put(
Byte.class,
methodHandlesLookup.findStatic(Def.class, "defToByteImplicit", MethodType.methodType(Byte.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Short.class,
methodHandlesLookup.findStatic(Def.class, "defToShortImplicit", MethodType.methodType(Short.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Character.class,
methodHandlesLookup.findStatic(Def.class, "defToCharacterImplicit", MethodType.methodType(Character.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Integer.class,
methodHandlesLookup.findStatic(Def.class, "defToIntegerImplicit", MethodType.methodType(Integer.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Long.class,
methodHandlesLookup.findStatic(Def.class, "defToLongImplicit", MethodType.methodType(Long.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Float.class,
methodHandlesLookup.findStatic(Def.class, "defToFloatImplicit", MethodType.methodType(Float.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Double.class,
methodHandlesLookup.findStatic(Def.class, "defToDoubleImplicit", MethodType.methodType(Double.class, Object.class))
);
} catch (NoSuchMethodException | IllegalAccessException exception) {
throw new IllegalStateException(exception);
}

DEF_TO_BOXED_TYPE_IMPLICIT_CAST = Collections.unmodifiableMap(defToBoxedTypeImplicitCast);
}

/** Hack to rethrow unknown Exceptions from {@link MethodHandle#invokeExact}: */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@

package org.elasticsearch.painless.lookup;

import org.elasticsearch.bootstrap.BootstrapInfo;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.painless.Def;
import org.elasticsearch.painless.MethodWriter;
import org.elasticsearch.painless.WriterConstants;
import org.elasticsearch.painless.spi.Whitelist;
import org.elasticsearch.painless.spi.WhitelistClass;
import org.elasticsearch.painless.spi.WhitelistClassBinding;
Expand All @@ -25,9 +22,6 @@
import org.elasticsearch.painless.spi.annotation.CompileTimeOnlyAnnotation;
import org.elasticsearch.painless.spi.annotation.InjectConstantAnnotation;
import org.elasticsearch.painless.spi.annotation.NoImportAnnotation;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.commons.GeneratorAdapter;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand All @@ -37,13 +31,6 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.AccessController;
import java.security.CodeSource;
import java.security.PrivilegedAction;
import java.security.SecureClassLoader;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -56,15 +43,6 @@
import java.util.function.Supplier;
import java.util.regex.Pattern;

import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_BYTE_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_CHARACTER_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_DOUBLE_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_FLOAT_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_INTEGER_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_LONG_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_SHORT_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_UTIL_TYPE;
import static org.elasticsearch.painless.WriterConstants.OBJECT_TYPE;
import static org.elasticsearch.painless.lookup.PainlessLookupUtility.DEF_CLASS_NAME;
import static org.elasticsearch.painless.lookup.PainlessLookupUtility.buildPainlessConstructorKey;
import static org.elasticsearch.painless.lookup.PainlessLookupUtility.buildPainlessFieldKey;
Expand All @@ -75,42 +53,17 @@

public final class PainlessLookupBuilder {

private static final class BridgeLoader extends SecureClassLoader {
BridgeLoader(ClassLoader parent) {
super(parent);
}

@Override
public Class<?> findClass(String name) throws ClassNotFoundException {
return Def.class.getName().equals(name) ? Def.class : super.findClass(name);
}

Class<?> defineBridge(String name, byte[] bytes) {
return defineClass(name, bytes, 0, bytes.length, CODESOURCE);
}
}

private static final CodeSource CODESOURCE;

private static final Map<PainlessConstructor, PainlessConstructor> painlessConstructorCache = new HashMap<>();
private static final Map<PainlessMethod, PainlessMethod> painlessMethodCache = new HashMap<>();
private static final Map<PainlessField, PainlessField> painlessFieldCache = new HashMap<>();
private static final Map<PainlessClassBinding, PainlessClassBinding> painlessClassBindingCache = new HashMap<>();
private static final Map<PainlessInstanceBinding, PainlessInstanceBinding> painlessInstanceBindingCache = new HashMap<>();
private static final Map<PainlessMethod, PainlessMethod> painlessBridgeCache = new HashMap<>();
private static final Map<PainlessMethod, PainlessMethod> painlessFilteredCache = new HashMap<>();

private static final Pattern CLASS_NAME_PATTERN = Pattern.compile("^[_a-zA-Z][._a-zA-Z0-9]*$");
private static final Pattern METHOD_NAME_PATTERN = Pattern.compile("^[_a-zA-Z][_a-zA-Z0-9]*$");
private static final Pattern FIELD_NAME_PATTERN = Pattern.compile("^[_a-zA-Z][_a-zA-Z0-9]*$");

static {
try {
CODESOURCE = new CodeSource(new URL("file:" + BootstrapInfo.UNTRUSTED_CODEBASE), (Certificate[]) null);
} catch (MalformedURLException mue) {
throw new RuntimeException(mue);
}
}

public static PainlessLookup buildFromWhitelists(List<Whitelist> whitelists) {
PainlessLookupBuilder painlessLookupBuilder = new PainlessLookupBuilder();
String origin = "internal error";
Expand Down Expand Up @@ -2216,7 +2169,9 @@ private void setFunctionalInterfaceMethod(Class<?> targetClass, PainlessClassBui
* run-time resulting from calls with a def type value target.
*/
private void generateRuntimeMethods() {
for (PainlessClassBuilder painlessClassBuilder : classesToPainlessClassBuilders.values()) {
for (Map.Entry<Class<?>, PainlessClassBuilder> painlessClassBuilderEntry : classesToPainlessClassBuilders.entrySet()) {
Class<?> targetClass = painlessClassBuilderEntry.getKey();
PainlessClassBuilder painlessClassBuilder = painlessClassBuilderEntry.getValue();
painlessClassBuilder.runtimeMethods.putAll(painlessClassBuilder.methods);

for (PainlessMethod painlessMethod : painlessClassBuilder.runtimeMethods.values()) {
Expand All @@ -2228,63 +2183,25 @@ private void generateRuntimeMethods() {
|| typeParameter == Long.class
|| typeParameter == Float.class
|| typeParameter == Double.class) {
generateBridgeMethod(painlessClassBuilder, painlessMethod);
generateFilteredMethod(targetClass, painlessClassBuilder, painlessMethod);
}
}
}
}
}

private void generateBridgeMethod(PainlessClassBuilder painlessClassBuilder, PainlessMethod painlessMethod) {
private void generateFilteredMethod(Class<?> targetClass, PainlessClassBuilder painlessClassBuilder, PainlessMethod painlessMethod) {
String painlessMethodKey = buildPainlessMethodKey(painlessMethod.javaMethod().getName(), painlessMethod.typeParameters().size());
PainlessMethod bridgePainlessMethod = painlessBridgeCache.get(painlessMethod);
PainlessMethod filteredPainlessMethod = painlessFilteredCache.get(painlessMethod);

if (bridgePainlessMethod == null) {
if (filteredPainlessMethod == null) {
Method javaMethod = painlessMethod.javaMethod();
boolean isStatic = Modifier.isStatic(painlessMethod.javaMethod().getModifiers());

int bridgeClassFrames = ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS;
int bridgeClassAccess = Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER | Opcodes.ACC_FINAL;
String bridgeClassName = "org/elasticsearch/painless/Bridge$"
+ javaMethod.getDeclaringClass().getSimpleName()
+ "$"
+ javaMethod.getName();
ClassWriter bridgeClassWriter = new ClassWriter(bridgeClassFrames);
bridgeClassWriter.visit(
WriterConstants.CLASS_VERSION,
bridgeClassAccess,
bridgeClassName,
null,
OBJECT_TYPE.getInternalName(),
null
);

org.objectweb.asm.commons.Method bridgeConstructorType = new org.objectweb.asm.commons.Method(
"<init>",
MethodType.methodType(void.class).toMethodDescriptorString()
);
GeneratorAdapter bridgeConstructorWriter = new GeneratorAdapter(
Opcodes.ASM5,
bridgeConstructorType,
bridgeClassWriter.visitMethod(
Opcodes.ACC_PRIVATE,
bridgeConstructorType.getName(),
bridgeConstructorType.getDescriptor(),
null,
null
)
);
bridgeConstructorWriter.visitCode();
bridgeConstructorWriter.loadThis();
bridgeConstructorWriter.invokeConstructor(OBJECT_TYPE, bridgeConstructorType);
bridgeConstructorWriter.returnValue();
bridgeConstructorWriter.endMethod();

int bridgeTypeParameterOffset = isStatic ? 0 : 1;
List<Class<?>> bridgeTypeParameters = new ArrayList<>(javaMethod.getParameterTypes().length + bridgeTypeParameterOffset);
int filteredTypeParameterOffset = isStatic ? 0 : 1;
List<Class<?>> filteredTypeParameters = new ArrayList<>(javaMethod.getParameterTypes().length + filteredTypeParameterOffset);

if (isStatic == false) {
bridgeTypeParameters.add(javaMethod.getDeclaringClass());
filteredTypeParameters.add(javaMethod.getDeclaringClass());
}

for (Class<?> typeParameter : javaMethod.getParameterTypes()) {
Expand All @@ -2295,78 +2212,48 @@ private void generateBridgeMethod(PainlessClassBuilder painlessClassBuilder, Pai
|| typeParameter == Long.class
|| typeParameter == Float.class
|| typeParameter == Double.class) {
bridgeTypeParameters.add(Object.class);
filteredTypeParameters.add(Object.class);
} else {
bridgeTypeParameters.add(typeParameter);
filteredTypeParameters.add(typeParameter);
}
}

MethodType bridgeMethodType = MethodType.methodType(painlessMethod.returnType(), bridgeTypeParameters);
MethodWriter bridgeMethodWriter = new MethodWriter(
Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC,
new org.objectweb.asm.commons.Method(painlessMethod.javaMethod().getName(), bridgeMethodType.toMethodDescriptorString()),
bridgeClassWriter,
null,
null
);
bridgeMethodWriter.visitCode();

if (isStatic == false) {
bridgeMethodWriter.loadArg(0);
}

for (int typeParameterCount = 0; typeParameterCount < javaMethod.getParameterTypes().length; ++typeParameterCount) {
bridgeMethodWriter.loadArg(typeParameterCount + bridgeTypeParameterOffset);
Class<?> typeParameter = javaMethod.getParameterTypes()[typeParameterCount];

if (typeParameter == Byte.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_BYTE_IMPLICIT);
else if (typeParameter == Short.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_SHORT_IMPLICIT);
else if (typeParameter == Character.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_CHARACTER_IMPLICIT);
else if (typeParameter == Integer.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_INTEGER_IMPLICIT);
else if (typeParameter == Long.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_LONG_IMPLICIT);
else if (typeParameter == Float.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_FLOAT_IMPLICIT);
else if (typeParameter == Double.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_DOUBLE_IMPLICIT);
}

bridgeMethodWriter.invokeMethodCall(painlessMethod);
bridgeMethodWriter.returnValue();
bridgeMethodWriter.endMethod();

bridgeClassWriter.visitEnd();
MethodType filteredMethodType = MethodType.methodType(painlessMethod.returnType(), filteredTypeParameters);
MethodHandle filteredMethodHandle = painlessMethod.methodHandle();

try {
BridgeLoader bridgeLoader = AccessController.doPrivileged(new PrivilegedAction<BridgeLoader>() {
@Override
public BridgeLoader run() {
return new BridgeLoader(javaMethod.getDeclaringClass().getClassLoader());
for (int typeParameterCount = 0; typeParameterCount < javaMethod.getParameterTypes().length; ++typeParameterCount) {
Class<?> typeParameter = javaMethod.getParameterTypes()[typeParameterCount];
MethodHandle castMethodHandle = Def.DEF_TO_BOXED_TYPE_IMPLICIT_CAST.get(typeParameter);

if (castMethodHandle != null) {
filteredMethodHandle = MethodHandles.filterArguments(
filteredMethodHandle,
typeParameterCount + filteredTypeParameterOffset,
castMethodHandle
);
}
});
}

Class<?> bridgeClass = bridgeLoader.defineBridge(bridgeClassName.replace('/', '.'), bridgeClassWriter.toByteArray());
Method bridgeMethod = bridgeClass.getMethod(
painlessMethod.javaMethod().getName(),
bridgeTypeParameters.toArray(new Class<?>[0])
);
MethodHandle bridgeHandle = lookup(bridgeClass).unreflect(bridgeClass.getMethods()[0]);
bridgePainlessMethod = new PainlessMethod(
bridgeMethod,
bridgeClass,
filteredPainlessMethod = new PainlessMethod(
painlessMethod.javaMethod(),
targetClass,
painlessMethod.returnType(),
bridgeTypeParameters,
bridgeHandle,
bridgeMethodType,
filteredTypeParameters,
filteredMethodHandle,
filteredMethodType,
Collections.emptyMap()
);
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), bridgePainlessMethod);
painlessBridgeCache.put(painlessMethod, bridgePainlessMethod);
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), filteredPainlessMethod);
painlessFilteredCache.put(painlessMethod, filteredPainlessMethod);
} catch (Exception exception) {
throw new IllegalStateException(
"internal error occurred attempting to generate a bridge method [" + bridgeClassName + "]",
"internal error occurred attempting to generate a runtime method [" + painlessMethodKey + "]",
exception
);
}
} else {
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), bridgePainlessMethod);
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), filteredPainlessMethod);
}
}

Expand Down

0 comments on commit 7234730

Please sign in to comment.