Skip to content

Commit

Permalink
Support member access on parameter and constant introduced in pipeline
Browse files Browse the repository at this point in the history
Entity equality introduces member access expressions on what may be a
parameter or a constant. Identify these cases and generate a new
parameter (for access of a parameter) or evaluate the constant.

Fixes #15855
Fixes #14645
Fixes #14644
  • Loading branch information
roji committed Jul 9, 2019
1 parent f535d39 commit ae4a9ed
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 145 deletions.
4 changes: 2 additions & 2 deletions src/EFCore/Metadata/Internal/PropertyBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private void UpdateFieldInfoConfigurationSource(ConfigurationSource configuratio
/// </summary>
public virtual IClrPropertyGetter Getter =>
NonCapturingLazyInitializer.EnsureInitialized(
ref _getter, this,p => new ClrPropertyGetterFactory().Create(p));
ref _getter, this, p => new ClrPropertyGetterFactory().Create(p));

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -300,7 +300,7 @@ private void UpdateFieldInfoConfigurationSource(ConfigurationSource configuratio
/// </summary>
public virtual IClrPropertySetter Setter =>
NonCapturingLazyInitializer.EnsureInitialized(
ref _setter, this,p => new ClrPropertySetterFactory().Create(p));
ref _setter, this, p => new ClrPropertySetterFactory().Create(p));

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors;

namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
Expand All @@ -24,23 +26,31 @@ namespace Microsoft.EntityFrameworkCore.Query.Pipeline
/// </remarks>
public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
{
/// <summary>
/// If the entity equality visitors introduces new runtime parameters (because it adds key access over existing parameters),
/// those parameters will have this prefix.
/// </summary>
private const string RuntimeParameterPrefix = CompiledQueryCache.CompiledQueryParameterPrefix + "entity_equality_";

protected QueryCompilationContext QueryCompilationContext { get; }
protected IDiagnosticsLogger<DbLoggerCategory.Query> Logger { get; }
protected IModel Model { get; }

private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
Model = queryCompilationContext.Model;
QueryCompilationContext = queryCompilationContext;
Logger = queryCompilationContext.Logger;
}

public Expression Rewrite(Expression expression) => Unwrap(Visit(expression));

protected override Expression VisitConstant(ConstantExpression constantExpression)
=> constantExpression.IsEntityQueryable()
? new EntityReferenceExpression(constantExpression, Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
? new EntityReferenceExpression(
constantExpression,
QueryCompilationContext.Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
: (Expression)constantExpression;

protected override Expression VisitNew(NewExpression newExpression)
Expand Down Expand Up @@ -278,15 +288,15 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(param.CreateEFPropertyExpression(keyProperty, makeNullable: false), param);
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
var keyProjection = Expression.Call(
LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

var rewrittenItem = newItem.IsNullConstantExpression()
? Expression.Constant(null)
: Unwrap(newItem).CreateEFPropertyExpression(keyProperty, makeNullable: false);
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);

return Expression.Call(
LinqMethodHelpers.QueryableContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
Expand Down Expand Up @@ -333,7 +343,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
var rewrittenKeySelector = Expression.Lambda(
ReplacingExpressionVisitor.Replace(
oldParam, param,
body.CreateEFPropertyExpression(keyProperty, makeNullable: false)),
CreatePropertyAccessExpression(body, keyProperty)),
param);

var orderingMethodInfo = GetOrderingMethodInfo(firstOrdering, isAscending);
Expand Down Expand Up @@ -609,7 +619,7 @@ private Expression RewriteNullEquality(
// (this is also why we can do it even over a subquery with a composite key)
return Expression.MakeBinary(
equality ? ExpressionType.Equal : ExpressionType.NotEqual,
nonNullExpression.CreateEFPropertyExpression(keyProperties[0]),
CreatePropertyAccessExpression(nonNullExpression, keyProperties[0], makeNullable: true),
Expression.Constant(null));
}

Expand Down Expand Up @@ -688,11 +698,11 @@ protected virtual Expression VisitNullConditional(NullConditionalExpression expr
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// TODO: DRY with NavigationExpansionHelpers
protected static Expression CreateKeyAccessExpression(
protected Expression CreateKeyAccessExpression(
[NotNull] Expression target,
[NotNull] IReadOnlyList<IProperty> properties)
=> properties.Count == 1
? target.CreateEFPropertyExpression(properties[0])
? CreatePropertyAccessExpression(target, properties[0])
: Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
Expand All @@ -701,11 +711,53 @@ protected static Expression CreateKeyAccessExpression(
.Select(
p =>
Expression.Convert(
target.CreateEFPropertyExpression(p),
CreatePropertyAccessExpression(target, p),
typeof(object)))
.Cast<Expression>()
.ToArray()));

private Expression CreatePropertyAccessExpression(Expression target, IProperty property, bool makeNullable = false)
{
// The target is a constant - evaluate the property immediately and return the result
if (target is ConstantExpression constantExpression)
{
return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType);
}

// If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new
// parameter to be added at runtime, with the value of the property on the base parameter.
if (target is ParameterExpression baseParameterExpression
&& baseParameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
{
// Generate an expression to get the base parameter from the query context's parameter list
var baseParameterValueVariable = Expression.Variable(baseParameterExpression.Type);
var assignBaseParameterValue =
Expression.Assign(
baseParameterValueVariable,
Expression.Convert(
Expression.Property(
Expression.Property(QueryCompilationContext.QueryContextParameter, nameof(QueryContext.ParameterValues)),
"Item",
Expression.Constant(baseParameterExpression.Name, typeof(string))),
baseParameterExpression.Type));

var lambda = Expression.Lambda(
Expression.Block(
new[] { baseParameterValueVariable },
assignBaseParameterValue,
Expression.Condition( // The target could be null, wrap in a conditional expression to coalesce
Expression.ReferenceEqual(baseParameterValueVariable, Expression.Constant(null)),
Expression.Constant(null),
Expression.Convert(Expression.PropertyOrField(baseParameterValueVariable, property.Name), typeof(object)))),
QueryCompilationContext.QueryContextParameter);

var newParameterName = $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}";
QueryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda);
return Expression.Parameter(property.ClrType, newParameterName);
}

return target.CreateEFPropertyExpression(property, makeNullable);
}

protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ private Expression TryGetConstantValue(Expression expression)
{
if (_evaluatableExpressions.ContainsKey(expression))
{
var value = GetValue(expression, out var _);
var value = GetValue(expression, out _);

if (value is bool)
{
Expand Down
54 changes: 53 additions & 1 deletion src/EFCore/Query/Pipeline/QueryCompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
Expand All @@ -20,6 +22,13 @@ public class QueryCompilationContext
private readonly IShapedQueryOptimizerFactory _shapedQueryOptimizerFactory;
private readonly IShapedQueryCompilingExpressionVisitorFactory _shapedQueryCompilingExpressionVisitorFactory;

/// <summary>
/// A dictionary mapping parameter names to lambdas that, given a QueryContext, can extract that parameter's value.
/// This is needed for cases where we need to introduce a parameter during the compilation phase (e.g. entity equality rewrites
/// a parameter to an ID property on that parameter).
/// </summary>
private Dictionary<string, LambdaExpression> _runtimeParameters;

public QueryCompilationContext(
IModel model,
IQueryOptimizerFactory queryOptimizerFactory,
Expand All @@ -42,7 +51,6 @@ public QueryCompilationContext(
_queryableMethodTranslatingExpressionVisitorFactory = queryableMethodTranslatingExpressionVisitorFactory;
_shapedQueryOptimizerFactory = shapedQueryOptimizerFactory;
_shapedQueryCompilingExpressionVisitorFactory = shapedQueryCompilingExpressionVisitorFactory;

}

public bool Async { get; }
Expand All @@ -69,6 +77,10 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
// Inject tracking
query = _shapedQueryCompilingExpressionVisitorFactory.Create(this).Visit(query);

// If any additional parameters were added during the compilation phase (e.g. entity equality ID expression),
// wrap the query with code adding those parameters to the query context
query = InsertRuntimeParameters(query);

var queryExecutorExpression = Expression.Lambda<Func<QueryContext, TResult>>(
query,
QueryContextParameter);
Expand All @@ -82,5 +94,45 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
Logger.QueryExecutionPlanned(new ExpressionPrinter(), queryExecutorExpression);
}
}

/// <summary>
/// Registers a runtime parameter that is being added at some point during the compilation phase.
/// A lambda must be provided, which will extract the parameter's value from the QueryContext every time
/// the query is executed.
/// </summary>
public void RegisterRuntimeParameter(string name, LambdaExpression valueExtractor)
{
if (valueExtractor.Parameters.Count != 1
|| valueExtractor.Parameters[0] != QueryContextParameter
|| valueExtractor.ReturnType != typeof(object))
{
throw new ArgumentException("Runtime parameter extraction lambda must have one QueryContext parameter and return an object",
nameof(valueExtractor));
}

if (_runtimeParameters == null)
{
_runtimeParameters = new Dictionary<string, LambdaExpression>();
}

_runtimeParameters[name] = valueExtractor;
}

private Expression InsertRuntimeParameters(Expression query)
=> _runtimeParameters == null
? query
: Expression.Block(_runtimeParameters
.Select(kv =>
Expression.Call(
QueryContextParameter,
_queryContextAddParameterMethodInfo,
Expression.Constant(kv.Key),
Expression.Invoke(kv.Value, QueryContextParameter)))
.Append(query));

private static readonly MethodInfo _queryContextAddParameterMethodInfo
= typeof(QueryContext)
.GetTypeInfo()
.GetDeclaredMethod(nameof(QueryContext.AddParameter));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void Contains_over_entityType_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_should_rewrite_to_identity_equality();
Expand All @@ -1206,6 +1207,17 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

public override void Contains_over_entityType_should_materialize_when_composite()
{
base.Contains_over_entityType_should_materialize_when_composite();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public SimpleQueryCosmosTest(
: base(fixture)
{
ClearLog();
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

[ConditionalFact(Skip = "See issue#13857")]
Expand Down Expand Up @@ -116,12 +116,26 @@ public override async Task Entity_equality_local(bool isAsync)
{
await base.Entity_equality_local(isAsync);

AssertSql(
@"@__entity_equality_local_0_CustomerID='ANATR'
SELECT c[""CustomerID""]
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = @__entity_equality_local_0_CustomerID))");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Entity_equality_local_composite_key(bool isAsync)
{
await base.Entity_equality_local_composite_key(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Join_with_entity_equality_local_on_both_sources(bool isAsync)
{
await base.Join_with_entity_equality_local_on_both_sources(isAsync);
Expand All @@ -136,6 +150,17 @@ public override async Task Entity_equality_local_inline(bool isAsync)
{
await base.Entity_equality_local_inline(isAsync);

AssertSql(
@"SELECT c[""CustomerID""]
FROM root c
WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""CustomerID""] = ""ANATR""))");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Entity_equality_local_inline_composite_key(bool isAsync)
{
await base.Entity_equality_local_inline_composite_key(isAsync);

AssertSql(
@"SELECT c
FROM root c
Expand Down Expand Up @@ -321,6 +346,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Employee"")");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Where_query_composition_entity_equality_one_element_FirstOrDefault(bool isAsync)
{
await base.Where_query_composition_entity_equality_one_element_FirstOrDefault(isAsync);
Expand All @@ -341,6 +367,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Employee"")");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Where_query_composition_entity_equality_no_elements_FirstOrDefault(bool isAsync)
{
await base.Where_query_composition_entity_equality_no_elements_FirstOrDefault(isAsync);
Expand All @@ -351,6 +378,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Employee"")");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Where_query_composition_entity_equality_multiple_elements_FirstOrDefault(bool isAsync)
{
await base.Where_query_composition_entity_equality_multiple_elements_FirstOrDefault(isAsync);
Expand Down Expand Up @@ -3721,6 +3749,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

[ConditionalTheory(Skip = "Issue#14935")]
public override async Task Let_entity_equality_to_other_entity(bool isAsync)
{
await base.Let_entity_equality_to_other_entity(isAsync);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ public virtual void FromSqlRaw_does_not_parameterize_interpolated_string()
}
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact]
public virtual void Entity_equality_through_fromsql()
{
using (var context = CreateContext())
Expand All @@ -1002,7 +1002,7 @@ public virtual void Entity_equality_through_fromsql()
})
.ToArray();

Assert.Equal(1, actual.Length);
Assert.Equal(5, actual.Length);
}
}

Expand Down
Loading

0 comments on commit ae4a9ed

Please sign in to comment.