Skip to content

Commit

Permalink
Query: Generate predicate correctly when expanding owned collections
Browse files Browse the repository at this point in the history
Resolves #23130

Some additional tests changed because it ended up causing client eval in the middle due to another level of OwnsMany
  • Loading branch information
smitpatel committed Oct 29, 2020
1 parent 33570d0 commit 8a17e31
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1282,9 +1282,13 @@ internal Expression ExpandWeakEntities(InMemoryQueryExpression queryExpression,

private sealed class WeakEntityExpandingExpressionVisitor : ExpressionVisitor
{
private InMemoryQueryExpression _queryExpression;
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;

private InMemoryQueryExpression _queryExpression;

public WeakEntityExpandingExpressionVisitor(InMemoryExpressionTranslatingExpressionVisitor expressionTranslator)
{
_expressionTranslator = expressionTranslator;
Expand Down Expand Up @@ -1402,15 +1406,23 @@ private Expression TryExpand(Expression source, MemberIdentity member)
: foreignKey.Properties,
makeNullable);

var outerKeyFirstProperty = outerKey is NewExpression newExpression
? ((UnaryExpression)((NewArrayExpression)newExpression.Arguments[0]).Expressions[0]).Operand
: outerKey;
var keyComparison = Expression.Call(_objectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey));

var predicate = outerKeyFirstProperty.Type.IsNullableType()
var predicate = makeNullable
? Expression.AndAlso(
Expression.NotEqual(outerKeyFirstProperty, Expression.Constant(null, outerKeyFirstProperty.Type)),
Expression.Equal(outerKey, innerKey))
: Expression.Equal(outerKey, innerKey);
outerKey is NewArrayExpression newArrayExpression
? newArrayExpression.Expressions
.Select(
e =>
{
var left = (e as UnaryExpression)?.Operand ?? e;
return Expression.NotEqual(left, Expression.Constant(null, left.Type));
})
.Aggregate((l, r) => Expression.AndAlso(l, r))
: Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)),
keyComparison)
: (Expression)keyComparison;

var correlationPredicate = _expressionTranslator.Translate(predicate);
innerQueryExpression.UpdateServerQueryExpression(
Expand Down Expand Up @@ -1460,6 +1472,10 @@ ProjectionBindingExpression projectionBindingExpression

return innerShaper;
}
private static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}

private ShapedQueryExpression TranslateScalarAggregate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1220,10 +1220,14 @@ internal Expression ExpandWeakEntities(SelectExpression selectExpression, Expres

private sealed class WeakEntityExpandingExpressionVisitor : ExpressionVisitor
{
private SelectExpression _selectExpression;
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

private SelectExpression _selectExpression;

public WeakEntityExpandingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitor sqlTranslator,
ISqlExpressionFactory sqlExpressionFactory)
Expand Down Expand Up @@ -1347,15 +1351,23 @@ private Expression TryExpand(Expression source, MemberIdentity member)
: foreignKey.Properties,
makeNullable);

var outerKeyFirstProperty = outerKey is NewExpression newExpression
? ((UnaryExpression)((NewArrayExpression)newExpression.Arguments[0]).Expressions[0]).Operand
: outerKey;
var keyComparison = Expression.Call(_objectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey));

var predicate = outerKeyFirstProperty.Type.IsNullableType()
var predicate = makeNullable
? Expression.AndAlso(
Expression.NotEqual(outerKeyFirstProperty, Expression.Constant(null, outerKeyFirstProperty.Type)),
Expression.Equal(outerKey, innerKey))
: Expression.Equal(outerKey, innerKey);
outerKey is NewArrayExpression newArrayExpression
? newArrayExpression.Expressions
.Select(
e =>
{
var left = (e as UnaryExpression)?.Operand ?? e;
return Expression.NotEqual(left, Expression.Constant(null, left.Type));
})
.Aggregate((l, r) => Expression.AndAlso(l, r))
: Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)),
keyComparison)
: (Expression)keyComparison;

var correlationPredicate = Expression.Lambda(predicate, correlationPredicateParameter);

Expand Down Expand Up @@ -1446,6 +1458,11 @@ private Expression TryExpand(Expression source, MemberIdentity member)
return innerShaper;
}

private static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;

private static IDictionary<IProperty, ColumnExpression> GetPropertyExpressionFromSameTable(
IEntityType entityType,
ITableBase table,
Expand Down
33 changes: 33 additions & 0 deletions test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,39 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
OrderDate = Convert.ToDateTime("2016-04-25 19:23:56")
}
);
ob.OwnsMany(e => e.Details, odb =>
{
odb.HasData(
new
{
Id = -100,
OrderId = -10,
OrderClientId = 1,
Detail = "Discounted Order"
},
new
{
Id = -101,
OrderId = -10,
OrderClientId = 1,
Detail = "Full Price Order"
},
new
{
Id = -200,
OrderId = -20,
OrderClientId = 2,
Detail = "Internal Order"
},
new
{
Id = -300,
OrderId = -30,
OrderClientId = 3,
Detail = "Bulk Order"
});
});
});
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public virtual Task Unmapped_property_projection_loads_owned_navigations_split(b
return AssertQuery(
async,
ss => ss.Set<OwnedPerson>().Where(e => e.Id == 1).AsTracking().Select(e => new { e.ReadOnlyProperty }).AsSplitQuery(),
entryCount: 5);
entryCount: 7);
}

[ConditionalTheory]
Expand Down
98 changes: 91 additions & 7 deletions test/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public virtual Task Query_for_branch_type_loads_all_owned_navs_tracking(bool asy
return AssertQuery(
async,
ss => ss.Set<Branch>().AsTracking(),
entryCount: 14);
entryCount: 16);
}

[ConditionalTheory]
Expand Down Expand Up @@ -425,7 +425,7 @@ public virtual Task Unmapped_property_projection_loads_owned_navigations(bool as
return AssertQuery(
async,
ss => ss.Set<OwnedPerson>().Where(e => e.Id == 1).AsTracking().Select(e => new { e.ReadOnlyProperty }),
entryCount: 5);
entryCount: 7);
}
// Issue#18140
Expand Down Expand Up @@ -498,7 +498,8 @@ public virtual Task Where_owned_collection_navigation_ToList_Count(bool async)
async,
ss => ss.Set<OwnedPerson>()
.OrderBy(p => p.Id)
.Select(p => p.Orders.ToList())
.SelectMany(p => p.Orders)
.Select(p => p.Details.ToList())
.Where(e => e.Count() == 0),
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
Expand All @@ -512,7 +513,8 @@ public virtual Task Where_collection_navigation_ToArray_Count(bool async)
async,
ss => ss.Set<OwnedPerson>()
.OrderBy(p => p.Id)
.Select(p => p.Orders.ToArray())
.SelectMany(p => p.Orders)
.Select(p => p.Details.AsEnumerable().ToArray())
.Where(e => e.Count() == 0),
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
Expand All @@ -526,7 +528,8 @@ public virtual Task Where_collection_navigation_AsEnumerable_Count(bool async)
async,
ss => ss.Set<OwnedPerson>()
.OrderBy(p => p.Id)
.Select(p => p.Orders.AsEnumerable())
.SelectMany(p => p.Orders)
.Select(p => p.Details.AsEnumerable())
.Where(e => e.Count() == 0),
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
Expand All @@ -540,7 +543,8 @@ public virtual Task Where_collection_navigation_ToList_Count_member(bool async)
async,
ss => ss.Set<OwnedPerson>()
.OrderBy(p => p.Id)
.Select(p => p.Orders.ToList())
.SelectMany(p => p.Orders)
.Select(p => p.Details.ToList())
.Where(e => e.Count == 0),
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
Expand All @@ -554,7 +558,8 @@ public virtual Task Where_collection_navigation_ToArray_Length_member(bool async
async,
ss => ss.Set<OwnedPerson>()
.OrderBy(p => p.Id)
.Select(p => p.Orders.ToArray())
.SelectMany(p => p.Orders)
.Select(p => p.Details.AsEnumerable().ToArray())
.Where(e => e.Length == 0),
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
Expand Down Expand Up @@ -902,6 +907,18 @@ private static void AssertOrders(ICollection<Order> expectedOrders, ICollection<
Assert.Equal(element.e.Id, element.a.Id);
Assert.Equal(element.e["OrderDate"], element.a["OrderDate"]);
Assert.Equal(element.e.Client.Id, element.a.Client.Id);
AssertOrderDetails(element.e.Details, element.a.Details);
}
}

private static void AssertOrderDetails(IList<OrderDetail> expectedOrderDetails, IList<OrderDetail> actualOrderDetails)
{
Assert.Equal(expectedOrderDetails.Count, actualOrderDetails.Count);
expectedOrderDetails = expectedOrderDetails.OrderBy(e => e.Detail).ToList();
actualOrderDetails = actualOrderDetails.OrderBy(e => e.Detail).ToList();
for (var i = 0; i < expectedOrderDetails.Count; i++)
{
Assert.Equal(expectedOrderDetails[i].Detail, actualOrderDetails[i].Detail);
}
}

Expand All @@ -926,6 +943,7 @@ public IReadOnlyDictionary<Type, object> GetEntitySorters()
// owned entities - still need comparers in case they are projected directly
{ typeof(Order), e => ((Order)e)?.Id },
{ typeof(OrderDetail), e => ((OrderDetail)e)?.Detail },
{ typeof(OwnedAddress), e => ((OwnedAddress)e)?.Country.Name },
{ typeof(OwnedCountry), e => ((OwnedCountry)e)?.Name },
{ typeof(Element), e => ((Element)e)?.Id },
Expand Down Expand Up @@ -1107,6 +1125,16 @@ public IReadOnlyDictionary<Type, object> GetEntityAsserters()
}
}
},
{
typeof(OrderDetail), (e, a) =>
{
Assert.Equal(e == null, a == null);
if (a != null)
{
Assert.Equal(((OrderDetail)e).Detail, ((OrderDetail)a).Detail);
}
}
},
{
typeof(OwnedAddress), (e, a) =>
{
Expand Down Expand Up @@ -1272,6 +1300,39 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
OrderDate = Convert.ToDateTime("2016-04-25 19:23:56")
}
);
ob.OwnsMany(e => e.Details, odb =>
{
odb.HasData(
new
{
Id = -100,
OrderId = -10,
OrderClientId = 1,
Detail = "Discounted Order"
},
new
{
Id = -101,
OrderId = -10,
OrderClientId = 1,
Detail = "Full Price Order"
},
new
{
Id = -200,
OrderId = -20,
OrderClientId = 2,
Detail = "Internal Order"
},
new
{
Id = -300,
OrderId = -30,
OrderClientId = 3,
Detail = "Bulk Order"
});
});
});
});

Expand Down Expand Up @@ -1616,20 +1677,36 @@ private static IReadOnlyList<OwnedPerson> CreateOwnedPeople()

var order1 = new Order { Id = -10, Client = ownedPerson1 };
order1["OrderDate"] = Convert.ToDateTime("2018-07-11 10:01:41");
order1.Details = new List<OrderDetail>
{
new OrderDetail { Detail = "Discounted Order" },
new OrderDetail { Detail = "Full Price Order" }
};

var order2 = new Order { Id = -11, Client = ownedPerson1 };
order2["OrderDate"] = Convert.ToDateTime("2015-03-03 04:37:59");
order2.Details = new List<OrderDetail>();
ownedPerson1.Orders = new List<Order> { order1, order2 };

var order3 = new Order { Id = -20, Client = ownedPerson2 };
order3["OrderDate"] = Convert.ToDateTime("2015-05-25 20:35:48");
order3.Details = new List<OrderDetail>
{
new OrderDetail { Detail = "Internal Order" }
};
ownedPerson2.Orders = new List<Order> { order3 };

var order4 = new Order { Id = -30, Client = ownedPerson3 };
order4["OrderDate"] = Convert.ToDateTime("2014-11-10 04:32:42");
order4.Details = new List<OrderDetail>
{
new OrderDetail { Detail = "Bulk Order" }
};
ownedPerson3.Orders = new List<Order> { order4 };

var order5 = new Order { Id = -40, Client = ownedPerson4 };
order5["OrderDate"] = Convert.ToDateTime("2016-04-25 19:23:56");
order5.Details = new List<OrderDetail>();
ownedPerson4.Orders = new List<Order> { order5 };

return new List<OwnedPerson>
Expand Down Expand Up @@ -1825,6 +1902,13 @@ public object this[string name]
}

public OwnedPerson Client { get; set; }

public List<OrderDetail> Details { get; set; }
}

protected class OrderDetail
{
public string Detail { get; set; }
}

protected class Branch : OwnedPerson
Expand Down
Loading

0 comments on commit 8a17e31

Please sign in to comment.