Skip to content

Commit

Permalink
Query: Merge query optimizing expression visitors
Browse files Browse the repository at this point in the history
Part of #18923

Resolves #20155
Resolves #20369
We convert Queryable.Contains to Queryable.Any after navigation expansion has run so only true queraybles would have Queryable.Contains. Array properties would have Enumerable.Contains hence does not get rewritten.

Resolves #19433
  • Loading branch information
smitpatel committed Apr 1, 2020
1 parent bb88898 commit 8b6bba7
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 197 deletions.

This file was deleted.

31 changes: 12 additions & 19 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ protected override Expression VisitMember(MemberExpression memberExpression)
if (innerQueryable.Type.TryGetElementType(typeof(IQueryable<>)) != null)
{
return Visit(
Expression.Call(
QueryableMethods.CountWithoutPredicate.MakeGenericMethod(innerQueryable.Type.TryGetSequenceType()),
innerQueryable));
Expression.Call(
QueryableMethods.CountWithoutPredicate.MakeGenericMethod(innerQueryable.Type.TryGetSequenceType()),
innerQueryable));
}
}

Expand Down Expand Up @@ -528,13 +528,8 @@ when QueryableMethods.IsSumWithSelector(method):
&& (method.GetGenericMethodDefinition() == EnumerableMethods.ToList
|| method.GetGenericMethodDefinition() == EnumerableMethods.ToArray))
{
var argument = Visit(methodCallExpression.Arguments[0]);
if (argument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
{
argument = materializeCollectionNavigationExpression.Subquery;
}

return methodCallExpression.Update(null, new[] { argument });
return methodCallExpression.Update(
null, new[] { UnwrapCollectionMaterialization(Visit(methodCallExpression.Arguments[0])) });
}

return ProcessUnknownMethod(methodCallExpression);
Expand Down Expand Up @@ -1584,16 +1579,14 @@ private LambdaExpression GenerateLambda(Expression body, ParameterExpression cur

private Expression UnwrapCollectionMaterialization(Expression expression)
{
if (expression is MethodCallExpression innerMethodCall
&& innerMethodCall.Method.IsGenericMethod)
while (expression is MethodCallExpression innerMethodCall
&& innerMethodCall.Method.IsGenericMethod
&& innerMethodCall.Method.GetGenericMethodDefinition() is MethodInfo innerMethod
&& (innerMethod == EnumerableMethods.AsEnumerable
|| innerMethod == EnumerableMethods.ToList
|| innerMethod == EnumerableMethods.ToArray))
{
var innerGenericMethod = innerMethodCall.Method.GetGenericMethodDefinition();
if (innerGenericMethod == EnumerableMethods.AsEnumerable
|| innerGenericMethod == EnumerableMethods.ToList
|| innerGenericMethod == EnumerableMethods.ToArray)
{
expression = innerMethodCall.Arguments[0];
}
expression = innerMethodCall.Arguments[0];
}

if (expression is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,51 @@

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class AllAnyContainsRewritingExpressionVisitor : ExpressionVisitor
public class QueryOptimizingExpressionVisitor : ExpressionVisitor
{
private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2)
=> type.IsGenericType
&& type.GetGenericArguments().Length == funcGenericArgs;
private static readonly MethodInfo _stringCompareWithComparisonMethod =
typeof(string).GetRuntimeMethod(nameof(string.Compare), new[] { typeof(string), typeof(string), typeof(StringComparison) });
private static readonly MethodInfo _stringCompareWithoutComparisonMethod =
typeof(string).GetRuntimeMethod(nameof(string.Compare), new[] { typeof(string), typeof(string) });
private static readonly MethodInfo _startsWithMethodInfo =
typeof(string).GetRuntimeMethod(nameof(string.StartsWith), new[] { typeof(string) });
private static readonly MethodInfo _endsWithMethodInfo =
typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) });

private static readonly Expression _constantNullString = Expression.Constant(null, typeof(string));

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
Check.NotNull(methodCallExpression, nameof(methodCallExpression));

if (_startsWithMethodInfo.Equals(methodCallExpression.Method)
|| _endsWithMethodInfo.Equals(methodCallExpression.Method))
{
if (methodCallExpression.Arguments[0] is ConstantExpression constantArgument
&& (string)constantArgument.Value == string.Empty)
{
// every string starts/ends with empty string.
return Expression.Constant(true);
}

var newObject = Visit(methodCallExpression.Object);
var newArgument = Visit(methodCallExpression.Arguments[0]);

var result = Expression.AndAlso(
Expression.NotEqual(newObject, _constantNullString),
Expression.AndAlso(
Expression.NotEqual(newArgument, _constantNullString),
methodCallExpression.Update(newObject, new[] { newArgument })));

return newArgument is ConstantExpression
? result
: Expression.OrElse(
Expression.Equal(
newArgument,
Expression.Constant(string.Empty)),
result);
}

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo methodInfo
&& (methodInfo.Equals(EnumerableMethods.AnyWithPredicate) || methodInfo.Equals(EnumerableMethods.All))
Expand Down Expand Up @@ -46,9 +81,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo containsMethodInfo
&& containsMethodInfo.Equals(QueryableMethods.Contains)
// special case Queryable.Contains(byte_array, byte) - we don't want those to be rewritten
&& methodCallExpression.Arguments[1].Type != typeof(byte))
&& containsMethodInfo.Equals(QueryableMethods.Contains))
{
var typeArgument = methodCallExpression.Method.GetGenericArguments()[0];
var anyMethod = QueryableMethods.AnyWithPredicate.MakeGenericMethod(typeArgument);
Expand All @@ -63,7 +96,67 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return Expression.Call(null, anyMethod, new[] { methodCallExpression.Arguments[0], anyLambda });
}

return base.VisitMethodCall(methodCallExpression);
var visited = (MethodCallExpression)base.VisitMethodCall(methodCallExpression);

// In VB.NET, comparison operators between strings (equality, greater-than, less-than) yield
// calls to a VB-specific CompareString method. Normalize that to string.Compare.
if (visited.Method.Name == "CompareString"
&& visited.Method.DeclaringType?.Name == "Operators"
&& visited.Method.DeclaringType?.Namespace == "Microsoft.VisualBasic.CompilerServices"
&& visited.Object == null
&& visited.Arguments.Count == 3
&& visited.Arguments[2] is ConstantExpression textCompareConstantExpression)
{
return (bool)textCompareConstantExpression.Value
? Expression.Call(
_stringCompareWithComparisonMethod,
visited.Arguments[0],
visited.Arguments[1],
Expression.Constant(StringComparison.OrdinalIgnoreCase))
: Expression.Call(
_stringCompareWithoutComparisonMethod,
visited.Arguments[0],
visited.Arguments[1]);
}

return visited;
}

protected override Expression VisitUnary(UnaryExpression unaryExpression)
{
Check.NotNull(unaryExpression, nameof(unaryExpression));

if (unaryExpression.NodeType == ExpressionType.Not
&& unaryExpression.Operand is MethodCallExpression innerMethodCall
&& (_startsWithMethodInfo.Equals(innerMethodCall.Method)
|| _endsWithMethodInfo.Equals(innerMethodCall.Method)))
{
if (innerMethodCall.Arguments[0] is ConstantExpression constantArgument
&& (string)constantArgument.Value == string.Empty)
{
// every string starts/ends with empty string.
return Expression.Constant(false);
}

var newObject = Visit(innerMethodCall.Object);
var newArgument = Visit(innerMethodCall.Arguments[0]);

var result = Expression.AndAlso(
Expression.NotEqual(newObject, _constantNullString),
Expression.AndAlso(
Expression.NotEqual(newArgument, _constantNullString),
Expression.Not(innerMethodCall.Update(newObject, new[] { newArgument }))));

return newArgument is ConstantExpression
? result
: Expression.AndAlso(
Expression.NotEqual(
newArgument,
Expression.Constant(string.Empty)),
result);
}

return base.VisitUnary(unaryExpression);
}

private bool TryExtractEqualityOperands(Expression expression, out Expression left, out Expression right, out bool negated)
Expand Down
54 changes: 0 additions & 54 deletions src/EFCore/Query/Internal/VBToCSharpConvertingExpressionVisitor.cs

This file was deleted.

6 changes: 1 addition & 5 deletions src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,12 @@ public virtual Expression Process([NotNull] Expression query)
Check.NotNull(query, nameof(query));

query = new InvocationExpressionRemovingExpressionVisitor().Visit(query);

query = NormalizeQueryableMethodCall(query);

query = new VBToCSharpConvertingExpressionVisitor().Visit(query);
query = new AllAnyContainsRewritingExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new SubqueryMemberPushdownExpressionVisitor(QueryCompilationContext.Model).Visit(query);
query = new NavigationExpandingExpressionVisitor(this, QueryCompilationContext, Dependencies.EvaluatableExpressionFilter)
.Expand(query);
query = new FunctionPreprocessingExpressionVisitor().Visit(query);
query = new QueryOptimizingExpressionVisitor().Visit(query);

return query;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2176,7 +2176,7 @@ public virtual Task Where_collection_navigation_ToArray_Count(bool async)
elementAsserter: (e, a) => AssertCollection(e, a));
}

[ConditionalTheory(Skip = "Issue#19433")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_collection_navigation_ToArray_Contains(bool async)
{
Expand All @@ -2185,7 +2185,7 @@ public virtual Task Where_collection_navigation_ToArray_Contains(bool async)
return AssertQuery(
async,
ss => ss.Set<Customer>()
.Select(c => c.Orders.ToArray())
.Select(c => c.Orders.AsEnumerable().ToArray())
.Where(e => e.Contains(order)),
entryCount: 5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2190,9 +2190,13 @@ public override async Task Contains_with_subquery_optional_navigation_and_consta
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id]
WHERE EXISTS (
SELECT DISTINCT 1
FROM [LevelThree] AS [l1]
WHERE ([l0].[Id] IS NOT NULL AND ([l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id])) AND ([l1].[Id] = 1))");
SELECT 1
FROM (
SELECT DISTINCT [l1].[Id], [l1].[Level2_Optional_Id], [l1].[Level2_Required_Id], [l1].[Name], [l1].[OneToMany_Optional_Inverse3Id], [l1].[OneToMany_Optional_Self_Inverse3Id], [l1].[OneToMany_Required_Inverse3Id], [l1].[OneToMany_Required_Self_Inverse3Id], [l1].[OneToOne_Optional_PK_Inverse3Id], [l1].[OneToOne_Optional_Self3Id]
FROM [LevelThree] AS [l1]
WHERE [l0].[Id] IS NOT NULL AND ([l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id])
) AS [t]
WHERE [t].[Id] = 1)");
}

public override async Task Contains_with_subquery_optional_navigation_scalar_distinct_and_constant_item(bool async)
Expand Down
Loading

0 comments on commit 8b6bba7

Please sign in to comment.