From 57f691afbc6d092395dbd64e72a54a0ae55884bb Mon Sep 17 00:00:00 2001
From: David Wrighton <davidwr@microsoft.com>
Date: Tue, 11 Jul 2023 16:41:24 -0700
Subject: [PATCH] Fix virtual static dispatch for variant interfaces when using
 default implementations (#88639)

- Move the handling for variance so that it is in the same place as it is for instance methods
- Validate the behavior of such dispatch is equivalent to the dispatch behavior involving instance variant default dispatch
- Rework the fix for #80350 to avoid violating an invariant of the type system (which is that all MethodDef's of a TypeDef are loaded as MethodDescs when loading the open type).

Fixes #78621
---
 src/coreclr/vm/methodtable.cpp                |  89 ++++++++----
 src/coreclr/vm/methodtable.h                  |   2 +-
 src/coreclr/vm/methodtablebuilder.cpp         |   7 -
 src/coreclr/vm/runtimehandles.cpp             |   2 +-
 .../ComplexHierarchyPositive.cs               | 133 ++++++++++++++++++
 .../ComplexHierarchyPositive.csproj           |   9 ++
 src/tests/issues.targets                      |   9 ++
 7 files changed, 213 insertions(+), 38 deletions(-)
 create mode 100644 src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs
 create mode 100644 src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj

diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp
index 0a137ed547ef1..2709be3f7ef62 100644
--- a/src/coreclr/vm/methodtable.cpp
+++ b/src/coreclr/vm/methodtable.cpp
@@ -6349,7 +6349,7 @@ namespace
                     candidateMaybe = interfaceMD;
                 }
             }
-            else
+            else if (!interfaceMD->IsStatic())
             {
                 //
                 // A more specific interface - search for an methodimpl for explicit override
@@ -6413,17 +6413,18 @@ namespace
                             }
                         }
                     }
-                    else if (pMD->IsStatic() && pMD->HasMethodImplSlot())
-                    {
-                        // Static virtual methods don't record MethodImpl slots so they need special handling
-                        candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
-                            interfaceMT,
-                            interfaceMD,
-                            /* verifyImplemented */ FALSE,
-                            /* level */ level);
-                    }
                 }
             }
+            else
+            {
+                // Static virtual methods don't record MethodImpl slots so they need special handling
+                candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
+                    interfaceMT,
+                    interfaceMD,
+                    /* verifyImplemented */ FALSE,
+                    /* allowVariance */ allowVariance,
+                    /* level */ level);
+            }
         }
 
         if (candidateMaybe == NULL)
@@ -8911,7 +8912,7 @@ MethodTable::ResolveVirtualStaticMethod(
             // Search for match on a per-level in the type hierarchy
             for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable())
             {
-                MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, level);
+                MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
                 if (pMD != nullptr)
                 {
                     return pMD;
@@ -8920,7 +8921,7 @@ MethodTable::ResolveVirtualStaticMethod(
                 if (pInterfaceType->HasVariance() || pInterfaceType->HasTypeEquivalence())
                 {
                     // Variant interface dispatch
-                    MethodTable::InterfaceMapIterator it = IterateInterfaceMap();
+                    MethodTable::InterfaceMapIterator it = pMT->IterateInterfaceMap();
                     while (it.Next())
                     {
                         if (it.CurrentInterfaceMatches(this, pInterfaceType))
@@ -8955,7 +8956,7 @@ MethodTable::ResolveVirtualStaticMethod(
                         {
                             // Variant or equivalent matching interface found
                             // Attempt to resolve on variance matched interface
-                            pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, level);
+                            pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
                             if (pMD != nullptr)
                             {
                                 return pMD;
@@ -8966,22 +8967,38 @@ MethodTable::ResolveVirtualStaticMethod(
             }
 
             MethodDesc *pMDDefaultImpl = nullptr;
-            BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
-                pInterfaceMD,
-                pInterfaceType,
-                &pMDDefaultImpl,
-                /* allowVariance */ allowVariantMatches,
-                /* throwOnConflict */ uniqueResolution == nullptr,
-                level);
-            if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
+            BOOL allowVariantMatchInDefaultImplementationLookup = FALSE;
+            do
             {
-                // We tolerate conflicts upon verification of implemented SVMs so that they only blow up when actually called at execution time.
-                if (uniqueResolution != nullptr)
+                BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
+                    pInterfaceMD,
+                    pInterfaceType,
+                    &pMDDefaultImpl,
+                    /* allowVariance */ allowVariantMatchInDefaultImplementationLookup,
+                    /* throwOnConflict */ uniqueResolution == nullptr,
+                    level);
+                if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
                 {
-                    *uniqueResolution = haveUniqueDefaultImplementation;
+                    // We tolerate conflicts upon verification of implemented SVMs so that they only blow up when actually called at execution time.
+                    if (uniqueResolution != nullptr)
+                    {
+                        // Always report a unique resolution when reporting results of a variant match
+                        if (allowVariantMatchInDefaultImplementationLookup)
+                            *uniqueResolution = TRUE;
+                        else
+                            *uniqueResolution = haveUniqueDefaultImplementation;
+                    }
+                    return pMDDefaultImpl;
                 }
-                return pMDDefaultImpl;
-            }
+
+                // We only loop at most twice here
+                if (allowVariantMatchInDefaultImplementationLookup)
+                {
+                    break;
+                }
+
+                allowVariantMatchInDefaultImplementationLookup = allowVariantMatches;
+            } while (allowVariantMatchInDefaultImplementationLookup);
         }
 
         // Default implementation logic, which only kicks in for default implementations when looking up on an exact interface target
@@ -9001,7 +9018,7 @@ MethodTable::ResolveVirtualStaticMethod(
 // Try to locate the appropriate MethodImpl matching a given interface static virtual method.
 // Returns nullptr on failure.
 MethodDesc*
-MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, ClassLoadLevel level)
+MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level)
 {
     HRESULT hr = S_OK;
     IMDInternalImport* pMDInternalImport = GetMDImport();
@@ -9049,9 +9066,22 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
             ClassLoader::LoadTypes,
             CLASS_LOAD_EXACTPARENTS)
             .GetMethodTable();
-        if (pInterfaceMT != pInterfaceType)
+
+        if (allowVariance)
         {
-            continue;
+            // Allow variant, but not equivalent interface match
+            if (!pInterfaceType->HasSameTypeDefAs(pInterfaceMT) ||
+                !pInterfaceMT->CanCastTo(pInterfaceType, NULL))
+            {
+                continue;
+            }
+        }
+        else
+        {
+            if (pInterfaceMT != pInterfaceType)
+            {
+                continue;
+            }
         }
         MethodDesc *pMethodDecl;
 
@@ -9167,6 +9197,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented()
                 BOOL uniqueResolution;
                 if (pMD->IsVirtual() &&
                     pMD->IsStatic() &&
+                    !pMD->HasMethodImplSlot() && // Re-abstractions are virtual static abstract with a MethodImpl
                     (pMD->IsAbstract() &&
                     !ResolveVirtualStaticMethod(
                         pInterfaceMT,
diff --git a/src/coreclr/vm/methodtable.h b/src/coreclr/vm/methodtable.h
index e506299d9cd7f..9b347bf5ccc85 100644
--- a/src/coreclr/vm/methodtable.h
+++ b/src/coreclr/vm/methodtable.h
@@ -2219,7 +2219,7 @@ class MethodTable
 
     // Try to resolve a given static virtual method override on this type. Return nullptr
     // when not found.
-    MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, ClassLoadLevel level);
+    MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level);
 
 public:
     static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl);
diff --git a/src/coreclr/vm/methodtablebuilder.cpp b/src/coreclr/vm/methodtablebuilder.cpp
index 2a8d775ba57de..1c6eadbe5bdc4 100644
--- a/src/coreclr/vm/methodtablebuilder.cpp
+++ b/src/coreclr/vm/methodtablebuilder.cpp
@@ -3293,13 +3293,6 @@ MethodTableBuilder::EnumerateClassMethods()
             }
         }
 
-        if (implType == METHOD_IMPL && isStaticVirtual && IsMdAbstract(dwMemberAttrs))
-        {
-            // Don't record reabstracted static virtual methods as they don't constitute
-            // actual method slots, they're just markers used by the static virtual lookup.
-            continue;
-        }
-
         // For delegates we don't allow any non-runtime implemented bodies
         // for any of the four special methods
         if (IsDelegate() && !IsMiRuntime(dwImplFlags))
diff --git a/src/coreclr/vm/runtimehandles.cpp b/src/coreclr/vm/runtimehandles.cpp
index e0eba7e172613..ef5a2dcaaf93a 100644
--- a/src/coreclr/vm/runtimehandles.cpp
+++ b/src/coreclr/vm/runtimehandles.cpp
@@ -1104,7 +1104,7 @@ extern "C" MethodDesc* QCALLTYPE RuntimeTypeHandle_GetInterfaceMethodImplementat
             pMD,
             /* allowNullResult */ TRUE,
             /* verifyImplemented*/ FALSE,
-            /*allowVariantMatches */ TRUE);
+            /* allowVariantMatches */ TRUE);
     }
     else
     {
diff --git a/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs
new file mode 100644
index 0000000000000..188686795ff9a
--- /dev/null
+++ b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.cs
@@ -0,0 +1,133 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime;
+using Xunit;
+
+// This regression test tracks the issue where variant static interface dispatch crashes the runtime, and behaves incorrectly
+
+namespace VariantStaticInterfaceDispatchRegressionTest
+{
+    class Test
+    {
+        static int Main()
+        {
+            Console.WriteLine("Test cases");
+
+            Console.WriteLine("---FooBar");
+            TestTheFooString<FooBar, Base>("IFoo<Base>");
+            TestTheFooString<FooBar, Mid>("IFoo<Base>");
+            TestTheFooString<FooBar, Derived>("IFoo<Base>");
+
+            TestTheBarString<FooBar, Base>("IBar<Derived>");
+            TestTheBarString<FooBar, Mid>("IBar<Derived>");
+            TestTheBarString<FooBar, Derived>("IBar<Derived>");
+
+            Console.WriteLine("---FooBar2");
+            TestTheFooString<FooBar2, Base>("IFoo<Base>");
+            TestTheFooString<FooBar2, Mid>("IFoo<Mid>");
+            TestTheFooString<FooBar2, Derived>("IFoo<Base>");
+
+            TestTheBarString<FooBar2, Base>("IBar<Derived>");
+            TestTheBarString<FooBar2, Mid>("IBar<Mid>");
+            TestTheBarString<FooBar2, Derived>("IBar<Derived>");
+
+            Console.WriteLine("---FooBarBaz");
+            TestTheFooString<FooBarBaz, Base>("IFoo<Base>");
+            TestTheFooString<FooBarBaz, Mid>("IBaz");
+            TestTheFooString<FooBarBaz, Derived>("IBaz");
+
+            TestTheBarString<FooBarBaz, Base>("IBaz");
+            TestTheBarString<FooBarBaz, Mid>("IBaz");
+            TestTheBarString<FooBarBaz, Derived>("IBar<Derived>");
+
+            Console.WriteLine("---FooBarBazBoz");
+            TestTheFooString<FooBarBazBoz, Base>("IBoz");
+            TestTheFooString<FooBarBazBoz, Mid>("IBaz");
+            TestTheFooString<FooBarBazBoz, Derived>("IBoz");
+
+            TestTheBarString<FooBarBazBoz, Base>("IBoz");
+            TestTheBarString<FooBarBazBoz, Mid>("IBaz");
+            TestTheBarString<FooBarBazBoz, Derived>("IBoz");
+
+            Console.WriteLine("---FooBarBaz2");
+            TestTheFooString<FooBarBaz2, Base>("IFoo<Base>");
+            TestTheFooString<FooBarBaz2, Mid>("IBaz");
+            TestTheFooString<FooBarBaz2, Derived>("IFoo<Base>");
+
+            TestTheBarString<FooBarBaz2, Base>("IBar<Derived>");
+            TestTheBarString<FooBarBaz2, Mid>("IBaz");
+            TestTheBarString<FooBarBaz2, Derived>("IBar<Derived>");
+
+            Console.WriteLine("---FooBarBazBoz2");
+            TestTheFooString<FooBarBazBoz2, Base>("IBoz");
+            TestTheFooString<FooBarBazBoz2, Mid>("IBaz");
+            TestTheFooString<FooBarBazBoz2, Derived>("IBoz");
+
+            TestTheBarString<FooBarBazBoz2, Base>("IBoz");
+            TestTheBarString<FooBarBazBoz2, Mid>("IBaz");
+            TestTheBarString<FooBarBazBoz2, Derived>("IBoz");
+            return 100;
+        }
+
+        static string GetTheFooString<T, U>() where T : IFoo<U> { try { return T.GetString(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
+        static string GetTheBarString<T, U>() where T : IBar<U> { try { return T.GetString(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
+        static string GetTheFooStringInstance<T, U>() where T : IFoo<U>, new() { try { return (new T()).GetStringInstance(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
+        static string GetTheBarStringInstance<T, U>() where T : IBar<U>, new() { try { return (new T()).GetStringInstance(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
+
+        static void TestTheFooString<T, U>(string expected) where T : IFoo<U>, new()
+        {
+            Console.WriteLine($"TestTheFooString {typeof(T).Name} {typeof(T).Name} {expected}");
+            Assert.Equal(expected, GetTheFooString<T, U>());
+            Assert.Equal(expected, GetTheFooStringInstance<T, U>());
+        }
+
+        static void TestTheBarString<T, U>(string expected) where T : IBar<U>, new()
+        {
+            Console.WriteLine($"TestTheBarString {typeof(T).Name} {typeof(T).Name} {expected}");
+            Assert.Equal(expected, GetTheBarString<T, U>());
+            Assert.Equal(expected, GetTheBarStringInstance<T, U>());
+        }
+
+        interface IFoo<in T>
+        {
+            static virtual string GetString() => $"IFoo<{typeof(T).Name}>";
+            virtual string GetStringInstance() => $"IFoo<{typeof(T).Name}>";
+        };
+
+        interface IBar<out T>
+        {
+            static virtual string GetString() => $"IBar<{typeof(T).Name}>";
+            virtual string GetStringInstance() => $"IBar<{typeof(T).Name}>";
+        };
+
+
+        interface IBaz : IFoo<Mid>, IBar<Mid>
+        {
+            static string IFoo<Mid>.GetString() => "IBaz";
+            static string IBar<Mid>.GetString() => "IBaz";
+            string IFoo<Mid>.GetStringInstance() => "IBaz";
+            string IBar<Mid>.GetStringInstance() => "IBaz";
+        }
+
+        interface IBoz : IFoo<Base>, IBar<Derived>
+        {
+            static string IFoo<Base>.GetString() => "IBoz";
+            static string IBar<Derived>.GetString() => "IBoz";
+            string IFoo<Base>.GetStringInstance() => "IBoz";
+            string IBar<Derived>.GetStringInstance() => "IBoz";
+        }
+
+        class FooBar : IFoo<Base>, IBar<Derived> { }
+        class FooBar2 : IFoo<Base>, IBar<Derived>, IFoo<Mid>, IBar<Mid> { }
+        class FooBarBaz : FooBar, IBaz { }
+        class FooBarBaz2 : IFoo<Base>, IBar<Derived>, IBaz { } // Implementation with all interfaces defined on the same type
+        class FooBarBazBoz : FooBarBaz, IBoz { }
+        class FooBarBazBoz2 : IFoo<Base>, IBar<Derived>, IBaz, IBoz { } // Implementation with all interfaces defined on the same type
+
+        class Base { }
+        class Mid : Base { }
+        class Derived : Mid { }
+    }
+}
diff --git a/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj
new file mode 100644
index 0000000000000..7f0c12d57d8e5
--- /dev/null
+++ b/src/tests/Loader/classloader/StaticVirtualMethods/InterfaceVariance/ComplexHierarchyPositive.csproj
@@ -0,0 +1,9 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+    <OutputType>Exe</OutputType>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+  </ItemGroup>
+</Project>
diff --git a/src/tests/issues.targets b/src/tests/issues.targets
index dcf0616de702d..a1d35dac9357e 100644
--- a/src/tests/issues.targets
+++ b/src/tests/issues.targets
@@ -694,6 +694,9 @@
 
     <!-- NativeAOT specific -->
     <ItemGroup Condition="'$(XunitTestBinBase)' != '' and '$(TestBuildMode)' == 'nativeaot' and '$(RuntimeFlavor)' == 'coreclr'">
+        <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/InterfaceVariance/**">
+            <Issue>https://github.com/dotnet/runtime/issues/88690</Issue>
+        </ExcludeList>
         <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/NegativeTestCases/**">
             <Issue>https://github.com/dotnet/runtimelab/issues/155: Compatible TypeLoadException for invalid inputs</Issue>
         </ExcludeList>
@@ -703,6 +706,9 @@
         <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/DiamondShape/svm_diamondshape_r/*">
             <Issue>https://github.com/dotnet/runtime/issues/72589</Issue>
         </ExcludeList>
+        <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/NegativeTestCases/**">
+            <Issue>https://github.com/dotnet/runtimelab/issues/155: Compatible TypeLoadException for invalid inputs</Issue>
+        </ExcludeList>
         <ExcludeList Include="$(XunitTestBinBase)/baseservices/RuntimeConfiguration/TestConfig/*">
             <Issue>Test expects being run with corerun</Issue>
         </ExcludeList>
@@ -1248,6 +1254,9 @@
 
     <!-- Known failures for mono runtime on *all* architectures/operating systems in *all* runtime modes -->
     <ItemGroup Condition="'$(RuntimeFlavor)' == 'mono'" >
+        <ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/InterfaceVariance/**">
+            <Issue>https://github.com/dotnet/runtime/issues/88689</Issue>
+        </ExcludeList>
         <ExcludeList Include="$(XunitTestBinBase)/Regressions/coreclr/GitHub_85240/test85240/**">
             <Issue>https://github.com/dotnet/runtime/issues/71095</Issue>
         </ExcludeList>