From 57f691afbc6d092395dbd64e72a54a0ae55884bb Mon Sep 17 00:00:00 2001 From: David Wrighton <davidwr@microsoft.com> Date: Tue, 11 Jul 2023 16:41:24 -0700 Subject: [PATCH] Fix virtual static dispatch for variant interfaces when using default implementations (#88639) - Move the handling for variance so that it is in the same place as it is for instance methods - Validate the behavior of such dispatch is equivalent to the dispatch behavior involving instance variant default dispatch - Rework the fix for #80350 to avoid violating an invariant of the type system (which is that all MethodDef's of a TypeDef are loaded as MethodDescs when loading the open type). Fixes #78621 --- src/coreclr/vm/methodtable.cpp | 89 ++++++++---- src/coreclr/vm/methodtable.h | 2 +- src/coreclr/vm/methodtablebuilder.cpp | 7 - src/coreclr/vm/runtimehandles.cpp | 2 +- .../ComplexHierarchyPositive.cs | 133 ++++++++++++++++++ .../ComplexHierarchyPositive.csproj | 9 ++ src/tests/issues.targets | 9 ++ 7 files changed, 213 insertions(+), 38 deletions(-) create mode 100644 src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs create mode 100644 src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp index 0a137ed547ef1..2709be3f7ef62 100644 --- a/src/coreclr/vm/methodtable.cpp +++ b/src/coreclr/vm/methodtable.cpp @@ -6349,7 +6349,7 @@ namespace candidateMaybe = interfaceMD; } } - else + else if (!interfaceMD->IsStatic()) { // // A more specific interface - search for an methodimpl for explicit override @@ -6413,17 +6413,18 @@ namespace } } } - else if (pMD->IsStatic() && pMD->HasMethodImplSlot()) - { - // Static virtual methods don't record MethodImpl slots so they need special handling - candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType( - interfaceMT, - interfaceMD, - /* verifyImplemented */ FALSE, - /* level */ level); - } } } + else + { + // Static virtual methods don't record MethodImpl slots so they need special handling + candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType( + interfaceMT, + interfaceMD, + /* verifyImplemented */ FALSE, + /* allowVariance */ allowVariance, + /* level */ level); + } } if (candidateMaybe == NULL) @@ -8911,7 +8912,7 @@ MethodTable::ResolveVirtualStaticMethod( // Search for match on a per-level in the type hierarchy for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable()) { - MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, level); + MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level); if (pMD != nullptr) { return pMD; @@ -8920,7 +8921,7 @@ MethodTable::ResolveVirtualStaticMethod( if (pInterfaceType->HasVariance() || pInterfaceType->HasTypeEquivalence()) { // Variant interface dispatch - MethodTable::InterfaceMapIterator it = IterateInterfaceMap(); + MethodTable::InterfaceMapIterator it = pMT->IterateInterfaceMap(); while (it.Next()) { if (it.CurrentInterfaceMatches(this, pInterfaceType)) @@ -8955,7 +8956,7 @@ MethodTable::ResolveVirtualStaticMethod( { // Variant or equivalent matching interface found // Attempt to resolve on variance matched interface - pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, level); + pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level); if (pMD != nullptr) { return pMD; @@ -8966,22 +8967,38 @@ MethodTable::ResolveVirtualStaticMethod( } MethodDesc *pMDDefaultImpl = nullptr; - BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation( - pInterfaceMD, - pInterfaceType, - &pMDDefaultImpl, - /* allowVariance */ allowVariantMatches, - /* throwOnConflict */ uniqueResolution == nullptr, - level); - if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr))) + BOOL allowVariantMatchInDefaultImplementationLookup = FALSE; + do { - // We tolerate conflicts upon verification of implemented SVMs so that they only blow up when actually called at execution time. - if (uniqueResolution != nullptr) + BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation( + pInterfaceMD, + pInterfaceType, + &pMDDefaultImpl, + /* allowVariance */ allowVariantMatchInDefaultImplementationLookup, + /* throwOnConflict */ uniqueResolution == nullptr, + level); + if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr))) { - *uniqueResolution = haveUniqueDefaultImplementation; + // We tolerate conflicts upon verification of implemented SVMs so that they only blow up when actually called at execution time. + if (uniqueResolution != nullptr) + { + // Always report a unique resolution when reporting results of a variant match + if (allowVariantMatchInDefaultImplementationLookup) + *uniqueResolution = TRUE; + else + *uniqueResolution = haveUniqueDefaultImplementation; + } + return pMDDefaultImpl; } - return pMDDefaultImpl; - } + + // We only loop at most twice here + if (allowVariantMatchInDefaultImplementationLookup) + { + break; + } + + allowVariantMatchInDefaultImplementationLookup = allowVariantMatches; + } while (allowVariantMatchInDefaultImplementationLookup); } // Default implementation logic, which only kicks in for default implementations when looking up on an exact interface target @@ -9001,7 +9018,7 @@ MethodTable::ResolveVirtualStaticMethod( // Try to locate the appropriate MethodImpl matching a given interface static virtual method. // Returns nullptr on failure. MethodDesc* -MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, ClassLoadLevel level) +MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level) { HRESULT hr = S_OK; IMDInternalImport* pMDInternalImport = GetMDImport(); @@ -9049,9 +9066,22 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType ClassLoader::LoadTypes, CLASS_LOAD_EXACTPARENTS) .GetMethodTable(); - if (pInterfaceMT != pInterfaceType) + + if (allowVariance) { - continue; + // Allow variant, but not equivalent interface match + if (!pInterfaceType->HasSameTypeDefAs(pInterfaceMT) || + !pInterfaceMT->CanCastTo(pInterfaceType, NULL)) + { + continue; + } + } + else + { + if (pInterfaceMT != pInterfaceType) + { + continue; + } } MethodDesc *pMethodDecl; @@ -9167,6 +9197,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented() BOOL uniqueResolution; if (pMD->IsVirtual() && pMD->IsStatic() && + !pMD->HasMethodImplSlot() && // Re-abstractions are virtual static abstract with a MethodImpl (pMD->IsAbstract() && !ResolveVirtualStaticMethod( pInterfaceMT, diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h index e506299d9cd7f..9b347bf5ccc85 100644 --- a/src/coreclr/vm/methodtable.h +++ b/src/coreclr/vm/methodtable.h @@ -2219,7 +2219,7 @@ class MethodTable // Try to resolve a given static virtual method override on this type. Return nullptr // when not found. - MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, ClassLoadLevel level); + MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level); public: static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl); diff --git a/src/coreclr/vm/methodtablebuilder.cpp b/src/coreclr/vm/methodtablebuilder.cpp index 2a8d775ba57de..1c6eadbe5bdc4 100644 --- a/src/coreclr/vm/methodtablebuilder.cpp +++ b/src/coreclr/vm/methodtablebuilder.cpp @@ -3293,13 +3293,6 @@ MethodTableBuilder::EnumerateClassMethods() } } - if (implType == METHOD_IMPL && isStaticVirtual && IsMdAbstract(dwMemberAttrs)) - { - // Don't record reabstracted static virtual methods as they don't constitute - // actual method slots, they're just markers used by the static virtual lookup. - continue; - } - // For delegates we don't allow any non-runtime implemented bodies // for any of the four special methods if (IsDelegate() && !IsMiRuntime(dwImplFlags)) diff --git a/src/coreclr/vm/runtimehandles.cpp b/src/coreclr/vm/runtimehandles.cpp index e0eba7e172613..ef5a2dcaaf93a 100644 --- a/src/coreclr/vm/runtimehandles.cpp +++ b/src/coreclr/vm/runtimehandles.cpp @@ -1104,7 +1104,7 @@ extern "C" MethodDesc* QCALLTYPE RuntimeTypeHandle_GetInterfaceMethodImplementat pMD, /* allowNullResult */ TRUE, /* verifyImplemented*/ FALSE, - /*allowVariantMatches */ TRUE); + /* allowVariantMatches */ TRUE); } else { diff --git a/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs new file mode 100644 index 0000000000000..188686795ff9a --- /dev/null +++ b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs @@ -0,0 +1,133 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime; +using Xunit; + +// This regression test tracks the issue where variant static interface dispatch crashes the runtime, and behaves incorrectly + +namespace VariantStaticInterfaceDispatchRegressionTest +{ + class Test + { + static int Main() + { + Console.WriteLine("Test cases"); + + Console.WriteLine("---FooBar"); + TestTheFooString<FooBar, Base>("IFoo<Base>"); + TestTheFooString<FooBar, Mid>("IFoo<Base>"); + TestTheFooString<FooBar, Derived>("IFoo<Base>"); + + TestTheBarString<FooBar, Base>("IBar<Derived>"); + TestTheBarString<FooBar, Mid>("IBar<Derived>"); + TestTheBarString<FooBar, Derived>("IBar<Derived>"); + + Console.WriteLine("---FooBar2"); + TestTheFooString<FooBar2, Base>("IFoo<Base>"); + TestTheFooString<FooBar2, Mid>("IFoo<Mid>"); + TestTheFooString<FooBar2, Derived>("IFoo<Base>"); + + TestTheBarString<FooBar2, Base>("IBar<Derived>"); + TestTheBarString<FooBar2, Mid>("IBar<Mid>"); + TestTheBarString<FooBar2, Derived>("IBar<Derived>"); + + Console.WriteLine("---FooBarBaz"); + TestTheFooString<FooBarBaz, Base>("IFoo<Base>"); + TestTheFooString<FooBarBaz, Mid>("IBaz"); + TestTheFooString<FooBarBaz, Derived>("IBaz"); + + TestTheBarString<FooBarBaz, Base>("IBaz"); + TestTheBarString<FooBarBaz, Mid>("IBaz"); + TestTheBarString<FooBarBaz, Derived>("IBar<Derived>"); + + Console.WriteLine("---FooBarBazBoz"); + TestTheFooString<FooBarBazBoz, Base>("IBoz"); + TestTheFooString<FooBarBazBoz, Mid>("IBaz"); + TestTheFooString<FooBarBazBoz, Derived>("IBoz"); + + TestTheBarString<FooBarBazBoz, Base>("IBoz"); + TestTheBarString<FooBarBazBoz, Mid>("IBaz"); + TestTheBarString<FooBarBazBoz, Derived>("IBoz"); + + Console.WriteLine("---FooBarBaz2"); + TestTheFooString<FooBarBaz2, Base>("IFoo<Base>"); + TestTheFooString<FooBarBaz2, Mid>("IBaz"); + TestTheFooString<FooBarBaz2, Derived>("IFoo<Base>"); + + TestTheBarString<FooBarBaz2, Base>("IBar<Derived>"); + TestTheBarString<FooBarBaz2, Mid>("IBaz"); + TestTheBarString<FooBarBaz2, Derived>("IBar<Derived>"); + + Console.WriteLine("---FooBarBazBoz2"); + TestTheFooString<FooBarBazBoz2, Base>("IBoz"); + TestTheFooString<FooBarBazBoz2, Mid>("IBaz"); + TestTheFooString<FooBarBazBoz2, Derived>("IBoz"); + + TestTheBarString<FooBarBazBoz2, Base>("IBoz"); + TestTheBarString<FooBarBazBoz2, Mid>("IBaz"); + TestTheBarString<FooBarBazBoz2, Derived>("IBoz"); + return 100; + } + + static string GetTheFooString<T, U>() where T : IFoo<U> { try { return T.GetString(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } } + static string GetTheBarString<T, U>() where T : IBar<U> { try { return T.GetString(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } } + static string GetTheFooStringInstance<T, U>() where T : IFoo<U>, new() { try { return (new T()).GetStringInstance(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } } + static string GetTheBarStringInstance<T, U>() where T : IBar<U>, new() { try { return (new T()).GetStringInstance(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } } + + static void TestTheFooString<T, U>(string expected) where T : IFoo<U>, new() + { + Console.WriteLine($"TestTheFooString {typeof(T).Name} {typeof(T).Name} {expected}"); + Assert.Equal(expected, GetTheFooString<T, U>()); + Assert.Equal(expected, GetTheFooStringInstance<T, U>()); + } + + static void TestTheBarString<T, U>(string expected) where T : IBar<U>, new() + { + Console.WriteLine($"TestTheBarString {typeof(T).Name} {typeof(T).Name} {expected}"); + Assert.Equal(expected, GetTheBarString<T, U>()); + Assert.Equal(expected, GetTheBarStringInstance<T, U>()); + } + + interface IFoo<in T> + { + static virtual string GetString() => $"IFoo<{typeof(T).Name}>"; + virtual string GetStringInstance() => $"IFoo<{typeof(T).Name}>"; + }; + + interface IBar<out T> + { + static virtual string GetString() => $"IBar<{typeof(T).Name}>"; + virtual string GetStringInstance() => $"IBar<{typeof(T).Name}>"; + }; + + + interface IBaz : IFoo<Mid>, IBar<Mid> + { + static string IFoo<Mid>.GetString() => "IBaz"; + static string IBar<Mid>.GetString() => "IBaz"; + string IFoo<Mid>.GetStringInstance() => "IBaz"; + string IBar<Mid>.GetStringInstance() => "IBaz"; + } + + interface IBoz : IFoo<Base>, IBar<Derived> + { + static string IFoo<Base>.GetString() => "IBoz"; + static string IBar<Derived>.GetString() => "IBoz"; + string IFoo<Base>.GetStringInstance() => "IBoz"; + string IBar<Derived>.GetStringInstance() => "IBoz"; + } + + class FooBar : IFoo<Base>, IBar<Derived> { } + class FooBar2 : IFoo<Base>, IBar<Derived>, IFoo<Mid>, IBar<Mid> { } + class FooBarBaz : FooBar, IBaz { } + class FooBarBaz2 : IFoo<Base>, IBar<Derived>, IBaz { } // Implementation with all interfaces defined on the same type + class FooBarBazBoz : FooBarBaz, IBoz { } + class FooBarBazBoz2 : IFoo<Base>, IBar<Derived>, IBaz, IBoz { } // Implementation with all interfaces defined on the same type + + class Base { } + class Mid : Base { } + class Derived : Mid { } + } +} diff --git a/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj new file mode 100644 index 0000000000000..7f0c12d57d8e5 --- /dev/null +++ b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj @@ -0,0 +1,9 @@ +<Project Sdk="Microsoft.NET.Sdk"> + <PropertyGroup> + <AllowUnsafeBlocks>true</AllowUnsafeBlocks> + <OutputType>Exe</OutputType> + </PropertyGroup> + <ItemGroup> + <Compile Include="$(MSBuildProjectName).cs" /> + </ItemGroup> +</Project> diff --git a/src/tests/issues.targets b/src/tests/issues.targets index dcf0616de702d..a1d35dac9357e 100644 --- a/src/tests/issues.targets +++ b/src/tests/issues.targets @@ -694,6 +694,9 @@ <!-- NativeAOT specific --> <ItemGroup Condition="'$(XunitTestBinBase)' != '' and '$(TestBuildMode)' == 'nativeaot' and '$(RuntimeFlavor)' == 'coreclr'"> + <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/InterfaceVariance/**"> + <Issue>https://github.com/dotnet/runtime/issues/88690</Issue> + </ExcludeList> <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/NegativeTestCases/**"> <Issue>https://github.com/dotnet/runtimelab/issues/155: Compatible TypeLoadException for invalid inputs</Issue> </ExcludeList> @@ -703,6 +706,9 @@ <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/DiamondShape/svm_diamondshape_r/*"> <Issue>https://github.com/dotnet/runtime/issues/72589</Issue> </ExcludeList> + <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/NegativeTestCases/**"> + <Issue>https://github.com/dotnet/runtimelab/issues/155: Compatible TypeLoadException for invalid inputs</Issue> + </ExcludeList> <ExcludeList Include="$(XunitTestBinBase)/baseservices/RuntimeConfiguration/TestConfig/*"> <Issue>Test expects being run with corerun</Issue> </ExcludeList> @@ -1248,6 +1254,9 @@ <!-- Known failures for mono runtime on *all* architectures/operating systems in *all* runtime modes --> <ItemGroup Condition="'$(RuntimeFlavor)' == 'mono'" > + <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/InterfaceVariance/**"> + <Issue>https://github.com/dotnet/runtime/issues/88689</Issue> + </ExcludeList> <ExcludeList Include="$(XunitTestBinBase)/Regressions/coreclr/GitHub_85240/test85240/**"> <Issue>https://github.com/dotnet/runtime/issues/71095</Issue> </ExcludeList>