From 17c2fd7bb5d04deeab0d1d027b2a3ec0b6ea4d6d Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Wed, 24 Jan 2024 17:09:11 -0800 Subject: [PATCH 01/19] Mark all base types and interfaces as RelevantToVariantCasting --- .../illink/src/linker/Linker/Annotations.cs | 22 ++++-- .../MostSpecificDefaultImplementationKept.cs | 68 +++++++++++++++++++ ...efaultInterfaceMethodOnDerivedInterface.cs | 68 +++++++++++++++++++ 3 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs create mode 100644 src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index ec507ea25c0a5e..cdf3785d39fb31 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -246,8 +246,22 @@ public bool IsInstantiated (TypeDefinition type) public void MarkRelevantToVariantCasting (TypeDefinition type) { - if (type != null) - types_relevant_to_variant_casting.Add (type); + if (type == null) + return; + + if (!types_relevant_to_variant_casting.Add (type)) + return; + + foreach (var baseType in type.Interfaces) { + var resolvedBaseType = context.Resolve (baseType.InterfaceType); + if (resolvedBaseType is null) + continue; + // Don't need to rercurse for interfaces - types implement all interfaces in the interface/base type hierarchy + types_relevant_to_variant_casting.Add (resolvedBaseType); + } + if (type.BaseType is not null && context.Resolve(type.BaseType) is {} baseTypeDef) { + MarkRelevantToVariantCasting (baseTypeDef); + } } public bool IsRelevantToVariantCasting (TypeDefinition type) @@ -454,9 +468,9 @@ public bool IsPublic (IMetadataTokenProvider provider) return TypeMapInfo.GetOverrides (method); } - public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface)>? GetDefaultInterfaceImplementations (MethodDefinition method) + public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface)> GetDefaultInterfaceImplementations (MethodDefinition method) { - return TypeMapInfo.GetDefaultInterfaceImplementations (method); + return TypeMapInfo.GetDefaultInterfaceImplementations (method) ?? []; } /// diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs new file mode 100644 index 00000000000000..b3a90330a9c8a2 --- /dev/null +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs @@ -0,0 +1,68 @@ + +using Mono.Linker.Tests.Cases.Expectations.Assertions; + +namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.DefaultInterfaceMethods +{ + [TestCaseRequirements (TestRunCharacteristics.SupportsDefaultInterfaceMethods, "Requires support for default interface methods")] + class MostSpecificDefaultImplementationKept + { + [Kept] + public static void Main () + { +#if SUPPORTS_DEFAULT_INTERFACE_METHODS + M(); +#endif + } + +#if SUPPORTS_DEFAULT_INTERFACE_METHODS + + [Kept] + static int M() where T : IBase { + return T.Value; + } + + [Kept] + interface IBase { + [Kept] + static virtual int Value + { + [Kept] + get=>0; + } + } + + [Kept] + [KeptInterface(typeof(IBase))] + interface IMiddle : IBase { + [Kept] // Should be removable -- Add link to bug before merge + static int IBase.Value + { + [Kept] // Should be removable -- Add link to bug before merge + get=>1; + } + } + + [Kept] + [KeptInterface(typeof(IBase))] + [KeptInterface(typeof(IMiddle))] + interface IDerived : IMiddle { + [Kept] + static int IBase.Value + { + [Kept] + get=>2; + } + } + + interface INotReferenced + {} + + [Kept] + [KeptInterface(typeof(IDerived))] + [KeptInterface(typeof(IMiddle))] + [KeptInterface(typeof(IBase))] + struct Instance : IDerived, INotReferenced { + } +#endif + } +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs new file mode 100644 index 00000000000000..eeb626cd5b95d7 --- /dev/null +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs @@ -0,0 +1,68 @@ + +using Mono.Linker.Tests.Cases.Expectations.Assertions; + +namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.DefaultInterfaceMethods +{ + [TestCaseRequirements (TestRunCharacteristics.SupportsDefaultInterfaceMethods, "Requires support for default interface methods")] + class StaticDefaultInterfaceMethodOnDerivedInterface + { + [Kept] + public static void Main () + { +#if SUPPORTS_DEFAULT_INTERFACE_METHODS + M(); +#endif + } + +#if SUPPORTS_DEFAULT_INTERFACE_METHODS + + [Kept] + static int M() where T : IBase { + return T.Value; + } + + [Kept] + interface IBase { + [Kept] + static abstract int Value + { + [Kept] + get; + } + } + + [Kept] + [KeptInterface(typeof(IBase))] + interface IMiddle : IBase { + [Kept] // Should be removable -- Add link to bug before merge + static int IBase.Value + { + [Kept] // Should be removable -- Add link to bug before merge + get=>1; + } + } + + [Kept] + [KeptInterface(typeof(IBase))] + [KeptInterface(typeof(IMiddle))] + interface IDerived : IMiddle { + [Kept] + static int IBase.Value + { + [Kept] + get=>2; + } + } + + interface INotReferenced + {} + + [Kept] + [KeptInterface(typeof(IDerived))] + [KeptInterface(typeof(IMiddle))] + [KeptInterface(typeof(IBase))] + struct Instance : IDerived, INotReferenced { + } +#endif + } +} From 642385d0ae972d85b10f97a8fbc075ba138bb0cd Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Fri, 26 Jan 2024 17:14:41 -0800 Subject: [PATCH 02/19] Update Test infra --- .../src/linker/Linker.Steps/MarkStep.cs | 78 ++- .../illink/src/linker/Linker/Annotations.cs | 6 +- .../src/linker/Linker/DependencyInfo.cs | 2 + .../illink/src/linker/Linker/TypeMapInfo.cs | 18 +- .../MostSpecificDefaultImplementationKept.cs | 87 +++- .../TestCasesRunner/AssemblyChecker.cs | 478 ++++++++++-------- 6 files changed, 408 insertions(+), 261 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index da4dc455f0abca..5215eb5608eabc 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -215,6 +215,7 @@ internal DynamicallyAccessedMembersTypeHierarchy DynamicallyAccessedMembersTypeH DependencyKind.ReturnTypeMarshalSpec, DependencyKind.XmlDescriptor, DependencyKind.UnsafeAccessorTarget, + DependencyKind.DefaultImplementationForImplementingType, }; #endif @@ -732,6 +733,9 @@ bool ShouldMarkOverrideForBase (OverrideInformation overrideInformation) if (!Context.IsOptimizationEnabled (CodeOptimizations.OverrideRemoval, overrideInformation.Override)) return true; + if(Annotations.GetDefaultInterfaceImplementations(overrideInformation.Base).Where(dim => dim.DefaultInterfaceMethods == overrideInformation.Override).ToList() is [var dim]) + return Annotations.IsRelevantToVariantCasting(dim.InstanceType); + // In this context, an override needs to be kept if // a) it's an override on an instantiated type (of a marked base) or // b) it's an override of an abstract base (required for valid IL) @@ -2275,9 +2279,9 @@ void MarkTypeWithDebuggerDisplayAttribute (TypeDefinition type, CustomAttribute // Record a logical dependency on the attribute so that we can blame it for the kept members below. Tracer.AddDirectDependency (attribute, new DependencyInfo (DependencyKind.CustomAttribute, type), marked: false); - MarkTypeWithDebuggerDisplayAttributeValue(type, attribute, (string) attribute.ConstructorArguments[0].Value); + MarkTypeWithDebuggerDisplayAttributeValue (type, attribute, (string) attribute.ConstructorArguments[0].Value); if (attribute.HasProperties) { - foreach (var property in attribute.Properties) { + foreach (var property in attribute.Properties) { if (property.Name is "Name" or "Type") { MarkTypeWithDebuggerDisplayAttributeValue (type, attribute, (string) property.Argument.Value); } @@ -2578,7 +2582,7 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat // If the method is static and the implementing type is relevant to variant casting, mark the implementation method. // A static method may only be called through a constrained call if the type is relevant to variant casting. if (@base.IsStatic) - return Annotations.IsRelevantToVariantCasting (method.DeclaringType) + return Annotations.IsRelevantToVariantCasting (method.DeclaringType) || method.DeclaringType.IsInterface || IgnoreScope (@base.DeclaringType.Scope); // If the implementing type is marked as instantiated, mark the implementation method. @@ -2814,7 +2818,51 @@ void MarkGenericArguments (IGenericInstance instance) if (parameter.HasDefaultConstructorConstraint) MarkDefaultConstructor (argumentTypeDef, new DependencyInfo (DependencyKind.DefaultCtorForNewConstrainedGenericArgument, instance)); - } + + // var interfaceConstraints = GetConstraintInterfaces (parameter); + // foreach (var constrainedInterfacemethod in interfaceConstraints.SelectMany (c => c.Methods).Where (m => m.IsStatic)) { + // foreach (var dim in Annotations.GetDefaultInterfaceImplementations (constrainedInterfacemethod)) { + // var asdf = new MessageOrigin(parameter); + // MarkMethod (dim.DefaultInterfaceMethods, new DependencyInfo (DependencyKind.DefaultImplementationForImplementingType, constrainedInterfacemethod), in asdf); + // } + // } + } + // IEnumerable GetConstraintInterfaces (GenericParameter gp) + // { + // foreach (var constraint in gp.Constraints) { + // switch (constraint.ConstraintType) { + // case GenericParameter gp2: + // foreach (var td1 in GetConstraintInterfaces (gp2)) { + // yield return td1; + // } + // break; + // case TypeSpecification when constraint.ConstraintType is not GenericInstanceType: + // // Not a resolvable type + // continue; + // default: + // if (Context.Resolve (constraint.ConstraintType) is TypeDefinition { IsInterface: true } td2) { + // yield return td2; + // } + // break; + // } + // } + // } + // MethodDefinition? def = Context.Resolve (method); + // if (def is not null) { + // for (int i = 0; i < gim.GenericArguments.Count; i++) { + // TypeReference arg = gim.GenericArguments[i]; + // GenericParameter param = def.GenericParameters[i]; + // var interfaceConstraints = GetConstraintInterfaces (param); + // foreach (var constrainedInterfacemethod in interfaceConstraints.SelectMany (c => c.Methods).Where (m => m.IsStatic)) { + // foreach (var dim in Annotations.GetDefaultInterfaceImplementations (constrainedInterfacemethod)) { + // var thisMethodOrigin = new MessageOrigin (def); + // using (ScopeStack.PushScope (thisMethodOrigin)) { + // MarkMethod (dim.DefaultInterfaceMethods, new DependencyInfo (DependencyKind.DefaultImplementationForImplementingType, method), in thisMethodOrigin); + // } + // } + // } + // } + // } } IGenericParameterProvider? GetGenericProviderFromInstance (IGenericInstance instance) @@ -3007,7 +3055,7 @@ void MarkMethodCollection (IList methods, in DependencyInfo re protected virtual MethodDefinition? MarkMethod (MethodReference reference, DependencyInfo reason, in MessageOrigin origin) { DependencyKind originalReasonKind = reason.Kind; - (reference, reason) = GetOriginalMethod (reference, reason); + (reference, reason) = GetOriginalMethod (reference, reason, in origin); if (reference.DeclaringType is ArrayType arrayType) { MarkType (reference.DeclaringType, new DependencyInfo (DependencyKind.DeclaringType, reference)); @@ -3178,14 +3226,15 @@ internal static void ReportRequiresUnreferencedCode (string displayName, Require diagnosticContext.AddDiagnostic (DiagnosticId.RequiresUnreferencedCode, displayName, arg1, arg2); } - protected (MethodReference, DependencyInfo) GetOriginalMethod (MethodReference method, DependencyInfo reason) + protected (MethodReference, DependencyInfo) GetOriginalMethod (MethodReference method, DependencyInfo reason, in MessageOrigin origin) { while (method is MethodSpecification specification) { // Blame the method reference (which isn't marked) on the original reason. Tracer.AddDirectDependency (specification, reason, marked: false); // Blame the outgoing element method on the specification. - if (method is GenericInstanceMethod gim) + if (method is GenericInstanceMethod gim) { MarkGenericArguments (gim); + } (method, reason) = (specification.ElementMethod, new DependencyInfo (DependencyKind.ElementMethod, specification)); Debug.Assert (!(method is MethodSpecification)); @@ -3231,7 +3280,7 @@ protected virtual void ProcessMethod (MethodDefinition method, in DependencyInfo } else if (method.TryGetProperty (out PropertyDefinition? property)) MarkProperty (property, new DependencyInfo (PropagateDependencyKindToAccessors (reason.Kind, DependencyKind.PropertyOfPropertyMethod), method)); else if (method.TryGetEvent (out EventDefinition? @event)) { - MarkEvent (@event, new DependencyInfo (PropagateDependencyKindToAccessors(reason.Kind, DependencyKind.EventOfEventMethod), method)); + MarkEvent (@event, new DependencyInfo (PropagateDependencyKindToAccessors (reason.Kind, DependencyKind.EventOfEventMethod), method)); } if (method.HasMetadataParameters ()) { @@ -3315,7 +3364,7 @@ protected virtual void DoAdditionalMethodProcessing (MethodDefinition method) { } - static DependencyKind PropagateDependencyKindToAccessors(DependencyKind parentDependencyKind, DependencyKind kind) + static DependencyKind PropagateDependencyKindToAccessors (DependencyKind parentDependencyKind, DependencyKind kind) { switch (parentDependencyKind) { // If the member is marked due to descriptor or similar, propagate the original reason to suppress some warnings correctly @@ -3335,11 +3384,11 @@ void MarkImplicitlyUsedFields (TypeDefinition type) return; // keep fields for types with explicit layout, for enums and for InlineArray types - if (!type.IsAutoLayout || type.IsEnum || TypeIsInlineArrayType(type)) + if (!type.IsAutoLayout || type.IsEnum || TypeIsInlineArrayType (type)) MarkFields (type, includeStatic: type.IsEnum, reason: new DependencyInfo (DependencyKind.MemberOfType, type)); } - static bool TypeIsInlineArrayType(TypeDefinition type) + static bool TypeIsInlineArrayType (TypeDefinition type) { if (!type.IsValueType) return false; @@ -3584,7 +3633,7 @@ protected internal virtual void MarkEvent (EventDefinition evt, in DependencyInf MarkCustomAttributes (evt, new DependencyInfo (DependencyKind.CustomAttribute, evt)); - DependencyKind dependencyKind = PropagateDependencyKindToAccessors(reason.Kind, DependencyKind.EventMethod); + DependencyKind dependencyKind = PropagateDependencyKindToAccessors (reason.Kind, DependencyKind.EventMethod); MarkMethodIfNotNull (evt.AddMethod, new DependencyInfo (dependencyKind, evt), ScopeStack.CurrentScope.Origin); MarkMethodIfNotNull (evt.InvokeMethod, new DependencyInfo (dependencyKind, evt), ScopeStack.CurrentScope.Origin); MarkMethodIfNotNull (evt.RemoveMethod, new DependencyInfo (dependencyKind, evt), ScopeStack.CurrentScope.Origin); @@ -3762,8 +3811,7 @@ protected virtual void MarkInstruction (Instruction instruction, MethodDefinitio ScopeStack.UpdateCurrentScopeInstructionOffset (instruction.Offset); if (markForReflectionAccess) { MarkMethodVisibleToReflection (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin); - } - else { + } else { MarkMethod (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin); } break; @@ -3828,6 +3876,8 @@ protected internal virtual void MarkInterfaceImplementation (InterfaceImplementa { if (Annotations.IsMarked (iface)) return; + if (iface.InterfaceType.Name == "IBase") + _ = 0; Annotations.MarkProcessed (iface, reason ?? new DependencyInfo (DependencyKind.InterfaceImplementationOnType, ScopeStack.CurrentScope.Origin.Provider)); using var localScope = origin.HasValue ? ScopeStack.PushScope (origin.Value) : null; diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index cdf3785d39fb31..0389a3b3e8c939 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -236,6 +236,8 @@ public bool IsReflectionUsed (IMemberDefinition method) public void MarkInstantiated (TypeDefinition type) { + if (type.Name == "NotUsedAsIBase") + _ = 0; marked_instantiated.Add (type); } @@ -256,7 +258,7 @@ public void MarkRelevantToVariantCasting (TypeDefinition type) var resolvedBaseType = context.Resolve (baseType.InterfaceType); if (resolvedBaseType is null) continue; - // Don't need to rercurse for interfaces - types implement all interfaces in the interface/base type hierarchy + // Don't need to recurse for interfaces - types implement all interfaces in the interface/base type hierarchy types_relevant_to_variant_casting.Add (resolvedBaseType); } if (type.BaseType is not null && context.Resolve(type.BaseType) is {} baseTypeDef) { @@ -468,7 +470,7 @@ public bool IsPublic (IMetadataTokenProvider provider) return TypeMapInfo.GetOverrides (method); } - public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface)> GetDefaultInterfaceImplementations (MethodDefinition method) + public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface, MethodDefinition DefaultInterfaceMethods)> GetDefaultInterfaceImplementations (MethodDefinition method) { return TypeMapInfo.GetDefaultInterfaceImplementations (method) ?? []; } diff --git a/src/tools/illink/src/linker/Linker/DependencyInfo.cs b/src/tools/illink/src/linker/Linker/DependencyInfo.cs index 7eb0681c5a9347..7476baa99757ff 100644 --- a/src/tools/illink/src/linker/Linker/DependencyInfo.cs +++ b/src/tools/illink/src/linker/Linker/DependencyInfo.cs @@ -145,6 +145,8 @@ public enum DependencyKind DynamicallyAccessedMemberOnType = 88, // type with DynamicallyAccessedMembers annotations (including those inherited from base types and interfaces) UnsafeAccessorTarget = 89, // the member is referenced via UnsafeAccessor attribute + + DefaultImplementationForImplementingType = 90 // The member is a default implementation of an interface method for a type that (may) need the method } public readonly struct DependencyInfo : IEquatable diff --git a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs index 6fc0232896a8c1..46d790d4d634d8 100644 --- a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs +++ b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs @@ -42,7 +42,7 @@ public class TypeMapInfo readonly LinkContext context; protected readonly Dictionary> base_methods = new Dictionary> (); protected readonly Dictionary> override_methods = new Dictionary> (); - protected readonly Dictionary> default_interface_implementations = new Dictionary> (); + protected readonly Dictionary> default_interface_implementations = new Dictionary> (); public TypeMapInfo (LinkContext context) { @@ -84,9 +84,9 @@ public void EnsureProcessed (AssemblyDefinition assembly) return bases; } - public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface)>? GetDefaultInterfaceImplementations (MethodDefinition method) + public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface, MethodDefinition DefaultImplementationMethod)>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod) { - default_interface_implementations.TryGetValue (method, out var ret); + default_interface_implementations.TryGetValue (baseMethod, out var ret); return ret; } @@ -110,14 +110,14 @@ public void AddOverride (MethodDefinition @base, MethodDefinition @override, Int methods.Add (new OverrideInformation (@base, @override, context, matchingInterfaceImplementation)); } - public void AddDefaultInterfaceImplementation (MethodDefinition @base, TypeDefinition implementingType, InterfaceImplementation matchingInterfaceImplementation) + public void AddDefaultInterfaceImplementation (MethodDefinition @base, TypeDefinition implementingType, (InterfaceImplementation, MethodDefinition) matchingInterfaceImplementation) { if (!default_interface_implementations.TryGetValue (@base, out var implementations)) { - implementations = new List<(TypeDefinition, InterfaceImplementation)> (); + implementations = new List<(TypeDefinition, InterfaceImplementation, MethodDefinition)> (); default_interface_implementations.Add (@base, implementations); } - implementations.Add ((implementingType, matchingInterfaceImplementation)); + implementations.Add ((implementingType, matchingInterfaceImplementation.Item1, matchingInterfaceImplementation.Item2)); } protected virtual void MapType (TypeDefinition type) @@ -276,7 +276,7 @@ void AnnotateMethods (MethodDefinition @base, MethodDefinition @override, Interf // Note that this returns a list to potentially cover the diamond case (more than one // most specific implementation of the given interface methods). ILLink needs to preserve // all the implementations so that the proper exception can be thrown at runtime. - IEnumerable GetDefaultInterfaceImplementations (TypeDefinition type, MethodDefinition interfaceMethod) + IEnumerable<(InterfaceImplementation, MethodDefinition)> GetDefaultInterfaceImplementations (TypeDefinition type, MethodDefinition interfaceMethod) { // Go over all interfaces, trying to find a method that is an explicit MethodImpl of the // interface method in question. @@ -290,7 +290,7 @@ IEnumerable GetDefaultInterfaceImplementations (TypeDef foreach (var potentialImplMethod in potentialImplInterface.Methods) { if (potentialImplMethod == interfaceMethod && !potentialImplMethod.IsAbstract) { - yield return interfaceImpl; + yield return (interfaceImpl, potentialImplMethod); } if (!potentialImplMethod.HasOverrides) @@ -299,7 +299,7 @@ IEnumerable GetDefaultInterfaceImplementations (TypeDef // This method is an override of something. Let's see if it's the method we are looking for. foreach (var @override in potentialImplMethod.Overrides) { if (context.TryResolve (@override) == interfaceMethod) { - yield return interfaceImpl; + yield return (interfaceImpl, potentialImplMethod); foundImpl = true; break; } diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs index b3a90330a9c8a2..7b92c0948991b1 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs @@ -1,4 +1,3 @@ - using Mono.Linker.Tests.Cases.Expectations.Assertions; namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.DefaultInterfaceMethods @@ -10,58 +9,94 @@ class MostSpecificDefaultImplementationKept public static void Main () { #if SUPPORTS_DEFAULT_INTERFACE_METHODS - M(); + M (); + NotUsedAsIBase.Keep (); + GenericType.M (); + #endif } #if SUPPORTS_DEFAULT_INTERFACE_METHODS [Kept] - static int M() where T : IBase { + static int M () where T : IBase + { return T.Value; } [Kept] - interface IBase { + interface IBase + { [Kept] - static virtual int Value - { + static virtual int Value { [Kept] - get=>0; + get => 0; + } + + static virtual int Value2 { + get => 0; } } - [Kept] - [KeptInterface(typeof(IBase))] - interface IMiddle : IBase { - [Kept] // Should be removable -- Add link to bug before merge - static int IBase.Value - { - [Kept] // Should be removable -- Add link to bug before merge - get=>1; + // [Kept] + // [KeptInterface (typeof (IBase))] + interface IMiddle : IBase + { + // [Kept] // Should be removable -- Add link to bug before merge + static int IBase.Value { + // [Kept] // Should be removable -- Add link to bug before merge + get => 1; } } [Kept] - [KeptInterface(typeof(IBase))] - [KeptInterface(typeof(IMiddle))] - interface IDerived : IMiddle { + [KeptInterface (typeof (IBase))] + [KeptInterface (typeof (IMiddle))] + interface IDerived : IMiddle + { [Kept] - static int IBase.Value - { + static int IBase.Value { [Kept] - get=>2; + get => 2; } } interface INotReferenced - {} + { } + + [Kept] + [KeptInterface (typeof (IDerived))] + [KeptInterface (typeof (IMiddle))] + [KeptInterface (typeof (IBase))] + class UsedAsIBase : IDerived, INotReferenced + { + } [Kept] - [KeptInterface(typeof(IDerived))] - [KeptInterface(typeof(IMiddle))] - [KeptInterface(typeof(IBase))] - struct Instance : IDerived, INotReferenced { + // [KeptInterface(typeof(IDerived))] + // [KeptInterface(typeof(IMiddle))] + // [KeptInterface(typeof(IBase))] + class NotUsedAsIBase : IDerived, INotReferenced + { + [Kept] + public static void Keep () { } + } + + [Kept] + class GenericType where T : IBase + { + [Kept] + public static int M () => T.Value; + } + + [Kept] + [KeptInterface (typeof (IDerived))] + [KeptInterface (typeof (IMiddle))] + [KeptInterface (typeof (IBase))] + class UsedAsIBase2 : IBase + { + [Kept] + public static int Value => 0; } #endif } diff --git a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs index 36031245d5191b..7d26b4c939d24d 100644 --- a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs +++ b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; +using System.Linq.Expressions; using System.Text; using Mono.Cecil; using Mono.Cecil.Cil; @@ -37,19 +38,25 @@ public AssemblyChecker (AssemblyDefinition original, AssemblyDefinition linked, attr.Name == nameof (RemovedNameValueAttribute)); } - public void Verify () + static (IEnumerable Missing, IEnumerable Extra) SetDifference (IEnumerable expected, IEnumerable found) { - VerifyExportedTypes (originalAssembly, linkedAssembly); + var missing = expected.Except (found); + var extra = found.Except (expected); + return (missing, extra); + } - VerifyCustomAttributes (originalAssembly, linkedAssembly); - VerifySecurityAttributes (originalAssembly, linkedAssembly); + public void Verify () + { + IEnumerable failures = VerifyExportedTypes (originalAssembly, linkedAssembly); + failures = failures.Concat (VerifyCustomAttributes (originalAssembly, linkedAssembly)); + failures = failures.Concat (VerifySecurityAttributes (originalAssembly, linkedAssembly)); foreach (var originalModule in originalAssembly.Modules) - VerifyModule (originalModule, linkedAssembly.Modules.FirstOrDefault (m => m.Name == originalModule.Name)); + failures = failures.Concat (VerifyModule (originalModule, linkedAssembly.Modules.FirstOrDefault (m => m.Name == originalModule.Name))); - VerifyResources (originalAssembly, linkedAssembly); - VerifyReferences (originalAssembly, linkedAssembly); - VerifyKeptByAttributes (originalAssembly, originalAssembly.FullName); + failures = failures.Concat (VerifyResources (originalAssembly, linkedAssembly)); + failures = failures.Concat (VerifyReferences (originalAssembly, linkedAssembly)); + failures = failures.Concat (VerifyKeptByAttributes (originalAssembly, originalAssembly.FullName)); linkedMembers = new HashSet (linkedAssembly.MainModule.AllMembers ().Select (s => { return s.FullName; @@ -73,7 +80,7 @@ public void Verify () } TypeDefinition linkedType = linkedAssembly.MainModule.GetType (originalMember.FullName); - VerifyTypeDefinition (td, linkedType); + failures = failures.Concat (VerifyTypeDefinition (td, linkedType)); linkedMembers.Remove (td.FullName); continue; @@ -82,34 +89,37 @@ public void Verify () throw new NotImplementedException ($"Don't know how to check member of type {originalMember.GetType ()}"); } - Assert.IsEmpty (linkedMembers, "Linked output includes unexpected member"); + if (linkedMembers.Any ()) + failures = failures.Concat (linkedMembers.Select (m => $"Member `{m}' was not expected to be kept")); + Assert.IsEmpty (failures, string.Join (Environment.NewLine, failures)); } static bool IsBackingField (FieldDefinition field) => field.Name.StartsWith ("<") && field.Name.EndsWith (">k__BackingField"); - protected virtual void VerifyModule (ModuleDefinition original, ModuleDefinition linked) + protected virtual IEnumerable VerifyModule (ModuleDefinition original, ModuleDefinition linked) { // We never link away a module today so let's make sure the linked one isn't null if (linked == null) - Assert.Fail ($"Linked assembly `{original.Assembly.Name.Name}` is missing module `{original.Name}`"); + yield return $"Linked assembly `{original.Assembly.Name.Name}` is missing module `{original.Name}`"; var expected = original.Assembly.MainModule.AllDefinedTypes () .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptModuleReferenceAttribute))) - .ToArray (); + .ToHashSet (); var actual = linked.ModuleReferences .Select (name => name.Name) - .ToArray (); + .ToHashSet (); - Assert.That (actual, Is.EquivalentTo (expected)); + if (!expected.SetEquals (actual)) + yield return $"In module {original.FileName} Expected module references `{string.Join (", ", expected)}` but got `{string.Join (", ", actual)}`"; - VerifyCustomAttributes (original, linked); + foreach (var err in VerifyCustomAttributes (original, linked)) yield return err; } - protected virtual void VerifyTypeDefinition (TypeDefinition original, TypeDefinition linked) + protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition original, TypeDefinition linked) { if (linked != null && verifiedGeneratedTypes.Contains (linked.FullName)) - return; + yield break; ModuleDefinition linkedModule = linked?.Module; @@ -126,7 +136,7 @@ protected virtual void VerifyTypeDefinition (TypeDefinition original, TypeDefini if (!expectedKept) { if (linked == null) - return; + yield break; // Compiler generated members can't be annotated with `Kept` attributes directly // For some of them we have special attributes (backing fields for example), but it's impractical to define @@ -137,13 +147,13 @@ protected virtual void VerifyTypeDefinition (TypeDefinition original, TypeDefini // we do want to validate. There's no specific use case right now, but I can easily imagine one // for more detailed testing of for example custom attributes on local functions, or similar. if (!IsCompilerGeneratedMember (original)) - Assert.Fail ($"Type `{original}' should have been removed"); + yield return $"Type `{original}' should have been removed"; } bool prev = checkNames; checkNames |= original.HasAttribute (nameof (VerifyMetadataNamesAttribute)); - VerifyTypeDefinitionKept (original, linked); + foreach (var err in VerifyTypeDefinitionKept (original, linked)) yield return err; checkNames = prev; @@ -151,7 +161,10 @@ protected virtual void VerifyTypeDefinition (TypeDefinition original, TypeDefini foreach (var attr in original.CustomAttributes.Where (l => l.AttributeType.Name == nameof (CreatedMemberAttribute))) { var newName = original.FullName + "::" + attr.ConstructorArguments[0].Value.ToString (); - Assert.AreEqual (1, linkedMembers.RemoveWhere (l => l.Contains (newName)), $"Newly created member '{newName}' was not found"); + var asdf = linkedMembers.Where (l => l.Contains (newName)).ToList (); + if (1 != linkedMembers.RemoveWhere (l => l.Contains (newName))) { + yield return $"Newly created member '{newName}' was not found"; + } } } } @@ -159,23 +172,21 @@ protected virtual void VerifyTypeDefinition (TypeDefinition original, TypeDefini /// /// Validates that all instances on a member are valid (i.e. ILLink recorded a marked dependency described in the attribute) /// - void VerifyKeptByAttributes (IMemberDefinition src, IMemberDefinition linked) + IEnumerable VerifyKeptByAttributes (IMemberDefinition src, IMemberDefinition linked) { - foreach (var keptByAttribute in src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ())) - VerifyKeptByAttribute (linked.FullName, keptByAttribute); + return src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ()).SelectMany (keptByAttribute => VerifyKeptByAttribute (linked.FullName, keptByAttribute)); } /// /// Validates that all instances on an attribute provider are valid (i.e. ILLink recorded a marked dependency described in the attribute) /// is the attribute provider that may have a , and is the 'FullName' of . /// - void VerifyKeptByAttributes (ICustomAttributeProvider src, string attributeProviderFullName) + IEnumerable VerifyKeptByAttributes (ICustomAttributeProvider src, string attributeProviderFullName) { - foreach (var keptByAttribute in src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ())) - VerifyKeptByAttribute (attributeProviderFullName, keptByAttribute); + return src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ()).SelectMany (attr => VerifyKeptByAttribute (attributeProviderFullName, attr)); } - void VerifyKeptByAttribute (string keptAttributeProviderName, CustomAttribute attribute) + IEnumerable VerifyKeptByAttribute (string keptAttributeProviderName, CustomAttribute attribute) { // public KeptByAttribute (string dependencyProvider, string reason) { } // public KeptByAttribute (Type dependencyProvider, string reason) { } @@ -213,57 +224,58 @@ void VerifyKeptByAttribute (string keptAttributeProviderName, CustomAttribute at foreach (var dep in this.linkedTestCase.Customizations.DependencyRecorder.Dependencies) { if (dep == expectedDependency) { - return; + yield break; } } - string errorMessage = $"{keptAttributeProviderName} was expected to be kept by {expectedDependency.Source} with reason {expectedDependency.DependencyKind.ToString ()}."; - Assert.Fail (errorMessage); + yield return $"{keptAttributeProviderName} was expected to be kept by {expectedDependency.Source} with reason {expectedDependency.DependencyKind.ToString ()}."; } - protected virtual void VerifyTypeDefinitionKept (TypeDefinition original, TypeDefinition linked) + protected virtual IEnumerable VerifyTypeDefinitionKept (TypeDefinition original, TypeDefinition linked) { - if (linked == null) - Assert.Fail ($"Type `{original}' should have been kept"); + if (linked == null) { + yield return $"Type `{original}' should have been kept"; + yield break; + } // Skip verification of type metadata for compiler generated types (we don't currently need it yet) if (!IsCompilerGeneratedMember (original)) { - VerifyKeptByAttributes (original, linked); + foreach (var err in VerifyKeptByAttributes (original, linked)) yield return err; if (!original.IsInterface) - VerifyBaseType (original, linked); + foreach (var err in VerifyBaseType (original, linked)) yield return err; - VerifyInterfaces (original, linked); - VerifyPseudoAttributes (original, linked); - VerifyGenericParameters (original, linked, compilerGenerated: false); - VerifyCustomAttributes (original, linked); - VerifySecurityAttributes (original, linked); + foreach (var err in VerifyInterfaces (original, linked)) yield return err; + foreach (var err in VerifyPseudoAttributes (original, linked)) yield return err; + foreach (var err in VerifyGenericParameters (original, linked, compilerGenerated: false)) yield return err; + foreach (var err in VerifyCustomAttributes (original, linked)) yield return err; + foreach (var err in VerifySecurityAttributes (original, linked)) yield return err; - VerifyFixedBufferFields (original, linked); + foreach (var err in VerifyFixedBufferFields (original, linked)) yield return err; } // Need to check delegate cache fields before the normal field check - VerifyDelegateBackingFields (original, linked); - VerifyPrivateImplementationDetails (original, linked); + foreach (var err in VerifyDelegateBackingFields (original, linked)) yield return err; + foreach (var err in VerifyPrivateImplementationDetails (original, linked)) yield return err; foreach (var td in original.NestedTypes) { - VerifyTypeDefinition (td, linked?.NestedTypes.FirstOrDefault (l => td.FullName == l.FullName)); + foreach (var err in VerifyTypeDefinition (td, linked?.NestedTypes.FirstOrDefault (l => td.FullName == l.FullName))) yield return err; linkedMembers.Remove (td.FullName); } // Need to check properties before fields so that the KeptBackingFieldAttribute is handled correctly foreach (var p in original.Properties) { - VerifyProperty (p, linked?.Properties.FirstOrDefault (l => p.Name == l.Name), linked); + foreach (var err in VerifyProperty (p, linked?.Properties.FirstOrDefault (l => p.Name == l.Name), linked)) yield return err; linkedMembers.Remove (p.FullName); } // Need to check events before fields so that the KeptBackingFieldAttribute is handled correctly foreach (var e in original.Events) { - VerifyEvent (e, linked?.Events.FirstOrDefault (l => e.Name == l.Name), linked); + foreach (var err in VerifyEvent (e, linked?.Events.FirstOrDefault (l => e.Name == l.Name), linked)) yield return err; linkedMembers.Remove (e.FullName); } foreach (var f in original.Fields) { if (verifiedGeneratedFields.Contains (f.FullName)) continue; - VerifyField (f, linked?.Fields.FirstOrDefault (l => f.Name == l.Name)); + foreach (var err in VerifyField (f, linked?.Fields.FirstOrDefault (l => f.Name == l.Name))) yield return err; linkedMembers.Remove (f.FullName); } @@ -271,12 +283,12 @@ protected virtual void VerifyTypeDefinitionKept (TypeDefinition original, TypeDe if (verifiedEventMethods.Contains (m.FullName)) continue; var msign = m.GetSignature (); - VerifyMethod (m, linked?.Methods.FirstOrDefault (l => msign == l.GetSignature ())); + foreach (var err in VerifyMethod (m, linked?.Methods.FirstOrDefault (l => msign == l.GetSignature ()))) yield return err; linkedMembers.Remove (m.FullName); } } - void VerifyBaseType (TypeDefinition src, TypeDefinition linked) + IEnumerable VerifyBaseType (TypeDefinition src, TypeDefinition linked) { string expectedBaseName; var expectedBaseGenericAttr = src.CustomAttributes.FirstOrDefault (w => w.AttributeType.Name == nameof (KeptBaseTypeAttribute) && w.ConstructorArguments.Count > 1); @@ -286,24 +298,31 @@ void VerifyBaseType (TypeDefinition src, TypeDefinition linked) var defaultBaseType = src.IsEnum ? "System.Enum" : src.IsValueType ? "System.ValueType" : "System.Object"; expectedBaseName = GetCustomAttributeCtorValues (src, nameof (KeptBaseTypeAttribute)).FirstOrDefault ()?.ToString () ?? defaultBaseType; } - Assert.AreEqual (expectedBaseName, linked.BaseType?.FullName, $"Incorrect base type on : {linked.Name}"); + if (expectedBaseName != linked.BaseType?.FullName) + yield return $"Incorrect base type on : {linked.Name}"; } - void VerifyInterfaces (TypeDefinition src, TypeDefinition linked) + IEnumerable VerifyInterfaces (TypeDefinition src, TypeDefinition linked) { var expectedInterfaces = new HashSet (src.CustomAttributes .Where (w => w.AttributeType.Name == nameof (KeptInterfaceAttribute)) .Select (FormatBaseOrInterfaceAttributeValue)); if (expectedInterfaces.Count == 0) { - Assert.IsFalse (linked.HasInterfaces, $"Type `{src}' has unexpected interfaces"); + if (linked.HasInterfaces) { + yield return $"Type `{src}' has unexpected interfaces"; + } } else { foreach (var iface in linked.Interfaces) { if (!expectedInterfaces.Remove (iface.InterfaceType.FullName)) { - Assert.IsTrue (expectedInterfaces.Remove (iface.InterfaceType.Resolve ().FullName), $"Type `{src}' interface `{iface.InterfaceType.Resolve ().FullName}' should have been removed"); + if (!expectedInterfaces.Remove (iface.InterfaceType.Resolve ().FullName)) { + yield return $"Type `{src}' interface `{iface.InterfaceType.FullName}' should have been removed"; + } } } - Assert.IsEmpty (expectedInterfaces, $"Expected interfaces were not found on {src}"); + if (expectedInterfaces.Any ()) { + yield return $"Expected interfaces were not found on {src}"; + } } } @@ -377,7 +396,7 @@ static string FormatBaseOrInterfaceAttributeValue (CustomAttribute attr) return builder.ToString (); } - void VerifyField (FieldDefinition src, FieldDefinition linked) + IEnumerable VerifyField (FieldDefinition src, FieldDefinition linked) { bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldBeKept (src) || @@ -385,115 +404,123 @@ void VerifyField (FieldDefinition src, FieldDefinition linked) if (!expectedKept) { if (linked != null) - Assert.Fail ($"Field `{src}' should have been removed"); + yield return $"Field `{src}' should have been removed"; - return; + yield break; } - VerifyFieldKept (src, linked, compilerGenerated); + foreach (var err in VerifyFieldKept (src, linked, compilerGenerated)) yield return err; } - void VerifyFieldKept (FieldDefinition src, FieldDefinition linked, bool compilerGenerated) + IEnumerable VerifyFieldKept (FieldDefinition src, FieldDefinition linked, bool compilerGenerated) { - if (linked == null) - Assert.Fail ($"Field `{src}' should have been kept"); + if (linked == null) { + yield return $"Field `{src}' should have been kept"; + yield break; ; + } - Assert.AreEqual (src?.Constant, linked?.Constant, $"Field `{src}' value"); + if (src?.Constant != linked?.Constant) + yield return $"Field `{src}' value"; - VerifyKeptByAttributes (src, linked); + foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; VerifyPseudoAttributes (src, linked); if (!compilerGenerated) - VerifyCustomAttributes (src, linked); + foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; } - void VerifyProperty (PropertyDefinition src, PropertyDefinition linked, TypeDefinition linkedType) + IEnumerable VerifyProperty (PropertyDefinition src, PropertyDefinition linked, TypeDefinition linkedType) { - VerifyMemberBackingField (src, linkedType); + foreach (var err in VerifyMemberBackingField (src, linkedType)) yield return err; bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldBeKept (src) || compilerGenerated; if (!expectedKept) { if (linked != null) - Assert.Fail ($"Property `{src}' should have been removed"); + yield return $"Property `{src}' should have been removed"; - return; + yield break; } - if (linked == null) - Assert.Fail ($"Property `{src}' should have been kept"); + if (linked == null) { + yield return $"Property `{src}' should have been kept"; + yield break; + } - Assert.AreEqual (src?.Constant, linked?.Constant, $"Property `{src}' value"); + if (src?.Constant != linked?.Constant) + yield return $"Property `{src}' value"; - VerifyKeptByAttributes (src, linked); - VerifyPseudoAttributes (src, linked); + foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; + foreach (var err in VerifyPseudoAttributes (src, linked)) yield return err; if (!compilerGenerated) - VerifyCustomAttributes (src, linked); + foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; } - void VerifyEvent (EventDefinition src, EventDefinition linked, TypeDefinition linkedType) + IEnumerable VerifyEvent (EventDefinition src, EventDefinition linked, TypeDefinition linkedType) { - VerifyMemberBackingField (src, linkedType); + foreach (var err in VerifyMemberBackingField (src, linkedType)) yield return err; bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldBeKept (src) || compilerGenerated; if (!expectedKept) { if (linked != null) - Assert.Fail ($"Event `{src}' should have been removed"); + yield return $"Event `{src}' should have been removed"; - return; + yield break; } - if (linked == null) - Assert.Fail ($"Event `{src}' should have been kept"); + if (linked == null) { + yield return $"Event `{src}' should have been kept"; + yield break; + } if (src.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptEventAddMethodAttribute))) { - VerifyMethodInternal (src.AddMethod, linked.AddMethod, true, compilerGenerated); + foreach (var err in VerifyMethodInternal (src.AddMethod, linked.AddMethod, true, compilerGenerated)) yield return err; verifiedEventMethods.Add (src.AddMethod.FullName); linkedMembers.Remove (src.AddMethod.FullName); } if (src.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptEventRemoveMethodAttribute))) { - VerifyMethodInternal (src.RemoveMethod, linked.RemoveMethod, true, compilerGenerated); + foreach (var err in VerifyMethodInternal (src.RemoveMethod, linked.RemoveMethod, true, compilerGenerated)) yield return err; verifiedEventMethods.Add (src.RemoveMethod.FullName); linkedMembers.Remove (src.RemoveMethod.FullName); } - VerifyKeptByAttributes (src, linked); - VerifyPseudoAttributes (src, linked); + foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; + foreach (var err in VerifyPseudoAttributes (src, linked)) yield return err; if (!compilerGenerated) - VerifyCustomAttributes (src, linked); + foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; } - void VerifyMethod (MethodDefinition src, MethodDefinition linked) + IEnumerable VerifyMethod (MethodDefinition src, MethodDefinition linked) { bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldMethodBeKept (src); - VerifyMethodInternal (src, linked, expectedKept, compilerGenerated); + foreach (var err in VerifyMethodInternal (src, linked, expectedKept, compilerGenerated)) yield return err; } - void VerifyMethodInternal (MethodDefinition src, MethodDefinition linked, bool expectedKept, bool compilerGenerated) + IEnumerable VerifyMethodInternal (MethodDefinition src, MethodDefinition linked, bool expectedKept, bool compilerGenerated) { if (!expectedKept) { if (linked == null) - return; + yield break; // Similar to comment on types, compiler-generated methods can't be annotated with Kept attribute directly // so we're not going to validate kept/remove on them. Note that we're still going to go validate "into" them // to check for other properties (like parameter name presence/removal for example) if (!compilerGenerated) - Assert.Fail ($"Method `{src.FullName}' should have been removed"); + yield return $"Method `{src.FullName}' should have been removed"; } - VerifyMethodKept (src, linked, compilerGenerated); + foreach (var err in VerifyMethodKept (src, linked, compilerGenerated)) yield return err; } - void VerifyMemberBackingField (IMemberDefinition src, TypeDefinition linkedType) + IEnumerable VerifyMemberBackingField (IMemberDefinition src, TypeDefinition linkedType) { var keptBackingFieldAttribute = src.CustomAttributes.FirstOrDefault (attr => attr.AttributeType.Name == nameof (KeptBackingFieldAttribute)); if (keptBackingFieldAttribute == null) - return; + yield break; var backingFieldName = src.MetadataToken.TokenType == TokenType.Property ? $"<{src.Name}>k__BackingField" : src.Name; @@ -508,51 +535,56 @@ void VerifyMemberBackingField (IMemberDefinition src, TypeDefinition linkedType) srcField = src.DeclaringType.Fields.FirstOrDefault (f => f.Name == backingFieldName); } - if (srcField == null) - Assert.Fail ($"{src.MetadataToken.TokenType} `{src}', could not locate the expected backing field {backingFieldName}"); + if (srcField == null) { + yield return $"{src.MetadataToken.TokenType} `{src}', could not locate the expected backing field {backingFieldName}"; + yield break; + } - VerifyFieldKept (srcField, linkedType?.Fields.FirstOrDefault (l => srcField.Name == l.Name), compilerGenerated: true); + foreach (var err in VerifyFieldKept (srcField, linkedType?.Fields.FirstOrDefault (l => srcField.Name == l.Name), compilerGenerated: true)) yield return err; verifiedGeneratedFields.Add (srcField.FullName); linkedMembers.Remove (srcField.FullName); } - protected virtual void VerifyMethodKept (MethodDefinition src, MethodDefinition linked, bool compilerGenerated) + protected virtual IEnumerable VerifyMethodKept (MethodDefinition src, MethodDefinition linked, bool compilerGenerated) { - if (linked == null) - Assert.Fail ($"Method `{src.FullName}' should have been kept"); + if (linked == null) { + yield return $"Method `{src.FullName}' should have been kept"; + yield break; + } - VerifyPseudoAttributes (src, linked); - VerifyGenericParameters (src, linked, compilerGenerated); + foreach (var err in VerifyPseudoAttributes (src, linked)) yield return err; + foreach (var err in VerifyGenericParameters (src, linked, compilerGenerated)) yield return err; if (!compilerGenerated) { - VerifyCustomAttributes (src, linked); - VerifyCustomAttributes (src.MethodReturnType, linked.MethodReturnType); + foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; + foreach (var err in VerifyCustomAttributes (src.MethodReturnType, linked.MethodReturnType)) yield return err; + } - VerifyParameters (src, linked, compilerGenerated); - VerifySecurityAttributes (src, linked); - VerifyArrayInitializers (src, linked); - VerifyMethodBody (src, linked); - VerifyKeptByAttributes (src, linked); + foreach (var err in VerifyParameters (src, linked, compilerGenerated)) yield return err; + foreach (var err in VerifySecurityAttributes (src, linked)) yield return err; + foreach (var err in VerifyArrayInitializers (src, linked)) yield return err; + foreach (var err in VerifyMethodBody (src, linked)) yield return err; + foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; } - protected virtual void VerifyMethodBody (MethodDefinition src, MethodDefinition linked) + protected virtual IEnumerable VerifyMethodBody (MethodDefinition src, MethodDefinition linked) { if (!src.HasBody) - return; + yield break; - VerifyInstructions (src, linked); - VerifyLocals (src, linked); + foreach (var err in VerifyInstructions (src, linked)) yield return err; + foreach (var err in VerifyLocals (src, linked)) yield return err; } - protected static void VerifyInstructions (MethodDefinition src, MethodDefinition linked) + protected static IEnumerable VerifyInstructions (MethodDefinition src, MethodDefinition linked) { - VerifyBodyProperties ( + foreach (var err in VerifyBodyProperties ( src, linked, nameof (ExpectedInstructionSequenceAttribute), nameof (ExpectBodyModifiedAttribute), "instructions", m => FormatMethodBody (m.Body), - attr => GetStringArrayAttributeValue (attr).ToArray ()); + attr => GetStringArrayAttributeValue (attr).ToArray ())) yield return err; } public static string[] FormatMethodBody (MethodBody body) @@ -669,19 +701,19 @@ static string FormatInstruction (Instruction instr) } } - static void VerifyLocals (MethodDefinition src, MethodDefinition linked) + static IEnumerable VerifyLocals (MethodDefinition src, MethodDefinition linked) { - VerifyBodyProperties ( + foreach (var err in VerifyBodyProperties ( src, linked, nameof (ExpectedLocalsSequenceAttribute), nameof (ExpectLocalsModifiedAttribute), "locals", m => m.Body.Variables.Select (v => v.VariableType.ToString ()).ToArray (), - attr => GetStringOrTypeArrayAttributeValue (attr).ToArray ()); + attr => GetStringOrTypeArrayAttributeValue (attr).ToArray ())) yield return err; } - public static void VerifyBodyProperties (MethodDefinition src, MethodDefinition linked, string sequenceAttributeName, string expectModifiedAttributeName, + public static IEnumerable VerifyBodyProperties (MethodDefinition src, MethodDefinition linked, string sequenceAttributeName, string expectModifiedAttributeName, string propertyDescription, Func valueCollector, Func getExpectFromSequenceAttribute) @@ -691,30 +723,27 @@ public static void VerifyBodyProperties (MethodDefinition src, MethodDefinition var srcValues = valueCollector (src); if (src.CustomAttributes.Any (attr => attr.AttributeType.Name == expectModifiedAttributeName)) { - Assert.That ( - linkedValues, - Is.Not.EqualTo (srcValues), - $"Expected method `{src} to have {propertyDescription} modified, however, the {propertyDescription} were the same as the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"); + if (linkedValues.ToHashSet ().SetEquals (srcValues.ToHashSet ())) { + yield return $"Expected method `{src} to have it's {propertyDescription} modified, however, the {propertyDescription} were the same as the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"; + } } else if (expectedSequenceAttribute != null) { var expected = getExpectFromSequenceAttribute (expectedSequenceAttribute).ToArray (); - Assert.That ( - linkedValues, - Is.EqualTo (expected), - $"Expected method `{src} to have it's {propertyDescription} modified, however, the sequence of {propertyDescription} does not match the expected value\n{FormattingUtils.FormatSequenceCompareFailureMessage2 (linkedValues, expected, srcValues)}"); + if (!linkedValues.ToHashSet ().SetEquals (expected.ToHashSet ())) { + yield return $"Expected method `{src} to have it's {propertyDescription} modified, however, the sequence of {propertyDescription} does not match the expected value\n{FormattingUtils.FormatSequenceCompareFailureMessage2 (linkedValues, expected, srcValues)}"; + } } else { - Assert.That ( - linkedValues, - Is.EqualTo (srcValues), - $"Expected method `{src} to have it's {propertyDescription} unchanged, however, the {propertyDescription} differ from the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"); + if (!linkedValues.ToHashSet ().SetEquals (srcValues.ToHashSet ())) { + yield return $"Expected method `{src} to have it's {propertyDescription} unchanged, however, the {propertyDescription} differ from the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"; + } } } - void VerifyReferences (AssemblyDefinition original, AssemblyDefinition linked) + IEnumerable VerifyReferences (AssemblyDefinition original, AssemblyDefinition linked) { var expected = original.MainModule.AllDefinedTypes () .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptReferenceAttribute))) .Select (ReduceAssemblyFileNameOrNameToNameOnly) - .ToArray (); + .ToHashSet (); /* - The test case will always need to have at least 1 reference. @@ -725,14 +754,15 @@ void VerifyReferences (AssemblyDefinition original, AssemblyDefinition linked) Once 1 kept reference attribute is used, the test will need to define all of of it's expected references */ - if (expected.Length == 0) - return; + if (expected.Count == 0) + yield break; var actual = linked.MainModule.AssemblyReferences .Select (name => name.Name) - .ToArray (); + .ToHashSet (); - Assert.That (actual, Is.EquivalentTo (expected)); + if (!expected.SetEquals (actual)) + yield return $"Expected references `{string.Join (", ", expected)}` do not match actual references `{string.Join (", ", actual)}`"; } string ReduceAssemblyFileNameOrNameToNameOnly (string fileNameOrAssemblyName) @@ -744,7 +774,7 @@ string ReduceAssemblyFileNameOrNameToNameOnly (string fileNameOrAssemblyName) return fileNameOrAssemblyName; } - void VerifyResources (AssemblyDefinition original, AssemblyDefinition linked) + IEnumerable VerifyResources (AssemblyDefinition original, AssemblyDefinition linked) { var expectedResourceNames = original.MainModule.AllDefinedTypes () .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptResourceAttribute))) @@ -752,111 +782,129 @@ void VerifyResources (AssemblyDefinition original, AssemblyDefinition linked) foreach (var resource in linked.MainModule.Resources) { if (!expectedResourceNames.Remove (resource.Name)) - Assert.Fail ($"Resource '{resource.Name}' should be removed."); + yield return $"Resource '{resource.Name}' should be removed."; EmbeddedResource embeddedResource = (EmbeddedResource) resource; var expectedResource = (EmbeddedResource) original.MainModule.Resources.First (r => r.Name == resource.Name); - Assert.That (embeddedResource.GetResourceData (), Is.EquivalentTo (expectedResource.GetResourceData ()), $"Resource '{resource.Name}' data doesn't match."); + if (!embeddedResource.GetResourceData ().SequenceEqual(expectedResource.GetResourceData ())) + yield return $"Resource '{resource.Name}' data doesn't match."; } - Assert.IsEmpty (expectedResourceNames, $"Resource '{expectedResourceNames.FirstOrDefault ()}' should be kept."); + if (expectedResourceNames.Any ()) yield return $"Resource '{expectedResourceNames.FirstOrDefault ()}' should be kept."; } - void VerifyExportedTypes (AssemblyDefinition original, AssemblyDefinition linked) + IEnumerable VerifyExportedTypes (AssemblyDefinition original, AssemblyDefinition linked) { var expectedTypes = original.MainModule.AllDefinedTypes () - .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptExportedTypeAttribute)).Select (l => l.FullName)).ToArray (); + .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptExportedTypeAttribute)).Select (l => l.FullName)); - Assert.That (linked.MainModule.ExportedTypes.Select (l => l.FullName), Is.EquivalentTo (expectedTypes)); + if (!linked.MainModule.ExportedTypes.Select (l => l.FullName).ToHashSet ().SetEquals (expectedTypes.ToHashSet ())) + yield return $"Exported types do not match expected."; } - protected virtual void VerifyPseudoAttributes (MethodDefinition src, MethodDefinition linked) + protected virtual IEnumerable VerifyPseudoAttributes (MethodDefinition src, MethodDefinition linked) { var expected = (MethodAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - Assert.AreEqual (expected, linked.Attributes, $"Method `{src}' pseudo attributes did not match expected"); + if (expected != linked.Attributes) + yield return $"Method `{src}' pseudo attributes did not match expected"; } - protected virtual void VerifyPseudoAttributes (TypeDefinition src, TypeDefinition linked) + protected virtual IEnumerable VerifyPseudoAttributes (TypeDefinition src, TypeDefinition linked) { var expected = (TypeAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - Assert.AreEqual (expected, linked.Attributes, $"Type `{src}' pseudo attributes did not match expected"); + if (expected == linked.Attributes) + yield break; + + yield return $"Type `{src}' pseudo attributes did not match expected"; } - protected virtual void VerifyPseudoAttributes (FieldDefinition src, FieldDefinition linked) + protected virtual IEnumerable VerifyPseudoAttributes (FieldDefinition src, FieldDefinition linked) { var expected = (FieldAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - Assert.AreEqual (expected, linked.Attributes, $"Field `{src}' pseudo attributes did not match expected"); + if (expected != linked.Attributes) yield return $"Field `{src}' pseudo attributes did not match expected"; } - protected virtual void VerifyPseudoAttributes (PropertyDefinition src, PropertyDefinition linked) + protected virtual IEnumerable VerifyPseudoAttributes (PropertyDefinition src, PropertyDefinition linked) { var expected = (PropertyAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - Assert.AreEqual (expected, linked.Attributes, $"Property `{src}' pseudo attributes did not match expected"); + if (expected != linked.Attributes) yield return $"Property `{src}' pseudo attributes did not match expected"; } - protected virtual void VerifyPseudoAttributes (EventDefinition src, EventDefinition linked) + protected virtual IEnumerable VerifyPseudoAttributes (EventDefinition src, EventDefinition linked) { var expected = (EventAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - Assert.AreEqual (expected, linked.Attributes, $"Event `{src}' pseudo attributes did not match expected"); + if (expected != linked.Attributes) yield return $"Event `{src}' pseudo attributes did not match expected"; } - protected virtual void VerifyCustomAttributes (ICustomAttributeProvider src, ICustomAttributeProvider linked) + protected virtual IEnumerable VerifyCustomAttributes (ICustomAttributeProvider src, ICustomAttributeProvider linked) { - var expectedAttrs = GetExpectedAttributes (src).ToList (); - var linkedAttrs = FilterLinkedAttributes (linked).ToList (); + var expectedAttrs = GetExpectedAttributes (src).ToHashSet (); + var linkedAttrs = FilterLinkedAttributes (linked).ToHashSet (); + if (!linkedAttrs.SetEquals (expectedAttrs)) { + var missing = $"Missing: {string.Join (", ", expectedAttrs.Except (linkedAttrs))}"; + var extra = $"Extra: {string.Join (", ", linkedAttrs.Except (expectedAttrs))}"; - Assert.That (linkedAttrs, Is.EquivalentTo (expectedAttrs), $"Custom attributes on `{src}' are not matching"); + yield return string.Join (Environment.NewLine, $"Custom attributes on `{src}' are not matching:", missing, extra); + } } - protected virtual void VerifySecurityAttributes (ICustomAttributeProvider src, ISecurityDeclarationProvider linked) + protected virtual IEnumerable VerifySecurityAttributes (ICustomAttributeProvider src, ISecurityDeclarationProvider linked) { var expectedAttrs = GetCustomAttributeCtorValues (src, nameof (KeptSecurityAttribute)) .Select (attr => attr.ToString ()) - .ToList (); + .ToHashSet (); - var linkedAttrs = FilterLinkedSecurityAttributes (linked).ToList (); + var linkedAttrs = FilterLinkedSecurityAttributes (linked).ToHashSet (); - Assert.That (linkedAttrs, Is.EquivalentTo (expectedAttrs), $"Security attributes on `{src}' are not matching"); + if (!linkedAttrs.SetEquals (expectedAttrs)) { + var missing = $"Missing: {string.Join (", ", expectedAttrs.Except (linkedAttrs))}"; + var extra = $"Extra: {string.Join (", ", linkedAttrs.Except (expectedAttrs))}"; + yield return string.Join ($"Security attributes on `{src}' are not matching:", missing, extra); + } } - void VerifyPrivateImplementationDetails (TypeDefinition original, TypeDefinition linked) + IEnumerable VerifyPrivateImplementationDetails (TypeDefinition original, TypeDefinition linked) { var expectedImplementationDetailsMethods = GetCustomAttributeCtorValues (original, nameof (KeptPrivateImplementationDetailsAttribute)) .Select (attr => attr.ToString ()) .ToList (); if (expectedImplementationDetailsMethods.Count == 0) - return; + yield break; - VerifyPrivateImplementationDetailsType (original.Module, linked.Module, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails); + TypeDefinition srcImplementationDetails; + TypeDefinition linkedImplementationDetails; + foreach (var err in VerifyPrivateImplementationDetailsType (original.Module, linked.Module, out srcImplementationDetails, out linkedImplementationDetails)) yield return err; foreach (var methodName in expectedImplementationDetailsMethods) { var originalMethod = srcImplementationDetails.Methods.FirstOrDefault (m => m.Name == methodName); if (originalMethod == null) - Assert.Fail ($"Could not locate original private implementation details method {methodName}"); + yield return $"Could not locate original private implementation details method {methodName}"; var linkedMethod = linkedImplementationDetails.Methods.FirstOrDefault (m => m.Name == methodName); - VerifyMethodKept (originalMethod, linkedMethod, compilerGenerated: true); + foreach (var erro in VerifyMethodKept (originalMethod, linkedMethod, compilerGenerated: true)) yield return erro; linkedMembers.Remove (linkedMethod.FullName); } verifiedGeneratedTypes.Add (srcImplementationDetails.FullName); } - static void VerifyPrivateImplementationDetailsType (ModuleDefinition src, ModuleDefinition linked, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails) + static IEnumerable VerifyPrivateImplementationDetailsType (ModuleDefinition src, ModuleDefinition linked, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails) { srcImplementationDetails = src.Types.FirstOrDefault (t => IsPrivateImplementationDetailsType (t)); - if (srcImplementationDetails == null) - Assert.Fail ("Could not locate in the original assembly. Does your test use initializers?"); - linkedImplementationDetails = linked.Types.FirstOrDefault (t => IsPrivateImplementationDetailsType (t)); - - if (linkedImplementationDetails == null) - Assert.Fail ("Could not locate in the linked assembly"); + const string srcMissingMessage = "Could not locate in the original assembly. Does your test use initializers?"; + const string linkedMissingMessage = "Could not locate in the linked assembly"; + return (srcImplementationDetails, linkedImplementationDetails) switch { + (null, null) => [srcMissingMessage, linkedMissingMessage], + (null, _) => [srcMissingMessage], + (_, null) => [linkedMissingMessage], + _ => Enumerable.Empty () + }; } - protected virtual void VerifyArrayInitializers (MethodDefinition src, MethodDefinition linked) + protected virtual IEnumerable VerifyArrayInitializers (MethodDefinition src, MethodDefinition linked) { var expectedIndices = GetCustomAttributeCtorValues (src, nameof (KeptInitializerData)) .Cast () @@ -865,12 +913,13 @@ protected virtual void VerifyArrayInitializers (MethodDefinition src, MethodDefi var expectKeptAll = src.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptInitializerData) && !attr.HasConstructorArguments); if (expectedIndices.Length == 0 && !expectKeptAll) - return; + yield break; if (!src.HasBody) - Assert.Fail ($"`{nameof (KeptInitializerData)}` cannot be used on methods that don't have bodies"); - - VerifyPrivateImplementationDetailsType (src.Module, linked.Module, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails); + yield return $"`{nameof (KeptInitializerData)}` cannot be used on methods that don't have bodies"; + TypeDefinition srcImplementationDetails; + TypeDefinition linkedImplementationDetails; + foreach (var err in VerifyPrivateImplementationDetailsType (src.Module, linked.Module, out srcImplementationDetails, out linkedImplementationDetails)) yield return err; var possibleInitializerFields = src.Body.Instructions .Where (ins => IsLdtokenOnPrivateImplementationDetails (srcImplementationDetails, ins)) @@ -878,32 +927,32 @@ protected virtual void VerifyArrayInitializers (MethodDefinition src, MethodDefi .ToArray (); if (possibleInitializerFields.Length == 0) - Assert.Fail ($"`{src}` does not make use of any initializers"); + yield return $"`{src}` does not make use of any initializers"; if (expectKeptAll) { foreach (var srcField in possibleInitializerFields) { var linkedField = linkedImplementationDetails.Fields.FirstOrDefault (f => f.InitialValue.SequenceEqual (srcField.InitialValue)); - VerifyInitializerField (srcField, linkedField); + foreach (var err in VerifyInitializerField (srcField, linkedField)) yield return err; } } else { foreach (var index in expectedIndices) { if (index < 0 || index > possibleInitializerFields.Length) - Assert.Fail ($"Invalid expected index `{index}` in {src}. Value must be between 0 and {expectedIndices.Length}"); + yield return $"Invalid expected index `{index}` in {src}. Value must be between 0 and {expectedIndices.Length}"; var srcField = possibleInitializerFields[index]; var linkedField = linkedImplementationDetails.Fields.FirstOrDefault (f => f.InitialValue.SequenceEqual (srcField.InitialValue)); - VerifyInitializerField (srcField, linkedField); + foreach (var err in VerifyInitializerField (srcField, linkedField)) yield return err; } } } - void VerifyInitializerField (FieldDefinition src, FieldDefinition linked) + IEnumerable VerifyInitializerField (FieldDefinition src, FieldDefinition linked) { - VerifyFieldKept (src, linked, compilerGenerated: true); + foreach (var err in VerifyFieldKept (src, linked, compilerGenerated: true)) yield return err; verifiedGeneratedFields.Add (linked.FullName); linkedMembers.Remove (linked.FullName); - VerifyTypeDefinitionKept (src.FieldType.Resolve (), linked.FieldType.Resolve ()); + foreach (var err in VerifyTypeDefinitionKept (src.FieldType.Resolve (), linked.FieldType.Resolve ())) yield return err; linkedMembers.Remove (linked.FieldType.FullName); linkedMembers.Remove (linked.DeclaringType.FullName); verifiedGeneratedTypes.Add (linked.DeclaringType.FullName); @@ -977,7 +1026,7 @@ protected virtual IEnumerable FilterLinkedSecurityAttributes (ISecurityD .Select (attr => attr.AttributeType.ToString ()); } - void VerifyFixedBufferFields (TypeDefinition src, TypeDefinition linked) + IEnumerable VerifyFixedBufferFields (TypeDefinition src, TypeDefinition linked) { var fields = src.Fields.Where (f => f.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptFixedBufferAttribute))); @@ -987,28 +1036,28 @@ void VerifyFixedBufferFields (TypeDefinition src, TypeDefinition linked) // while mcs and other versions of csc name it `__FixedBuffer0` var originalCompilerGeneratedBufferType = src.NestedTypes.FirstOrDefault (t => t.FullName.Contains ($"<{field.Name}>") && t.FullName.Contains ("__FixedBuffer")); if (originalCompilerGeneratedBufferType == null) - Assert.Fail ($"Could not locate original compiler generated fixed buffer type for field {field}"); + yield return $"Could not locate original compiler generated fixed buffer type for field {field}"; var linkedCompilerGeneratedBufferType = linked.NestedTypes.FirstOrDefault (t => t.Name == originalCompilerGeneratedBufferType.Name); if (linkedCompilerGeneratedBufferType == null) - Assert.Fail ($"Missing expected type {originalCompilerGeneratedBufferType}"); + yield return $"Missing expected type {originalCompilerGeneratedBufferType}"; // Have to verify the field before the type var originalElementField = originalCompilerGeneratedBufferType.Fields.FirstOrDefault (); if (originalElementField == null) - Assert.Fail ($"Could not locate original compiler generated FixedElementField on {originalCompilerGeneratedBufferType}"); + yield return $"Could not locate original compiler generated FixedElementField on {originalCompilerGeneratedBufferType}"; var linkedField = linkedCompilerGeneratedBufferType?.Fields.FirstOrDefault (); - VerifyFieldKept (originalElementField, linkedField, compilerGenerated: true); + foreach (var err in VerifyFieldKept (originalElementField, linkedField, compilerGenerated: true)) yield return err; verifiedGeneratedFields.Add (originalElementField.FullName); linkedMembers.Remove (linkedField.FullName); - VerifyTypeDefinitionKept (originalCompilerGeneratedBufferType, linkedCompilerGeneratedBufferType); + foreach (var err in VerifyTypeDefinitionKept (originalCompilerGeneratedBufferType, linkedCompilerGeneratedBufferType)) yield return err; verifiedGeneratedTypes.Add (originalCompilerGeneratedBufferType.FullName); } } - void VerifyDelegateBackingFields (TypeDefinition src, TypeDefinition linked) + IEnumerable VerifyDelegateBackingFields (TypeDefinition src, TypeDefinition linked) { var expectedFieldNames = src.CustomAttributes .Where (a => a.AttributeType.Name == nameof (KeptDelegateCacheFieldAttribute)) @@ -1017,7 +1066,7 @@ void VerifyDelegateBackingFields (TypeDefinition src, TypeDefinition linked) .ToList (); if (expectedFieldNames.Count == 0) - return; + yield break; foreach (var nestedType in src.NestedTypes) { if (!IsDelegateBackingFieldsType (nestedType)) @@ -1027,20 +1076,20 @@ void VerifyDelegateBackingFields (TypeDefinition src, TypeDefinition linked) foreach (var expectedFieldName in expectedFieldNames) { var originalField = nestedType.Fields.FirstOrDefault (f => f.Name == expectedFieldName); if (originalField is null) - Assert.Fail ($"Invalid expected delegate backing field {expectedFieldName} in {src}. This member was not in the unlinked assembly"); + yield return $"Invalid expected delegate backing field {expectedFieldName} in {src}. This member was not in the unlinked assembly"; var linkedField = linkedNestedType?.Fields.FirstOrDefault (f => f.Name == expectedFieldName); - VerifyFieldKept (originalField, linkedField, compilerGenerated: true); + foreach (var err in VerifyFieldKept (originalField, linkedField, compilerGenerated: true)) yield return err; verifiedGeneratedFields.Add (linkedField.FullName); linkedMembers.Remove (linkedField.FullName); } - VerifyTypeDefinitionKept (nestedType, linkedNestedType); + foreach (var err in VerifyTypeDefinitionKept (nestedType, linkedNestedType)) yield return err; verifiedGeneratedTypes.Add (linkedNestedType.FullName); } } - void VerifyGenericParameters (IGenericParameterProvider src, IGenericParameterProvider linked, bool compilerGenerated) + IEnumerable VerifyGenericParameters (IGenericParameterProvider src, IGenericParameterProvider linked, bool compilerGenerated) { Assert.AreEqual (src.HasGenericParameters, linked.HasGenericParameters); if (src.HasGenericParameters) { @@ -1050,38 +1099,47 @@ void VerifyGenericParameters (IGenericParameterProvider src, IGenericParameterPr var lnkp = linked.GenericParameters[i]; if (!compilerGenerated) { - VerifyCustomAttributes (srcp, lnkp); + foreach (var err in VerifyCustomAttributes (srcp, lnkp)) yield return err; } if (checkNames) { if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) { string name = (src.GenericParameterType == GenericParameterType.Method ? "!!" : "!") + srcp.Position; - Assert.AreEqual (name, lnkp.Name, "Expected empty generic parameter name"); + if (name != lnkp.Name) { + yield return $"Expected empty generic parameter name. Parameter {i} of {(src.ToString ())}"; + } } else { - Assert.AreEqual (srcp.Name, lnkp.Name, "Mismatch in generic parameter name"); + if (srcp.Name != lnkp.Name) { + yield return $"Mismatch in generic parameter name. Parameter {i} of {(src.ToString ())}"; + } } } } } } - void VerifyParameters (IMethodSignature src, IMethodSignature linked, bool compilerGenerated) + IEnumerable VerifyParameters (IMethodSignature src, IMethodSignature linked, bool compilerGenerated) { - Assert.AreEqual (src.HasParameters, linked.HasParameters); + if (src.HasParameters != linked.HasParameters) + yield return $"Mismatch in parameters. {src} has parameters: {src.HasParameters}, {linked} has parameters: {linked.HasParameters}"; if (src.HasParameters) { for (int i = 0; i < src.Parameters.Count; ++i) { var srcp = src.Parameters[i]; var lnkp = linked.Parameters[i]; if (!compilerGenerated) { - VerifyCustomAttributes (srcp, lnkp); + foreach (var err in VerifyCustomAttributes (srcp, lnkp)) yield return err; } if (checkNames) { if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) - Assert.IsEmpty (lnkp.Name, $"Expected empty parameter name. Parameter {i} of {(src as MethodDefinition)}"); + { + if (lnkp.Name != string.Empty) yield return $"Expected empty parameter name. Parameter {i} of {(src as MethodDefinition)}"; + } else - Assert.AreEqual (srcp.Name, lnkp.Name, $"Mismatch in parameter name. Parameter {i} of {(src as MethodDefinition)}"); + { + if (srcp.Name != lnkp.Name) yield return $"Mismatch in parameter name. Parameter {i} of {(src as MethodDefinition)}"; + } } } } @@ -1113,7 +1171,7 @@ protected virtual bool ShouldBeKept (T member, string signature = null) where private static IEnumerable GetActiveKeptAttributes (ICustomAttributeProvider provider, string attributeName) { - return provider.CustomAttributes.Where(ca => { + return provider.CustomAttributes.Where (ca => { if (ca.AttributeType.Name != attributeName) { return false; } @@ -1121,7 +1179,7 @@ private static IEnumerable GetActiveKeptAttributes (ICustomAttr object keptBy = ca.GetPropertyValue (nameof (KeptAttribute.By)); return keptBy is null ? true : ((Tool) keptBy).HasFlag (Tool.Trimmer); }); - } + } private static bool HasActiveKeptAttribute (ICustomAttributeProvider provider) { From 565837a47808c86e2d492e00c123c664788ffad0 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:29:09 -0800 Subject: [PATCH 03/19] Mark static DIM if it provides an implementation for a type that is relevant to variant casting --- .../src/linker/Linker.Steps/MarkStep.cs | 61 ++-------- .../illink/src/linker/Linker/Annotations.cs | 4 +- .../MostSpecificDefaultImplementationKept.cs | 15 +-- .../TestCasesRunner/AssemblyChecker.cs | 109 +++++++++--------- 4 files changed, 66 insertions(+), 123 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 5215eb5608eabc..8c7767741f1bc8 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -723,19 +723,16 @@ void ProcessVirtualMethod (MethodDefinition method) bool ShouldMarkOverrideForBase (OverrideInformation overrideInformation) { Debug.Assert (Annotations.IsMarked (overrideInformation.Base) || IgnoreScope (overrideInformation.Base.DeclaringType.Scope)); - if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) - return false; if (overrideInformation.IsOverrideOfInterfaceMember) { _interfaceOverrides.Add ((overrideInformation, ScopeStack.CurrentScope)); return false; } + if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) + return false; if (!Context.IsOptimizationEnabled (CodeOptimizations.OverrideRemoval, overrideInformation.Override)) return true; - if(Annotations.GetDefaultInterfaceImplementations(overrideInformation.Base).Where(dim => dim.DefaultInterfaceMethods == overrideInformation.Override).ToList() is [var dim]) - return Annotations.IsRelevantToVariantCasting(dim.InstanceType); - // In this context, an override needs to be kept if // a) it's an override on an instantiated type (of a marked base) or // b) it's an override of an abstract base (required for valid IL) @@ -2553,14 +2550,16 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat { var @base = overrideInformation.Base; var method = overrideInformation.Override; + Debug.Assert(@base.DeclaringType.IsInterface); if (@base is null || method is null || @base.DeclaringType is null) return false; if (Annotations.IsMarked (method)) return false; - if (!@base.DeclaringType.IsInterface) - return false; + // If the override is a DIM that provides an implementation for a type that requires the interface, mark the method + if (Annotations.GetDefaultInterfaceImplementations (@base).Where (dim => dim.DefaultInterfaceMethod == overrideInformation.Override).Any (dim => Annotations.IsRelevantToVariantCasting (dim.InstanceType))) + return true; // If the interface implementation is not marked, do not mark the implementation method // A type that doesn't implement the interface isn't required to have methods that implement the interface. @@ -2818,51 +2817,7 @@ void MarkGenericArguments (IGenericInstance instance) if (parameter.HasDefaultConstructorConstraint) MarkDefaultConstructor (argumentTypeDef, new DependencyInfo (DependencyKind.DefaultCtorForNewConstrainedGenericArgument, instance)); - - // var interfaceConstraints = GetConstraintInterfaces (parameter); - // foreach (var constrainedInterfacemethod in interfaceConstraints.SelectMany (c => c.Methods).Where (m => m.IsStatic)) { - // foreach (var dim in Annotations.GetDefaultInterfaceImplementations (constrainedInterfacemethod)) { - // var asdf = new MessageOrigin(parameter); - // MarkMethod (dim.DefaultInterfaceMethods, new DependencyInfo (DependencyKind.DefaultImplementationForImplementingType, constrainedInterfacemethod), in asdf); - // } - // } - } - // IEnumerable GetConstraintInterfaces (GenericParameter gp) - // { - // foreach (var constraint in gp.Constraints) { - // switch (constraint.ConstraintType) { - // case GenericParameter gp2: - // foreach (var td1 in GetConstraintInterfaces (gp2)) { - // yield return td1; - // } - // break; - // case TypeSpecification when constraint.ConstraintType is not GenericInstanceType: - // // Not a resolvable type - // continue; - // default: - // if (Context.Resolve (constraint.ConstraintType) is TypeDefinition { IsInterface: true } td2) { - // yield return td2; - // } - // break; - // } - // } - // } - // MethodDefinition? def = Context.Resolve (method); - // if (def is not null) { - // for (int i = 0; i < gim.GenericArguments.Count; i++) { - // TypeReference arg = gim.GenericArguments[i]; - // GenericParameter param = def.GenericParameters[i]; - // var interfaceConstraints = GetConstraintInterfaces (param); - // foreach (var constrainedInterfacemethod in interfaceConstraints.SelectMany (c => c.Methods).Where (m => m.IsStatic)) { - // foreach (var dim in Annotations.GetDefaultInterfaceImplementations (constrainedInterfacemethod)) { - // var thisMethodOrigin = new MessageOrigin (def); - // using (ScopeStack.PushScope (thisMethodOrigin)) { - // MarkMethod (dim.DefaultInterfaceMethods, new DependencyInfo (DependencyKind.DefaultImplementationForImplementingType, method), in thisMethodOrigin); - // } - // } - // } - // } - // } + } } IGenericParameterProvider? GetGenericProviderFromInstance (IGenericInstance instance) @@ -3876,8 +3831,6 @@ protected internal virtual void MarkInterfaceImplementation (InterfaceImplementa { if (Annotations.IsMarked (iface)) return; - if (iface.InterfaceType.Name == "IBase") - _ = 0; Annotations.MarkProcessed (iface, reason ?? new DependencyInfo (DependencyKind.InterfaceImplementationOnType, ScopeStack.CurrentScope.Origin.Provider)); using var localScope = origin.HasValue ? ScopeStack.PushScope (origin.Value) : null; diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index 0389a3b3e8c939..c4d7bf0a27391b 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -236,8 +236,6 @@ public bool IsReflectionUsed (IMemberDefinition method) public void MarkInstantiated (TypeDefinition type) { - if (type.Name == "NotUsedAsIBase") - _ = 0; marked_instantiated.Add (type); } @@ -470,7 +468,7 @@ public bool IsPublic (IMetadataTokenProvider provider) return TypeMapInfo.GetOverrides (method); } - public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface, MethodDefinition DefaultInterfaceMethods)> GetDefaultInterfaceImplementations (MethodDefinition method) + public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface, MethodDefinition DefaultInterfaceMethod)> GetDefaultInterfaceImplementations (MethodDefinition method) { return TypeMapInfo.GetDefaultInterfaceImplementations (method) ?? []; } diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs index 7b92c0948991b1..ff91de8800310d 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs @@ -38,13 +38,13 @@ static virtual int Value2 { } } - // [Kept] - // [KeptInterface (typeof (IBase))] + [Kept] + [KeptInterface (typeof (IBase))] interface IMiddle : IBase { - // [Kept] // Should be removable -- Add link to bug before merge + [Kept] // Should be removable -- Add link to bug before merge static int IBase.Value { - // [Kept] // Should be removable -- Add link to bug before merge + [Kept] // Should be removable -- Add link to bug before merge get => 1; } } @@ -73,9 +73,6 @@ class UsedAsIBase : IDerived, INotReferenced } [Kept] - // [KeptInterface(typeof(IDerived))] - // [KeptInterface(typeof(IMiddle))] - // [KeptInterface(typeof(IBase))] class NotUsedAsIBase : IDerived, INotReferenced { [Kept] @@ -93,10 +90,8 @@ class GenericType where T : IBase [KeptInterface (typeof (IDerived))] [KeptInterface (typeof (IMiddle))] [KeptInterface (typeof (IBase))] - class UsedAsIBase2 : IBase + class UsedAsIBase2 : IDerived { - [Kept] - public static int Value => 0; } #endif } diff --git a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs index 7d26b4c939d24d..8d034019e69f76 100644 --- a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs +++ b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs @@ -38,60 +38,59 @@ public AssemblyChecker (AssemblyDefinition original, AssemblyDefinition linked, attr.Name == nameof (RemovedNameValueAttribute)); } - static (IEnumerable Missing, IEnumerable Extra) SetDifference (IEnumerable expected, IEnumerable found) - { - var missing = expected.Except (found); - var extra = found.Except (expected); - return (missing, extra); - } - public void Verify () { - IEnumerable failures = VerifyExportedTypes (originalAssembly, linkedAssembly); - failures = failures.Concat (VerifyCustomAttributes (originalAssembly, linkedAssembly)); - failures = failures.Concat (VerifySecurityAttributes (originalAssembly, linkedAssembly)); - - foreach (var originalModule in originalAssembly.Modules) - failures = failures.Concat (VerifyModule (originalModule, linkedAssembly.Modules.FirstOrDefault (m => m.Name == originalModule.Name))); - - failures = failures.Concat (VerifyResources (originalAssembly, linkedAssembly)); - failures = failures.Concat (VerifyReferences (originalAssembly, linkedAssembly)); - failures = failures.Concat (VerifyKeptByAttributes (originalAssembly, originalAssembly.FullName)); - - linkedMembers = new HashSet (linkedAssembly.MainModule.AllMembers ().Select (s => { - return s.FullName; - }), StringComparer.Ordinal); - - // Workaround for compiler injected attribute to describe the language version - linkedMembers.Remove ("System.Void Microsoft.CodeAnalysis.EmbeddedAttribute::.ctor()"); - linkedMembers.Remove ("System.Int32 System.Runtime.CompilerServices.RefSafetyRulesAttribute::Version"); - linkedMembers.Remove ("System.Void System.Runtime.CompilerServices.RefSafetyRulesAttribute::.ctor(System.Int32)"); - - // Workaround for compiler injected attribute to describe the language version - verifiedGeneratedTypes.Add ("Microsoft.CodeAnalysis.EmbeddedAttribute"); - verifiedGeneratedTypes.Add ("System.Runtime.CompilerServices.RefSafetyRulesAttribute"); - - var membersToAssert = originalAssembly.MainModule.Types; - foreach (var originalMember in membersToAssert) { - if (originalMember is TypeDefinition td) { - if (td.Name == "") { - linkedMembers.Remove (td.Name); + var failures = GetFailures ().ToList (); + if (failures.Count > 0) + Assert.Fail (string.Join (Environment.NewLine, failures)); + + IEnumerable GetFailures () + { + foreach (var err in VerifyExportedTypes (originalAssembly, linkedAssembly)) yield return err; + foreach (var err in VerifyCustomAttributes (originalAssembly, linkedAssembly)) yield return err; + foreach (var err in VerifySecurityAttributes (originalAssembly, linkedAssembly)) yield return err; + + foreach (var originalModule in originalAssembly.Modules) + foreach (var err in VerifyModule (originalModule, linkedAssembly.Modules.FirstOrDefault (m => m.Name == originalModule.Name))) yield return err; + + foreach (var err in VerifyResources (originalAssembly, linkedAssembly)) yield return err; + foreach (var err in VerifyReferences (originalAssembly, linkedAssembly)) yield return err; + foreach (var err in VerifyKeptByAttributes (originalAssembly, originalAssembly.FullName)) yield return err; + + linkedMembers = new HashSet (linkedAssembly.MainModule.AllMembers ().Select (s => { + return s.FullName; + }), StringComparer.Ordinal); + + // Workaround for compiler injected attribute to describe the language version + linkedMembers.Remove ("System.Void Microsoft.CodeAnalysis.EmbeddedAttribute::.ctor()"); + linkedMembers.Remove ("System.Int32 System.Runtime.CompilerServices.RefSafetyRulesAttribute::Version"); + linkedMembers.Remove ("System.Void System.Runtime.CompilerServices.RefSafetyRulesAttribute::.ctor(System.Int32)"); + + // Workaround for compiler injected attribute to describe the language version + verifiedGeneratedTypes.Add ("Microsoft.CodeAnalysis.EmbeddedAttribute"); + verifiedGeneratedTypes.Add ("System.Runtime.CompilerServices.RefSafetyRulesAttribute"); + + var membersToAssert = originalAssembly.MainModule.Types; + foreach (var originalMember in membersToAssert) { + if (originalMember is TypeDefinition td) { + if (td.Name == "") { + linkedMembers.Remove (td.Name); + continue; + } + + TypeDefinition linkedType = linkedAssembly.MainModule.GetType (originalMember.FullName); + foreach (var err in VerifyTypeDefinition (td, linkedType)) yield return err; + linkedMembers.Remove (td.FullName); + continue; } - TypeDefinition linkedType = linkedAssembly.MainModule.GetType (originalMember.FullName); - failures = failures.Concat (VerifyTypeDefinition (td, linkedType)); - linkedMembers.Remove (td.FullName); - - continue; + yield return $"Don't know how to check member of type {originalMember.GetType ()}"; } - throw new NotImplementedException ($"Don't know how to check member of type {originalMember.GetType ()}"); + if (linkedMembers.Any ()) + foreach (var err in linkedMembers.Select (m => $"Member `{m}' was not expected to be kept")) yield return err; } - - if (linkedMembers.Any ()) - failures = failures.Concat (linkedMembers.Select (m => $"Member `{m}' was not expected to be kept")); - Assert.IsEmpty (failures, string.Join (Environment.NewLine, failures)); } static bool IsBackingField (FieldDefinition field) => field.Name.StartsWith ("<") && field.Name.EndsWith (">k__BackingField"); @@ -161,6 +160,7 @@ protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition origi foreach (var attr in original.CustomAttributes.Where (l => l.AttributeType.Name == nameof (CreatedMemberAttribute))) { var newName = original.FullName + "::" + attr.ConstructorArguments[0].Value.ToString (); + // Assert.AreEqual (1, linkedMembers.RemoveWhere (l => l.Contains (newName)), $"Newly created member '{newName}' was not found"); var asdf = linkedMembers.Where (l => l.Contains (newName)).ToList (); if (1 != linkedMembers.RemoveWhere (l => l.Contains (newName))) { yield return $"Newly created member '{newName}' was not found"; @@ -321,7 +321,7 @@ IEnumerable VerifyInterfaces (TypeDefinition src, TypeDefinition linked) } if (expectedInterfaces.Any ()) { - yield return $"Expected interfaces were not found on {src}"; + yield return $"Expected interfaces were not found on {src}: {string.Join (", ", expectedInterfaces.Select(i => i.Split('.', '/').Last()))}"; } } } @@ -416,11 +416,11 @@ IEnumerable VerifyFieldKept (FieldDefinition src, FieldDefinition linked { if (linked == null) { yield return $"Field `{src}' should have been kept"; - yield break; ; + yield break; } - if (src?.Constant != linked?.Constant) - yield return $"Field `{src}' value"; + if (!src?.Constant?.Equals (linked?.Constant) == true) + yield return $"Field `{src}' value was expected to be {src?.Constant} but was {linked?.Constant}"; foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; VerifyPseudoAttributes (src, linked); @@ -788,7 +788,7 @@ IEnumerable VerifyResources (AssemblyDefinition original, AssemblyDefini var expectedResource = (EmbeddedResource) original.MainModule.Resources.First (r => r.Name == resource.Name); - if (!embeddedResource.GetResourceData ().SequenceEqual(expectedResource.GetResourceData ())) + if (!embeddedResource.GetResourceData ().SequenceEqual (expectedResource.GetResourceData ())) yield return $"Resource '{resource.Name}' data doesn't match."; } @@ -1132,12 +1132,9 @@ IEnumerable VerifyParameters (IMethodSignature src, IMethodSignature lin } if (checkNames) { - if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) - { + if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) { if (lnkp.Name != string.Empty) yield return $"Expected empty parameter name. Parameter {i} of {(src as MethodDefinition)}"; - } - else - { + } else { if (srcp.Name != lnkp.Name) yield return $"Mismatch in parameter name. Parameter {i} of {(src as MethodDefinition)}"; } } From c3b0c63dc258e8fe56855a410df038991093fd3d Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:00:01 -0800 Subject: [PATCH 04/19] Remove unused code and optimize default interface method marking --- .../src/linker/Linker.Steps/MarkStep.cs | 49 ++++++++++++++++--- .../illink/src/linker/Linker/Annotations.cs | 18 +------ .../src/linker/Linker/DependencyInfo.cs | 2 - .../MostSpecificDefaultImplementationKept.cs | 2 - 4 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 8c7767741f1bc8..a387b65335ad5f 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -215,7 +215,6 @@ internal DynamicallyAccessedMembersTypeHierarchy DynamicallyAccessedMembersTypeH DependencyKind.ReturnTypeMarshalSpec, DependencyKind.XmlDescriptor, DependencyKind.UnsafeAccessorTarget, - DependencyKind.DefaultImplementationForImplementingType, }; #endif @@ -2550,16 +2549,54 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat { var @base = overrideInformation.Base; var method = overrideInformation.Override; - Debug.Assert(@base.DeclaringType.IsInterface); + Debug.Assert (@base.DeclaringType.IsInterface); if (@base is null || method is null || @base.DeclaringType is null) return false; if (Annotations.IsMarked (method)) return false; - // If the override is a DIM that provides an implementation for a type that requires the interface, mark the method - if (Annotations.GetDefaultInterfaceImplementations (@base).Where (dim => dim.DefaultInterfaceMethod == overrideInformation.Override).Any (dim => Annotations.IsRelevantToVariantCasting (dim.InstanceType))) - return true; + // If the override is a DIM that provides an implementation for a type that requires the interface, we may need to mark the DIM + var dims = Annotations.GetDefaultInterfaceImplementations (@base).Where(dim => Annotations.IsRelevantToVariantCasting (dim.InstanceType)); + if (dims.Any (dim => dim.DefaultInterfaceMethod == method && Annotations.IsRelevantToVariantCasting (dim.InstanceType))) + { + // We need to find the most derived DIM for each type that has DIMs -- if there is a most derived DIM, we only mark that DIM + var dimsByInstanceType = dims.GroupBy (dim => dim.InstanceType).Where(group => group.Any(dim => dim.DefaultInterfaceMethod == method)); + foreach (var group in dimsByInstanceType) { + var allDims = group.ToArray (); + int mostDerivedDimIndex = -1; + for (int i = 0; i < allDims.Length; i++) { + var derivesFromAllOtherDimProviders = true; + // Check if DIM i is the most specific DIM for the type by checking if it implements all the other DIM providers for the type + for (int j = 0; j < allDims.Length && derivesFromAllOtherDimProviders; j++) { + if (j == i) + continue; + // If the DIM provider i implements DIM provider j, then i is a more specific implementation that j. Otherwise, it is not and we check the next DIM provider + if (!allDims[i].DefaultInterfaceMethod.DeclaringType.Interfaces + .Any(iface => Context.Resolve(iface.InterfaceType) == allDims[j].DefaultInterfaceMethod.DeclaringType)) + { + derivesFromAllOtherDimProviders = false; + break; + } + } + if (derivesFromAllOtherDimProviders) { + mostDerivedDimIndex = i; + break; + } + } + // If there is a most derived DIM, we only need to mark that DIM -- return true if the override is the most derived DIM + if (mostDerivedDimIndex != -1) { + if (allDims[mostDerivedDimIndex].DefaultInterfaceMethod == method) + return true; + else + continue; + } else { + // If there is no most derived DIM, all DIMs should be marked. + // We already checked that the override is a DIM that provides an implementation for a type that requires the interface, so we can return true + return true; + } + } + } // If the interface implementation is not marked, do not mark the implementation method // A type that doesn't implement the interface isn't required to have methods that implement the interface. @@ -2581,7 +2618,7 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat // If the method is static and the implementing type is relevant to variant casting, mark the implementation method. // A static method may only be called through a constrained call if the type is relevant to variant casting. if (@base.IsStatic) - return Annotations.IsRelevantToVariantCasting (method.DeclaringType) || method.DeclaringType.IsInterface + return Annotations.IsRelevantToVariantCasting (method.DeclaringType) || IgnoreScope (@base.DeclaringType.Scope); // If the implementing type is marked as instantiated, mark the implementation method. diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index c4d7bf0a27391b..8c256cb199b222 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -246,22 +246,8 @@ public bool IsInstantiated (TypeDefinition type) public void MarkRelevantToVariantCasting (TypeDefinition type) { - if (type == null) - return; - - if (!types_relevant_to_variant_casting.Add (type)) - return; - - foreach (var baseType in type.Interfaces) { - var resolvedBaseType = context.Resolve (baseType.InterfaceType); - if (resolvedBaseType is null) - continue; - // Don't need to recurse for interfaces - types implement all interfaces in the interface/base type hierarchy - types_relevant_to_variant_casting.Add (resolvedBaseType); - } - if (type.BaseType is not null && context.Resolve(type.BaseType) is {} baseTypeDef) { - MarkRelevantToVariantCasting (baseTypeDef); - } + if (type != null) + types_relevant_to_variant_casting.Add (type); } public bool IsRelevantToVariantCasting (TypeDefinition type) diff --git a/src/tools/illink/src/linker/Linker/DependencyInfo.cs b/src/tools/illink/src/linker/Linker/DependencyInfo.cs index 7476baa99757ff..7eb0681c5a9347 100644 --- a/src/tools/illink/src/linker/Linker/DependencyInfo.cs +++ b/src/tools/illink/src/linker/Linker/DependencyInfo.cs @@ -145,8 +145,6 @@ public enum DependencyKind DynamicallyAccessedMemberOnType = 88, // type with DynamicallyAccessedMembers annotations (including those inherited from base types and interfaces) UnsafeAccessorTarget = 89, // the member is referenced via UnsafeAccessor attribute - - DefaultImplementationForImplementingType = 90 // The member is a default implementation of an interface method for a type that (may) need the method } public readonly struct DependencyInfo : IEquatable diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs index ff91de8800310d..4707bf68ff283a 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs @@ -42,9 +42,7 @@ static virtual int Value2 { [KeptInterface (typeof (IBase))] interface IMiddle : IBase { - [Kept] // Should be removable -- Add link to bug before merge static int IBase.Value { - [Kept] // Should be removable -- Add link to bug before merge get => 1; } } From e6d3788c1a43935c2f4647c2bd03a1e3b4a64497 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:09:48 -0800 Subject: [PATCH 05/19] Revert AssemblyChecker changes --- .../TestCasesRunner/AssemblyChecker.cs | 531 ++++++++---------- 1 file changed, 238 insertions(+), 293 deletions(-) diff --git a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs index 8d034019e69f76..36031245d5191b 100644 --- a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs +++ b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; -using System.Linq.Expressions; using System.Text; using Mono.Cecil; using Mono.Cecil.Cil; @@ -40,85 +39,77 @@ public AssemblyChecker (AssemblyDefinition original, AssemblyDefinition linked, public void Verify () { - var failures = GetFailures ().ToList (); - if (failures.Count > 0) - Assert.Fail (string.Join (Environment.NewLine, failures)); - - IEnumerable GetFailures () - { - foreach (var err in VerifyExportedTypes (originalAssembly, linkedAssembly)) yield return err; - foreach (var err in VerifyCustomAttributes (originalAssembly, linkedAssembly)) yield return err; - foreach (var err in VerifySecurityAttributes (originalAssembly, linkedAssembly)) yield return err; - - foreach (var originalModule in originalAssembly.Modules) - foreach (var err in VerifyModule (originalModule, linkedAssembly.Modules.FirstOrDefault (m => m.Name == originalModule.Name))) yield return err; - - foreach (var err in VerifyResources (originalAssembly, linkedAssembly)) yield return err; - foreach (var err in VerifyReferences (originalAssembly, linkedAssembly)) yield return err; - foreach (var err in VerifyKeptByAttributes (originalAssembly, originalAssembly.FullName)) yield return err; - - linkedMembers = new HashSet (linkedAssembly.MainModule.AllMembers ().Select (s => { - return s.FullName; - }), StringComparer.Ordinal); - - // Workaround for compiler injected attribute to describe the language version - linkedMembers.Remove ("System.Void Microsoft.CodeAnalysis.EmbeddedAttribute::.ctor()"); - linkedMembers.Remove ("System.Int32 System.Runtime.CompilerServices.RefSafetyRulesAttribute::Version"); - linkedMembers.Remove ("System.Void System.Runtime.CompilerServices.RefSafetyRulesAttribute::.ctor(System.Int32)"); - - // Workaround for compiler injected attribute to describe the language version - verifiedGeneratedTypes.Add ("Microsoft.CodeAnalysis.EmbeddedAttribute"); - verifiedGeneratedTypes.Add ("System.Runtime.CompilerServices.RefSafetyRulesAttribute"); - - var membersToAssert = originalAssembly.MainModule.Types; - foreach (var originalMember in membersToAssert) { - if (originalMember is TypeDefinition td) { - if (td.Name == "") { - linkedMembers.Remove (td.Name); - continue; - } + VerifyExportedTypes (originalAssembly, linkedAssembly); + + VerifyCustomAttributes (originalAssembly, linkedAssembly); + VerifySecurityAttributes (originalAssembly, linkedAssembly); + + foreach (var originalModule in originalAssembly.Modules) + VerifyModule (originalModule, linkedAssembly.Modules.FirstOrDefault (m => m.Name == originalModule.Name)); + + VerifyResources (originalAssembly, linkedAssembly); + VerifyReferences (originalAssembly, linkedAssembly); + VerifyKeptByAttributes (originalAssembly, originalAssembly.FullName); - TypeDefinition linkedType = linkedAssembly.MainModule.GetType (originalMember.FullName); - foreach (var err in VerifyTypeDefinition (td, linkedType)) yield return err; - linkedMembers.Remove (td.FullName); + linkedMembers = new HashSet (linkedAssembly.MainModule.AllMembers ().Select (s => { + return s.FullName; + }), StringComparer.Ordinal); + // Workaround for compiler injected attribute to describe the language version + linkedMembers.Remove ("System.Void Microsoft.CodeAnalysis.EmbeddedAttribute::.ctor()"); + linkedMembers.Remove ("System.Int32 System.Runtime.CompilerServices.RefSafetyRulesAttribute::Version"); + linkedMembers.Remove ("System.Void System.Runtime.CompilerServices.RefSafetyRulesAttribute::.ctor(System.Int32)"); + + // Workaround for compiler injected attribute to describe the language version + verifiedGeneratedTypes.Add ("Microsoft.CodeAnalysis.EmbeddedAttribute"); + verifiedGeneratedTypes.Add ("System.Runtime.CompilerServices.RefSafetyRulesAttribute"); + + var membersToAssert = originalAssembly.MainModule.Types; + foreach (var originalMember in membersToAssert) { + if (originalMember is TypeDefinition td) { + if (td.Name == "") { + linkedMembers.Remove (td.Name); continue; } - yield return $"Don't know how to check member of type {originalMember.GetType ()}"; + TypeDefinition linkedType = linkedAssembly.MainModule.GetType (originalMember.FullName); + VerifyTypeDefinition (td, linkedType); + linkedMembers.Remove (td.FullName); + + continue; } - if (linkedMembers.Any ()) - foreach (var err in linkedMembers.Select (m => $"Member `{m}' was not expected to be kept")) yield return err; + throw new NotImplementedException ($"Don't know how to check member of type {originalMember.GetType ()}"); } + + Assert.IsEmpty (linkedMembers, "Linked output includes unexpected member"); } static bool IsBackingField (FieldDefinition field) => field.Name.StartsWith ("<") && field.Name.EndsWith (">k__BackingField"); - protected virtual IEnumerable VerifyModule (ModuleDefinition original, ModuleDefinition linked) + protected virtual void VerifyModule (ModuleDefinition original, ModuleDefinition linked) { // We never link away a module today so let's make sure the linked one isn't null if (linked == null) - yield return $"Linked assembly `{original.Assembly.Name.Name}` is missing module `{original.Name}`"; + Assert.Fail ($"Linked assembly `{original.Assembly.Name.Name}` is missing module `{original.Name}`"); var expected = original.Assembly.MainModule.AllDefinedTypes () .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptModuleReferenceAttribute))) - .ToHashSet (); + .ToArray (); var actual = linked.ModuleReferences .Select (name => name.Name) - .ToHashSet (); + .ToArray (); - if (!expected.SetEquals (actual)) - yield return $"In module {original.FileName} Expected module references `{string.Join (", ", expected)}` but got `{string.Join (", ", actual)}`"; + Assert.That (actual, Is.EquivalentTo (expected)); - foreach (var err in VerifyCustomAttributes (original, linked)) yield return err; + VerifyCustomAttributes (original, linked); } - protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition original, TypeDefinition linked) + protected virtual void VerifyTypeDefinition (TypeDefinition original, TypeDefinition linked) { if (linked != null && verifiedGeneratedTypes.Contains (linked.FullName)) - yield break; + return; ModuleDefinition linkedModule = linked?.Module; @@ -135,7 +126,7 @@ protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition origi if (!expectedKept) { if (linked == null) - yield break; + return; // Compiler generated members can't be annotated with `Kept` attributes directly // For some of them we have special attributes (backing fields for example), but it's impractical to define @@ -146,13 +137,13 @@ protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition origi // we do want to validate. There's no specific use case right now, but I can easily imagine one // for more detailed testing of for example custom attributes on local functions, or similar. if (!IsCompilerGeneratedMember (original)) - yield return $"Type `{original}' should have been removed"; + Assert.Fail ($"Type `{original}' should have been removed"); } bool prev = checkNames; checkNames |= original.HasAttribute (nameof (VerifyMetadataNamesAttribute)); - foreach (var err in VerifyTypeDefinitionKept (original, linked)) yield return err; + VerifyTypeDefinitionKept (original, linked); checkNames = prev; @@ -160,11 +151,7 @@ protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition origi foreach (var attr in original.CustomAttributes.Where (l => l.AttributeType.Name == nameof (CreatedMemberAttribute))) { var newName = original.FullName + "::" + attr.ConstructorArguments[0].Value.ToString (); - // Assert.AreEqual (1, linkedMembers.RemoveWhere (l => l.Contains (newName)), $"Newly created member '{newName}' was not found"); - var asdf = linkedMembers.Where (l => l.Contains (newName)).ToList (); - if (1 != linkedMembers.RemoveWhere (l => l.Contains (newName))) { - yield return $"Newly created member '{newName}' was not found"; - } + Assert.AreEqual (1, linkedMembers.RemoveWhere (l => l.Contains (newName)), $"Newly created member '{newName}' was not found"); } } } @@ -172,21 +159,23 @@ protected virtual IEnumerable VerifyTypeDefinition (TypeDefinition origi /// /// Validates that all instances on a member are valid (i.e. ILLink recorded a marked dependency described in the attribute) /// - IEnumerable VerifyKeptByAttributes (IMemberDefinition src, IMemberDefinition linked) + void VerifyKeptByAttributes (IMemberDefinition src, IMemberDefinition linked) { - return src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ()).SelectMany (keptByAttribute => VerifyKeptByAttribute (linked.FullName, keptByAttribute)); + foreach (var keptByAttribute in src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ())) + VerifyKeptByAttribute (linked.FullName, keptByAttribute); } /// /// Validates that all instances on an attribute provider are valid (i.e. ILLink recorded a marked dependency described in the attribute) /// is the attribute provider that may have a , and is the 'FullName' of . /// - IEnumerable VerifyKeptByAttributes (ICustomAttributeProvider src, string attributeProviderFullName) + void VerifyKeptByAttributes (ICustomAttributeProvider src, string attributeProviderFullName) { - return src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ()).SelectMany (attr => VerifyKeptByAttribute (attributeProviderFullName, attr)); + foreach (var keptByAttribute in src.CustomAttributes.Where (ca => ca.AttributeType.IsTypeOf ())) + VerifyKeptByAttribute (attributeProviderFullName, keptByAttribute); } - IEnumerable VerifyKeptByAttribute (string keptAttributeProviderName, CustomAttribute attribute) + void VerifyKeptByAttribute (string keptAttributeProviderName, CustomAttribute attribute) { // public KeptByAttribute (string dependencyProvider, string reason) { } // public KeptByAttribute (Type dependencyProvider, string reason) { } @@ -224,58 +213,57 @@ IEnumerable VerifyKeptByAttribute (string keptAttributeProviderName, Cus foreach (var dep in this.linkedTestCase.Customizations.DependencyRecorder.Dependencies) { if (dep == expectedDependency) { - yield break; + return; } } - yield return $"{keptAttributeProviderName} was expected to be kept by {expectedDependency.Source} with reason {expectedDependency.DependencyKind.ToString ()}."; + string errorMessage = $"{keptAttributeProviderName} was expected to be kept by {expectedDependency.Source} with reason {expectedDependency.DependencyKind.ToString ()}."; + Assert.Fail (errorMessage); } - protected virtual IEnumerable VerifyTypeDefinitionKept (TypeDefinition original, TypeDefinition linked) + protected virtual void VerifyTypeDefinitionKept (TypeDefinition original, TypeDefinition linked) { - if (linked == null) { - yield return $"Type `{original}' should have been kept"; - yield break; - } + if (linked == null) + Assert.Fail ($"Type `{original}' should have been kept"); // Skip verification of type metadata for compiler generated types (we don't currently need it yet) if (!IsCompilerGeneratedMember (original)) { - foreach (var err in VerifyKeptByAttributes (original, linked)) yield return err; + VerifyKeptByAttributes (original, linked); if (!original.IsInterface) - foreach (var err in VerifyBaseType (original, linked)) yield return err; + VerifyBaseType (original, linked); - foreach (var err in VerifyInterfaces (original, linked)) yield return err; - foreach (var err in VerifyPseudoAttributes (original, linked)) yield return err; - foreach (var err in VerifyGenericParameters (original, linked, compilerGenerated: false)) yield return err; - foreach (var err in VerifyCustomAttributes (original, linked)) yield return err; - foreach (var err in VerifySecurityAttributes (original, linked)) yield return err; + VerifyInterfaces (original, linked); + VerifyPseudoAttributes (original, linked); + VerifyGenericParameters (original, linked, compilerGenerated: false); + VerifyCustomAttributes (original, linked); + VerifySecurityAttributes (original, linked); - foreach (var err in VerifyFixedBufferFields (original, linked)) yield return err; + VerifyFixedBufferFields (original, linked); } // Need to check delegate cache fields before the normal field check - foreach (var err in VerifyDelegateBackingFields (original, linked)) yield return err; - foreach (var err in VerifyPrivateImplementationDetails (original, linked)) yield return err; + VerifyDelegateBackingFields (original, linked); + VerifyPrivateImplementationDetails (original, linked); foreach (var td in original.NestedTypes) { - foreach (var err in VerifyTypeDefinition (td, linked?.NestedTypes.FirstOrDefault (l => td.FullName == l.FullName))) yield return err; + VerifyTypeDefinition (td, linked?.NestedTypes.FirstOrDefault (l => td.FullName == l.FullName)); linkedMembers.Remove (td.FullName); } // Need to check properties before fields so that the KeptBackingFieldAttribute is handled correctly foreach (var p in original.Properties) { - foreach (var err in VerifyProperty (p, linked?.Properties.FirstOrDefault (l => p.Name == l.Name), linked)) yield return err; + VerifyProperty (p, linked?.Properties.FirstOrDefault (l => p.Name == l.Name), linked); linkedMembers.Remove (p.FullName); } // Need to check events before fields so that the KeptBackingFieldAttribute is handled correctly foreach (var e in original.Events) { - foreach (var err in VerifyEvent (e, linked?.Events.FirstOrDefault (l => e.Name == l.Name), linked)) yield return err; + VerifyEvent (e, linked?.Events.FirstOrDefault (l => e.Name == l.Name), linked); linkedMembers.Remove (e.FullName); } foreach (var f in original.Fields) { if (verifiedGeneratedFields.Contains (f.FullName)) continue; - foreach (var err in VerifyField (f, linked?.Fields.FirstOrDefault (l => f.Name == l.Name))) yield return err; + VerifyField (f, linked?.Fields.FirstOrDefault (l => f.Name == l.Name)); linkedMembers.Remove (f.FullName); } @@ -283,12 +271,12 @@ protected virtual IEnumerable VerifyTypeDefinitionKept (TypeDefinition o if (verifiedEventMethods.Contains (m.FullName)) continue; var msign = m.GetSignature (); - foreach (var err in VerifyMethod (m, linked?.Methods.FirstOrDefault (l => msign == l.GetSignature ()))) yield return err; + VerifyMethod (m, linked?.Methods.FirstOrDefault (l => msign == l.GetSignature ())); linkedMembers.Remove (m.FullName); } } - IEnumerable VerifyBaseType (TypeDefinition src, TypeDefinition linked) + void VerifyBaseType (TypeDefinition src, TypeDefinition linked) { string expectedBaseName; var expectedBaseGenericAttr = src.CustomAttributes.FirstOrDefault (w => w.AttributeType.Name == nameof (KeptBaseTypeAttribute) && w.ConstructorArguments.Count > 1); @@ -298,31 +286,24 @@ IEnumerable VerifyBaseType (TypeDefinition src, TypeDefinition linked) var defaultBaseType = src.IsEnum ? "System.Enum" : src.IsValueType ? "System.ValueType" : "System.Object"; expectedBaseName = GetCustomAttributeCtorValues (src, nameof (KeptBaseTypeAttribute)).FirstOrDefault ()?.ToString () ?? defaultBaseType; } - if (expectedBaseName != linked.BaseType?.FullName) - yield return $"Incorrect base type on : {linked.Name}"; + Assert.AreEqual (expectedBaseName, linked.BaseType?.FullName, $"Incorrect base type on : {linked.Name}"); } - IEnumerable VerifyInterfaces (TypeDefinition src, TypeDefinition linked) + void VerifyInterfaces (TypeDefinition src, TypeDefinition linked) { var expectedInterfaces = new HashSet (src.CustomAttributes .Where (w => w.AttributeType.Name == nameof (KeptInterfaceAttribute)) .Select (FormatBaseOrInterfaceAttributeValue)); if (expectedInterfaces.Count == 0) { - if (linked.HasInterfaces) { - yield return $"Type `{src}' has unexpected interfaces"; - } + Assert.IsFalse (linked.HasInterfaces, $"Type `{src}' has unexpected interfaces"); } else { foreach (var iface in linked.Interfaces) { if (!expectedInterfaces.Remove (iface.InterfaceType.FullName)) { - if (!expectedInterfaces.Remove (iface.InterfaceType.Resolve ().FullName)) { - yield return $"Type `{src}' interface `{iface.InterfaceType.FullName}' should have been removed"; - } + Assert.IsTrue (expectedInterfaces.Remove (iface.InterfaceType.Resolve ().FullName), $"Type `{src}' interface `{iface.InterfaceType.Resolve ().FullName}' should have been removed"); } } - if (expectedInterfaces.Any ()) { - yield return $"Expected interfaces were not found on {src}: {string.Join (", ", expectedInterfaces.Select(i => i.Split('.', '/').Last()))}"; - } + Assert.IsEmpty (expectedInterfaces, $"Expected interfaces were not found on {src}"); } } @@ -396,7 +377,7 @@ static string FormatBaseOrInterfaceAttributeValue (CustomAttribute attr) return builder.ToString (); } - IEnumerable VerifyField (FieldDefinition src, FieldDefinition linked) + void VerifyField (FieldDefinition src, FieldDefinition linked) { bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldBeKept (src) || @@ -404,123 +385,115 @@ IEnumerable VerifyField (FieldDefinition src, FieldDefinition linked) if (!expectedKept) { if (linked != null) - yield return $"Field `{src}' should have been removed"; + Assert.Fail ($"Field `{src}' should have been removed"); - yield break; + return; } - foreach (var err in VerifyFieldKept (src, linked, compilerGenerated)) yield return err; + VerifyFieldKept (src, linked, compilerGenerated); } - IEnumerable VerifyFieldKept (FieldDefinition src, FieldDefinition linked, bool compilerGenerated) + void VerifyFieldKept (FieldDefinition src, FieldDefinition linked, bool compilerGenerated) { - if (linked == null) { - yield return $"Field `{src}' should have been kept"; - yield break; - } + if (linked == null) + Assert.Fail ($"Field `{src}' should have been kept"); - if (!src?.Constant?.Equals (linked?.Constant) == true) - yield return $"Field `{src}' value was expected to be {src?.Constant} but was {linked?.Constant}"; + Assert.AreEqual (src?.Constant, linked?.Constant, $"Field `{src}' value"); - foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; + VerifyKeptByAttributes (src, linked); VerifyPseudoAttributes (src, linked); if (!compilerGenerated) - foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; + VerifyCustomAttributes (src, linked); } - IEnumerable VerifyProperty (PropertyDefinition src, PropertyDefinition linked, TypeDefinition linkedType) + void VerifyProperty (PropertyDefinition src, PropertyDefinition linked, TypeDefinition linkedType) { - foreach (var err in VerifyMemberBackingField (src, linkedType)) yield return err; + VerifyMemberBackingField (src, linkedType); bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldBeKept (src) || compilerGenerated; if (!expectedKept) { if (linked != null) - yield return $"Property `{src}' should have been removed"; + Assert.Fail ($"Property `{src}' should have been removed"); - yield break; + return; } - if (linked == null) { - yield return $"Property `{src}' should have been kept"; - yield break; - } + if (linked == null) + Assert.Fail ($"Property `{src}' should have been kept"); - if (src?.Constant != linked?.Constant) - yield return $"Property `{src}' value"; + Assert.AreEqual (src?.Constant, linked?.Constant, $"Property `{src}' value"); - foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; - foreach (var err in VerifyPseudoAttributes (src, linked)) yield return err; + VerifyKeptByAttributes (src, linked); + VerifyPseudoAttributes (src, linked); if (!compilerGenerated) - foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; + VerifyCustomAttributes (src, linked); } - IEnumerable VerifyEvent (EventDefinition src, EventDefinition linked, TypeDefinition linkedType) + void VerifyEvent (EventDefinition src, EventDefinition linked, TypeDefinition linkedType) { - foreach (var err in VerifyMemberBackingField (src, linkedType)) yield return err; + VerifyMemberBackingField (src, linkedType); bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldBeKept (src) || compilerGenerated; if (!expectedKept) { if (linked != null) - yield return $"Event `{src}' should have been removed"; + Assert.Fail ($"Event `{src}' should have been removed"); - yield break; + return; } - if (linked == null) { - yield return $"Event `{src}' should have been kept"; - yield break; - } + if (linked == null) + Assert.Fail ($"Event `{src}' should have been kept"); if (src.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptEventAddMethodAttribute))) { - foreach (var err in VerifyMethodInternal (src.AddMethod, linked.AddMethod, true, compilerGenerated)) yield return err; + VerifyMethodInternal (src.AddMethod, linked.AddMethod, true, compilerGenerated); verifiedEventMethods.Add (src.AddMethod.FullName); linkedMembers.Remove (src.AddMethod.FullName); } if (src.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptEventRemoveMethodAttribute))) { - foreach (var err in VerifyMethodInternal (src.RemoveMethod, linked.RemoveMethod, true, compilerGenerated)) yield return err; + VerifyMethodInternal (src.RemoveMethod, linked.RemoveMethod, true, compilerGenerated); verifiedEventMethods.Add (src.RemoveMethod.FullName); linkedMembers.Remove (src.RemoveMethod.FullName); } - foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; - foreach (var err in VerifyPseudoAttributes (src, linked)) yield return err; + VerifyKeptByAttributes (src, linked); + VerifyPseudoAttributes (src, linked); if (!compilerGenerated) - foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; + VerifyCustomAttributes (src, linked); } - IEnumerable VerifyMethod (MethodDefinition src, MethodDefinition linked) + void VerifyMethod (MethodDefinition src, MethodDefinition linked) { bool compilerGenerated = IsCompilerGeneratedMember (src); bool expectedKept = ShouldMethodBeKept (src); - foreach (var err in VerifyMethodInternal (src, linked, expectedKept, compilerGenerated)) yield return err; + VerifyMethodInternal (src, linked, expectedKept, compilerGenerated); } - IEnumerable VerifyMethodInternal (MethodDefinition src, MethodDefinition linked, bool expectedKept, bool compilerGenerated) + void VerifyMethodInternal (MethodDefinition src, MethodDefinition linked, bool expectedKept, bool compilerGenerated) { if (!expectedKept) { if (linked == null) - yield break; + return; // Similar to comment on types, compiler-generated methods can't be annotated with Kept attribute directly // so we're not going to validate kept/remove on them. Note that we're still going to go validate "into" them // to check for other properties (like parameter name presence/removal for example) if (!compilerGenerated) - yield return $"Method `{src.FullName}' should have been removed"; + Assert.Fail ($"Method `{src.FullName}' should have been removed"); } - foreach (var err in VerifyMethodKept (src, linked, compilerGenerated)) yield return err; + VerifyMethodKept (src, linked, compilerGenerated); } - IEnumerable VerifyMemberBackingField (IMemberDefinition src, TypeDefinition linkedType) + void VerifyMemberBackingField (IMemberDefinition src, TypeDefinition linkedType) { var keptBackingFieldAttribute = src.CustomAttributes.FirstOrDefault (attr => attr.AttributeType.Name == nameof (KeptBackingFieldAttribute)); if (keptBackingFieldAttribute == null) - yield break; + return; var backingFieldName = src.MetadataToken.TokenType == TokenType.Property ? $"<{src.Name}>k__BackingField" : src.Name; @@ -535,56 +508,51 @@ IEnumerable VerifyMemberBackingField (IMemberDefinition src, TypeDefinit srcField = src.DeclaringType.Fields.FirstOrDefault (f => f.Name == backingFieldName); } - if (srcField == null) { - yield return $"{src.MetadataToken.TokenType} `{src}', could not locate the expected backing field {backingFieldName}"; - yield break; - } + if (srcField == null) + Assert.Fail ($"{src.MetadataToken.TokenType} `{src}', could not locate the expected backing field {backingFieldName}"); - foreach (var err in VerifyFieldKept (srcField, linkedType?.Fields.FirstOrDefault (l => srcField.Name == l.Name), compilerGenerated: true)) yield return err; + VerifyFieldKept (srcField, linkedType?.Fields.FirstOrDefault (l => srcField.Name == l.Name), compilerGenerated: true); verifiedGeneratedFields.Add (srcField.FullName); linkedMembers.Remove (srcField.FullName); } - protected virtual IEnumerable VerifyMethodKept (MethodDefinition src, MethodDefinition linked, bool compilerGenerated) + protected virtual void VerifyMethodKept (MethodDefinition src, MethodDefinition linked, bool compilerGenerated) { - if (linked == null) { - yield return $"Method `{src.FullName}' should have been kept"; - yield break; - } + if (linked == null) + Assert.Fail ($"Method `{src.FullName}' should have been kept"); - foreach (var err in VerifyPseudoAttributes (src, linked)) yield return err; - foreach (var err in VerifyGenericParameters (src, linked, compilerGenerated)) yield return err; + VerifyPseudoAttributes (src, linked); + VerifyGenericParameters (src, linked, compilerGenerated); if (!compilerGenerated) { - foreach (var err in VerifyCustomAttributes (src, linked)) yield return err; - foreach (var err in VerifyCustomAttributes (src.MethodReturnType, linked.MethodReturnType)) yield return err; - + VerifyCustomAttributes (src, linked); + VerifyCustomAttributes (src.MethodReturnType, linked.MethodReturnType); } - foreach (var err in VerifyParameters (src, linked, compilerGenerated)) yield return err; - foreach (var err in VerifySecurityAttributes (src, linked)) yield return err; - foreach (var err in VerifyArrayInitializers (src, linked)) yield return err; - foreach (var err in VerifyMethodBody (src, linked)) yield return err; - foreach (var err in VerifyKeptByAttributes (src, linked)) yield return err; + VerifyParameters (src, linked, compilerGenerated); + VerifySecurityAttributes (src, linked); + VerifyArrayInitializers (src, linked); + VerifyMethodBody (src, linked); + VerifyKeptByAttributes (src, linked); } - protected virtual IEnumerable VerifyMethodBody (MethodDefinition src, MethodDefinition linked) + protected virtual void VerifyMethodBody (MethodDefinition src, MethodDefinition linked) { if (!src.HasBody) - yield break; + return; - foreach (var err in VerifyInstructions (src, linked)) yield return err; - foreach (var err in VerifyLocals (src, linked)) yield return err; + VerifyInstructions (src, linked); + VerifyLocals (src, linked); } - protected static IEnumerable VerifyInstructions (MethodDefinition src, MethodDefinition linked) + protected static void VerifyInstructions (MethodDefinition src, MethodDefinition linked) { - foreach (var err in VerifyBodyProperties ( + VerifyBodyProperties ( src, linked, nameof (ExpectedInstructionSequenceAttribute), nameof (ExpectBodyModifiedAttribute), "instructions", m => FormatMethodBody (m.Body), - attr => GetStringArrayAttributeValue (attr).ToArray ())) yield return err; + attr => GetStringArrayAttributeValue (attr).ToArray ()); } public static string[] FormatMethodBody (MethodBody body) @@ -701,19 +669,19 @@ static string FormatInstruction (Instruction instr) } } - static IEnumerable VerifyLocals (MethodDefinition src, MethodDefinition linked) + static void VerifyLocals (MethodDefinition src, MethodDefinition linked) { - foreach (var err in VerifyBodyProperties ( + VerifyBodyProperties ( src, linked, nameof (ExpectedLocalsSequenceAttribute), nameof (ExpectLocalsModifiedAttribute), "locals", m => m.Body.Variables.Select (v => v.VariableType.ToString ()).ToArray (), - attr => GetStringOrTypeArrayAttributeValue (attr).ToArray ())) yield return err; + attr => GetStringOrTypeArrayAttributeValue (attr).ToArray ()); } - public static IEnumerable VerifyBodyProperties (MethodDefinition src, MethodDefinition linked, string sequenceAttributeName, string expectModifiedAttributeName, + public static void VerifyBodyProperties (MethodDefinition src, MethodDefinition linked, string sequenceAttributeName, string expectModifiedAttributeName, string propertyDescription, Func valueCollector, Func getExpectFromSequenceAttribute) @@ -723,27 +691,30 @@ public static IEnumerable VerifyBodyProperties (MethodDefinition src, Me var srcValues = valueCollector (src); if (src.CustomAttributes.Any (attr => attr.AttributeType.Name == expectModifiedAttributeName)) { - if (linkedValues.ToHashSet ().SetEquals (srcValues.ToHashSet ())) { - yield return $"Expected method `{src} to have it's {propertyDescription} modified, however, the {propertyDescription} were the same as the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"; - } + Assert.That ( + linkedValues, + Is.Not.EqualTo (srcValues), + $"Expected method `{src} to have {propertyDescription} modified, however, the {propertyDescription} were the same as the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"); } else if (expectedSequenceAttribute != null) { var expected = getExpectFromSequenceAttribute (expectedSequenceAttribute).ToArray (); - if (!linkedValues.ToHashSet ().SetEquals (expected.ToHashSet ())) { - yield return $"Expected method `{src} to have it's {propertyDescription} modified, however, the sequence of {propertyDescription} does not match the expected value\n{FormattingUtils.FormatSequenceCompareFailureMessage2 (linkedValues, expected, srcValues)}"; - } + Assert.That ( + linkedValues, + Is.EqualTo (expected), + $"Expected method `{src} to have it's {propertyDescription} modified, however, the sequence of {propertyDescription} does not match the expected value\n{FormattingUtils.FormatSequenceCompareFailureMessage2 (linkedValues, expected, srcValues)}"); } else { - if (!linkedValues.ToHashSet ().SetEquals (srcValues.ToHashSet ())) { - yield return $"Expected method `{src} to have it's {propertyDescription} unchanged, however, the {propertyDescription} differ from the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"; - } + Assert.That ( + linkedValues, + Is.EqualTo (srcValues), + $"Expected method `{src} to have it's {propertyDescription} unchanged, however, the {propertyDescription} differ from the original\n{FormattingUtils.FormatSequenceCompareFailureMessage (linkedValues, srcValues)}"); } } - IEnumerable VerifyReferences (AssemblyDefinition original, AssemblyDefinition linked) + void VerifyReferences (AssemblyDefinition original, AssemblyDefinition linked) { var expected = original.MainModule.AllDefinedTypes () .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptReferenceAttribute))) .Select (ReduceAssemblyFileNameOrNameToNameOnly) - .ToHashSet (); + .ToArray (); /* - The test case will always need to have at least 1 reference. @@ -754,15 +725,14 @@ IEnumerable VerifyReferences (AssemblyDefinition original, AssemblyDefin Once 1 kept reference attribute is used, the test will need to define all of of it's expected references */ - if (expected.Count == 0) - yield break; + if (expected.Length == 0) + return; var actual = linked.MainModule.AssemblyReferences .Select (name => name.Name) - .ToHashSet (); + .ToArray (); - if (!expected.SetEquals (actual)) - yield return $"Expected references `{string.Join (", ", expected)}` do not match actual references `{string.Join (", ", actual)}`"; + Assert.That (actual, Is.EquivalentTo (expected)); } string ReduceAssemblyFileNameOrNameToNameOnly (string fileNameOrAssemblyName) @@ -774,7 +744,7 @@ string ReduceAssemblyFileNameOrNameToNameOnly (string fileNameOrAssemblyName) return fileNameOrAssemblyName; } - IEnumerable VerifyResources (AssemblyDefinition original, AssemblyDefinition linked) + void VerifyResources (AssemblyDefinition original, AssemblyDefinition linked) { var expectedResourceNames = original.MainModule.AllDefinedTypes () .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptResourceAttribute))) @@ -782,129 +752,111 @@ IEnumerable VerifyResources (AssemblyDefinition original, AssemblyDefini foreach (var resource in linked.MainModule.Resources) { if (!expectedResourceNames.Remove (resource.Name)) - yield return $"Resource '{resource.Name}' should be removed."; + Assert.Fail ($"Resource '{resource.Name}' should be removed."); EmbeddedResource embeddedResource = (EmbeddedResource) resource; var expectedResource = (EmbeddedResource) original.MainModule.Resources.First (r => r.Name == resource.Name); - if (!embeddedResource.GetResourceData ().SequenceEqual (expectedResource.GetResourceData ())) - yield return $"Resource '{resource.Name}' data doesn't match."; + Assert.That (embeddedResource.GetResourceData (), Is.EquivalentTo (expectedResource.GetResourceData ()), $"Resource '{resource.Name}' data doesn't match."); } - if (expectedResourceNames.Any ()) yield return $"Resource '{expectedResourceNames.FirstOrDefault ()}' should be kept."; + Assert.IsEmpty (expectedResourceNames, $"Resource '{expectedResourceNames.FirstOrDefault ()}' should be kept."); } - IEnumerable VerifyExportedTypes (AssemblyDefinition original, AssemblyDefinition linked) + void VerifyExportedTypes (AssemblyDefinition original, AssemblyDefinition linked) { var expectedTypes = original.MainModule.AllDefinedTypes () - .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptExportedTypeAttribute)).Select (l => l.FullName)); + .SelectMany (t => GetCustomAttributeCtorValues (t, nameof (KeptExportedTypeAttribute)).Select (l => l.FullName)).ToArray (); - if (!linked.MainModule.ExportedTypes.Select (l => l.FullName).ToHashSet ().SetEquals (expectedTypes.ToHashSet ())) - yield return $"Exported types do not match expected."; + Assert.That (linked.MainModule.ExportedTypes.Select (l => l.FullName), Is.EquivalentTo (expectedTypes)); } - protected virtual IEnumerable VerifyPseudoAttributes (MethodDefinition src, MethodDefinition linked) + protected virtual void VerifyPseudoAttributes (MethodDefinition src, MethodDefinition linked) { var expected = (MethodAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - if (expected != linked.Attributes) - yield return $"Method `{src}' pseudo attributes did not match expected"; + Assert.AreEqual (expected, linked.Attributes, $"Method `{src}' pseudo attributes did not match expected"); } - protected virtual IEnumerable VerifyPseudoAttributes (TypeDefinition src, TypeDefinition linked) + protected virtual void VerifyPseudoAttributes (TypeDefinition src, TypeDefinition linked) { var expected = (TypeAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - if (expected == linked.Attributes) - yield break; - - yield return $"Type `{src}' pseudo attributes did not match expected"; + Assert.AreEqual (expected, linked.Attributes, $"Type `{src}' pseudo attributes did not match expected"); } - protected virtual IEnumerable VerifyPseudoAttributes (FieldDefinition src, FieldDefinition linked) + protected virtual void VerifyPseudoAttributes (FieldDefinition src, FieldDefinition linked) { var expected = (FieldAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - if (expected != linked.Attributes) yield return $"Field `{src}' pseudo attributes did not match expected"; + Assert.AreEqual (expected, linked.Attributes, $"Field `{src}' pseudo attributes did not match expected"); } - protected virtual IEnumerable VerifyPseudoAttributes (PropertyDefinition src, PropertyDefinition linked) + protected virtual void VerifyPseudoAttributes (PropertyDefinition src, PropertyDefinition linked) { var expected = (PropertyAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - if (expected != linked.Attributes) yield return $"Property `{src}' pseudo attributes did not match expected"; + Assert.AreEqual (expected, linked.Attributes, $"Property `{src}' pseudo attributes did not match expected"); } - protected virtual IEnumerable VerifyPseudoAttributes (EventDefinition src, EventDefinition linked) + protected virtual void VerifyPseudoAttributes (EventDefinition src, EventDefinition linked) { var expected = (EventAttributes) GetExpectedPseudoAttributeValue (src, (uint) src.Attributes); - if (expected != linked.Attributes) yield return $"Event `{src}' pseudo attributes did not match expected"; + Assert.AreEqual (expected, linked.Attributes, $"Event `{src}' pseudo attributes did not match expected"); } - protected virtual IEnumerable VerifyCustomAttributes (ICustomAttributeProvider src, ICustomAttributeProvider linked) + protected virtual void VerifyCustomAttributes (ICustomAttributeProvider src, ICustomAttributeProvider linked) { - var expectedAttrs = GetExpectedAttributes (src).ToHashSet (); - var linkedAttrs = FilterLinkedAttributes (linked).ToHashSet (); - if (!linkedAttrs.SetEquals (expectedAttrs)) { - var missing = $"Missing: {string.Join (", ", expectedAttrs.Except (linkedAttrs))}"; - var extra = $"Extra: {string.Join (", ", linkedAttrs.Except (expectedAttrs))}"; + var expectedAttrs = GetExpectedAttributes (src).ToList (); + var linkedAttrs = FilterLinkedAttributes (linked).ToList (); - yield return string.Join (Environment.NewLine, $"Custom attributes on `{src}' are not matching:", missing, extra); - } + Assert.That (linkedAttrs, Is.EquivalentTo (expectedAttrs), $"Custom attributes on `{src}' are not matching"); } - protected virtual IEnumerable VerifySecurityAttributes (ICustomAttributeProvider src, ISecurityDeclarationProvider linked) + protected virtual void VerifySecurityAttributes (ICustomAttributeProvider src, ISecurityDeclarationProvider linked) { var expectedAttrs = GetCustomAttributeCtorValues (src, nameof (KeptSecurityAttribute)) .Select (attr => attr.ToString ()) - .ToHashSet (); + .ToList (); - var linkedAttrs = FilterLinkedSecurityAttributes (linked).ToHashSet (); + var linkedAttrs = FilterLinkedSecurityAttributes (linked).ToList (); - if (!linkedAttrs.SetEquals (expectedAttrs)) { - var missing = $"Missing: {string.Join (", ", expectedAttrs.Except (linkedAttrs))}"; - var extra = $"Extra: {string.Join (", ", linkedAttrs.Except (expectedAttrs))}"; - yield return string.Join ($"Security attributes on `{src}' are not matching:", missing, extra); - } + Assert.That (linkedAttrs, Is.EquivalentTo (expectedAttrs), $"Security attributes on `{src}' are not matching"); } - IEnumerable VerifyPrivateImplementationDetails (TypeDefinition original, TypeDefinition linked) + void VerifyPrivateImplementationDetails (TypeDefinition original, TypeDefinition linked) { var expectedImplementationDetailsMethods = GetCustomAttributeCtorValues (original, nameof (KeptPrivateImplementationDetailsAttribute)) .Select (attr => attr.ToString ()) .ToList (); if (expectedImplementationDetailsMethods.Count == 0) - yield break; + return; - TypeDefinition srcImplementationDetails; - TypeDefinition linkedImplementationDetails; - foreach (var err in VerifyPrivateImplementationDetailsType (original.Module, linked.Module, out srcImplementationDetails, out linkedImplementationDetails)) yield return err; + VerifyPrivateImplementationDetailsType (original.Module, linked.Module, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails); foreach (var methodName in expectedImplementationDetailsMethods) { var originalMethod = srcImplementationDetails.Methods.FirstOrDefault (m => m.Name == methodName); if (originalMethod == null) - yield return $"Could not locate original private implementation details method {methodName}"; + Assert.Fail ($"Could not locate original private implementation details method {methodName}"); var linkedMethod = linkedImplementationDetails.Methods.FirstOrDefault (m => m.Name == methodName); - foreach (var erro in VerifyMethodKept (originalMethod, linkedMethod, compilerGenerated: true)) yield return erro; + VerifyMethodKept (originalMethod, linkedMethod, compilerGenerated: true); linkedMembers.Remove (linkedMethod.FullName); } verifiedGeneratedTypes.Add (srcImplementationDetails.FullName); } - static IEnumerable VerifyPrivateImplementationDetailsType (ModuleDefinition src, ModuleDefinition linked, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails) + static void VerifyPrivateImplementationDetailsType (ModuleDefinition src, ModuleDefinition linked, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails) { srcImplementationDetails = src.Types.FirstOrDefault (t => IsPrivateImplementationDetailsType (t)); + if (srcImplementationDetails == null) + Assert.Fail ("Could not locate in the original assembly. Does your test use initializers?"); + linkedImplementationDetails = linked.Types.FirstOrDefault (t => IsPrivateImplementationDetailsType (t)); - const string srcMissingMessage = "Could not locate in the original assembly. Does your test use initializers?"; - const string linkedMissingMessage = "Could not locate in the linked assembly"; - return (srcImplementationDetails, linkedImplementationDetails) switch { - (null, null) => [srcMissingMessage, linkedMissingMessage], - (null, _) => [srcMissingMessage], - (_, null) => [linkedMissingMessage], - _ => Enumerable.Empty () - }; + + if (linkedImplementationDetails == null) + Assert.Fail ("Could not locate in the linked assembly"); } - protected virtual IEnumerable VerifyArrayInitializers (MethodDefinition src, MethodDefinition linked) + protected virtual void VerifyArrayInitializers (MethodDefinition src, MethodDefinition linked) { var expectedIndices = GetCustomAttributeCtorValues (src, nameof (KeptInitializerData)) .Cast () @@ -913,13 +865,12 @@ protected virtual IEnumerable VerifyArrayInitializers (MethodDefinition var expectKeptAll = src.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptInitializerData) && !attr.HasConstructorArguments); if (expectedIndices.Length == 0 && !expectKeptAll) - yield break; + return; if (!src.HasBody) - yield return $"`{nameof (KeptInitializerData)}` cannot be used on methods that don't have bodies"; - TypeDefinition srcImplementationDetails; - TypeDefinition linkedImplementationDetails; - foreach (var err in VerifyPrivateImplementationDetailsType (src.Module, linked.Module, out srcImplementationDetails, out linkedImplementationDetails)) yield return err; + Assert.Fail ($"`{nameof (KeptInitializerData)}` cannot be used on methods that don't have bodies"); + + VerifyPrivateImplementationDetailsType (src.Module, linked.Module, out TypeDefinition srcImplementationDetails, out TypeDefinition linkedImplementationDetails); var possibleInitializerFields = src.Body.Instructions .Where (ins => IsLdtokenOnPrivateImplementationDetails (srcImplementationDetails, ins)) @@ -927,32 +878,32 @@ protected virtual IEnumerable VerifyArrayInitializers (MethodDefinition .ToArray (); if (possibleInitializerFields.Length == 0) - yield return $"`{src}` does not make use of any initializers"; + Assert.Fail ($"`{src}` does not make use of any initializers"); if (expectKeptAll) { foreach (var srcField in possibleInitializerFields) { var linkedField = linkedImplementationDetails.Fields.FirstOrDefault (f => f.InitialValue.SequenceEqual (srcField.InitialValue)); - foreach (var err in VerifyInitializerField (srcField, linkedField)) yield return err; + VerifyInitializerField (srcField, linkedField); } } else { foreach (var index in expectedIndices) { if (index < 0 || index > possibleInitializerFields.Length) - yield return $"Invalid expected index `{index}` in {src}. Value must be between 0 and {expectedIndices.Length}"; + Assert.Fail ($"Invalid expected index `{index}` in {src}. Value must be between 0 and {expectedIndices.Length}"); var srcField = possibleInitializerFields[index]; var linkedField = linkedImplementationDetails.Fields.FirstOrDefault (f => f.InitialValue.SequenceEqual (srcField.InitialValue)); - foreach (var err in VerifyInitializerField (srcField, linkedField)) yield return err; + VerifyInitializerField (srcField, linkedField); } } } - IEnumerable VerifyInitializerField (FieldDefinition src, FieldDefinition linked) + void VerifyInitializerField (FieldDefinition src, FieldDefinition linked) { - foreach (var err in VerifyFieldKept (src, linked, compilerGenerated: true)) yield return err; + VerifyFieldKept (src, linked, compilerGenerated: true); verifiedGeneratedFields.Add (linked.FullName); linkedMembers.Remove (linked.FullName); - foreach (var err in VerifyTypeDefinitionKept (src.FieldType.Resolve (), linked.FieldType.Resolve ())) yield return err; + VerifyTypeDefinitionKept (src.FieldType.Resolve (), linked.FieldType.Resolve ()); linkedMembers.Remove (linked.FieldType.FullName); linkedMembers.Remove (linked.DeclaringType.FullName); verifiedGeneratedTypes.Add (linked.DeclaringType.FullName); @@ -1026,7 +977,7 @@ protected virtual IEnumerable FilterLinkedSecurityAttributes (ISecurityD .Select (attr => attr.AttributeType.ToString ()); } - IEnumerable VerifyFixedBufferFields (TypeDefinition src, TypeDefinition linked) + void VerifyFixedBufferFields (TypeDefinition src, TypeDefinition linked) { var fields = src.Fields.Where (f => f.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (KeptFixedBufferAttribute))); @@ -1036,28 +987,28 @@ IEnumerable VerifyFixedBufferFields (TypeDefinition src, TypeDefinition // while mcs and other versions of csc name it `__FixedBuffer0` var originalCompilerGeneratedBufferType = src.NestedTypes.FirstOrDefault (t => t.FullName.Contains ($"<{field.Name}>") && t.FullName.Contains ("__FixedBuffer")); if (originalCompilerGeneratedBufferType == null) - yield return $"Could not locate original compiler generated fixed buffer type for field {field}"; + Assert.Fail ($"Could not locate original compiler generated fixed buffer type for field {field}"); var linkedCompilerGeneratedBufferType = linked.NestedTypes.FirstOrDefault (t => t.Name == originalCompilerGeneratedBufferType.Name); if (linkedCompilerGeneratedBufferType == null) - yield return $"Missing expected type {originalCompilerGeneratedBufferType}"; + Assert.Fail ($"Missing expected type {originalCompilerGeneratedBufferType}"); // Have to verify the field before the type var originalElementField = originalCompilerGeneratedBufferType.Fields.FirstOrDefault (); if (originalElementField == null) - yield return $"Could not locate original compiler generated FixedElementField on {originalCompilerGeneratedBufferType}"; + Assert.Fail ($"Could not locate original compiler generated FixedElementField on {originalCompilerGeneratedBufferType}"); var linkedField = linkedCompilerGeneratedBufferType?.Fields.FirstOrDefault (); - foreach (var err in VerifyFieldKept (originalElementField, linkedField, compilerGenerated: true)) yield return err; + VerifyFieldKept (originalElementField, linkedField, compilerGenerated: true); verifiedGeneratedFields.Add (originalElementField.FullName); linkedMembers.Remove (linkedField.FullName); - foreach (var err in VerifyTypeDefinitionKept (originalCompilerGeneratedBufferType, linkedCompilerGeneratedBufferType)) yield return err; + VerifyTypeDefinitionKept (originalCompilerGeneratedBufferType, linkedCompilerGeneratedBufferType); verifiedGeneratedTypes.Add (originalCompilerGeneratedBufferType.FullName); } } - IEnumerable VerifyDelegateBackingFields (TypeDefinition src, TypeDefinition linked) + void VerifyDelegateBackingFields (TypeDefinition src, TypeDefinition linked) { var expectedFieldNames = src.CustomAttributes .Where (a => a.AttributeType.Name == nameof (KeptDelegateCacheFieldAttribute)) @@ -1066,7 +1017,7 @@ IEnumerable VerifyDelegateBackingFields (TypeDefinition src, TypeDefinit .ToList (); if (expectedFieldNames.Count == 0) - yield break; + return; foreach (var nestedType in src.NestedTypes) { if (!IsDelegateBackingFieldsType (nestedType)) @@ -1076,20 +1027,20 @@ IEnumerable VerifyDelegateBackingFields (TypeDefinition src, TypeDefinit foreach (var expectedFieldName in expectedFieldNames) { var originalField = nestedType.Fields.FirstOrDefault (f => f.Name == expectedFieldName); if (originalField is null) - yield return $"Invalid expected delegate backing field {expectedFieldName} in {src}. This member was not in the unlinked assembly"; + Assert.Fail ($"Invalid expected delegate backing field {expectedFieldName} in {src}. This member was not in the unlinked assembly"); var linkedField = linkedNestedType?.Fields.FirstOrDefault (f => f.Name == expectedFieldName); - foreach (var err in VerifyFieldKept (originalField, linkedField, compilerGenerated: true)) yield return err; + VerifyFieldKept (originalField, linkedField, compilerGenerated: true); verifiedGeneratedFields.Add (linkedField.FullName); linkedMembers.Remove (linkedField.FullName); } - foreach (var err in VerifyTypeDefinitionKept (nestedType, linkedNestedType)) yield return err; + VerifyTypeDefinitionKept (nestedType, linkedNestedType); verifiedGeneratedTypes.Add (linkedNestedType.FullName); } } - IEnumerable VerifyGenericParameters (IGenericParameterProvider src, IGenericParameterProvider linked, bool compilerGenerated) + void VerifyGenericParameters (IGenericParameterProvider src, IGenericParameterProvider linked, bool compilerGenerated) { Assert.AreEqual (src.HasGenericParameters, linked.HasGenericParameters); if (src.HasGenericParameters) { @@ -1099,44 +1050,38 @@ IEnumerable VerifyGenericParameters (IGenericParameterProvider src, IGen var lnkp = linked.GenericParameters[i]; if (!compilerGenerated) { - foreach (var err in VerifyCustomAttributes (srcp, lnkp)) yield return err; + VerifyCustomAttributes (srcp, lnkp); } if (checkNames) { if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) { string name = (src.GenericParameterType == GenericParameterType.Method ? "!!" : "!") + srcp.Position; - if (name != lnkp.Name) { - yield return $"Expected empty generic parameter name. Parameter {i} of {(src.ToString ())}"; - } + Assert.AreEqual (name, lnkp.Name, "Expected empty generic parameter name"); } else { - if (srcp.Name != lnkp.Name) { - yield return $"Mismatch in generic parameter name. Parameter {i} of {(src.ToString ())}"; - } + Assert.AreEqual (srcp.Name, lnkp.Name, "Mismatch in generic parameter name"); } } } } } - IEnumerable VerifyParameters (IMethodSignature src, IMethodSignature linked, bool compilerGenerated) + void VerifyParameters (IMethodSignature src, IMethodSignature linked, bool compilerGenerated) { - if (src.HasParameters != linked.HasParameters) - yield return $"Mismatch in parameters. {src} has parameters: {src.HasParameters}, {linked} has parameters: {linked.HasParameters}"; + Assert.AreEqual (src.HasParameters, linked.HasParameters); if (src.HasParameters) { for (int i = 0; i < src.Parameters.Count; ++i) { var srcp = src.Parameters[i]; var lnkp = linked.Parameters[i]; if (!compilerGenerated) { - foreach (var err in VerifyCustomAttributes (srcp, lnkp)) yield return err; + VerifyCustomAttributes (srcp, lnkp); } if (checkNames) { - if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) { - if (lnkp.Name != string.Empty) yield return $"Expected empty parameter name. Parameter {i} of {(src as MethodDefinition)}"; - } else { - if (srcp.Name != lnkp.Name) yield return $"Mismatch in parameter name. Parameter {i} of {(src as MethodDefinition)}"; - } + if (srcp.CustomAttributes.Any (attr => attr.AttributeType.Name == nameof (RemovedNameValueAttribute))) + Assert.IsEmpty (lnkp.Name, $"Expected empty parameter name. Parameter {i} of {(src as MethodDefinition)}"); + else + Assert.AreEqual (srcp.Name, lnkp.Name, $"Mismatch in parameter name. Parameter {i} of {(src as MethodDefinition)}"); } } } @@ -1168,7 +1113,7 @@ protected virtual bool ShouldBeKept (T member, string signature = null) where private static IEnumerable GetActiveKeptAttributes (ICustomAttributeProvider provider, string attributeName) { - return provider.CustomAttributes.Where (ca => { + return provider.CustomAttributes.Where(ca => { if (ca.AttributeType.Name != attributeName) { return false; } @@ -1176,7 +1121,7 @@ private static IEnumerable GetActiveKeptAttributes (ICustomAttr object keptBy = ca.GetPropertyValue (nameof (KeptAttribute.By)); return keptBy is null ? true : ((Tool) keptBy).HasFlag (Tool.Trimmer); }); - } + } private static bool HasActiveKeptAttribute (ICustomAttributeProvider provider) { From bb67ff15b1b59154f585a1d4915fb592ff39aa28 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:26:11 -0800 Subject: [PATCH 06/19] Add example of overmarking in test --- .../MostSpecificDefaultImplementationKept.cs | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs index 4707bf68ff283a..4cf90ec809d7b5 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs @@ -10,9 +10,9 @@ public static void Main () { #if SUPPORTS_DEFAULT_INTERFACE_METHODS M (); - NotUsedAsIBase.Keep (); + NotUsedInGeneric.Keep (); GenericType.M (); - + GenericType2.Keep (); #endif } @@ -59,6 +59,20 @@ static int IBase.Value { } } + [Kept] + [KeptInterface (typeof (IBase))] + [KeptInterface (typeof (IMiddle))] + interface IDerived2 : IMiddle + { + // https://github.com/dotnet/runtime/issues/97798 + // This shouldn't need to be kept. Implementor UsedInUnconstrainedGeneric is not passed as a constrained generic + [Kept] + static int IBase.Value { + [Kept] + get => 2; + } + } + interface INotReferenced { } @@ -71,12 +85,21 @@ class UsedAsIBase : IDerived, INotReferenced } [Kept] - class NotUsedAsIBase : IDerived, INotReferenced + class NotUsedInGeneric : IDerived, INotReferenced { [Kept] public static void Keep () { } } + [Kept] + [KeptInterface (typeof (IBase))] + [KeptInterface (typeof (IMiddle))] + [KeptInterface (typeof (IDerived2))] + class UsedInUnconstrainedGeneric : IDerived2, INotReferenced + { + } + + [Kept] class GenericType where T : IBase { @@ -84,6 +107,13 @@ class GenericType where T : IBase public static int M () => T.Value; } + [Kept] + class GenericType2 + { + [Kept] + public static void Keep() { } + } + [Kept] [KeptInterface (typeof (IDerived))] [KeptInterface (typeof (IMiddle))] From 8851aaedad21a2ece66f4e62641b3cd0e244d6d0 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Fri, 2 Feb 2024 16:21:11 -0800 Subject: [PATCH 07/19] wip --- .../src/linker/Linker.Steps/MarkStep.cs | 49 +++---------------- .../illink/src/linker/Linker/Annotations.cs | 2 +- .../illink/src/linker/Linker/TypeMapInfo.cs | 2 +- 3 files changed, 8 insertions(+), 45 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index a387b65335ad5f..4722b14a4dd888 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -708,7 +708,7 @@ void ProcessVirtualMethod (MethodDefinition method) var defaultImplementations = Annotations.GetDefaultInterfaceImplementations (method); if (defaultImplementations != null) { foreach (var defaultImplementationInfo in defaultImplementations) { - ProcessDefaultImplementation (defaultImplementationInfo.InstanceType, defaultImplementationInfo.ProvidingInterface); + ProcessDefaultImplementation (defaultImplementationInfo.ImplementingType, defaultImplementationInfo.InterfaceImpl); } } } @@ -2556,47 +2556,10 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat if (Annotations.IsMarked (method)) return false; - // If the override is a DIM that provides an implementation for a type that requires the interface, we may need to mark the DIM - var dims = Annotations.GetDefaultInterfaceImplementations (@base).Where(dim => Annotations.IsRelevantToVariantCasting (dim.InstanceType)); - if (dims.Any (dim => dim.DefaultInterfaceMethod == method && Annotations.IsRelevantToVariantCasting (dim.InstanceType))) - { - // We need to find the most derived DIM for each type that has DIMs -- if there is a most derived DIM, we only mark that DIM - var dimsByInstanceType = dims.GroupBy (dim => dim.InstanceType).Where(group => group.Any(dim => dim.DefaultInterfaceMethod == method)); - foreach (var group in dimsByInstanceType) { - var allDims = group.ToArray (); - int mostDerivedDimIndex = -1; - for (int i = 0; i < allDims.Length; i++) { - var derivesFromAllOtherDimProviders = true; - // Check if DIM i is the most specific DIM for the type by checking if it implements all the other DIM providers for the type - for (int j = 0; j < allDims.Length && derivesFromAllOtherDimProviders; j++) { - if (j == i) - continue; - // If the DIM provider i implements DIM provider j, then i is a more specific implementation that j. Otherwise, it is not and we check the next DIM provider - if (!allDims[i].DefaultInterfaceMethod.DeclaringType.Interfaces - .Any(iface => Context.Resolve(iface.InterfaceType) == allDims[j].DefaultInterfaceMethod.DeclaringType)) - { - derivesFromAllOtherDimProviders = false; - break; - } - } - if (derivesFromAllOtherDimProviders) { - mostDerivedDimIndex = i; - break; - } - } - // If there is a most derived DIM, we only need to mark that DIM -- return true if the override is the most derived DIM - if (mostDerivedDimIndex != -1) { - if (allDims[mostDerivedDimIndex].DefaultInterfaceMethod == method) - return true; - else - continue; - } else { - // If there is no most derived DIM, all DIMs should be marked. - // We already checked that the override is a DIM that provides an implementation for a type that requires the interface, so we can return true - return true; - } - } - } + // If the override is a DIM that provides an implementation for a type that requires the interface, we may need the DIM + // Technically we only need it if it's the most derived DIM, but we can overmark slightly here + if (Annotations.GetDefaultInterfaceImplementations (@base).Any (dim => dim.DefaultInterfaceMethod == method && Annotations.IsRelevantToVariantCasting (dim.ImplementingType))) + return true; // If the interface implementation is not marked, do not mark the implementation method // A type that doesn't implement the interface isn't required to have methods that implement the interface. @@ -2612,7 +2575,7 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat // If the interface method is abstract, mark the implementation method // The method is needed for valid IL. - if (@base.IsAbstract) + if (@base.IsAbstract && !@method.DeclaringType.IsInterface) return true; // If the method is static and the implementing type is relevant to variant casting, mark the implementation method. diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index 8c256cb199b222..2f12b3e6cb185c 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -454,7 +454,7 @@ public bool IsPublic (IMetadataTokenProvider provider) return TypeMapInfo.GetOverrides (method); } - public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface, MethodDefinition DefaultInterfaceMethod)> GetDefaultInterfaceImplementations (MethodDefinition method) + public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultInterfaceMethod)> GetDefaultInterfaceImplementations (MethodDefinition method) { return TypeMapInfo.GetDefaultInterfaceImplementations (method) ?? []; } diff --git a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs index 46d790d4d634d8..cd3a7a06728282 100644 --- a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs +++ b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs @@ -84,7 +84,7 @@ public void EnsureProcessed (AssemblyDefinition assembly) return bases; } - public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface, MethodDefinition DefaultImplementationMethod)>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod) + public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultImplementationMethod)>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod) { default_interface_implementations.TryGetValue (baseMethod, out var ret); return ret; From 4beb771589b7fad0effa449ccdab4807a6c15a9d Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:57:29 -0800 Subject: [PATCH 08/19] Keep all DIMs that provide an implementation for a kept interface method --- ...ecificDefaultImplementationKeptInstance.cs | 75 +++++++++++++++++++ ...pecificDefaultImplementationKeptStatic.cs} | 12 ++- 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs rename src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/{MostSpecificDefaultImplementationKept.cs => MostSpecificDefaultImplementationKeptStatic.cs} (93%) diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs new file mode 100644 index 00000000000000..f84730f76799d8 --- /dev/null +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs @@ -0,0 +1,75 @@ +using Mono.Linker.Tests.Cases.Expectations.Assertions; + +namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.DefaultInterfaceMethods +{ + [TestCaseRequirements (TestRunCharacteristics.SupportsDefaultInterfaceMethods, "Requires support for default interface methods")] + class MostSpecificDefaultImplementationKeptInstance + { + [Kept] + public static void Main () + { + M (new UsedAsIBase()); + } + + + [Kept] + static int M (IBase ibase) + { + return ibase.Value; + } + + [Kept] + interface IBase + { + [Kept] + int Value { + [Kept] + get => 0; + } + + int Value2 { + get => 0; + } + } + + [Kept] + [KeptInterface (typeof (IBase))] + interface IMiddle : IBase + { + int IBase.Value { + get => 1; + } + + int Value2 { + get => 0; + } + } + + [Kept] + [KeptInterface (typeof (IBase))] + [KeptInterface (typeof (IMiddle))] + interface IDerived : IMiddle + { + [Kept] + int IBase.Value { + [Kept] + get => 2; + } + + int Value2 { + get => 0; + } + } + + interface INotReferenced + { } + + [Kept] + [KeptInterface (typeof (IDerived))] + [KeptInterface (typeof (IMiddle))] + [KeptInterface (typeof (IBase))] + class UsedAsIBase : IDerived, INotReferenced + { + } + } +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs similarity index 93% rename from src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs rename to src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs index 4cf90ec809d7b5..9b9e1d25b55cbe 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKept.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs @@ -3,7 +3,7 @@ namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.DefaultInterfaceMethods { [TestCaseRequirements (TestRunCharacteristics.SupportsDefaultInterfaceMethods, "Requires support for default interface methods")] - class MostSpecificDefaultImplementationKept + class MostSpecificDefaultImplementationKeptStatic { [Kept] public static void Main () @@ -42,7 +42,13 @@ static virtual int Value2 { [KeptInterface (typeof (IBase))] interface IMiddle : IBase { + [Kept] static int IBase.Value { + [Kept] + get => 1; + } + + static int IBase.Value2 { get => 1; } } @@ -57,6 +63,10 @@ static int IBase.Value { [Kept] get => 2; } + + static int IBase.Value2 { + get => 1; + } } [Kept] From a46a642014808a3a4b71aeeee00f23baadaba653 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:58:02 -0800 Subject: [PATCH 09/19] Add generated tests --- ...terfaces.DefaultInterfaceMethodsTests.g.cs | 24 +++++++++++++++++++ .../LibrariesTests.g.cs | 12 ++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs b/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs index 98b66ac5299824..0d38c40a400a57 100644 --- a/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs +++ b/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs @@ -27,12 +27,36 @@ public Task InterfaceWithAttributeOnImplementation () return RunTest (allowMissingWarnings: true); } + [Fact] + public Task MostSpecificDefaultImplementationKeptInstance () + { + return RunTest (allowMissingWarnings: true); + } + + [Fact] + public Task MostSpecificDefaultImplementationKeptStatic () + { + return RunTest (allowMissingWarnings: true); + } + [Fact] public Task SimpleDefaultInterfaceMethod () { return RunTest (allowMissingWarnings: true); } + [Fact] + public Task StaticDefaultInterfaceMethodOnDerivedInterface () + { + return RunTest (allowMissingWarnings: true); + } + + [Fact] + public Task StaticDefaultInterfaceMethodOnStruct () + { + return RunTest (allowMissingWarnings: true); + } + [Fact] public Task UnusedDefaultInterfaceImplementation () { diff --git a/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/LibrariesTests.g.cs b/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/LibrariesTests.g.cs index e5f8e41b03e867..03c76d0f1ee9de 100644 --- a/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/LibrariesTests.g.cs +++ b/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/LibrariesTests.g.cs @@ -39,6 +39,18 @@ public Task LibraryWithUnresolvedInterfaces () return RunTest (allowMissingWarnings: true); } + [Fact] + public Task RootAllLibraryBehavior () + { + return RunTest (allowMissingWarnings: true); + } + + [Fact] + public Task RootAllLibraryCopyBehavior () + { + return RunTest (allowMissingWarnings: true); + } + [Fact] public Task RootLibrary () { From 601658a61ef3332daa4c3e5940b80bd90a61e5f5 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:47:58 -0800 Subject: [PATCH 10/19] Fix test expectations --- .../MostSpecificDefaultImplementationKeptInstance.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs index f84730f76799d8..c4b9f231a7065b 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs @@ -36,7 +36,9 @@ int Value2 { [KeptInterface (typeof (IBase))] interface IMiddle : IBase { + [Kept] int IBase.Value { + [Kept] get => 1; } @@ -68,6 +70,7 @@ interface INotReferenced [KeptInterface (typeof (IDerived))] [KeptInterface (typeof (IMiddle))] [KeptInterface (typeof (IBase))] + [KeptMember(".ctor()")] class UsedAsIBase : IDerived, INotReferenced { } From 5520d8548e93ab64f70596d4bd4a421b6cbfdf13 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:14:35 -0800 Subject: [PATCH 11/19] Use ProcessDefaultImplementations for static iface methods --- .../src/linker/Linker.Steps/MarkStep.cs | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 4722b14a4dd888..bd9b132dc9807a 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -708,7 +708,7 @@ void ProcessVirtualMethod (MethodDefinition method) var defaultImplementations = Annotations.GetDefaultInterfaceImplementations (method); if (defaultImplementations != null) { foreach (var defaultImplementationInfo in defaultImplementations) { - ProcessDefaultImplementation (defaultImplementationInfo.ImplementingType, defaultImplementationInfo.InterfaceImpl); + ProcessDefaultImplementation (defaultImplementationInfo.ImplementingType, defaultImplementationInfo.InterfaceImpl, defaultImplementationInfo.DefaultInterfaceMethod); } } } @@ -722,13 +722,13 @@ void ProcessVirtualMethod (MethodDefinition method) bool ShouldMarkOverrideForBase (OverrideInformation overrideInformation) { Debug.Assert (Annotations.IsMarked (overrideInformation.Base) || IgnoreScope (overrideInformation.Base.DeclaringType.Scope)); + if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) + return false; + if (overrideInformation.IsOverrideOfInterfaceMember) { _interfaceOverrides.Add ((overrideInformation, ScopeStack.CurrentScope)); return false; } - if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) - return false; - if (!Context.IsOptimizationEnabled (CodeOptimizations.OverrideRemoval, overrideInformation.Override)) return true; @@ -816,11 +816,15 @@ bool RequiresInterfaceRecursively (TypeDefinition typeToExamine, TypeDefinition return false; } - void ProcessDefaultImplementation (TypeDefinition typeWithDefaultImplementedInterfaceMethod, InterfaceImplementation implementation) + void ProcessDefaultImplementation (TypeDefinition typeWithDefaultImplementedInterfaceMethod, InterfaceImplementation implementation, MethodDefinition implementationMethod) { - if (!Annotations.IsInstantiated (typeWithDefaultImplementedInterfaceMethod)) + if ((!implementationMethod.IsStatic && !Annotations.IsInstantiated (typeWithDefaultImplementedInterfaceMethod)) + || implementationMethod.IsStatic && !Annotations.IsRelevantToVariantCasting(typeWithDefaultImplementedInterfaceMethod)) return; + var origin = ScopeStack.CurrentScope.Origin; + MarkMethod(implementationMethod, new DependencyInfo(DependencyKind.Unspecified, implementation), in origin); + MarkInterfaceImplementation (implementation); } @@ -2556,11 +2560,6 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat if (Annotations.IsMarked (method)) return false; - // If the override is a DIM that provides an implementation for a type that requires the interface, we may need the DIM - // Technically we only need it if it's the most derived DIM, but we can overmark slightly here - if (Annotations.GetDefaultInterfaceImplementations (@base).Any (dim => dim.DefaultInterfaceMethod == method && Annotations.IsRelevantToVariantCasting (dim.ImplementingType))) - return true; - // If the interface implementation is not marked, do not mark the implementation method // A type that doesn't implement the interface isn't required to have methods that implement the interface. InterfaceImplementation? iface = overrideInformation.MatchingInterfaceImplementation; From 64e9f866f638a16e5c0e5f3d97ff67e3066d30fc Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:17:06 -0800 Subject: [PATCH 12/19] Undo unrelated changes --- src/tools/illink/src/linker/Linker.Steps/MarkStep.cs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index bd9b132dc9807a..0c1ee0f7a3cbf3 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -3009,7 +3009,7 @@ void MarkMethodCollection (IList methods, in DependencyInfo re protected virtual MethodDefinition? MarkMethod (MethodReference reference, DependencyInfo reason, in MessageOrigin origin) { DependencyKind originalReasonKind = reason.Kind; - (reference, reason) = GetOriginalMethod (reference, reason, in origin); + (reference, reason) = GetOriginalMethod (reference, reason); if (reference.DeclaringType is ArrayType arrayType) { MarkType (reference.DeclaringType, new DependencyInfo (DependencyKind.DeclaringType, reference)); @@ -3180,15 +3180,14 @@ internal static void ReportRequiresUnreferencedCode (string displayName, Require diagnosticContext.AddDiagnostic (DiagnosticId.RequiresUnreferencedCode, displayName, arg1, arg2); } - protected (MethodReference, DependencyInfo) GetOriginalMethod (MethodReference method, DependencyInfo reason, in MessageOrigin origin) + protected (MethodReference, DependencyInfo) GetOriginalMethod (MethodReference method, DependencyInfo reason) { while (method is MethodSpecification specification) { // Blame the method reference (which isn't marked) on the original reason. Tracer.AddDirectDependency (specification, reason, marked: false); // Blame the outgoing element method on the specification. - if (method is GenericInstanceMethod gim) { + if (method is GenericInstanceMethod gim) MarkGenericArguments (gim); - } (method, reason) = (specification.ElementMethod, new DependencyInfo (DependencyKind.ElementMethod, specification)); Debug.Assert (!(method is MethodSpecification)); From 57ca7d19956fb06ac0c4bf3298ba796264597e0a Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:15:04 -0800 Subject: [PATCH 13/19] Get rid of _interfaceOverrides, use Annotations.GetOverrides and Annotations.GetDIMs --- .../src/linker/Linker.Steps/MarkStep.cs | 65 +++++++++++++------ 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 0c1ee0f7a3cbf3..887d0d6e973d2b 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -65,7 +65,6 @@ protected LinkContext Context { readonly List _ivt_attributes; protected Queue<(AttributeProviderPair, DependencyInfo, MarkScopeStack.Scope)> _lateMarkedAttributes; protected List<(TypeDefinition, MarkScopeStack.Scope)> _typesWithInterfaces; - protected HashSet<(OverrideInformation, MarkScopeStack.Scope)> _interfaceOverrides; protected HashSet _dynamicInterfaceCastableImplementationTypesDiscovered; protected List _dynamicInterfaceCastableImplementationTypes; protected List<(MethodBody, MarkScopeStack.Scope)> _unreachableBodies; @@ -226,7 +225,6 @@ public MarkStep () _ivt_attributes = new List (); _lateMarkedAttributes = new Queue<(AttributeProviderPair, DependencyInfo, MarkScopeStack.Scope)> (); _typesWithInterfaces = new List<(TypeDefinition, MarkScopeStack.Scope)> (); - _interfaceOverrides = new HashSet<(OverrideInformation, MarkScopeStack.Scope)> (); _dynamicInterfaceCastableImplementationTypesDiscovered = new HashSet (); _dynamicInterfaceCastableImplementationTypes = new List (); _unreachableBodies = new List<(MethodBody, MarkScopeStack.Scope)> (); @@ -573,9 +571,17 @@ protected virtual void EnqueueMethod (MethodDefinition method, in DependencyInfo void ProcessVirtualMethods () { - foreach ((MethodDefinition method, MarkScopeStack.Scope scope) in _virtual_methods) { + var vms = _virtual_methods.ToArray (); + foreach((var method, var scope) in vms) { using (ScopeStack.PushScope (scope)) + { ProcessVirtualMethod (method); + if (method.DeclaringType.IsInterface) + { + + } + } + } } @@ -610,19 +616,21 @@ void ProcessMarkedTypesWithInterfaces () continue; foreach (var ov in baseOverrideInformations) { if (ov.Base.DeclaringType is not null && ov.Base.DeclaringType.IsInterface && IgnoreScope (ov.Base.DeclaringType.Scope)) - _interfaceOverrides.Add ((ov, ScopeStack.CurrentScope)); + { + _virtual_methods.Add((ov.Base, ScopeStack.CurrentScope)); + } } } } } - var interfaceOverrides = _interfaceOverrides.ToArray (); - foreach ((var overrideInformation, var scope) in interfaceOverrides) { - using (ScopeStack.PushScope (scope)) { - if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (overrideInformation)) - MarkMethod (overrideInformation.Override, new DependencyInfo (DependencyKind.Override, overrideInformation.Base), scope.Origin); - } - } + // var interfaceOverrides = _interfaceOverrides.ToArray (); + // foreach ((var overrideInformation, var scope) in interfaceOverrides) { + // using (ScopeStack.PushScope (scope)) { + // if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (overrideInformation)) + // MarkMethod (overrideInformation.Override, new DependencyInfo (DependencyKind.Override, overrideInformation.Base), scope.Origin); + // } + // } } void DiscoverDynamicCastableImplementationInterfaces () @@ -709,6 +717,22 @@ void ProcessVirtualMethod (MethodDefinition method) if (defaultImplementations != null) { foreach (var defaultImplementationInfo in defaultImplementations) { ProcessDefaultImplementation (defaultImplementationInfo.ImplementingType, defaultImplementationInfo.InterfaceImpl, defaultImplementationInfo.DefaultInterfaceMethod); + + } + } + if (method.DeclaringType.IsInterface) + { + var overridingMethods = Annotations.GetOverrides(method); + foreach(var ov in overridingMethods ?? []) + { + if(IsInterfaceImplementationMethodNeededByTypeDueToInterface(ov, ov.Override.DeclaringType)) + MarkMethod(ov.Override, new DependencyInfo(DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); + } + foreach(var diminfo in defaultImplementations ?? []) + { + var ov = new OverrideInformation(method, diminfo.DefaultInterfaceMethod, Context); + if (IsInterfaceImplementationMethodNeededByTypeDueToInterface(ov, diminfo.ImplementingType)) + MarkMethod(ov.Override, new DependencyInfo(DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); } } } @@ -722,13 +746,12 @@ void ProcessVirtualMethod (MethodDefinition method) bool ShouldMarkOverrideForBase (OverrideInformation overrideInformation) { Debug.Assert (Annotations.IsMarked (overrideInformation.Base) || IgnoreScope (overrideInformation.Base.DeclaringType.Scope)); - if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) + if (overrideInformation.IsOverrideOfInterfaceMember) return false; - if (overrideInformation.IsOverrideOfInterfaceMember) { - _interfaceOverrides.Add ((overrideInformation, ScopeStack.CurrentScope)); + if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) return false; - } + if (!Context.IsOptimizationEnabled (CodeOptimizations.OverrideRemoval, overrideInformation.Override)) return true; @@ -822,8 +845,8 @@ void ProcessDefaultImplementation (TypeDefinition typeWithDefaultImplementedInte || implementationMethod.IsStatic && !Annotations.IsRelevantToVariantCasting(typeWithDefaultImplementedInterfaceMethod)) return; - var origin = ScopeStack.CurrentScope.Origin; - MarkMethod(implementationMethod, new DependencyInfo(DependencyKind.Unspecified, implementation), in origin); + // var origin = ScopeStack.CurrentScope.Origin; + // MarkMethod(implementationMethod, new DependencyInfo(DependencyKind.Unspecified, implementation), in origin); MarkInterfaceImplementation (implementation); } @@ -2549,7 +2572,7 @@ bool IsMethodNeededByTypeDueToPreservedScope (MethodDefinition method) /// /// Returns true if the override method is required due to the interface that the base method is declared on. See doc at for explanation of logic. /// - bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformation overrideInformation) + bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformation overrideInformation, TypeDefinition typeThatImplsInterface) { var @base = overrideInformation.Base; var method = overrideInformation.Override; @@ -2574,18 +2597,18 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat // If the interface method is abstract, mark the implementation method // The method is needed for valid IL. - if (@base.IsAbstract && !@method.DeclaringType.IsInterface) + if (@base.IsAbstract) return true; // If the method is static and the implementing type is relevant to variant casting, mark the implementation method. // A static method may only be called through a constrained call if the type is relevant to variant casting. if (@base.IsStatic) - return Annotations.IsRelevantToVariantCasting (method.DeclaringType) + return Annotations.IsRelevantToVariantCasting (typeThatImplsInterface) || IgnoreScope (@base.DeclaringType.Scope); // If the implementing type is marked as instantiated, mark the implementation method. // If the type is not instantiated, do not mark the implementation method - return Annotations.IsInstantiated (method.DeclaringType); + return Annotations.IsInstantiated (typeThatImplsInterface); } static bool IsSpecialSerializationConstructor (MethodDefinition method) From c47c7d31906be4d52c294d1b85356de8c4f2707c Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:52:30 -0800 Subject: [PATCH 14/19] Clean up changes --- .../src/linker/Linker.Steps/MarkStep.cs | 69 +++++++------------ .../illink/src/linker/Linker/Annotations.cs | 9 +-- .../illink/src/linker/Linker/TypeMapInfo.cs | 2 + 3 files changed, 31 insertions(+), 49 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 887d0d6e973d2b..6d7601797b017d 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -572,16 +572,10 @@ protected virtual void EnqueueMethod (MethodDefinition method, in DependencyInfo void ProcessVirtualMethods () { var vms = _virtual_methods.ToArray (); - foreach((var method, var scope) in vms) { - using (ScopeStack.PushScope (scope)) - { + foreach ((var method, var scope) in vms) { + using (ScopeStack.PushScope (scope)) { ProcessVirtualMethod (method); - if (method.DeclaringType.IsInterface) - { - - } } - } } @@ -609,28 +603,19 @@ void ProcessMarkedTypesWithInterfaces () !unusedInterfacesOptimizationEnabled) { MarkInterfaceImplementations (type); } - // OverrideInformation for interfaces in PreservedScope aren't added yet + // Interfaces in PreservedScope should have their methods added to _virtual_methods so that they are properly processed foreach (var method in type.Methods) { - var baseOverrideInformations = Annotations.GetBaseMethods (method); - if (baseOverrideInformations is null) + var baseMethods = Annotations.GetBaseMethods (method); + if (baseMethods is null) continue; - foreach (var ov in baseOverrideInformations) { - if (ov.Base.DeclaringType is not null && ov.Base.DeclaringType.IsInterface && IgnoreScope (ov.Base.DeclaringType.Scope)) - { - _virtual_methods.Add((ov.Base, ScopeStack.CurrentScope)); + foreach (var ov in baseMethods) { + if (ov.Base.DeclaringType is not null && ov.Base.DeclaringType.IsInterface && IgnoreScope (ov.Base.DeclaringType.Scope)) { + _virtual_methods.Add ((ov.Base, ScopeStack.CurrentScope)); } } } } } - - // var interfaceOverrides = _interfaceOverrides.ToArray (); - // foreach ((var overrideInformation, var scope) in interfaceOverrides) { - // using (ScopeStack.PushScope (scope)) { - // if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (overrideInformation)) - // MarkMethod (overrideInformation.Override, new DependencyInfo (DependencyKind.Override, overrideInformation.Base), scope.Origin); - // } - // } } void DiscoverDynamicCastableImplementationInterfaces () @@ -713,26 +698,23 @@ void ProcessVirtualMethod (MethodDefinition method) { Annotations.EnqueueVirtualMethod (method); - var defaultImplementations = Annotations.GetDefaultInterfaceImplementations (method); - if (defaultImplementations != null) { - foreach (var defaultImplementationInfo in defaultImplementations) { - ProcessDefaultImplementation (defaultImplementationInfo.ImplementingType, defaultImplementationInfo.InterfaceImpl, defaultImplementationInfo.DefaultInterfaceMethod); + if (method.DeclaringType.IsInterface) { + var defaultImplementations = Annotations.GetDefaultInterfaceImplementations (method); + if (defaultImplementations is not null) { + foreach (var dimInfo in defaultImplementations) { + ProcessDefaultImplementation (dimInfo.ImplementingType, dimInfo.InterfaceImpl, dimInfo.DefaultInterfaceMethod); + var ov = new OverrideInformation (method, dimInfo.DefaultInterfaceMethod, Context); + if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (ov, dimInfo.ImplementingType)) + MarkMethod (ov.Override, new DependencyInfo (DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); + } } - } - if (method.DeclaringType.IsInterface) - { - var overridingMethods = Annotations.GetOverrides(method); - foreach(var ov in overridingMethods ?? []) - { - if(IsInterfaceImplementationMethodNeededByTypeDueToInterface(ov, ov.Override.DeclaringType)) - MarkMethod(ov.Override, new DependencyInfo(DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); - } - foreach(var diminfo in defaultImplementations ?? []) - { - var ov = new OverrideInformation(method, diminfo.DefaultInterfaceMethod, Context); - if (IsInterfaceImplementationMethodNeededByTypeDueToInterface(ov, diminfo.ImplementingType)) - MarkMethod(ov.Override, new DependencyInfo(DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); + var overridingMethods = Annotations.GetOverrides (method); + if (overridingMethods is not null) { + foreach (var ov in overridingMethods ?? []) { + if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (ov, ov.Override.DeclaringType)) + MarkMethod (ov.Override, new DependencyInfo (DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); + } } } } @@ -842,12 +824,9 @@ bool RequiresInterfaceRecursively (TypeDefinition typeToExamine, TypeDefinition void ProcessDefaultImplementation (TypeDefinition typeWithDefaultImplementedInterfaceMethod, InterfaceImplementation implementation, MethodDefinition implementationMethod) { if ((!implementationMethod.IsStatic && !Annotations.IsInstantiated (typeWithDefaultImplementedInterfaceMethod)) - || implementationMethod.IsStatic && !Annotations.IsRelevantToVariantCasting(typeWithDefaultImplementedInterfaceMethod)) + || implementationMethod.IsStatic && !Annotations.IsRelevantToVariantCasting (typeWithDefaultImplementedInterfaceMethod)) return; - // var origin = ScopeStack.CurrentScope.Origin; - // MarkMethod(implementationMethod, new DependencyInfo(DependencyKind.Unspecified, implementation), in origin); - MarkInterfaceImplementation (implementation); } diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index 2f12b3e6cb185c..1236a274d0684b 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -447,22 +447,23 @@ public bool IsPublic (IMetadataTokenProvider provider) } /// - /// Returns a list of all known methods that override . The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet + /// Returns a list of all known methods that override . + /// The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet /// public IEnumerable? GetOverrides (MethodDefinition method) { return TypeMapInfo.GetOverrides (method); } - public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultInterfaceMethod)> GetDefaultInterfaceImplementations (MethodDefinition method) + public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultInterfaceMethod)>? GetDefaultInterfaceImplementations (MethodDefinition method) { - return TypeMapInfo.GetDefaultInterfaceImplementations (method) ?? []; + return TypeMapInfo.GetDefaultInterfaceImplementations (method); } /// /// Returns all base methods that overrides. /// This includes methods on 's declaring type's base type (but not methods higher up in the type hierarchy), - /// methods on an interface that 's delcaring type implements, + /// methods on an interface that 's declaring type implements, /// and methods an interface implemented by a derived type of 's declaring type if the derived type uses as the implementing method. /// The list may be incomplete if there are derived types in assemblies that havent been processed yet that use to implement an interface. /// diff --git a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs index cd3a7a06728282..8f58ce13d9ae39 100644 --- a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs +++ b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs @@ -30,6 +30,7 @@ // using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using Mono.Cecil; @@ -112,6 +113,7 @@ public void AddOverride (MethodDefinition @base, MethodDefinition @override, Int public void AddDefaultInterfaceImplementation (MethodDefinition @base, TypeDefinition implementingType, (InterfaceImplementation, MethodDefinition) matchingInterfaceImplementation) { + Debug.Assert(@base.DeclaringType.IsInterface); if (!default_interface_implementations.TryGetValue (@base, out var implementations)) { implementations = new List<(TypeDefinition, InterfaceImplementation, MethodDefinition)> (); default_interface_implementations.Add (@base, implementations); From 694443a322b2e804e6346a8f9fc24d7a0b979327 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Mon, 5 Feb 2024 20:18:46 -0800 Subject: [PATCH 15/19] Undo moving lines, update doc comments --- src/tools/illink/src/linker/Linker.Steps/MarkStep.cs | 8 ++++---- src/tools/illink/src/linker/Linker/Annotations.cs | 7 +++++++ src/tools/illink/src/linker/Linker/TypeMapInfo.cs | 7 +++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 6d7601797b017d..2a764390be09b9 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -728,11 +728,10 @@ void ProcessVirtualMethod (MethodDefinition method) bool ShouldMarkOverrideForBase (OverrideInformation overrideInformation) { Debug.Assert (Annotations.IsMarked (overrideInformation.Base) || IgnoreScope (overrideInformation.Base.DeclaringType.Scope)); - if (overrideInformation.IsOverrideOfInterfaceMember) - return false; - if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType)) return false; + if (overrideInformation.IsOverrideOfInterfaceMember) + return false; if (!Context.IsOptimizationEnabled (CodeOptimizations.OverrideRemoval, overrideInformation.Override)) return true; @@ -3766,7 +3765,8 @@ protected virtual void MarkInstruction (Instruction instruction, MethodDefinitio ScopeStack.UpdateCurrentScopeInstructionOffset (instruction.Offset); if (markForReflectionAccess) { MarkMethodVisibleToReflection (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin); - } else { + } + else { MarkMethod (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin); } break; diff --git a/src/tools/illink/src/linker/Linker/Annotations.cs b/src/tools/illink/src/linker/Linker/Annotations.cs index 1236a274d0684b..8f7747cba3543a 100644 --- a/src/tools/illink/src/linker/Linker/Annotations.cs +++ b/src/tools/illink/src/linker/Linker/Annotations.cs @@ -455,6 +455,13 @@ public bool IsPublic (IMetadataTokenProvider provider) return TypeMapInfo.GetOverrides (method); } + /// + /// Returns a list of all default interface methods that implement for a type. + /// ImplementingType is the type that implements the interface, + /// InterfaceImpl is the for the interface is declared on, and + /// DefaultInterfaceMethod is the method that implements . + /// + /// The interface method to find default implementations for public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultInterfaceMethod)>? GetDefaultInterfaceImplementations (MethodDefinition method) { return TypeMapInfo.GetDefaultInterfaceImplementations (method); diff --git a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs index 8f58ce13d9ae39..8f9b16d13082a0 100644 --- a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs +++ b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs @@ -85,6 +85,13 @@ public void EnsureProcessed (AssemblyDefinition assembly) return bases; } + /// + /// Returns a list of all default interface methods that implement for a type. + /// ImplementingType is the type that implements the interface, + /// InterfaceImpl is the for the interface is declared on, and + /// DefaultInterfaceMethod is the method that implements . + /// + /// The interface method to find default implementations for public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultImplementationMethod)>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod) { default_interface_implementations.TryGetValue (baseMethod, out var ret); From 42d87649193acb166449b629f709f304f11f0019 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:01:39 -0800 Subject: [PATCH 16/19] Remove redundant test --- ...terfaces.DefaultInterfaceMethodsTests.g.cs | 6 -- ...ecificDefaultImplementationKeptInstance.cs | 1 - ...efaultInterfaceMethodOnDerivedInterface.cs | 68 ------------------- 3 files changed, 75 deletions(-) delete mode 100644 src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs diff --git a/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs b/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs index 0d38c40a400a57..4b3f387a390151 100644 --- a/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs +++ b/src/tools/illink/test/ILLink.RoslynAnalyzer.Tests/generated/ILLink.RoslynAnalyzer.Tests.Generator/ILLink.RoslynAnalyzer.Tests.TestCaseGenerator/Inheritance.Interfaces.DefaultInterfaceMethodsTests.g.cs @@ -45,12 +45,6 @@ public Task SimpleDefaultInterfaceMethod () return RunTest (allowMissingWarnings: true); } - [Fact] - public Task StaticDefaultInterfaceMethodOnDerivedInterface () - { - return RunTest (allowMissingWarnings: true); - } - [Fact] public Task StaticDefaultInterfaceMethodOnStruct () { diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs index c4b9f231a7065b..6f5c02b5fad53d 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptInstance.cs @@ -11,7 +11,6 @@ public static void Main () M (new UsedAsIBase()); } - [Kept] static int M (IBase ibase) { diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs deleted file mode 100644 index eeb626cd5b95d7..00000000000000 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/StaticDefaultInterfaceMethodOnDerivedInterface.cs +++ /dev/null @@ -1,68 +0,0 @@ - -using Mono.Linker.Tests.Cases.Expectations.Assertions; - -namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.DefaultInterfaceMethods -{ - [TestCaseRequirements (TestRunCharacteristics.SupportsDefaultInterfaceMethods, "Requires support for default interface methods")] - class StaticDefaultInterfaceMethodOnDerivedInterface - { - [Kept] - public static void Main () - { -#if SUPPORTS_DEFAULT_INTERFACE_METHODS - M(); -#endif - } - -#if SUPPORTS_DEFAULT_INTERFACE_METHODS - - [Kept] - static int M() where T : IBase { - return T.Value; - } - - [Kept] - interface IBase { - [Kept] - static abstract int Value - { - [Kept] - get; - } - } - - [Kept] - [KeptInterface(typeof(IBase))] - interface IMiddle : IBase { - [Kept] // Should be removable -- Add link to bug before merge - static int IBase.Value - { - [Kept] // Should be removable -- Add link to bug before merge - get=>1; - } - } - - [Kept] - [KeptInterface(typeof(IBase))] - [KeptInterface(typeof(IMiddle))] - interface IDerived : IMiddle { - [Kept] - static int IBase.Value - { - [Kept] - get=>2; - } - } - - interface INotReferenced - {} - - [Kept] - [KeptInterface(typeof(IDerived))] - [KeptInterface(typeof(IMiddle))] - [KeptInterface(typeof(IBase))] - struct Instance : IDerived, INotReferenced { - } -#endif - } -} From 386f6448a76d0e89d1ff7d6dd0b1cd29769f30b8 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:11:31 -0800 Subject: [PATCH 17/19] Add test types to make sure all DIMs aren't kept --- ...SpecificDefaultImplementationKeptStatic.cs | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs index 9b9e1d25b55cbe..39e4ae308f65a1 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Inheritance.Interfaces/DefaultInterfaceMethods/MostSpecificDefaultImplementationKeptStatic.cs @@ -101,11 +101,32 @@ class NotUsedInGeneric : IDerived, INotReferenced public static void Keep () { } } + public interface IBaseUnused + { + public static virtual int Value { + get => 0; + } + } + + public interface IMiddleUnused : IBaseUnused + { + static int IBaseUnused.Value { + get => 1; + } + } + + public interface IDerivedUnused : IMiddleUnused + { + static int IBaseUnused.Value { + get => 2; + } + } + [Kept] [KeptInterface (typeof (IBase))] [KeptInterface (typeof (IMiddle))] [KeptInterface (typeof (IDerived2))] - class UsedInUnconstrainedGeneric : IDerived2, INotReferenced + class UsedInUnconstrainedGeneric : IDerived2, INotReferenced, IDerivedUnused { } From cab167f62f0e10926b08cecc9000037597687f17 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:43:35 -0800 Subject: [PATCH 18/19] PR feedback --- src/tools/illink/src/linker/Linker.Steps/MarkStep.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs index 2a764390be09b9..bb5e95e0a38dd0 100644 --- a/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs +++ b/src/tools/illink/src/linker/Linker.Steps/MarkStep.cs @@ -571,8 +571,7 @@ protected virtual void EnqueueMethod (MethodDefinition method, in DependencyInfo void ProcessVirtualMethods () { - var vms = _virtual_methods.ToArray (); - foreach ((var method, var scope) in vms) { + foreach ((var method, var scope) in _virtual_methods) { using (ScopeStack.PushScope (scope)) { ProcessVirtualMethod (method); } @@ -711,7 +710,7 @@ void ProcessVirtualMethod (MethodDefinition method) } var overridingMethods = Annotations.GetOverrides (method); if (overridingMethods is not null) { - foreach (var ov in overridingMethods ?? []) { + foreach (var ov in overridingMethods) { if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (ov, ov.Override.DeclaringType)) MarkMethod (ov.Override, new DependencyInfo (DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin); } From 51be916a2b6c369f1107842077c0fab5d26fccca Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Fri, 9 Feb 2024 17:34:10 -0800 Subject: [PATCH 19/19] Break early if the DIM is the interface method --- src/tools/illink/src/linker/Linker/TypeMapInfo.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs index 77f5e89b941ee0..a2f118adf9fb78 100644 --- a/src/tools/illink/src/linker/Linker/TypeMapInfo.cs +++ b/src/tools/illink/src/linker/Linker/TypeMapInfo.cs @@ -299,6 +299,8 @@ void FindAndAddDefaultInterfaceImplementations (TypeDefinition type, MethodDefin if (potentialImplMethod == interfaceMethod && !potentialImplMethod.IsAbstract) { AddDefaultInterfaceImplementation (interfaceMethod, type, (interfaceImpl, potentialImplMethod)); + foundImpl = true; + break; } if (!potentialImplMethod.HasOverrides)