Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Queryable-to-Enumerable overload mapping logic #65569

Merged
merged 4 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -255,25 +255,83 @@ protected override Expression VisitConstant(ConstantExpression c)
}

private static ILookup<string, MethodInfo>? s_seqMethods;
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2060:MakeGenericMethod",
Justification = "Enumerable methods don't have trim annotations.")]
private static MethodInfo FindEnumerableMethodForQueryable(string name, ReadOnlyCollection<Expression> args, params Type[]? typeArgs)
{
if (s_seqMethods == null)
s_seqMethods ??= GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name);

MethodInfo[] matchingMethods = s_seqMethods[name]
.Where(m => ArgsMatch(m, args, typeArgs))
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
.Select(ApplyTypeArgs)
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
.ToArray();

Debug.Assert(matchingMethods.Length > 0, "All static methods with arguments on Queryable have equivalents on Enumerable.");

if (matchingMethods.Length > 1)
{
s_seqMethods = GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name);
return DisambiguateMatches(matchingMethods);
}
MethodInfo? mi = s_seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
Debug.Assert(mi != null, "All static methods with arguments on Queryable have equivalents on Enumerable.");
if (typeArgs != null)
return mi.MakeGenericMethod(typeArgs);
return mi;

return matchingMethods[0];

[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",
Justification = "This method is intentionally hiding the Enumerable type from the trimmer so it doesn't preserve all Enumerable's methods. " +
"This is safe because all Queryable methods have a DynamicDependency to the corresponding Enumerable method.")]
static MethodInfo[] GetEnumerableStaticMethods(Type type) =>
type.GetMethods(BindingFlags.Public | BindingFlags.Static);

[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2060:MakeGenericMethod",
Justification = "Enumerable methods don't have trim annotations.")]
MethodInfo ApplyTypeArgs(MethodInfo methodInfo) => typeArgs == null ? methodInfo : methodInfo.MakeGenericMethod(typeArgs);

// In certain cases, there might be ambiguities when resolving matching overloads, for example between
// 1. FirstOrDefault<object>(IEnumerable<object> source, Func<object, bool> predicate) and
// 2. FirstOrDefault<object>(IEnumerable<object> source, object defaultvalue).
// In such cases we disambiguate by picking a method with the most derived signature.
static MethodInfo DisambiguateMatches(MethodInfo[] matchingMethods)
{
Debug.Assert(matchingMethods.Length > 1);
ParameterInfo[][] parameters = matchingMethods.Select(m => m.GetParameters()).ToArray();
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved

// `IsLessDerivedThan` defines a partial order on method signatures; pick a maximal element using that order.
// It is assumed that `matchingMethods` is a small array, so a naive quadratic search is probably better than
// doing some variant of topological sorting.

for (int i = 0; i < matchingMethods.Length; i++)
{
bool foundDerivedMethodSignature = false;
for (int j = 0; j < matchingMethods.Length; j++)
{
if (i != j && IsLessDerivedThan(parameters[i], parameters[j]))
{
foundDerivedMethodSignature = true;
break;
}
}

if (!foundDerivedMethodSignature)
{
// Found a maximal element
return matchingMethods[i];
}
}

Debug.Fail("Search should have found a maximal element");
throw new Exception();

static bool IsLessDerivedThan(ParameterInfo[] params1, ParameterInfo[] params2)
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
{
Debug.Assert(params1.Length == params2.Length);
for (int i = 0; i < params1.Length; i++)
{
if (!params1[i].ParameterType.IsAssignableFrom(params2[i].ParameterType))
{
return false;
}
}

return true;
}
}
}

[RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
Expand Down
11 changes: 11 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,16 @@ public void FirstOrDefault2()
var val = (new int[] { 0, 1, 2 }).AsQueryable().FirstOrDefault(n => n > 1);
Assert.Equal(2, val);
}

[Fact]
public void FirstOrDefault_OverloadResolution_Regression()
{
// Regression test for https://github.com/dotnet/runtime/issues/65419
object? result = new object[] { 1, "" }.AsQueryable().FirstOrDefault(x => x is string);
Assert.IsType<string>(result);

result = Array.Empty<object>().AsQueryable().FirstOrDefault(1);
Assert.IsType<int>(result);
}
}
}
11 changes: 11 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,16 @@ public void LastOrDefault2()
var val = (new int[] { 0, 1, 2 }).AsQueryable().LastOrDefault(n => n > 1);
Assert.Equal(2, val);
}

[Fact]
public void LastOrDefault_OverloadResolution_Regression()
{
// Regression test for https://github.com/dotnet/runtime/issues/65419
object? result = new object[] { 1, "" }.AsQueryable().LastOrDefault(x => x is int);
Assert.IsType<int>(result);

result = Array.Empty<object>().AsQueryable().LastOrDefault(1);
Assert.IsType<int>(result);
}
}
}
11 changes: 11 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,16 @@ public void SingleOrDefault2()
var val = (new int[] { 2 }).AsQueryable().SingleOrDefault(n => n > 1);
Assert.Equal(2, val);
}

[Fact]
public void SingleOrDefault_OverloadResolution_Regression()
{
// Regression test for https://github.com/dotnet/runtime/issues/65419
object? result = new object[] { 1, "" }.AsQueryable().SingleOrDefault(x => x is string);
Assert.IsType<string>(result);

result = Array.Empty<object>().AsQueryable().SingleOrDefault(1);
Assert.IsType<int>(result);
}
}
}