Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace bridge methods with filtered methods in Painless #88100

Merged
merged 4 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.lang.invoke.MethodType;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -146,6 +147,8 @@ private ArrayLengthHelper() {}
/** factory for arraylength MethodHandle (intrinsic) from Java 9 (pkg-private for tests) */
static final MethodHandle JAVA9_ARRAY_LENGTH_MH_FACTORY;

public static final Map<Class<?>, MethodHandle> DEF_TO_BOXED_TYPE_IMPLICIT_CAST;

static {
final MethodHandles.Lookup methodHandlesLookup = MethodHandles.publicLookup();

Expand Down Expand Up @@ -182,6 +185,43 @@ private ArrayLengthHelper() {}
arrayLengthMHFactory = null;
}
JAVA9_ARRAY_LENGTH_MH_FACTORY = arrayLengthMHFactory;

Map<Class<?>, MethodHandle> defToBoxedTypeImplicitCast = new HashMap<>();

try {
defToBoxedTypeImplicitCast.put(
Byte.class,
methodHandlesLookup.findStatic(Def.class, "defToByteImplicit", MethodType.methodType(Byte.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Short.class,
methodHandlesLookup.findStatic(Def.class, "defToShortImplicit", MethodType.methodType(Short.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Character.class,
methodHandlesLookup.findStatic(Def.class, "defToCharacterImplicit", MethodType.methodType(Character.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Integer.class,
methodHandlesLookup.findStatic(Def.class, "defToIntegerImplicit", MethodType.methodType(Integer.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Long.class,
methodHandlesLookup.findStatic(Def.class, "defToLongImplicit", MethodType.methodType(Long.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Float.class,
methodHandlesLookup.findStatic(Def.class, "defToFloatImplicit", MethodType.methodType(Float.class, Object.class))
);
defToBoxedTypeImplicitCast.put(
Double.class,
methodHandlesLookup.findStatic(Def.class, "defToDoubleImplicit", MethodType.methodType(Double.class, Object.class))
);
} catch (NoSuchMethodException | IllegalAccessException exception) {
throw new IllegalStateException(exception);
}

DEF_TO_BOXED_TYPE_IMPLICIT_CAST = Collections.unmodifiableMap(defToBoxedTypeImplicitCast);
}

/** Hack to rethrow unknown Exceptions from {@link MethodHandle#invokeExact}: */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@

package org.elasticsearch.painless.lookup;

import org.elasticsearch.bootstrap.BootstrapInfo;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.painless.Def;
import org.elasticsearch.painless.MethodWriter;
import org.elasticsearch.painless.WriterConstants;
import org.elasticsearch.painless.spi.Whitelist;
import org.elasticsearch.painless.spi.WhitelistClass;
import org.elasticsearch.painless.spi.WhitelistClassBinding;
Expand All @@ -25,9 +22,6 @@
import org.elasticsearch.painless.spi.annotation.CompileTimeOnlyAnnotation;
import org.elasticsearch.painless.spi.annotation.InjectConstantAnnotation;
import org.elasticsearch.painless.spi.annotation.NoImportAnnotation;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.commons.GeneratorAdapter;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand All @@ -37,13 +31,6 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.AccessController;
import java.security.CodeSource;
import java.security.PrivilegedAction;
import java.security.SecureClassLoader;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -56,15 +43,6 @@
import java.util.function.Supplier;
import java.util.regex.Pattern;

import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_BYTE_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_CHARACTER_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_DOUBLE_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_FLOAT_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_INTEGER_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_LONG_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_TO_B_SHORT_IMPLICIT;
import static org.elasticsearch.painless.WriterConstants.DEF_UTIL_TYPE;
import static org.elasticsearch.painless.WriterConstants.OBJECT_TYPE;
import static org.elasticsearch.painless.lookup.PainlessLookupUtility.DEF_CLASS_NAME;
import static org.elasticsearch.painless.lookup.PainlessLookupUtility.buildPainlessConstructorKey;
import static org.elasticsearch.painless.lookup.PainlessLookupUtility.buildPainlessFieldKey;
Expand All @@ -75,42 +53,17 @@

public final class PainlessLookupBuilder {

private static final class BridgeLoader extends SecureClassLoader {
BridgeLoader(ClassLoader parent) {
super(parent);
}

@Override
public Class<?> findClass(String name) throws ClassNotFoundException {
return Def.class.getName().equals(name) ? Def.class : super.findClass(name);
}

Class<?> defineBridge(String name, byte[] bytes) {
return defineClass(name, bytes, 0, bytes.length, CODESOURCE);
}
}

private static final CodeSource CODESOURCE;

private static final Map<PainlessConstructor, PainlessConstructor> painlessConstructorCache = new HashMap<>();
private static final Map<PainlessMethod, PainlessMethod> painlessMethodCache = new HashMap<>();
private static final Map<PainlessField, PainlessField> painlessFieldCache = new HashMap<>();
private static final Map<PainlessClassBinding, PainlessClassBinding> painlessClassBindingCache = new HashMap<>();
private static final Map<PainlessInstanceBinding, PainlessInstanceBinding> painlessInstanceBindingCache = new HashMap<>();
private static final Map<PainlessMethod, PainlessMethod> painlessBridgeCache = new HashMap<>();
private static final Map<PainlessMethod, PainlessMethod> painlessFilteredCache = new HashMap<>();

private static final Pattern CLASS_NAME_PATTERN = Pattern.compile("^[_a-zA-Z][._a-zA-Z0-9]*$");
private static final Pattern METHOD_NAME_PATTERN = Pattern.compile("^[_a-zA-Z][_a-zA-Z0-9]*$");
private static final Pattern FIELD_NAME_PATTERN = Pattern.compile("^[_a-zA-Z][_a-zA-Z0-9]*$");

static {
try {
CODESOURCE = new CodeSource(new URL("file:" + BootstrapInfo.UNTRUSTED_CODEBASE), (Certificate[]) null);
} catch (MalformedURLException mue) {
throw new RuntimeException(mue);
}
}

public static PainlessLookup buildFromWhitelists(List<Whitelist> whitelists) {
PainlessLookupBuilder painlessLookupBuilder = new PainlessLookupBuilder();
String origin = "internal error";
Expand Down Expand Up @@ -2216,7 +2169,9 @@ private void setFunctionalInterfaceMethod(Class<?> targetClass, PainlessClassBui
* run-time resulting from calls with a def type value target.
*/
private void generateRuntimeMethods() {
for (PainlessClassBuilder painlessClassBuilder : classesToPainlessClassBuilders.values()) {
for (Map.Entry<Class<?>, PainlessClassBuilder> painlessClassBuilderEntry : classesToPainlessClassBuilders.entrySet()) {
Class<?> targetClass = painlessClassBuilderEntry.getKey();
PainlessClassBuilder painlessClassBuilder = painlessClassBuilderEntry.getValue();
painlessClassBuilder.runtimeMethods.putAll(painlessClassBuilder.methods);

for (PainlessMethod painlessMethod : painlessClassBuilder.runtimeMethods.values()) {
Expand All @@ -2228,63 +2183,25 @@ private void generateRuntimeMethods() {
|| typeParameter == Long.class
|| typeParameter == Float.class
|| typeParameter == Double.class) {
generateBridgeMethod(painlessClassBuilder, painlessMethod);
generateFilteredMethod(targetClass, painlessClassBuilder, painlessMethod);
}
}
}
}
}

private void generateBridgeMethod(PainlessClassBuilder painlessClassBuilder, PainlessMethod painlessMethod) {
private void generateFilteredMethod(Class<?> targetClass, PainlessClassBuilder painlessClassBuilder, PainlessMethod painlessMethod) {
String painlessMethodKey = buildPainlessMethodKey(painlessMethod.javaMethod().getName(), painlessMethod.typeParameters().size());
PainlessMethod bridgePainlessMethod = painlessBridgeCache.get(painlessMethod);
PainlessMethod filteredPainlessMethod = painlessFilteredCache.get(painlessMethod);

if (bridgePainlessMethod == null) {
if (filteredPainlessMethod == null) {
Method javaMethod = painlessMethod.javaMethod();
boolean isStatic = Modifier.isStatic(painlessMethod.javaMethod().getModifiers());

int bridgeClassFrames = ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS;
int bridgeClassAccess = Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER | Opcodes.ACC_FINAL;
String bridgeClassName = "org/elasticsearch/painless/Bridge$"
+ javaMethod.getDeclaringClass().getSimpleName()
+ "$"
+ javaMethod.getName();
ClassWriter bridgeClassWriter = new ClassWriter(bridgeClassFrames);
bridgeClassWriter.visit(
WriterConstants.CLASS_VERSION,
bridgeClassAccess,
bridgeClassName,
null,
OBJECT_TYPE.getInternalName(),
null
);

org.objectweb.asm.commons.Method bridgeConstructorType = new org.objectweb.asm.commons.Method(
"<init>",
MethodType.methodType(void.class).toMethodDescriptorString()
);
GeneratorAdapter bridgeConstructorWriter = new GeneratorAdapter(
Opcodes.ASM5,
bridgeConstructorType,
bridgeClassWriter.visitMethod(
Opcodes.ACC_PRIVATE,
bridgeConstructorType.getName(),
bridgeConstructorType.getDescriptor(),
null,
null
)
);
bridgeConstructorWriter.visitCode();
bridgeConstructorWriter.loadThis();
bridgeConstructorWriter.invokeConstructor(OBJECT_TYPE, bridgeConstructorType);
bridgeConstructorWriter.returnValue();
bridgeConstructorWriter.endMethod();

int bridgeTypeParameterOffset = isStatic ? 0 : 1;
List<Class<?>> bridgeTypeParameters = new ArrayList<>(javaMethod.getParameterTypes().length + bridgeTypeParameterOffset);
int filteredTypeParameterOffset = isStatic ? 0 : 1;
List<Class<?>> filteredTypeParameters = new ArrayList<>(javaMethod.getParameterTypes().length + filteredTypeParameterOffset);

if (isStatic == false) {
bridgeTypeParameters.add(javaMethod.getDeclaringClass());
filteredTypeParameters.add(javaMethod.getDeclaringClass());
}

for (Class<?> typeParameter : javaMethod.getParameterTypes()) {
Expand All @@ -2295,78 +2212,48 @@ private void generateBridgeMethod(PainlessClassBuilder painlessClassBuilder, Pai
|| typeParameter == Long.class
|| typeParameter == Float.class
|| typeParameter == Double.class) {
bridgeTypeParameters.add(Object.class);
filteredTypeParameters.add(Object.class);
} else {
bridgeTypeParameters.add(typeParameter);
filteredTypeParameters.add(typeParameter);
}
}

MethodType bridgeMethodType = MethodType.methodType(painlessMethod.returnType(), bridgeTypeParameters);
MethodWriter bridgeMethodWriter = new MethodWriter(
Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC,
new org.objectweb.asm.commons.Method(painlessMethod.javaMethod().getName(), bridgeMethodType.toMethodDescriptorString()),
bridgeClassWriter,
null,
null
);
bridgeMethodWriter.visitCode();

if (isStatic == false) {
bridgeMethodWriter.loadArg(0);
}

for (int typeParameterCount = 0; typeParameterCount < javaMethod.getParameterTypes().length; ++typeParameterCount) {
bridgeMethodWriter.loadArg(typeParameterCount + bridgeTypeParameterOffset);
Class<?> typeParameter = javaMethod.getParameterTypes()[typeParameterCount];

if (typeParameter == Byte.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_BYTE_IMPLICIT);
else if (typeParameter == Short.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_SHORT_IMPLICIT);
else if (typeParameter == Character.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_CHARACTER_IMPLICIT);
else if (typeParameter == Integer.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_INTEGER_IMPLICIT);
else if (typeParameter == Long.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_LONG_IMPLICIT);
else if (typeParameter == Float.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_FLOAT_IMPLICIT);
else if (typeParameter == Double.class) bridgeMethodWriter.invokeStatic(DEF_UTIL_TYPE, DEF_TO_B_DOUBLE_IMPLICIT);
}

bridgeMethodWriter.invokeMethodCall(painlessMethod);
bridgeMethodWriter.returnValue();
bridgeMethodWriter.endMethod();

bridgeClassWriter.visitEnd();
MethodType filteredMethodType = MethodType.methodType(painlessMethod.returnType(), filteredTypeParameters);
MethodHandle filteredMethodHandle = painlessMethod.methodHandle();

try {
BridgeLoader bridgeLoader = AccessController.doPrivileged(new PrivilegedAction<BridgeLoader>() {
@Override
public BridgeLoader run() {
return new BridgeLoader(javaMethod.getDeclaringClass().getClassLoader());
for (int typeParameterCount = 0; typeParameterCount < javaMethod.getParameterTypes().length; ++typeParameterCount) {
Class<?> typeParameter = javaMethod.getParameterTypes()[typeParameterCount];
MethodHandle castMethodHandle = Def.DEF_TO_BOXED_TYPE_IMPLICIT_CAST.get(typeParameter);

if (castMethodHandle != null) {
filteredMethodHandle = MethodHandles.filterArguments(
filteredMethodHandle,
typeParameterCount + filteredTypeParameterOffset,
castMethodHandle
);
}
});
}

Class<?> bridgeClass = bridgeLoader.defineBridge(bridgeClassName.replace('/', '.'), bridgeClassWriter.toByteArray());
Method bridgeMethod = bridgeClass.getMethod(
painlessMethod.javaMethod().getName(),
bridgeTypeParameters.toArray(new Class<?>[0])
);
MethodHandle bridgeHandle = lookup(bridgeClass).unreflect(bridgeClass.getMethods()[0]);
bridgePainlessMethod = new PainlessMethod(
bridgeMethod,
bridgeClass,
filteredPainlessMethod = new PainlessMethod(
painlessMethod.javaMethod(),
targetClass,
painlessMethod.returnType(),
bridgeTypeParameters,
bridgeHandle,
bridgeMethodType,
filteredTypeParameters,
filteredMethodHandle,
filteredMethodType,
Collections.emptyMap()
);
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), bridgePainlessMethod);
painlessBridgeCache.put(painlessMethod, bridgePainlessMethod);
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), filteredPainlessMethod);
painlessFilteredCache.put(painlessMethod, filteredPainlessMethod);
} catch (Exception exception) {
throw new IllegalStateException(
"internal error occurred attempting to generate a bridge method [" + bridgeClassName + "]",
"internal error occurred attempting to generate a runtime method [" + painlessMethodKey + "]",
exception
);
}
} else {
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), bridgePainlessMethod);
painlessClassBuilder.runtimeMethods.put(painlessMethodKey.intern(), filteredPainlessMethod);
}
}

Expand Down