Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support JDK serialization/deserialization features #2730

Merged
merged 1 commit into from
Dec 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ public static void register(boolean finalIsWritable, Field... fields) {
ImageSingletons.lookup(RuntimeReflectionSupport.class).register(finalIsWritable, false, fields);
}

/**
* Makes the provided fields available for reflection at run time. The fields will be returned
* by {@link Class#getField}, {@link Class#getFields},and all the other methods on {@link Class}
* that return a single or a list of fields.
*
* @param finalIsWritable for all of the passed fields which are marked {@code final}, indicates
* whether it should be possible to change their value using reflection.
* @param allowUnsafeAccess for all of the passed fields, indicates whether it should be
* possible to access by unsafe operations.
* @since 21.0
*/
public static void register(boolean finalIsWritable, boolean allowUnsafeAccess, Field... fields) {
olpaw marked this conversation as resolved.
Show resolved Hide resolved
ImageSingletons.lookup(RuntimeReflectionSupport.class).register(finalIsWritable, allowUnsafeAccess, fields);
}

/**
* Makes the provided classes available for reflective instantiation by
* {@link Class#newInstance}. This is equivalent to registering the nullary constructors of the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.oracle.svm.core.util.VMError.guarantee;
import static com.oracle.svm.jni.JNIObjectHandles.nullHandle;
import static com.oracle.svm.jvmtiagentbase.Support.callObjectMethod;
import static com.oracle.svm.jvmtiagentbase.Support.check;
import static com.oracle.svm.jvmtiagentbase.Support.checkJni;
import static com.oracle.svm.jvmtiagentbase.Support.checkNoException;
Expand All @@ -39,11 +40,13 @@
import static com.oracle.svm.jvmtiagentbase.Support.getClassNameOr;
import static com.oracle.svm.jvmtiagentbase.Support.getClassNameOrNull;
import static com.oracle.svm.jvmtiagentbase.Support.getDirectCallerClass;
import static com.oracle.svm.jvmtiagentbase.Support.getObjectField;
import static com.oracle.svm.jvmtiagentbase.Support.getMethodDeclaringClass;
import static com.oracle.svm.jvmtiagentbase.Support.getObjectArgument;
import static com.oracle.svm.jvmtiagentbase.Support.jniFunctions;
import static com.oracle.svm.jvmtiagentbase.Support.jvmtiEnv;
import static com.oracle.svm.jvmtiagentbase.Support.jvmtiFunctions;
import static com.oracle.svm.jvmtiagentbase.Support.newObjectL;
import static com.oracle.svm.jvmtiagentbase.Support.testException;
import static com.oracle.svm.jvmtiagentbase.Support.toCString;
import static com.oracle.svm.jvmtiagentbase.jvmti.JvmtiEvent.JVMTI_EVENT_BREAKPOINT;
Expand All @@ -61,6 +64,8 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.locks.ReentrantLock;

import com.oracle.svm.util.SerializationChecksumCalculator;
import com.oracle.svm.jni.nativeapi.JNIFieldId;
import org.graalvm.compiler.core.common.NumUtil;
import org.graalvm.nativeimage.StackValue;
import org.graalvm.nativeimage.UnmanagedMemory;
Expand All @@ -76,6 +81,7 @@
import org.graalvm.nativeimage.c.type.CTypeConversion;
import org.graalvm.nativeimage.c.type.CTypeConversion.CCharPointerHolder;
import org.graalvm.nativeimage.c.type.WordPointer;
import org.graalvm.word.WordBase;
import org.graalvm.word.WordFactory;

import com.oracle.svm.core.c.function.CEntryPointOptions;
Expand Down Expand Up @@ -831,6 +837,121 @@ private static boolean resolveMemberName(JNIEnvironment jni, Breakpoint bp) {
return true;
}

static class CheckSumCalculator extends SerializationChecksumCalculator.JVMCIAgentCalculator {
private JNIEnvironment jni;
private Breakpoint bp;

CheckSumCalculator(JNIEnvironment jni, Breakpoint bp) {
this.jni = jni;
this.bp = bp;
}

@Override
protected WordBase getSuperClass(WordBase clazz) {
return jniFunctions().getGetSuperclass().invoke(jni, (JNIObjectHandle) clazz);
}

@Override
public Long calculateFromComputeDefaultSUID(WordBase clazz) {
JNIMethodId computeDefaultSUIDMId = agent.handles().getJavaIoObjectStreamClassComputeDefaultSUID(jni, bp.clazz);
JNIValue args = StackValue.get(1, JNIValue.class);
args.setObject((JNIObjectHandle) clazz);
return jniFunctions().getCallStaticLongMethodA().invoke(jni, bp.clazz, computeDefaultSUIDMId, args);
}

@Override
protected boolean isClassAbstract(WordBase clazz) {
CIntPointer modifiers = StackValue.get(CIntPointer.class);
if (jvmtiFunctions().GetClassModifiers().invoke(jvmtiEnv(), (JNIObjectHandle) clazz, modifiers) != JvmtiError.JVMTI_ERROR_NONE) {
return false;
}
// Checkstyle: allow reflection
return (modifiers.read() & java.lang.reflect.Modifier.ABSTRACT) != 0;
}

@Override
public String getClassName(WordBase clazz) {
return getClassNameOrNull(jni, (JNIObjectHandle) clazz);
}
}

private static boolean objectStreamClassConstructor(JNIEnvironment jni, Breakpoint bp) {
JNIObjectHandle serializeTargetClass = getObjectArgument(1);
String serializeTargetClassName = getClassNameOrNull(jni, serializeTargetClass);
long checksum = 0;
List<SerializationInfo> traceCandidates = new ArrayList<>();
CheckSumCalculator checkSumCalculator = new CheckSumCalculator(jni, bp);
JNIObjectHandle objectStreamClassInstance = newObjectL(jni, bp.clazz, bp.method, serializeTargetClass);
Object result = nullHandle().notEqual(objectStreamClassInstance);
if (clearException(jni)) {
result = false;
}
// Skip Lambda class serialization
if (serializeTargetClassName.contains("$$Lambda$")) {
return true;
}
if (result.equals(true)) {
checksum = checkSumCalculator.calculateChecksum(getConsClassName(jni, bp.clazz, objectStreamClassInstance), serializeTargetClassName, serializeTargetClass);
}
traceCandidates.add(new SerializationInfo(serializeTargetClassName, checksum));

/**
* When the ObjectStreamClass instance is created for the given serializeTargetClass, some
* additional ObjectStreamClass instances (usually the super classes) are created
* recursively. Call ObjectStreamClass.getClassDataLayout0() can get all of them.
*/
JNIMethodId getClassDataLayout0MId = agent.handles().getJavaIoObjectStreamClassGetClassDataLayout0(jni, bp.clazz);
JNIObjectHandle dataLayoutArray = callObjectMethod(jni, objectStreamClassInstance, getClassDataLayout0MId);
if (!clearException(jni) && nullHandle().notEqual(dataLayoutArray)) {
int length = jniFunctions().getGetArrayLength().invoke(jni, dataLayoutArray);
// If only 1 element is got from getClassDataLayout0(). it is base ObjectStreamClass
// instance itself.
if (!clearException(jni) && length > 1) {
JNIFieldId hasDataFId = agent.handles().getJavaIOObjectStreamClassClassDataSlotHasData(jni);
JNIFieldId descFId = agent.handles().getJavaIOObjectStreamClassClassDataSlotDesc(jni);
JNIMethodId javaIoObjectStreamClassForClassMId = agent.handles().getJavaIoObjectStreamClassForClass(jni, bp.clazz);
for (int i = 0; i < length; i++) {
JNIObjectHandle classDataSlot = jniFunctions().getGetObjectArrayElement().invoke(jni, dataLayoutArray, i);
boolean hasData = jniFunctions().getGetBooleanField().invoke(jni, classDataSlot, hasDataFId);
if (hasData) {
JNIObjectHandle oscInstanceInSlot = jniFunctions().getGetObjectField().invoke(jni, classDataSlot, descFId);
if (!jniFunctions().getIsSameObject().invoke(jni, oscInstanceInSlot, objectStreamClassInstance)) {
JNIObjectHandle oscClazz = callObjectMethod(jni, oscInstanceInSlot, javaIoObjectStreamClassForClassMId);
String oscClassName = getClassNameOrNull(jni, oscClazz);
traceCandidates.add(new SerializationInfo(oscClassName,
checkSumCalculator.calculateChecksum(getConsClassName(jni,
bp.clazz, oscInstanceInSlot), oscClassName, oscClazz)));
}
}
}
}
}
for (SerializationInfo serializationInfo : traceCandidates) {
if (traceWriter != null) {
traceWriter.traceCall("serialization",
"ObjectStreamClass.<init>",
null,
null,
null,
result,
// serializeTargetClassName, checksum);
serializationInfo.className, serializationInfo.checksum);
guarantee(!testException(jni));
}
}
return true;
}

private static String getConsClassName(JNIEnvironment jni, JNIObjectHandle objectStreamClassClazz, JNIObjectHandle objectStreamClassInstance) {
JNIObjectHandle cons = getObjectField(jni, objectStreamClassClazz, objectStreamClassInstance, "cons", "Ljava/lang/reflect/Constructor;");
String targetConstructorClassName = "";
if (nullHandle().notEqual(cons)) {
// Compute hashcode from the first unserializable superclass
targetConstructorClassName = getClassNameOrNull(jni, callObjectMethod(jni, cons, agent.handles().javaLangReflectMemberGetDeclaringClass));
}
return targetConstructorClassName;
}

@CEntryPoint
@CEntryPointOptions(prologue = AgentIsolate.Prologue.class)
private static void onBreakpoint(@SuppressWarnings("unused") JvmtiEnv jvmti, JNIEnvironment jni,
Expand Down Expand Up @@ -1135,6 +1256,7 @@ private interface BreakpointHandler {
brk("java/lang/reflect/Proxy", "newProxyInstance",
"(Ljava/lang/ClassLoader;[Ljava/lang/Class;Ljava/lang/reflect/InvocationHandler;)Ljava/lang/Object;", BreakpointInterceptor::newProxyInstance),

brk("java/io/ObjectStreamClass", "<init>", "(Ljava/lang/Class;)V", BreakpointInterceptor::objectStreamClassConstructor),
optionalBrk("java/util/ResourceBundle",
"getBundleImpl",
"(Ljava/lang/String;Ljava/util/Locale;Ljava/lang/ClassLoader;Ljava/util/ResourceBundle$Control;)Ljava/util/ResourceBundle;",
Expand Down Expand Up @@ -1264,6 +1386,16 @@ public int hashCode() {
}
}

private static final class SerializationInfo {
private String className;
private long checksum;

SerializationInfo(String serializeTargetClassName, long checksum) {
this.className = serializeTargetClassName;
this.checksum = checksum;
}
}

private BreakpointInterceptor() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ protected int onLoadCallback(JNIJavaVM vm, JvmtiEnv jvmti, JvmtiEventCallbacks c
// They should use the same filter sets, however.
AccessAdvisor advisor = createAccessAdvisor(builtinHeuristicFilter, callerFilter, accessFilter);
TraceProcessor processor = new TraceProcessor(advisor, mergeConfigs.loadJniConfig(handler), mergeConfigs.loadReflectConfig(handler),
mergeConfigs.loadProxyConfig(handler), mergeConfigs.loadResourceConfig(handler));
mergeConfigs.loadProxyConfig(handler), mergeConfigs.loadResourceConfig(handler), mergeConfigs.loadSerializationConfig(handler));
traceWriter = new TraceProcessorWriterAdapter(processor);
} catch (Throwable t) {
System.err.println(MESSAGE_PREFIX + t);
Expand Down Expand Up @@ -424,6 +424,7 @@ private void writeConfigurationFiles() {
allConfigFiles.put(ConfigurationFiles.JNI_NAME, p.getJniConfiguration());
allConfigFiles.put(ConfigurationFiles.DYNAMIC_PROXY_NAME, p.getProxyConfiguration());
allConfigFiles.put(ConfigurationFiles.RESOURCES_NAME, p.getResourceConfiguration());
allConfigFiles.put(ConfigurationFiles.SERIALIZATION_NAME, p.getSerializationConfiguration());

for (Map.Entry<String, JsonPrintable> configFile : allConfigFiles.entrySet()) {
Path tempPath = tempDirectory.resolve(configFile.getKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
*/
package com.oracle.svm.agent;

import static com.oracle.svm.jni.JNIObjectHandles.nullHandle;

import com.oracle.svm.jni.nativeapi.JNIFieldId;
import com.oracle.svm.jni.nativeapi.JNIEnvironment;
import com.oracle.svm.jni.nativeapi.JNIMethodId;
import com.oracle.svm.jni.nativeapi.JNIObjectHandle;
Expand All @@ -33,6 +36,7 @@ public class NativeImageAgentJNIHandleSet extends JNIHandleSet {

final JNIObjectHandle javaLangClass;
final JNIMethodId javaLangClassForName3;
final JNIMethodId javaUtilEnumerationNextElement;
final JNIMethodId javaLangClassGetDeclaredMethod;
final JNIMethodId javaLangClassGetDeclaredConstructor;
final JNIMethodId javaLangClassGetDeclaredField;
Expand All @@ -53,6 +57,15 @@ public class NativeImageAgentJNIHandleSet extends JNIHandleSet {
final JNIMethodId javaLangInvokeMemberNameIsField;

private JNIMethodId javaUtilResourceBundleGetBundleImplSLCC;

// Lazily look for serialization classes
private JNIMethodId javaIoObjectStreamClassComputeDefaultSUID;
private JNIMethodId javaIoObjectStreamClassForClass;
private JNIMethodId javaIoObjectStreamClassGetClassDataLayout0;
private JNIObjectHandle javaIOObjectStreamClassClassDataSlot;
private JNIFieldId javaIOObjectStreamClassClassDataSlotDesc;
private JNIFieldId javaIOObjectStreamClassClassDataSlotHasData;

private boolean queriedJavaUtilResourceBundleGetBundleImplSLCC;

NativeImageAgentJNIHandleSet(JNIEnvironment env) {
Expand All @@ -70,6 +83,7 @@ public class NativeImageAgentJNIHandleSet extends JNIHandleSet {

JNIObjectHandle javaUtilEnumeration = findClass(env, "java/util/Enumeration");
javaUtilEnumerationHasMoreElements = getMethodId(env, javaUtilEnumeration, "hasMoreElements", "()Z", false);
javaUtilEnumerationNextElement = getMethodId(env, javaUtilEnumeration, "nextElement", "()Ljava/lang/Object;", false);

javaLangClassLoader = newClassGlobalRef(env, "java/lang/ClassLoader");

Expand All @@ -91,4 +105,46 @@ JNIMethodId tryGetJavaUtilResourceBundleGetBundleImplSLCC(JNIEnvironment env) {
}
return javaUtilResourceBundleGetBundleImplSLCC;
}

JNIMethodId getJavaIoObjectStreamClassComputeDefaultSUID(JNIEnvironment env, JNIObjectHandle javaIoObjectStreamClass) {
if (javaIoObjectStreamClassComputeDefaultSUID.equal(nullHandle())) {
javaIoObjectStreamClassComputeDefaultSUID = getMethodId(env, javaIoObjectStreamClass, "computeDefaultSUID", "(Ljava/lang/Class;)J", true);
}
return javaIoObjectStreamClassComputeDefaultSUID;
}

JNIMethodId getJavaIoObjectStreamClassForClass(JNIEnvironment env, JNIObjectHandle javaIoObjectStreamClass) {
if (javaIoObjectStreamClassForClass.equal(nullHandle())) {
javaIoObjectStreamClassForClass = getMethodId(env, javaIoObjectStreamClass, "forClass", "()Ljava/lang/Class;", false);
}
return javaIoObjectStreamClassForClass;
}

JNIMethodId getJavaIoObjectStreamClassGetClassDataLayout0(JNIEnvironment env, JNIObjectHandle javaIoObjectStreamClass) {
if (javaIoObjectStreamClassGetClassDataLayout0.equal(nullHandle())) {
javaIoObjectStreamClassGetClassDataLayout0 = getMethodId(env, javaIoObjectStreamClass, "getClassDataLayout0", "()[Ljava/io/ObjectStreamClass$ClassDataSlot;", false);
}
return javaIoObjectStreamClassGetClassDataLayout0;
}

JNIObjectHandle getJavaIOObjectStreamClassClassDataSlot(JNIEnvironment env) {
if (javaIOObjectStreamClassClassDataSlot.equal(nullHandle())) {
javaIOObjectStreamClassClassDataSlot = newClassGlobalRef(env, "java/io/ObjectStreamClass$ClassDataSlot");
}
return javaIOObjectStreamClassClassDataSlot;
}

JNIFieldId getJavaIOObjectStreamClassClassDataSlotDesc(JNIEnvironment env) {
if (javaIOObjectStreamClassClassDataSlotDesc.equal(nullHandle())) {
javaIOObjectStreamClassClassDataSlotDesc = getFieldId(env, getJavaIOObjectStreamClassClassDataSlot(env), "desc", "Ljava/io/ObjectStreamClass;", false);
}
return javaIOObjectStreamClassClassDataSlotDesc;
}

JNIFieldId getJavaIOObjectStreamClassClassDataSlotHasData(JNIEnvironment env) {
if (javaIOObjectStreamClassClassDataSlotHasData.equal(nullHandle())) {
javaIOObjectStreamClassClassDataSlotHasData = getFieldId(env, getJavaIOObjectStreamClassClassDataSlot(env), "hasData", "Z", false);
}
return javaIOObjectStreamClassClassDataSlotHasData;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ private static void generate(Iterator<String> argsIter, boolean acceptTraceFileA
set.getResourceConfigPaths().add(requirePathUri(current, value));
break;

case "--serialization-input":
set = inputSet; // fall through
case "--serialization-output":
set.getSerializationConfigPaths().add(requirePathUri(current, value));
break;

case "--trace-input":
traceInputs.add(requirePathUri(current, value));
break;
Expand Down Expand Up @@ -249,7 +255,8 @@ private static void generate(Iterator<String> argsIter, boolean acceptTraceFileA
TraceProcessor p;
try {
p = new TraceProcessor(advisor, inputSet.loadJniConfig(ConfigurationSet.FAIL_ON_EXCEPTION), inputSet.loadReflectConfig(ConfigurationSet.FAIL_ON_EXCEPTION),
inputSet.loadProxyConfig(ConfigurationSet.FAIL_ON_EXCEPTION), inputSet.loadResourceConfig(ConfigurationSet.FAIL_ON_EXCEPTION));
inputSet.loadProxyConfig(ConfigurationSet.FAIL_ON_EXCEPTION), inputSet.loadResourceConfig(ConfigurationSet.FAIL_ON_EXCEPTION),
inputSet.loadSerializationConfig(ConfigurationSet.FAIL_ON_EXCEPTION));
} catch (IOException e) {
throw e;
} catch (Throwable t) {
Expand Down Expand Up @@ -287,6 +294,11 @@ private static void generate(Iterator<String> argsIter, boolean acceptTraceFileA
p.getResourceConfiguration().printJson(writer);
}
}
for (URI uri : outputSet.getSerializationConfigPaths()) {
try (JsonWriter writer = new JsonWriter(Paths.get(uri))) {
p.getSerializationConfiguration().printJson(writer);
}
}
}

private static void generateFilterRules(Iterator<String> argsIter) throws IOException {
Expand Down
Loading