Skip to content

Commit

Permalink
Query: Rewrite Entity Equality during translation phase
Browse files Browse the repository at this point in the history
Resolves #15080
Implemented behavior:
- If any part of composite key is null then key is null.
- If comparing entity with null then check if "any" key value is null.
- If comparing entity with non-null then check if "all" key values are non null.

Resolves #20344
Resolves #19431
Resolves #13568
Resolves #13655
Since we already convert property access to nullable, if entity from client is null, make key value as null.

Resolves #19676
Clr type mismatch between proxy type and entity type is ignored.

Resolves #20164
Rewrites entity equality during translation

Part of #18923
  • Loading branch information
smitpatel committed Mar 28, 2020
1 parent 379354c commit 904a105
Show file tree
Hide file tree
Showing 32 changed files with 1,756 additions and 1,517 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor(
_model = queryCompilationContext.Model;
_sqlExpressionFactory = sqlExpressionFactory;
_sqlTranslator = new CosmosSqlTranslatingExpressionVisitor(
_model,
queryCompilationContext,
sqlExpressionFactory,
memberTranslatorProvider,
methodCallTranslatorProvider);
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,18 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.InMemory.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class InMemoryQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private static readonly MethodInfo _efPropertyMethod = typeof(EF).GetTypeInfo().GetDeclaredMethod(nameof(EF.Property));

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly InMemoryProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
Expand All @@ -32,7 +29,7 @@ public InMemoryQueryableMethodTranslatingExpressionVisitor(
[NotNull] QueryCompilationContext queryCompilationContext)
: base(dependencies, subquery: false)
{
_expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(this, queryCompilationContext.Model);
_expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(queryCompilationContext, this);
_weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_expressionTranslator);
_projectionBindingExpressionVisitor = new InMemoryProjectionBindingExpressionVisitor(this, _expressionTranslator);
_model = queryCompilationContext.Model;
Expand Down Expand Up @@ -402,16 +399,14 @@ protected override ShapedQueryExpression TranslateJoin(
Check.NotNull(inner, nameof(inner));
Check.NotNull(resultSelector, nameof(resultSelector));

outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector);
innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector);
(outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector);

if (outerKeySelector == null
|| innerKeySelector == null)
{
return null;
}

(outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector);

var transparentIdentifierType = TransparentIdentifierFactory.Create(
resultSelector.Parameters[0].Type,
resultSelector.Parameters[1].Type);
Expand All @@ -429,6 +424,71 @@ protected override ShapedQueryExpression TranslateJoin(
transparentIdentifierType);
}

private (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) ProcessJoinKeySelector(
ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector)
{
var left = RemapLambdaBody(outer, outerKeySelector);
var right = RemapLambdaBody(inner, innerKeySelector);

var joinCondition = TranslateExpression(Expression.Equal(left, right));

var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition);

if (outerKeyBody == null
|| innerKeyBody == null)
{
return (null, null);
}

outerKeySelector = Expression.Lambda(outerKeyBody, ((InMemoryQueryExpression)outer.QueryExpression).CurrentParameter);
innerKeySelector = Expression.Lambda(innerKeyBody, ((InMemoryQueryExpression)inner.QueryExpression).CurrentParameter);

return AlignKeySelectorTypes(outerKeySelector, innerKeySelector);
}

private static (Expression, Expression) DecomposeJoinCondition(Expression joinCondition)
{
var leftExpressions = new List<Expression>();
var rightExpressions = new List<Expression>();

return ProcessJoinCondition(joinCondition, leftExpressions, rightExpressions)
? leftExpressions.Count == 1
? (leftExpressions[0], rightExpressions[0])
: (CreateAnonymousObject(leftExpressions), CreateAnonymousObject(rightExpressions))
: (null, null);

static Expression CreateAnonymousObject(List<Expression> expressions)
=> Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
typeof(object),
expressions.Select(e => Expression.Convert(e, typeof(object)))));
}


private static bool ProcessJoinCondition(
Expression joinCondition, List<Expression> leftExpressions, List<Expression> rightExpressions)
{
if (joinCondition is BinaryExpression binaryExpression)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
{
leftExpressions.Add(binaryExpression.Left);
rightExpressions.Add(binaryExpression.Right);

return true;
}

if (binaryExpression.NodeType == ExpressionType.AndAlso)
{
return ProcessJoinCondition(binaryExpression.Left, leftExpressions, rightExpressions)
&& ProcessJoinCondition(binaryExpression.Right, leftExpressions, rightExpressions);
}
}

return false;
}

private static (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector)
AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector)
{
Expand Down Expand Up @@ -477,15 +537,14 @@ protected override ShapedQueryExpression TranslateLeftJoin(
Check.NotNull(inner, nameof(inner));
Check.NotNull(resultSelector, nameof(resultSelector));

outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector);
innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector);
(outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector);

if (outerKeySelector == null
|| innerKeySelector == null)
{
return null;
}

(outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector);

var transparentIdentifierType = TransparentIdentifierFactory.Create(
resultSelector.Parameters[0].Type,
Expand Down Expand Up @@ -579,22 +638,16 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var parameter = Expression.Parameter(entityType.ClrType);

var callEFProperty = Expression.Call(
_efPropertyMethod.MakeGenericMethod(
discriminatorProperty.ClrType),
parameter,
Expression.Constant(discriminatorProperty.Name));

var equals = Expression.Equal(
callEFProperty,
parameter.CreateEFPropertyExpression(discriminatorProperty),
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
callEFProperty,
parameter.CreateEFPropertyExpression(discriminatorProperty),
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe
private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly RelationalProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _subquery;
Expand Down Expand Up @@ -54,7 +55,7 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor(
: base(parentVisitor.Dependencies, subquery: true)
{
RelationalDependencies = parentVisitor.RelationalDependencies;
_model = parentVisitor._model;
_queryCompilationContext = parentVisitor._queryCompilationContext;
_sqlTranslator = parentVisitor._sqlTranslator;
_weakEntityExpandingExpressionVisitor = parentVisitor._weakEntityExpandingExpressionVisitor;
_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
Expand Down Expand Up @@ -116,7 +117,7 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen
{
Check.NotNull(elementType, nameof(elementType));

var entityType = _model.FindEntityType(elementType);
var entityType = _queryCompilationContext.Model.FindEntityType(elementType);
var queryExpression = _sqlExpressionFactory.Select(entityType);

return CreateShapedQueryExpression(entityType, queryExpression);
Expand Down
Loading

0 comments on commit 904a105

Please sign in to comment.