Skip to content

Commit

Permalink
Support member access on parameter and constant introduced in pipeline
Browse files Browse the repository at this point in the history
Entity equality introduces member access expressions on what may be a
parameter or a constant. Identify these cases and transform them into
a new parameter (for access of a parameter) or evaluate the constant.

Fixes #15855
  • Loading branch information
roji committed Jul 5, 2019
1 parent f535d39 commit 19b0a36
Show file tree
Hide file tree
Showing 17 changed files with 256 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
public interface IRelationalSqlTranslatingExpressionVisitorFactory
{
RelationalSqlTranslatingExpressionVisitor Create(
IModel model,
QueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public RelationalQueryableMethodTranslatingExpressionVisitor(
IModel model,
QueryCompilationContext queryCompilationContext,
IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory,
ISqlExpressionFactory sqlExpressionFactory)
: base(subquery: false)
{
_sqlTranslator = relationalSqlTranslatingExpressionVisitorFactory.Create(model, this);
_sqlTranslator = relationalSqlTranslatingExpressionVisitorFactory.Create(queryCompilationContext, this);
_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
_model = model;
_model = queryCompilationContext.Model;
_sqlExpressionFactory = sqlExpressionFactory;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ public RelationalQueryableMethodTranslatingExpressionVisitorFactory(
_relationalSqlTranslatingExpressionVisitorFactory = relationalSqlTranslatingExpressionVisitorFactory;
}

public QueryableMethodTranslatingExpressionVisitor Create(IModel model)
public QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
{
return new RelationalQueryableMethodTranslatingExpressionVisitor(
model,
queryCompilationContext,
_relationalSqlTranslatingExpressionVisitorFactory,
_sqlExpressionFactory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;
Expand All @@ -19,21 +21,21 @@ namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
{
public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor
{
private readonly IModel _model;
private readonly QueryCompilationContext _queryCompilationContext;
private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IMemberTranslatorProvider _memberTranslatorProvider;
private readonly IMethodCallTranslatorProvider _methodCallTranslatorProvider;
private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlVerifyingExpressionVisitor;

public RelationalSqlTranslatingExpressionVisitor(
IModel model,
QueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
ISqlExpressionFactory sqlExpressionFactory,
IMemberTranslatorProvider memberTranslatorProvider,
IMethodCallTranslatorProvider methodCallTranslatorProvider)
{
_model = model;
_queryCompilationContext = queryCompilationContext;
_queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor;
_sqlExpressionFactory = sqlExpressionFactory;
_memberTranslatorProvider = memberTranslatorProvider;
Expand Down Expand Up @@ -212,6 +214,67 @@ private SqlExpression BindProperty(Expression source, string propertyName)
return BindProperty(entityProjection, entityType.FindProperty(propertyName));
}

if (source is SqlConstantExpression constantExpression)
{
var constantValue = constantExpression.Value;

if (constantValue.GetType().GetProperty(propertyName) is PropertyInfo propertyInfo)
{
return _sqlExpressionFactory.Constant(
propertyInfo.GetValue(constantValue),
_sqlExpressionFactory.FindMapping(propertyInfo.PropertyType));
}

if (constantValue.GetType().GetField(propertyName) is FieldInfo fieldInfo)
{
return _sqlExpressionFactory.Constant(
fieldInfo.GetValue(constantValue),
_sqlExpressionFactory.FindMapping(fieldInfo.FieldType));
}
}

if (source is SqlParameterExpression baseParameterExpression)
{
// Generate an expression to get the base parameter from the query context's parameter list
var baseParameterValueVariable = Expression.Variable(baseParameterExpression.Type, "baseParam");
var assignBaseParameterValue = Expression.Assign(
baseParameterValueVariable,
Expression.Convert(
Expression.Property(
Expression.Property(QueryCompilationContext.QueryContextParameter, nameof(QueryContext.ParameterValues)),
"Item",
Expression.Constant(baseParameterExpression.Name, typeof(string))),
baseParameterExpression.Type));

// Generate an expression to access the new parameter's value from the base parameter.
// Since the base parameter may be null, we want coalescing behavior.
var getNewParameterValue = (Expression)Expression.PropertyOrField(baseParameterValueVariable, propertyName);
if (getNewParameterValue.Type.IsValueType)
{
getNewParameterValue = Expression.Convert(getNewParameterValue, getNewParameterValue.Type.MakeNullable());
}

var newParameterType = getNewParameterValue.Type;

var lambda = Expression.Lambda(
Expression.Block(
new[] { baseParameterValueVariable },
assignBaseParameterValue,
Expression.Condition(
Expression.ReferenceEqual(baseParameterValueVariable, Expression.Constant(null)),
Expression.Constant(null, newParameterType),
getNewParameterValue )
), QueryCompilationContext.QueryContextParameter);

var newParameterName = $"{baseParameterExpression.Name}_{propertyName}";

// Note: if the same property is accessed twice on the same parameter, we just overwrite the previous lambda
// which is identical.
_queryCompilationContext.AddCompileTimeParameter(newParameterName, lambda);
return new SqlParameterExpression(Expression.Parameter(newParameterType, newParameterName),
_sqlExpressionFactory.FindMapping(newParameterType));
}

throw new InvalidOperationException();
}

Expand Down Expand Up @@ -342,7 +405,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
arguments[i] = (SqlExpression)argument;
}

return _methodCallTranslatorProvider.Translate(_model, (SqlExpression)@object, methodCallExpression.Method, arguments);
return _methodCallTranslatorProvider.Translate(_queryCompilationContext.Model, (SqlExpression)@object, methodCallExpression.Method, arguments);
}

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down Expand Up @@ -375,7 +438,6 @@ private static Expression TryRemoveImplicitConvert(Expression expression)
return expression;
}


private Expression ConvertAnonymousObjectEqualityComparison(BinaryExpression binaryExpression)
{
Expression removeObjectConvert(Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ public RelationalSqlTranslatingExpressionVisitorFactory(
}

public virtual RelationalSqlTranslatingExpressionVisitor Create(
IModel model,
QueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
{
return new RelationalSqlTranslatingExpressionVisitor(
model,
queryCompilationContext,
queryableMethodTranslatingExpressionVisitor,
_sqlExpressionFactory,
_memberTranslatorProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ private static readonly HashSet<ExpressionType> _arithmeticOperatorTypes
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public SqlServerSqlTranslatingExpressionVisitor(
IModel model,
QueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
ISqlExpressionFactory sqlExpressionFactory,
IMemberTranslatorProvider memberTranslatorProvider,
IMethodCallTranslatorProvider methodCallTranslatorProvider)
: base(model, queryableMethodTranslatingExpressionVisitor, sqlExpressionFactory, memberTranslatorProvider, methodCallTranslatorProvider)
: base(queryCompilationContext, queryableMethodTranslatingExpressionVisitor, sqlExpressionFactory, memberTranslatorProvider, methodCallTranslatorProvider)
{
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ public SqlServerSqlTranslatingExpressionVisitorFactory(
}

public override RelationalSqlTranslatingExpressionVisitor Create(
IModel model,
QueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
{
return new SqlServerSqlTranslatingExpressionVisitor(
model,
queryCompilationContext,
queryableMethodTranslatingExpressionVisitor,
_sqlExpressionFactory,
_memberTranslatorProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
public interface IQueryableMethodTranslatingExpressionVisitorFactory
{
QueryableMethodTranslatingExpressionVisitor Create(IModel model);
QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext);
}

}
45 changes: 43 additions & 2 deletions src/EFCore/Query/Pipeline/QueryCompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
Expand All @@ -20,6 +22,13 @@ public class QueryCompilationContext
private readonly IShapedQueryOptimizerFactory _shapedQueryOptimizerFactory;
private readonly IShapedQueryCompilingExpressionVisitorFactory _shapedQueryCompilingExpressionVisitorFactory;

/// <summary>
/// A dictionary mapping SQL parameter names to lambdas that, given a QueryContext, can extract that parameter's value.
/// This is needed for cases where we need to introduce a parameter during the compilation phase (e.g. entity equality rewrites
/// a parameter to an ID property on that parameter).
/// </summary>
private Dictionary<string, LambdaExpression> _additionalCompileTimeParameters;

public QueryCompilationContext(
IModel model,
IQueryOptimizerFactory queryOptimizerFactory,
Expand All @@ -42,7 +51,6 @@ public QueryCompilationContext(
_queryableMethodTranslatingExpressionVisitorFactory = queryableMethodTranslatingExpressionVisitorFactory;
_shapedQueryOptimizerFactory = shapedQueryOptimizerFactory;
_shapedQueryCompilingExpressionVisitorFactory = shapedQueryCompilingExpressionVisitorFactory;

}

public bool Async { get; }
Expand All @@ -62,13 +70,17 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
{
query = _queryOptimizerFactory.Create(this).Visit(query);
// Convert EntityQueryable to ShapedQueryExpression
query = _queryableMethodTranslatingExpressionVisitorFactory.Create(Model).Visit(query);
query = _queryableMethodTranslatingExpressionVisitorFactory.Create(this).Visit(query);
query = _shapedQueryOptimizerFactory.Create(this).Visit(query);

// Inject actual entity materializer
// Inject tracking
query = _shapedQueryCompilingExpressionVisitorFactory.Create(this).Visit(query);

// If any additional parameters were added during the compilation phase (e.g. entity equality ID expression),
// wrap the query with code adding those parameters to the query context
query = WrapWithAdditionalCompileTimeParameters(query);

var queryExecutorExpression = Expression.Lambda<Func<QueryContext, TResult>>(
query,
QueryContextParameter);
Expand All @@ -82,5 +94,34 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
Logger.QueryExecutionPlanned(new ExpressionPrinter(), queryExecutorExpression);
}
}

public void AddCompileTimeParameter(string name, LambdaExpression valueExtractor)
{
if (_additionalCompileTimeParameters == null)
{
_additionalCompileTimeParameters = new Dictionary<string, LambdaExpression>();
}

_additionalCompileTimeParameters[name] = valueExtractor;
}

private Expression WrapWithAdditionalCompileTimeParameters(Expression query)
=> _additionalCompileTimeParameters == null
? query
: Expression.Block(_additionalCompileTimeParameters
.Select(kv =>
Expression.Call(
QueryContextParameter,
_queryContextAddParameterMethodInfo,
Expression.Constant(kv.Key),
Expression.Convert(
Expression.Invoke(kv.Value, QueryContextParameter),
typeof(object))))
.Append(query));

private static readonly MethodInfo _queryContextAddParameterMethodInfo
= typeof(QueryContext)
.GetTypeInfo()
.GetDeclaredMethod(nameof(QueryContext.AddParameter));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ public virtual void FromSqlRaw_does_not_parameterize_interpolated_string()
}
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact(Skip = "Join bug?")]
public virtual void Entity_equality_through_fromsql()
{
using (var context = CreateContext())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected ComplexNavigationsQueryTestBase(TFixture fixture)
{
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_equality_empty(bool isAsync)
{
Expand Down Expand Up @@ -146,7 +146,7 @@ public virtual Task Key_equality_using_property_method_and_member_expression3(bo
(e, a) => Assert.Equal(e.Id, a.Id));
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Key_equality_navigation_converted_to_FK(bool isAsync)
{
Expand All @@ -163,7 +163,7 @@ public virtual Task Key_equality_navigation_converted_to_FK(bool isAsync)
(e, a) => Assert.Equal(e.Id, a.Id));
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Key_equality_two_conditions_on_same_navigation(bool isAsync)
{
Expand All @@ -185,7 +185,7 @@ public virtual Task Key_equality_two_conditions_on_same_navigation(bool isAsync)
(e, a) => Assert.Equal(e.Id, a.Id));
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Key_equality_two_conditions_on_same_navigation2(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ public virtual Task OrderBy_Skip_Last_gives_correct_result(bool isAsync)
entryCount: 1);
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact]
public virtual void Contains_over_entityType_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
Expand All @@ -1510,7 +1510,7 @@ var query
}
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact]
public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
Expand All @@ -1519,7 +1519,7 @@ var query
= context.Orders.Where(o => o.CustomerID == "VINET")
.Contains(null);

Assert.True(query);
Assert.False(query);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ await AssertQuery<Customer>(
entryCount: 1);
}

[ConditionalTheory(Skip = "#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Where_poco_closure(bool isAsync)
{
Expand Down Expand Up @@ -1970,7 +1970,7 @@ public virtual Task Where_subquery_FirstOrDefault_is_null(bool isAsync)
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_FirstOrDefault_compared_to_entity(bool isAsync)
{
Expand Down
Loading

0 comments on commit 19b0a36

Please sign in to comment.