Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jun 4, 2019
1 parent b769b59 commit 2385534
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ namespace Microsoft.EntityFrameworkCore.Query.Pipeline
/// <remarks>
/// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id).
/// </remarks>
public class EntityEqualityRewritingExpressionVisitor2 : ExpressionVisitor
public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
{
protected RewritingVisitor Rewriter { get; }
protected ReducingVisitor Reducer { get; }

protected IDiagnosticsLogger<DbLoggerCategory.Query> Logger { get; }
protected IModel Model { get; }

public EntityEqualityRewritingExpressionVisitor2(QueryCompilationContext2 queryCompilationContext)
public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext2 queryCompilationContext)
{
Rewriter = new RewritingVisitor(queryCompilationContext);
Reducer = new ReducingVisitor();
Expand Down Expand Up @@ -60,15 +60,15 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio

protected override Expression VisitNew(NewExpression newExpression)
{
var newNew = (NewExpression)base.VisitNew(newExpression);
var visitedExpression = (NewExpression)base.VisitNew(newExpression);

return (newExpression.Members?.Count ?? 0) == 0
? (Expression)newNew
: new EntityReferenceExpression(newNew, newNew.Members
? (Expression)visitedExpression
: new EntityReferenceExpression(visitedExpression, visitedExpression.Members
.Select((m, i) => (Member: m, Index: i))
.ToDictionary(
mi => mi.Member.Name,
mi => Visit(newExpression.Arguments[mi.Index])));
mi => visitedExpression.Arguments[mi.Index]));
}

protected override Expression VisitMember(MemberExpression memberExpression)
Expand Down Expand Up @@ -152,14 +152,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(Visit(methodCallExpression.Object), Array.Empty<Expression>());
}

// Methods with one lambda argument that has one parameter, and with a typed first argument (source) are rewritten automatically
// (e.g. Where(), FromSql(), Average()
// Methods with a typed first argument (source), and with no lambda arguments or a single lambda
// argument that has one parameter are rewritten automatically (e.g. Where(), FromSql(), Average()
var newArguments = new Expression[arguments.Count];
var lambdaParamCount = arguments.Count(a => GetLambdaOrNull(a) != null);
var singleLambda = arguments.Select(GetLambdaOrNull).SingleOrDefault(l => l != null);
var lambdaArgs = arguments.Select(GetLambdaOrNull).Where(l => l != null).ToArray();
newSource = newArguments[0] = Visit(arguments[0]);
if ((lambdaParamCount == 0 || singleLambda.Parameters.Count == 1)
&& newSource is EntityReferenceExpression newSourceWrapper)
if (newSource is EntityReferenceExpression newSourceWrapper
&& (lambdaArgs.Length == 0
|| lambdaArgs.Length == 1 && lambdaArgs[0].Parameters.Count == 1))
{
for (var i = 1; i < arguments.Count; i++)
{
Expand All @@ -170,11 +170,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

var sourceParamType = methodCallExpression.Method.GetParameters()[0].ParameterType;
if (IsQueryableOrEnumerable(sourceParamType, out var sourceElementType))
if (sourceParamType.TryGetSequenceType() is Type sourceElementType)
{
// If the method returns the element same type as the source, flow the type information
// (e.g. Where, OrderBy)
if (IsQueryableOrEnumerable(methodCallExpression.Method.ReturnType, out var returnElementType)
if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType
&& returnElementType == sourceElementType)
{
return newSourceWrapper.WithUnderlying(
Expand Down Expand Up @@ -208,21 +208,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(Visit(methodCallExpression.Object), newArguments);
}

protected static bool IsQueryableOrEnumerable(Type type, out Type elementType)
{
var queryableOrEnumerable =
type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(IQueryable<>)
|| type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
? type
: type.GetInterfaces().FirstOrDefault(
i => i.IsGenericType && (
i.GetGenericTypeDefinition() == typeof(IQueryable<>)
|| i.GetGenericTypeDefinition() == typeof(IEnumerable<>)));

elementType = queryableOrEnumerable?.GetGenericArguments()[0];
return queryableOrEnumerable != null;
}

protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
Expand Down Expand Up @@ -278,13 +263,11 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
return base.VisitMethodCall(methodCallExpression);
}

var (newOuter, newInner, outerKeySelector, innerKeySelector, resultSelector) = (
Visit(arguments[0]),
Visit(arguments[1]),
arguments[2].UnwrapQuote(),
arguments[3].UnwrapQuote(),
arguments[4].UnwrapQuote()
);
var newOuter = Visit(arguments[0]);
var newInner = Visit(arguments[1]);
var outerKeySelector = arguments[2].UnwrapQuote();
var innerKeySelector = arguments[3].UnwrapQuote();
var resultSelector = arguments[4].UnwrapQuote();

if (!(newOuter is EntityReferenceExpression outerWrapper && newInner is EntityReferenceExpression innerWrapper))
{
Expand Down Expand Up @@ -348,6 +331,7 @@ protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda,
/// <returns>The rewritten entity equality expression, or null if rewriting could not occur for some reason.</returns>
protected virtual Expression RewriteEquality(bool isEqual, Expression left, Expression right)
{

// TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model.
// TODO: This would indicate an issue in our flowing logic, and would help the user (and us) understand what's going on.

Expand Down Expand Up @@ -469,6 +453,7 @@ protected override Expression VisitExtension(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// TODO: DRY with NavigationExpansionHelpers
protected static Expression CreateKeyAccessExpression(
[NotNull] Expression target,
[NotNull] IReadOnlyList<IProperty> properties)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public Expression Visit(Expression query)
query = new AllAnyToContainsRewritingExpressionVisitor().Visit(query);
query = new GroupJoinFlatteningExpressionVisitor().Visit(query);
query = new NullCheckRemovingExpressionVisitor().Visit(query);
query = new EntityEqualityRewritingExpressionVisitor2(_queryCompilationContext).Visit(query);
query = new EntityEqualityRewritingExpressionVisitor(_queryCompilationContext).Visit(query);
query = new NavigationExpander(_queryCompilationContext.Model).ExpandNavigations(query);
query = new EnumerableToQueryableReMappingExpressionVisitor().Visit(query);
query = new QueryMetadataExtractingExpressionVisitor(_queryCompilationContext).Visit(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,7 @@ public virtual Task OrderBy_Skip_Last_gives_correct_result(bool isAsync)
entryCount: 1);
}

[ConditionalFact(Skip = "Null TypeMapping in Sql Tree")]
[ConditionalFact(Skip = "#15939")]
public virtual void Contains_over_entityType_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
Expand Down
8 changes: 5 additions & 3 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,8 @@ from od in odt
OrderID = 10248,
ProductID = 11
}
select (object)od.ProductID);
select od,
entryCount: 1);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand All @@ -451,7 +452,7 @@ public virtual Task Entity_equality_null_composite_key(bool isAsync)
odt =>
from od in odt
where od == null
select (object)od.ProductID);
select od);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand All @@ -473,7 +474,8 @@ public virtual Task Entity_equality_not_null_composite_key(bool isAsync)
odt =>
from od in odt
where od != null
select (object)od.ProductID);
select od,
entryCount: 2155);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down

0 comments on commit 2385534

Please sign in to comment.