Skip to content

Commit

Permalink
Fix generic parameter data flow validation in NativeAOT (#81532)
Browse files Browse the repository at this point in the history
(This is a revert of a revert of #80956 with additional fixes for #81358)

This reworks how generic parameter data flow validation is done in the NativeAOT compiler.

Previously, generic data flow was done from generic dictionary nodes. The problem with that approach is that there's no origin information at that point. The warnings can't point to the place in the code where the problematic instantiation is - we only know that it exists.
Aside from being unfriendly to users, it means that RequiresUnreferencedCode (RUC) attributes and suppressions don't work on these warnings the same way they do in linker/analyzer.

This change modifies the logic to tag the method as "needs data flow" whenever we spot an instantiation of an annotated generic in it somewhere.
Then the actual validation/marking is done from data flow using the trim analysis patterns.

The only exception to this is generic data flow for base types and interface implementations; that one is done on the EEType nodes.

Note that AOT implements a much more precise version of the generic data flow validation as compared to linker/analyzer. See the big comment at the beginning of `GenericParameterWarningLocation.cs` for how that works.

Due to an issue with DependencyInjection, this change also implements a behavior where if a method or field is reflection accessible, the compiler will perform generic argument data flow on all types in the signature of the method/field (which it normally wouldn't do). See #81358 for details about the issue and discussions on the fix approach.

Test changes:
Adds the two tests from linker which cover this functionality.

Changes the test infra to use tokens to compare message origins for expected warnings. Consistently converting generic types/methods into strings across two type systems is just very difficult - the tokens are simple and reliable.

Changes the tests to avoid expecting specific generic types/methods formatting in the messages - again, it's too hard to make this consistent without a lot of effort. And the tests don't really need it.

Adds a smoke test which has a simplified version of the DI problem from #81358.
  • Loading branch information
vitek-karas authored Feb 7, 2023
1 parent 55b35db commit e71a4fb
Show file tree
Hide file tree
Showing 28 changed files with 2,942 additions and 177 deletions.
5 changes: 4 additions & 1 deletion src/coreclr/tools/Common/Compiler/DisplayNameHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ public static string GetDisplayName(this MethodDesc method)
if (method.Signature.Length > 0)
{
for (int i = 0; i < method.Signature.Length - 1; i++)
sb.Append(method.Signature[i].GetDisplayNameWithoutNamespace()).Append(',');
{
TypeDesc instantiatedType = method.Signature[i].InstantiateSignature(method.OwningType.Instantiation, method.Instantiation);
sb.Append(instantiatedType.GetDisplayNameWithoutNamespace()).Append(',');
}

sb.Append(method.Signature[method.Signature.Length - 1].GetDisplayNameWithoutNamespace());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ public bool RequiresDataflowAnalysis(MethodDesc method)
try
{
method = method.GetTypicalMethodDefinition();
return GetAnnotations(method.OwningType).TryGetAnnotation(method, out var methodAnnotations)
&& (methodAnnotations.ReturnParameterAnnotation != DynamicallyAccessedMemberTypes.None || methodAnnotations.ParameterAnnotations != null);
TypeAnnotations typeAnnotations = GetAnnotations(method.OwningType);
return typeAnnotations.HasGenericParameterAnnotation() || typeAnnotations.TryGetAnnotation(method, out _);
}
catch (TypeSystemException)
{
Expand All @@ -73,7 +73,8 @@ public bool RequiresDataflowAnalysis(FieldDesc field)
try
{
field = field.GetTypicalFieldDefinition();
return GetAnnotations(field.OwningType).TryGetAnnotation(field, out _);
TypeAnnotations typeAnnotations = GetAnnotations(field.OwningType);
return typeAnnotations.HasGenericParameterAnnotation() || typeAnnotations.TryGetAnnotation(field, out _);
}
catch (TypeSystemException)
{
Expand Down Expand Up @@ -105,6 +106,31 @@ public bool HasAnyAnnotations(TypeDesc type)
}
}

/// <summary>
/// Reports whether the type definition of <paramref name="type"/> has generic parameter annotations.
/// </summary>
/// <param name="type">Type to check; the lookup is done on its type definition.</param>
/// <returns>
/// The result of <c>HasGenericParameterAnnotation()</c> on the annotations of the type definition,
/// or <see langword="false"/> if a <see cref="TypeSystemException"/> is thrown during the lookup.
/// </returns>
public bool HasGenericParameterAnnotation(TypeDesc type)
{
    try
    {
        TypeDesc definition = type.GetTypeDefinition();
        TypeAnnotations definitionAnnotations = GetAnnotations(definition);
        return definitionAnnotations.HasGenericParameterAnnotation();
    }
    catch (TypeSystemException)
    {
        // A type system failure is treated as "no annotations" rather than propagated.
        return false;
    }
}

/// <summary>
/// Reports whether the typical definition of <paramref name="method"/> has generic parameter annotations.
/// </summary>
/// <param name="method">Method to check; the lookup is done on its typical method definition.</param>
/// <returns>
/// <see langword="true"/> when annotations exist for the method and their
/// <c>GenericParameterAnnotations</c> is non-null; <see langword="false"/> otherwise,
/// including when a <see cref="TypeSystemException"/> is thrown during the lookup.
/// </returns>
public bool HasGenericParameterAnnotation(MethodDesc method)
{
    try
    {
        MethodDesc typicalMethod = method.GetTypicalMethodDefinition();

        // No annotation entry for the method at all means there is nothing to report.
        if (!GetAnnotations(typicalMethod.OwningType).TryGetAnnotation(typicalMethod, out var methodAnnotations))
            return false;

        return methodAnnotations.GenericParameterAnnotations != null;
    }
    catch (TypeSystemException)
    {
        // A type system failure is treated as "no annotations" rather than propagated.
        return false;
    }
}

internal DynamicallyAccessedMemberTypes GetParameterAnnotation(ParameterProxy param)
{
MethodDesc method = param.Method.Method.GetTypicalMethodDefinition();
Expand Down Expand Up @@ -884,6 +910,8 @@ public bool TryGetAnnotation(GenericParameterDesc genericParameter, out Dynamica

return false;
}

/// <summary>
/// Reports whether generic parameter annotations are present, i.e. the backing
/// <c>_genericParameterAnnotations</c> field is non-null.
/// </summary>
public bool HasGenericParameterAnnotation()
{
    return _genericParameterAnnotations != null;
}
}

private readonly struct MethodAnnotations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,83 @@

namespace ILCompiler.Dataflow
{
public readonly struct GenericArgumentDataFlow
public static class GenericArgumentDataFlow
{
private readonly Logger _logger;
private readonly NodeFactory _factory;
private readonly FlowAnnotations _annotations;
private readonly MessageOrigin _origin;
/// <summary>
/// Processes generic argument data flow for <paramref name="type"/>, using a type as the instantiation context.
/// </summary>
/// <param name="dependencies">Dependency list passed through by reference to the core overload.</param>
/// <param name="factory">Node factory passed through to the core overload.</param>
/// <param name="origin">Message origin passed through to the core overload.</param>
/// <param name="type">Type whose generic arguments are to be processed.</param>
/// <param name="contextType">Type whose instantiation supplies the type context; the method context is empty.</param>
public static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, NodeFactory factory, in MessageOrigin origin, TypeDesc type, TypeDesc contextType)
{
    // A type context has no method-level generic arguments.
    Instantiation typeContext = contextType.Instantiation;
    Instantiation methodContext = Instantiation.Empty;

    ProcessGenericArgumentDataFlow(ref dependencies, factory, origin, type, typeContext, methodContext);
}

public GenericArgumentDataFlow(Logger logger, NodeFactory factory, FlowAnnotations annotations, in MessageOrigin origin)
public static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, NodeFactory factory, in MessageOrigin origin, TypeDesc type, MethodDesc contextMethod)
{
_logger = logger;
_factory = factory;
_annotations = annotations;
_origin = origin;
ProcessGenericArgumentDataFlow(ref dependencies, factory, origin, type, contextMethod.OwningType.Instantiation, contextMethod.Instantiation);
}

public DependencyList ProcessGenericArgumentDataFlow(GenericParameterDesc genericParameter, TypeDesc genericArgument)
private static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, NodeFactory factory, in MessageOrigin origin, TypeDesc type, Instantiation typeContext, Instantiation methodContext)
{
var genericParameterValue = _annotations.GetGenericParameterValue(genericParameter);
Debug.Assert(genericParameterValue.DynamicallyAccessedMemberTypes != DynamicallyAccessedMemberTypes.None);
if (!type.HasInstantiation)
return;

MultiValue genericArgumentValue = _annotations.GetTypeValueFromGenericArgument(genericArgument);
TypeDesc instantiatedType = type.InstantiateSignature(typeContext, methodContext);

var mdManager = (UsageBasedMetadataManager)factory.MetadataManager;

var diagnosticContext = new DiagnosticContext(
_origin,
_logger.ShouldSuppressAnalysisWarningsForRequires(_origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
_logger);
return RequireDynamicallyAccessedMembers(diagnosticContext, genericArgumentValue, genericParameterValue, genericParameter.GetDisplayName());
origin,
!mdManager.Logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
mdManager.Logger);
var reflectionMarker = new ReflectionMarker(mdManager.Logger, factory, mdManager.FlowAnnotations, typeHierarchyDataFlowOrigin: null, enabled: true);

ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, instantiatedType);

if (reflectionMarker.Dependencies.Count > 0)
{
if (dependencies == null)
dependencies = reflectionMarker.Dependencies;
else
dependencies.AddRange(reflectionMarker.Dependencies);
}
}

/// <summary>
/// Processes the generic arguments of <paramref name="type"/> against the generic parameters
/// of its type definition.
/// </summary>
/// <param name="diagnosticContext">Diagnostic context passed through to the instantiation processing.</param>
/// <param name="reflectionMarker">Reflection marker passed through to the instantiation processing.</param>
/// <param name="type">Type to process. Nothing is done when it is its own type definition.</param>
public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, TypeDesc type)
{
    TypeDesc definition = type.GetTypeDefinition();

    // The type is its own definition, so there is no instantiation to check.
    if (definition == type)
        return;

    ProcessGenericInstantiation(diagnosticContext, reflectionMarker, type.Instantiation, definition.Instantiation);
}

/// <summary>
/// Processes the generic arguments of <paramref name="method"/> against the generic parameters
/// of its typical definition, and then processes the method's owning type.
/// </summary>
/// <param name="diagnosticContext">Diagnostic context passed through to the instantiation processing.</param>
/// <param name="reflectionMarker">Reflection marker passed through to the instantiation processing.</param>
/// <param name="method">Method to process.</param>
public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, MethodDesc method)
{
    // Method-level generic arguments first; skipped when the method is its own typical definition.
    MethodDesc definition = method.GetTypicalMethodDefinition();
    bool differsFromDefinition = definition != method;
    if (differsFromDefinition)
    {
        ProcessGenericInstantiation(diagnosticContext, reflectionMarker, method.Instantiation, definition.Instantiation);
    }

    // The owning type is always handed to the type overload, which does its own definition check.
    ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, method.OwningType);
}

/// <summary>
/// Processes generic argument data flow for <paramref name="field"/> by processing its owning type.
/// </summary>
/// <param name="diagnosticContext">Diagnostic context passed through to the type overload.</param>
/// <param name="reflectionMarker">Reflection marker passed through to the type overload.</param>
/// <param name="field">Field whose owning type is processed.</param>
public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, FieldDesc field)
    => ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, field.OwningType);

private DependencyList RequireDynamicallyAccessedMembers(
in DiagnosticContext diagnosticContext,
in MultiValue value,
ValueWithDynamicallyAccessedMembers targetValue,
string reason)
private static void ProcessGenericInstantiation(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, Instantiation instantiation, Instantiation typicalInstantiation)
{
var reflectionMarker = new ReflectionMarker(_logger, _factory, _annotations, typeHierarchyDataFlowOrigin: null, enabled: true);
var requireDynamicallyAccessedMembersAction = new RequireDynamicallyAccessedMembersAction(reflectionMarker, diagnosticContext, reason);
requireDynamicallyAccessedMembersAction.Invoke(value, targetValue);
return reflectionMarker.Dependencies;
for (int i = 0; i < instantiation.Length; i++)
{
var genericParameter = (GenericParameterDesc)typicalInstantiation[i];
if (reflectionMarker.Annotations.GetGenericParameterAnnotation(genericParameter) != default)
{
var genericParameterValue = reflectionMarker.Annotations.GetGenericParameterValue(genericParameter);
Debug.Assert(genericParameterValue.DynamicallyAccessedMemberTypes != DynamicallyAccessedMemberTypes.None);
MultiValue genericArgumentValue = reflectionMarker.Annotations.GetTypeValueFromGenericArgument(instantiation[i]);
var requireDynamicallyAccessedMembersAction = new RequireDynamicallyAccessedMembersAction(reflectionMarker, diagnosticContext, genericParameter.GetDisplayName());
requireDynamicallyAccessedMembersAction.Invoke(genericArgumentValue, genericParameterValue);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ protected virtual void Scan(MethodIL methodBody, ref InterproceduralState interp
StackSlot retValue = PopUnknown(currentStack, 1, methodBody, offset);
// If the return value is a reference, treat it as the value itself for now
// We can handle ref return values better later
ReturnValue = MultiValueLattice.Meet(ReturnValue, DereferenceValue(retValue.Value, locals, ref interproceduralState));
ReturnValue = MultiValueLattice.Meet(ReturnValue, DereferenceValue(methodBody, offset, retValue.Value, locals, ref interproceduralState));
ValidateNoReferenceToReference(locals, methodBody, offset);
}
ClearStack(ref currentStack);
Expand Down Expand Up @@ -947,23 +947,24 @@ private void ScanLdtoken(MethodIL methodBody, int offset, object operand, Stack<
var nullableDam = new RuntimeTypeHandleForNullableValueWithDynamicallyAccessedMembers(new TypeProxy(type),
new RuntimeTypeHandleForGenericParameterValue(genericParam));
currentStack.Push(new StackSlot(nullableDam));
return;
break;
case MetadataType underlyingType:
var nullableType = new RuntimeTypeHandleForNullableSystemTypeValue(new TypeProxy(type), new SystemTypeValue(underlyingType));
currentStack.Push(new StackSlot(nullableType));
return;
break;
default:
PushUnknown(currentStack);
return;
break;
}
}
else
{
var typeHandle = new RuntimeTypeHandleValue(new TypeProxy(type));
currentStack.Push(new StackSlot(typeHandle));
return;
}
}

HandleTypeReflectionAccess(methodBody, offset, type);
}
else if (operand is MethodDesc method)
{
Expand Down Expand Up @@ -1026,7 +1027,7 @@ protected void StoreInReference(MultiValue target, MultiValue source, MethodIL m
StoreMethodLocalValue(locals, source, localReference.LocalIndex, curBasicBlock);
break;
case FieldReferenceValue fieldReference
when GetFieldValue(fieldReference.FieldDefinition).AsSingleValue() is FieldValue fieldValue:
when HandleGetField(method, offset, fieldReference.FieldDefinition).AsSingleValue() is FieldValue fieldValue:
HandleStoreField(method, offset, fieldValue, source);
break;
case ParameterReferenceValue parameterReference
Expand All @@ -1038,7 +1039,7 @@ when GetMethodParameterValue(parameterReference.Parameter) is MethodParameterVal
HandleStoreMethodReturnValue(method, offset, methodReturnValue, source);
break;
case FieldValue fieldValue:
HandleStoreField(method, offset, fieldValue, DereferenceValue(source, locals, ref ipState));
HandleStoreField(method, offset, fieldValue, DereferenceValue(method, offset, source, locals, ref ipState));
break;
case IValueWithStaticType valueWithStaticType:
if (valueWithStaticType.StaticType is not null && FlowAnnotations.IsTypeInterestingForDataflow(valueWithStaticType.StaticType))
Expand All @@ -1057,7 +1058,25 @@ when GetMethodParameterValue(parameterReference.Parameter) is MethodParameterVal

}

protected abstract MultiValue GetFieldValue(FieldDesc field);
/// <summary>
/// HandleGetField is called every time the scanner needs to represent the value of a field,
/// either as a source or as a target. It is not called when just a reference to the field is created,
/// but if such a reference is dereferenced then it will get called.
/// It is NOT called for hoisted locals.
/// </summary>
/// <remarks>
/// There should be no need to perform checks for hoisted locals. All of our reflection checks are based
/// on the assumption that problematic things happen because of running code. Doing things purely in the type system
/// (declaring new types which are never instantiated, declaring fields which are never assigned to, ...)
/// doesn't cause problems (or, better put, it won't show observable behavioral differences).
/// Typically that would mean that accessing fields is also an uninteresting operation; unfortunately,
/// static field access can cause execution of the static constructor (.cctor), and that is running code -> possible problems.
/// So we have to track accesses in that case.
/// Hoisted locals are fields on closure classes/structs which should not have static constructors (.cctor), so we don't
/// need to track those. It makes the design a bit cleaner because hoisted locals are purely handled in here
/// and don't leak over to the reflection handling code in any way.
/// </remarks>
/// <param name="methodBody">The method body being scanned in which the field is accessed.</param>
/// <param name="offset">The offset the scanner passes along with <paramref name="methodBody"/> for the access (presumably the IL offset of the instruction - confirm against implementations).</param>
/// <param name="field">The field whose value needs to be represented.</param>
/// <returns>
/// The value representing the field. Callers do not assume every node in it is a <c>FieldValue</c>;
/// the store path skips nodes of other kinds.
/// </returns>
protected abstract MultiValue HandleGetField(MethodIL methodBody, int offset, FieldDesc field);

private void ScanLdfld(
MethodIL methodBody,
Expand All @@ -1083,7 +1102,7 @@ private void ScanLdfld(
}
else
{
value = GetFieldValue(field);
value = HandleGetField(methodBody, offset, field);
}
currentStack.Push(new StackSlot(value));
}
Expand Down Expand Up @@ -1119,15 +1138,15 @@ private void ScanStfld(
return;
}

foreach (var value in GetFieldValue(field))
foreach (var value in HandleGetField(methodBody, offset, field))
{
// GetFieldValue may return different node types, in which case they can't be stored to.
// At least not yet.
if (value is not FieldValue fieldValue)
continue;

// Incomplete handling of ref fields -- if we're storing a reference to a value, pretend it's just the value
MultiValue valueToStore = DereferenceValue(valueToStoreSlot.Value, locals, ref interproceduralState);
MultiValue valueToStore = DereferenceValue(methodBody, offset, valueToStoreSlot.Value, locals, ref interproceduralState);

HandleStoreField(methodBody, offset, fieldValue, valueToStore);
}
Expand Down Expand Up @@ -1163,7 +1182,12 @@ private ValueNodeList PopCallArguments(
return methodParams;
}

internal MultiValue DereferenceValue(MultiValue maybeReferenceValue, ValueBasicBlockPair?[] locals, ref InterproceduralState interproceduralState)
internal MultiValue DereferenceValue(
MethodIL methodBody,
int offset,
MultiValue maybeReferenceValue,
ValueBasicBlockPair?[] locals,
ref InterproceduralState interproceduralState)
{
MultiValue dereferencedValue = MultiValueLattice.Top;
foreach (var value in maybeReferenceValue)
Expand All @@ -1175,7 +1199,7 @@ internal MultiValue DereferenceValue(MultiValue maybeReferenceValue, ValueBasicB
dereferencedValue,
CompilerGeneratedState.IsHoistedLocal(fieldReferenceValue.FieldDefinition)
? interproceduralState.GetHoistedLocal(new HoistedLocalKey(fieldReferenceValue.FieldDefinition))
: GetFieldValue(fieldReferenceValue.FieldDefinition));
: HandleGetField(methodBody, offset, fieldReferenceValue.FieldDefinition));
break;
case ParameterReferenceValue parameterReferenceValue:
dereferencedValue = MultiValue.Meet(
Expand Down Expand Up @@ -1224,6 +1248,11 @@ protected void AssignRefAndOutParameters(
}
}

/// <summary>
/// Called when a type is accessed directly (basically only ldtoken).
/// </summary>
/// <param name="methodBody">The method body being scanned in which the type is accessed.</param>
/// <param name="offset">The offset the scanner passes along with <paramref name="methodBody"/> for the access (presumably the IL offset of the instruction - confirm against implementations).</param>
/// <param name="accessedType">The type that is being accessed.</param>
protected abstract void HandleTypeReflectionAccess(MethodIL methodBody, int offset, TypeDesc accessedType);

/// <summary>
/// Called to handle reflection access to a method without any other specifics (ldtoken or ldftn for example)
/// </summary>
Expand Down Expand Up @@ -1260,7 +1289,7 @@ private void HandleCall(

var dereferencedMethodParams = new List<MultiValue>();
foreach (var argument in methodArguments)
dereferencedMethodParams.Add(DereferenceValue(argument, locals, ref interproceduralState));
dereferencedMethodParams.Add(DereferenceValue(callingMethodBody, offset, argument, locals, ref interproceduralState));
MultiValue methodReturnValue;
bool handledFunction = HandleCall(
callingMethodBody,
Expand Down
Loading

0 comments on commit e71a4fb

Please sign in to comment.