Skip to content

Commit

Permalink
ExecuteUpdate: Allow using other tables in the query to generate resu…
Browse files Browse the repository at this point in the history
…lt set

Part of #795
  • Loading branch information
smitpatel committed Aug 15, 2022
1 parent 446be9a commit f1f2086
Show file tree
Hide file tree
Showing 14 changed files with 524 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections;
using System.Collections.Concurrent;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.Extensions.Caching.Memory;

namespace Microsoft.EntityFrameworkCore.Query.Internal;
Expand Down
53 changes: 48 additions & 5 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1237,8 +1237,6 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] == updateExpression.Table
&& selectExpression.Projection.Count == 0)
{
_relationalCommandBuilder.Append("UPDATE ");
Expand All @@ -1255,13 +1253,58 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
},
joinAction: e => e.AppendLine(","));
_relationalCommandBuilder.AppendLine();
}

if (selectExpression.Predicate != null)
var predicate = selectExpression.Predicate;
var firstTablePrinted = false;
if (selectExpression.Tables.Count > 1)
{
_relationalCommandBuilder.AppendLine().Append("FROM ");
for (var i = 0; i < selectExpression.Tables.Count; i++)
{
var table = selectExpression.Tables[i];
var joinExpression = table as JoinExpressionBase;

if (ReferenceEquals(updateExpression.Table, joinExpression?.Table ?? table))
{
LiftPredicate(table);
continue;
}

if (firstTablePrinted)
{
_relationalCommandBuilder.AppendLine();
}
else
{
firstTablePrinted = true;
LiftPredicate(table);
table = joinExpression?.Table ?? table;
}

Visit(table);

void LiftPredicate(TableExpressionBase joinTable)
{
if (joinTable is PredicateJoinExpressionBase predicateJoinExpression)
{
predicate = predicate == null
? predicateJoinExpression.JoinPredicate
: new SqlBinaryExpression(
ExpressionType.AndAlso,
predicateJoinExpression.JoinPredicate,
predicate,
typeof(bool),
predicate.TypeMapping);
}
}
}
}

if (predicate != null)
{
_relationalCommandBuilder.AppendLine().Append("WHERE ");
Visit(selectExpression.Predicate);
Visit(predicate);
}

return updateExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
foreach (var (propertyExpression, valueExpression) in propertyValueLambdaExpressions)
{
var left = RemapLambdaBody(source, propertyExpression);
left = left.UnwrapTypeConversion(out _);
if (!IsValidPropertyAccess(left, out var ese))
{
AddTranslationErrorDetails(RelationalStrings.InvalidPropertyInSetProperty(propertyExpression.Print()));
Expand All @@ -1148,6 +1149,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

var right = RemapLambdaBody(source, valueExpression);
if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
}
// We generate equality between property = value while translating sothat value infer tye type mapping from property correctly.
// Later we decompose it back into left/right components so that the equality is not in the tree which can get affected by
// null semantics or other visitor.
Expand Down Expand Up @@ -1305,7 +1310,7 @@ static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out
/// <param name="selectExpression">The select expression to validate.</param>
/// <param name="entityShaperExpression">The entity shaper expression on which the delete operation is being applied.</param>
/// <param name="tableExpression">The table expression from which rows are being deleted.</param>
/// <returns> das </returns>
/// <returns>Returns <see langword="true" /> if the current select expression can be used for delete as-is, <see langword="false" /> otherwise.</returns>
protected virtual bool IsValidSelectExpressionForExecuteDelete(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
Expand All @@ -1330,13 +1335,12 @@ protected virtual bool IsValidSelectExpressionForExecuteDelete(
return false;
}

// TODO: Update this documentation.
/// <summary>
/// Validates if the current select expression can be used for execute update operation or it requires to be pushed into a subquery.
/// Validates if the current select expression can be used for execute update operation or it requires to be joined as a subquery.
/// </summary>
/// <remarks>
/// <para>
/// By default, only single-table select expressions are supported, and optionally with a predicate.
/// By default, only muli-table select expressions are supported, and optionally with a predicate.
/// </para>
/// <para>
/// Providers can override this to allow more select expression features to be supported without pushing down into a subquery.
Expand All @@ -1347,7 +1351,7 @@ protected virtual bool IsValidSelectExpressionForExecuteDelete(
/// <param name="selectExpression">The select expression to validate.</param>
/// <param name="entityShaperExpression">The entity shaper expression on which the update operation is being applied.</param>
/// <param name="tableExpression">The table expression from which rows are being deleted.</param>
/// <returns> das </returns>
/// <returns>Returns <see langword="true" /> if the current select expression can be used for update as-is, <see langword="false" /> otherwise.</returns>
protected virtual bool IsValidSelectExpressionForExecuteUpdate(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
Expand All @@ -1359,13 +1363,30 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate(
&& (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null)
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] is TableExpression expression)
&& selectExpression.Orderings.Count == 0)
{
tableExpression = expression;
TableExpressionBase table;
if (selectExpression.Tables.Count == 1)
{
table = selectExpression.Tables[0];
}
else
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression);
var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First());
table = column.Table;
if (table is JoinExpressionBase joinExpressionBase)
{
table = joinExpressionBase.Table;
}
}

return true;
if (table is TableExpression te)
{
tableExpression = te;
return true;
}
}

tableExpression = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,14 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
var selectExpression = updateExpression.SelectExpression;

if (selectExpression.Offset == null
&& selectExpression.Limit == null
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] == updateExpression.Table
&& selectExpression.Projection.Count == 0)
{
Sql.Append("UPDATE ");
GenerateTop(selectExpression);

Sql.AppendLine($"{Dependencies.SqlGenerationHelper.DelimitIdentifier(updateExpression.Table.Alias)}");
using (Sql.Indent())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,52 @@ protected override bool IsValidSelectExpressionForExecuteDelete(
return false;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool IsValidSelectExpressionForExecuteUpdate(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
[NotNullWhen(true)] out TableExpression? tableExpression)
{
if (selectExpression.Offset == null
// If entity type has primary key then Distinct is no-op
&& (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null)
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0)
{
TableExpressionBase table;
if (selectExpression.Tables.Count == 1)
{
table = selectExpression.Tables[0];
}
else
{
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression);
var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First());
table = column.Table;
if (table is JoinExpressionBase joinExpressionBase)
{
table = joinExpressionBase.Table;
}
}

if (table is TableExpression te)
{
tableExpression = te;
return true;
}
}

tableExpression = null;
return false;
}

private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
{
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,29 @@ public virtual Task Update_where_constant(bool async)
rowsAffectedCount: 8,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Update_where_parameter_in_predicate(bool async)
{
var customer = "ALFKI";
await AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == customer),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

customer = null;
await AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == customer),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 0,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_parameter(bool async)
Expand All @@ -357,6 +380,113 @@ public virtual Task Update_where_parameter(bool async)
(b, a) => a.ForEach(c => Assert.Equal("Abc", c.ContactName)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_take_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")).Take(4),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 4,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_aggregate_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => c.CustomerID == ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.Key).FirstOrDefault()),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_first_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => c.CustomerID == ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.First().CustomerID).FirstOrDefault()),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory(Skip = "Issue#26753")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_first_constant_2(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => c == ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.First().Customer).FirstOrDefault()),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory(Skip = "Issue#28524")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_group_by_first_constant_3(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>()
.Where(c => ss.Set<Order>()
.GroupBy(e => e.CustomerID).Where(g => g.Count() > 11).Select(e => e.First().Customer).Contains(c)),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_distinct_constant(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")).Distinct(),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 8,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_using_navigation(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Order>().Where(o => o.Customer.City == "Seattle"),
e => e,
s => s.SetProperty(c => c.OrderDate, c => null),
rowsAffectedCount: 14,
(b, a) => a.ForEach(c => Assert.Null(c.OrderDate)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_using_navigation_2(bool async)
=> AssertUpdate(
async,
ss => ss.Set<OrderDetail>().Where(od => od.Order.Customer.City == "Seattle"),
e => e,
s => s.SetProperty(c => c.Quantity, c => 1),
rowsAffectedCount: 40,
(b, a) => a.ForEach(c => Assert.Equal(1, c.Quantity)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_select_many(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")).SelectMany(c => c.Orders),
e => e,
s => s.SetProperty(c => c.OrderDate, c => null),
rowsAffectedCount: 63,
(b, a) => a.ForEach(c => Assert.Null(c.OrderDate)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where_using_property_plus_constant(bool async)
Expand Down
Loading

0 comments on commit f1f2086

Please sign in to comment.