Skip to content

Commit

Permalink
Fix virtual static dispatch for variant interfaces when using default…
Browse files Browse the repository at this point in the history
… implementations (#88639)

- Move the handling for variance so that it is in the same place as it is for instance methods
- Validate the behavior of such dispatch is equivalent to the dispatch behavior involving instance variant default dispatch
- Rework the fix for #80350 to avoid violating an invariant of the type system (which is that all MethodDef's of a TypeDef are loaded as MethodDescs when loading the open type).

Fixes #78621
  • Loading branch information
davidwrighton authored Jul 11, 2023
1 parent 83bf4b6 commit 57f691a
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 38 deletions.
89 changes: 60 additions & 29 deletions src/coreclr/vm/methodtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6349,7 +6349,7 @@ namespace
candidateMaybe = interfaceMD;
}
}
else
else if (!interfaceMD->IsStatic())
{
//
// A more specific interface - search for an methodimpl for explicit override
Expand Down Expand Up @@ -6413,17 +6413,18 @@ namespace
}
}
}
else if (pMD->IsStatic() && pMD->HasMethodImplSlot())
{
// Static virtual methods don't record MethodImpl slots so they need special handling
candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
interfaceMT,
interfaceMD,
/* verifyImplemented */ FALSE,
/* level */ level);
}
}
}
else
{
// Static virtual methods don't record MethodImpl slots so they need special handling
candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
interfaceMT,
interfaceMD,
/* verifyImplemented */ FALSE,
/* allowVariance */ allowVariance,
/* level */ level);
}
}

if (candidateMaybe == NULL)
Expand Down Expand Up @@ -8911,7 +8912,7 @@ MethodTable::ResolveVirtualStaticMethod(
// Search for match on a per-level in the type hierarchy
for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable())
{
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, level);
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
if (pMD != nullptr)
{
return pMD;
Expand All @@ -8920,7 +8921,7 @@ MethodTable::ResolveVirtualStaticMethod(
if (pInterfaceType->HasVariance() || pInterfaceType->HasTypeEquivalence())
{
// Variant interface dispatch
MethodTable::InterfaceMapIterator it = IterateInterfaceMap();
MethodTable::InterfaceMapIterator it = pMT->IterateInterfaceMap();
while (it.Next())
{
if (it.CurrentInterfaceMatches(this, pInterfaceType))
Expand Down Expand Up @@ -8955,7 +8956,7 @@ MethodTable::ResolveVirtualStaticMethod(
{
// Variant or equivalent matching interface found
// Attempt to resolve on variance matched interface
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, level);
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
if (pMD != nullptr)
{
return pMD;
Expand All @@ -8966,22 +8967,38 @@ MethodTable::ResolveVirtualStaticMethod(
}

MethodDesc *pMDDefaultImpl = nullptr;
BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
pInterfaceMD,
pInterfaceType,
&pMDDefaultImpl,
/* allowVariance */ allowVariantMatches,
/* throwOnConflict */ uniqueResolution == nullptr,
level);
if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
BOOL allowVariantMatchInDefaultImplementationLookup = FALSE;
do
{
// We tolerate conflicts upon verification of implemented SVMs so that they only blow up when actually called at execution time.
if (uniqueResolution != nullptr)
BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
pInterfaceMD,
pInterfaceType,
&pMDDefaultImpl,
/* allowVariance */ allowVariantMatchInDefaultImplementationLookup,
/* throwOnConflict */ uniqueResolution == nullptr,
level);
if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
{
*uniqueResolution = haveUniqueDefaultImplementation;
// We tolerate conflicts upon verification of implemented SVMs so that they only blow up when actually called at execution time.
if (uniqueResolution != nullptr)
{
// Always report a unique resolution when reporting results of a variant match
if (allowVariantMatchInDefaultImplementationLookup)
*uniqueResolution = TRUE;
else
*uniqueResolution = haveUniqueDefaultImplementation;
}
return pMDDefaultImpl;
}
return pMDDefaultImpl;
}

// We only loop at most twice here
if (allowVariantMatchInDefaultImplementationLookup)
{
break;
}

allowVariantMatchInDefaultImplementationLookup = allowVariantMatches;
} while (allowVariantMatchInDefaultImplementationLookup);
}

// Default implementation logic, which only kicks in for default implementations when looking up on an exact interface target
Expand All @@ -9001,7 +9018,7 @@ MethodTable::ResolveVirtualStaticMethod(
// Try to locate the appropriate MethodImpl matching a given interface static virtual method.
// Returns nullptr on failure.
MethodDesc*
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, ClassLoadLevel level)
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level)
{
HRESULT hr = S_OK;
IMDInternalImport* pMDInternalImport = GetMDImport();
Expand Down Expand Up @@ -9049,9 +9066,22 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
ClassLoader::LoadTypes,
CLASS_LOAD_EXACTPARENTS)
.GetMethodTable();
if (pInterfaceMT != pInterfaceType)

if (allowVariance)
{
continue;
// Allow variant, but not equivalent interface match
if (!pInterfaceType->HasSameTypeDefAs(pInterfaceMT) ||
!pInterfaceMT->CanCastTo(pInterfaceType, NULL))
{
continue;
}
}
else
{
if (pInterfaceMT != pInterfaceType)
{
continue;
}
}
MethodDesc *pMethodDecl;

Expand Down Expand Up @@ -9167,6 +9197,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented()
BOOL uniqueResolution;
if (pMD->IsVirtual() &&
pMD->IsStatic() &&
!pMD->HasMethodImplSlot() && // Re-abstractions are virtual static abstract with a MethodImpl
(pMD->IsAbstract() &&
!ResolveVirtualStaticMethod(
pInterfaceMT,
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/vm/methodtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,7 @@ class MethodTable

// Try to resolve a given static virtual method override on this type. Return nullptr
// when not found.
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, ClassLoadLevel level);
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level);

public:
static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl);
Expand Down
7 changes: 0 additions & 7 deletions src/coreclr/vm/methodtablebuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3293,13 +3293,6 @@ MethodTableBuilder::EnumerateClassMethods()
}
}

if (implType == METHOD_IMPL && isStaticVirtual && IsMdAbstract(dwMemberAttrs))
{
// Don't record reabstracted static virtual methods as they don't constitute
// actual method slots, they're just markers used by the static virtual lookup.
continue;
}

// For delegates we don't allow any non-runtime implemented bodies
// for any of the four special methods
if (IsDelegate() && !IsMiRuntime(dwImplFlags))
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/vm/runtimehandles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ extern "C" MethodDesc* QCALLTYPE RuntimeTypeHandle_GetInterfaceMethodImplementat
pMD,
/* allowNullResult */ TRUE,
/* verifyImplemented*/ FALSE,
/*allowVariantMatches */ TRUE);
/* allowVariantMatches */ TRUE);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime;
using Xunit;

// This regression test tracks the issue where variant static interface dispatch crashes the runtime, and behaves incorrectly

namespace VariantStaticInterfaceDispatchRegressionTest
{
class Test
{
static int Main()
{
Console.WriteLine("Test cases");

Console.WriteLine("---FooBar");
TestTheFooString<FooBar, Base>("IFoo<Base>");
TestTheFooString<FooBar, Mid>("IFoo<Base>");
TestTheFooString<FooBar, Derived>("IFoo<Base>");

TestTheBarString<FooBar, Base>("IBar<Derived>");
TestTheBarString<FooBar, Mid>("IBar<Derived>");
TestTheBarString<FooBar, Derived>("IBar<Derived>");

Console.WriteLine("---FooBar2");
TestTheFooString<FooBar2, Base>("IFoo<Base>");
TestTheFooString<FooBar2, Mid>("IFoo<Mid>");
TestTheFooString<FooBar2, Derived>("IFoo<Base>");

TestTheBarString<FooBar2, Base>("IBar<Derived>");
TestTheBarString<FooBar2, Mid>("IBar<Mid>");
TestTheBarString<FooBar2, Derived>("IBar<Derived>");

Console.WriteLine("---FooBarBaz");
TestTheFooString<FooBarBaz, Base>("IFoo<Base>");
TestTheFooString<FooBarBaz, Mid>("IBaz");
TestTheFooString<FooBarBaz, Derived>("IBaz");

TestTheBarString<FooBarBaz, Base>("IBaz");
TestTheBarString<FooBarBaz, Mid>("IBaz");
TestTheBarString<FooBarBaz, Derived>("IBar<Derived>");

Console.WriteLine("---FooBarBazBoz");
TestTheFooString<FooBarBazBoz, Base>("IBoz");
TestTheFooString<FooBarBazBoz, Mid>("IBaz");
TestTheFooString<FooBarBazBoz, Derived>("IBoz");

TestTheBarString<FooBarBazBoz, Base>("IBoz");
TestTheBarString<FooBarBazBoz, Mid>("IBaz");
TestTheBarString<FooBarBazBoz, Derived>("IBoz");

Console.WriteLine("---FooBarBaz2");
TestTheFooString<FooBarBaz2, Base>("IFoo<Base>");
TestTheFooString<FooBarBaz2, Mid>("IBaz");
TestTheFooString<FooBarBaz2, Derived>("IFoo<Base>");

TestTheBarString<FooBarBaz2, Base>("IBar<Derived>");
TestTheBarString<FooBarBaz2, Mid>("IBaz");
TestTheBarString<FooBarBaz2, Derived>("IBar<Derived>");

Console.WriteLine("---FooBarBazBoz2");
TestTheFooString<FooBarBazBoz2, Base>("IBoz");
TestTheFooString<FooBarBazBoz2, Mid>("IBaz");
TestTheFooString<FooBarBazBoz2, Derived>("IBoz");

TestTheBarString<FooBarBazBoz2, Base>("IBoz");
TestTheBarString<FooBarBazBoz2, Mid>("IBaz");
TestTheBarString<FooBarBazBoz2, Derived>("IBoz");
return 100;
}

static string GetTheFooString<T, U>() where T : IFoo<U> { try { return T.GetString(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
static string GetTheBarString<T, U>() where T : IBar<U> { try { return T.GetString(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
static string GetTheFooStringInstance<T, U>() where T : IFoo<U>, new() { try { return (new T()).GetStringInstance(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }
static string GetTheBarStringInstance<T, U>() where T : IBar<U>, new() { try { return (new T()).GetStringInstance(); } catch (AmbiguousImplementationException) { return "AmbiguousImplementationException"; } }

static void TestTheFooString<T, U>(string expected) where T : IFoo<U>, new()
{
Console.WriteLine($"TestTheFooString {typeof(T).Name} {typeof(T).Name} {expected}");
Assert.Equal(expected, GetTheFooString<T, U>());
Assert.Equal(expected, GetTheFooStringInstance<T, U>());
}

static void TestTheBarString<T, U>(string expected) where T : IBar<U>, new()
{
Console.WriteLine($"TestTheBarString {typeof(T).Name} {typeof(T).Name} {expected}");
Assert.Equal(expected, GetTheBarString<T, U>());
Assert.Equal(expected, GetTheBarStringInstance<T, U>());
}

interface IFoo<in T>
{
static virtual string GetString() => $"IFoo<{typeof(T).Name}>";
virtual string GetStringInstance() => $"IFoo<{typeof(T).Name}>";
};

interface IBar<out T>
{
static virtual string GetString() => $"IBar<{typeof(T).Name}>";
virtual string GetStringInstance() => $"IBar<{typeof(T).Name}>";
};


interface IBaz : IFoo<Mid>, IBar<Mid>
{
static string IFoo<Mid>.GetString() => "IBaz";
static string IBar<Mid>.GetString() => "IBaz";
string IFoo<Mid>.GetStringInstance() => "IBaz";
string IBar<Mid>.GetStringInstance() => "IBaz";
}

interface IBoz : IFoo<Base>, IBar<Derived>
{
static string IFoo<Base>.GetString() => "IBoz";
static string IBar<Derived>.GetString() => "IBoz";
string IFoo<Base>.GetStringInstance() => "IBoz";
string IBar<Derived>.GetStringInstance() => "IBoz";
}

class FooBar : IFoo<Base>, IBar<Derived> { }
class FooBar2 : IFoo<Base>, IBar<Derived>, IFoo<Mid>, IBar<Mid> { }
class FooBarBaz : FooBar, IBaz { }
class FooBarBaz2 : IFoo<Base>, IBar<Derived>, IBaz { } // Implementation with all interfaces defined on the same type
class FooBarBazBoz : FooBarBaz, IBoz { }
class FooBarBazBoz2 : IFoo<Base>, IBar<Derived>, IBaz, IBoz { } // Implementation with all interfaces defined on the same type

class Base { }
class Mid : Base { }
class Derived : Mid { }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<OutputType>Exe</OutputType>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(MSBuildProjectName).cs" />
</ItemGroup>
</Project>
9 changes: 9 additions & 0 deletions src/tests/issues.targets
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,9 @@

<!-- NativeAOT specific -->
<ItemGroup Condition="'$(XunitTestBinBase)' != '' and '$(TestBuildMode)' == 'nativeaot' and '$(RuntimeFlavor)' == 'coreclr'">
<ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/InterfaceVariance/**">
<Issue>https://github.com/dotnet/runtime/issues/88690</Issue>
</ExcludeList>
<ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/NegativeTestCases/**">
<Issue>https://github.com/dotnet/runtimelab/issues/155: Compatible TypeLoadException for invalid inputs</Issue>
</ExcludeList>
Expand All @@ -703,6 +706,9 @@
<ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/DiamondShape/svm_diamondshape_r/*">
<Issue>https://github.com/dotnet/runtime/issues/72589</Issue>
</ExcludeList>
<ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/NegativeTestCases/**">
<Issue>https://github.com/dotnet/runtimelab/issues/155: Compatible TypeLoadException for invalid inputs</Issue>
</ExcludeList>
<ExcludeList Include="$(XunitTestBinBase)/baseservices/RuntimeConfiguration/TestConfig/*">
<Issue>Test expects being run with corerun</Issue>
</ExcludeList>
Expand Down Expand Up @@ -1248,6 +1254,9 @@

<!-- Known failures for mono runtime on *all* architectures/operating systems in *all* runtime modes -->
<ItemGroup Condition="'$(RuntimeFlavor)' == 'mono'" >
<ExcludeList Include="$(XunitTestBinBase)/Loader/classloader/StaticVirtualMethods/InterfaceVariance/**">
<Issue>https://github.com/dotnet/runtime/issues/88689</Issue>
</ExcludeList>
<ExcludeList Include="$(XunitTestBinBase)/Regressions/coreclr/GitHub_85240/test85240/**">
<Issue>https://github.com/dotnet/runtime/issues/71095</Issue>
</ExcludeList>
Expand Down

0 comments on commit 57f691a

Please sign in to comment.