diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/EnumerableRewriter.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/EnumerableRewriter.cs index ad489f768ec11..faf127b5e3eb2 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/EnumerableRewriter.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/EnumerableRewriter.cs @@ -255,25 +255,86 @@ protected override Expression VisitConstant(ConstantExpression c) } private static ILookup? s_seqMethods; - [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2060:MakeGenericMethod", - Justification = "Enumerable methods don't have trim annotations.")] private static MethodInfo FindEnumerableMethodForQueryable(string name, ReadOnlyCollection args, params Type[]? typeArgs) { - if (s_seqMethods == null) + s_seqMethods ??= GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name); + + MethodInfo[] matchingMethods = s_seqMethods[name] + .Where(m => ArgsMatch(m, args, typeArgs)) + .Select(ApplyTypeArgs) + .ToArray(); + + Debug.Assert(matchingMethods.Length > 0, "All static methods with arguments on Queryable have equivalents on Enumerable."); + + if (matchingMethods.Length > 1) { - s_seqMethods = GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name); + return DisambiguateMatches(matchingMethods); } - MethodInfo? mi = s_seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs)); - Debug.Assert(mi != null, "All static methods with arguments on Queryable have equivalents on Enumerable."); - if (typeArgs != null) - return mi.MakeGenericMethod(typeArgs); - return mi; + + return matchingMethods[0]; [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern", Justification = "This method is intentionally hiding the Enumerable type from the trimmer so it doesn't preserve all Enumerable's methods. " + "This is safe because all Queryable methods have a DynamicDependency to the corresponding Enumerable method.")] static MethodInfo[] GetEnumerableStaticMethods(Type type) => type.GetMethods(BindingFlags.Public | BindingFlags.Static); + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2060:MakeGenericMethod", + Justification = "Enumerable methods don't have trim annotations.")] + MethodInfo ApplyTypeArgs(MethodInfo methodInfo) => typeArgs == null ? methodInfo : methodInfo.MakeGenericMethod(typeArgs); + + // In certain cases, there might be ambiguities when resolving matching overloads, for example between + // 1. FirstOrDefault(IEnumerable source, Func predicate) and + // 2. FirstOrDefault(IEnumerable source, object defaultvalue). + // In such cases we disambiguate by picking a method with the most derived signature. + static MethodInfo DisambiguateMatches(MethodInfo[] matchingMethods) + { + Debug.Assert(matchingMethods.Length > 1); + ParameterInfo[][] parameters = matchingMethods.Select(m => m.GetParameters()).ToArray(); + + // `AreAssignableFrom[Strict]` defines a partial order on method signatures; pick a maximal element using that order. + // It is assumed that `matchingMethods` is a small array, so a naive quadratic search is probably better than + // doing some variant of topological sorting. + + for (int i = 0; i < matchingMethods.Length; i++) + { + bool isMaximal = true; + for (int j = 0; j < matchingMethods.Length; j++) + { + if (i != j && AreAssignableFromStrict(parameters[i], parameters[j])) + { + // Found a matching method that contains strictly more specific parameter types. + isMaximal = false; + break; + } + } + + if (isMaximal) + { + return matchingMethods[i]; + } + } + + Debug.Fail("Search should have found a maximal element"); + throw new Exception(); + + static bool AreAssignableFromStrict(ParameterInfo[] left, ParameterInfo[] right) + { + Debug.Assert(left.Length == right.Length); + + bool areEqual = true; + bool areAssignableFrom = true; + for (int i = 0; i < left.Length; i++) + { + Type leftParam = left[i].ParameterType; + Type rightParam = right[i].ParameterType; + areEqual = areEqual && leftParam == rightParam; + areAssignableFrom = areAssignableFrom && leftParam.IsAssignableFrom(rightParam); + } + + return !areEqual && areAssignableFrom; + } + } } [RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)] diff --git a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs index ac74ca5bd6836..9d655177ef92d 100644 --- a/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs @@ -129,5 +129,16 @@ public void FirstOrDefault2() var val = (new int[] { 0, 1, 2 }).AsQueryable().FirstOrDefault(n => n > 1); Assert.Equal(2, val); } + + [Fact] + public void FirstOrDefault_OverloadResolution_Regression() + { + // Regression test for https://github.com/dotnet/runtime/issues/65419 + object? result = new object[] { 1, "" }.AsQueryable().FirstOrDefault(x => x is string); + Assert.IsType(result); + + result = Array.Empty().AsQueryable().FirstOrDefault(1); + Assert.IsType(result); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs index c66c42fce4a4f..d65a9eb27d1e7 100644 --- a/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs @@ -96,5 +96,16 @@ public void LastOrDefault2() var val = (new int[] { 0, 1, 2 }).AsQueryable().LastOrDefault(n => n > 1); Assert.Equal(2, val); } + + [Fact] + public void LastOrDefault_OverloadResolution_Regression() + { + // Regression test for https://github.com/dotnet/runtime/issues/65419 + object? result = new object[] { 1, "" }.AsQueryable().LastOrDefault(x => x is int); + Assert.IsType(result); + + result = Array.Empty().AsQueryable().LastOrDefault(1); + Assert.IsType(result); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs b/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs index f34edae8b9f84..991ac5106ba87 100644 --- a/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs @@ -79,5 +79,16 @@ public void SingleOrDefault2() var val = (new int[] { 2 }).AsQueryable().SingleOrDefault(n => n > 1); Assert.Equal(2, val); } + + [Fact] + public void SingleOrDefault_OverloadResolution_Regression() + { + // Regression test for https://github.com/dotnet/runtime/issues/65419 + object? result = new object[] { 1, "" }.AsQueryable().SingleOrDefault(x => x is string); + Assert.IsType(result); + + result = Array.Empty().AsQueryable().SingleOrDefault(1); + Assert.IsType(result); + } } }