From e72297c8f77488096a9a31cc860edc63170c4a99 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 17 Mar 2020 17:28:15 -0700 Subject: [PATCH] Query: Create Discriminator Condition on EntityShaper and use it in shaper Part of #18923 Part of #10140 --- src/EFCore/Query/EntityShaperExpression.cs | 84 ++++++++++++++++++- .../ShapedQueryCompilingExpressionVisitor.cs | 80 +++++++----------- 2 files changed, 114 insertions(+), 50 deletions(-) diff --git a/src/EFCore/Query/EntityShaperExpression.cs b/src/EFCore/Query/EntityShaperExpression.cs index c68aa5b45b6..2b45efec8a9 100644 --- a/src/EFCore/Query/EntityShaperExpression.cs +++ b/src/EFCore/Query/EntityShaperExpression.cs @@ -2,31 +2,111 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; +using System.Reflection; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Query { public class EntityShaperExpression : Expression, IPrintableExpression { + private static readonly MethodInfo _createUnableToDiscriminateException + = typeof(EntityShaperExpression).GetTypeInfo() + .GetDeclaredMethod(nameof(CreateUnableToDiscriminateException)); + + [UsedImplicitly] + private static Exception CreateUnableToDiscriminateException(IEntityType entityType, object discriminator) + => new InvalidOperationException(CoreStrings.UnableToDiscriminate(entityType.DisplayName(), discriminator?.ToString())); + public EntityShaperExpression( [NotNull] IEntityType entityType, [NotNull] Expression valueBufferExpression, bool nullable) + : this(entityType, valueBufferExpression, nullable, null) + { + } + + protected EntityShaperExpression( + [NotNull] IEntityType entityType, + [NotNull] Expression valueBufferExpression, + bool nullable, + [CanBeNull] LambdaExpression discriminatorCondition) { Check.NotNull(entityType, nameof(entityType)); Check.NotNull(valueBufferExpression, nameof(valueBufferExpression)); + if (discriminatorCondition == null) + { + // Generate condition to discriminator if TPH + discriminatorCondition = GenerateDiscriminatorCondition(entityType); + + } + else if (discriminatorCondition.Parameters.Count != 1 + || discriminatorCondition.Parameters[0].Type != typeof(ValueBuffer) + || discriminatorCondition.ReturnType != typeof(IEntityType)) + { + throw new InvalidOperationException( + "Discriminator condition must be lambda expression of type Func."); + } + EntityType = entityType; ValueBufferExpression = valueBufferExpression; IsNullable = nullable; + DiscriminatorCondition = discriminatorCondition; + } + + private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType) + { + var valueBufferParameter = Parameter(typeof(ValueBuffer)); + Expression body; + var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); + var discriminatorProperty = entityType.GetDiscriminatorProperty(); + if (discriminatorProperty != null) + { + var discriminatorValueVariable = Variable(discriminatorProperty.ClrType, "discriminator"); + var expressions = new List + { + Assign( + discriminatorValueVariable, + valueBufferParameter.CreateValueBufferReadValueExpression( + discriminatorProperty.ClrType, discriminatorProperty.GetIndex(), discriminatorProperty)) + }; + + var switchCases = new SwitchCase[concreteEntityTypes.Count]; + for (var i = 0; i < concreteEntityTypes.Count; i++) + { + var discriminatorValue = Constant(concreteEntityTypes[i].GetDiscriminatorValue(), discriminatorProperty.ClrType); + switchCases[i] = SwitchCase(Constant(concreteEntityTypes[i], typeof(IEntityType)), discriminatorValue); + } + + var exception = Block( + Throw(Call( + _createUnableToDiscriminateException, Constant(entityType), Convert(discriminatorValueVariable, typeof(object)))), + Constant(null, typeof(IEntityType))); + + expressions.Add(Switch(discriminatorValueVariable, exception, switchCases)); + body = Block(new[] { discriminatorValueVariable }, expressions); + } + else + { + body = Constant(concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType)); + } + + return Lambda(body, valueBufferParameter); } public virtual IEntityType EntityType { get; } public virtual Expression ValueBufferExpression { get; } public virtual bool IsNullable { get; } + public virtual LambdaExpression DiscriminatorCondition { get; } protected override Expression VisitChildren(ExpressionVisitor visitor) { @@ -48,7 +128,7 @@ public virtual EntityShaperExpression WithEntityType([NotNull] IEntityType entit public virtual EntityShaperExpression MarkAsNullable() => !IsNullable - ? new EntityShaperExpression(EntityType, ValueBufferExpression, true) + ? new EntityShaperExpression(EntityType, ValueBufferExpression, true, DiscriminatorCondition) : this; public virtual EntityShaperExpression Update([NotNull] Expression valueBufferExpression) @@ -56,7 +136,7 @@ public virtual EntityShaperExpression Update([NotNull] Expression valueBufferExp Check.NotNull(valueBufferExpression, nameof(valueBufferExpression)); return valueBufferExpression != ValueBufferExpression - ? new EntityShaperExpression(EntityType, valueBufferExpression, IsNullable) + ? new EntityShaperExpression(EntityType, valueBufferExpression, IsNullable, DiscriminatorCondition) : this; } diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index ce00d7250c8..d9a5eb5e970 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -396,14 +396,13 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres Expression.MakeMemberAccess(entryVariable, _entityMemberInfo), entityType.ClrType))), MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, entryVariable)))); } else { if (primaryKey != null) { - expressions.Add(Expression.IfThen( primaryKey.Properties.Select( p => Expression.NotEqual( @@ -411,7 +410,7 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres Expression.Constant(null))) .Aggregate((a, b) => Expression.AndAlso(a, b)), MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); } else { @@ -430,13 +429,13 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres Expression.Constant(null))) .Aggregate((a, b) => Expression.OrElse(a, b)), MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); } else { expressions.Add( MaterializeEntity( - entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)); + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)); } } } @@ -446,12 +445,14 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres } private Expression MaterializeEntity( - IEntityType entityType, + EntityShaperExpression entityShaperExpression, ParameterExpression materializationContextVariable, ParameterExpression concreteEntityTypeVariable, ParameterExpression instanceVariable, ParameterExpression entryVariable) { + var entityType = entityShaperExpression.EntityType; + var expressions = new List(); var variables = new List(); @@ -468,47 +469,35 @@ private Expression MaterializeEntity( Expression materializationExpression; var valueBufferExpression = Expression.Call(materializationContextVariable, MaterializationContext.GetValueBufferMethod); var expressionContext = (returnType, materializationContextVariable, concreteEntityTypeVariable, shadowValuesVariable); + expressions.Add( + Expression.Assign(concreteEntityTypeVariable, + ReplacingExpressionVisitor.Replace( + entityShaperExpression.DiscriminatorCondition.Parameters[0], + valueBufferExpression, + entityShaperExpression.DiscriminatorCondition.Body))); + var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); - var firstEntityType = concreteEntityTypes[0]; - if (concreteEntityTypes.Count == 1) + var discriminatorProperty = entityType.GetDiscriminatorProperty(); + if (discriminatorProperty != null) { - materializationExpression = CreateFullMaterializeExpression(firstEntityType, expressionContext); + var switchCases = new SwitchCase[concreteEntityTypes.Count]; + for (var i = 0; i < concreteEntityTypes.Count; i++) + { + switchCases[i] = Expression.SwitchCase( + CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext), + Expression.Constant(concreteEntityTypes[i], typeof(IEntityType))); + } + + materializationExpression = Expression.Switch( + concreteEntityTypeVariable, + Expression.Constant(null, returnType), + switchCases); } else { - var discriminatorProperty = firstEntityType.GetDiscriminatorProperty(); - var discriminatorValueVariable = Expression.Variable( - discriminatorProperty.ClrType, "discriminator" + _currentEntityIndex); - variables.Add(discriminatorValueVariable); - - expressions.Add( - Expression.Assign( - discriminatorValueVariable, - valueBufferExpression.CreateValueBufferReadValueExpression( - discriminatorProperty.ClrType, - discriminatorProperty.GetIndex(), - discriminatorProperty))); - - materializationExpression = Expression.Block( - Expression.Throw( - Expression.Call( - _createUnableToDiscriminateException, - Expression.Constant(entityType), - Expression.Convert(discriminatorValueVariable, typeof(object)))), - Expression.Constant(null, returnType)); - - foreach (var concreteEntityType in concreteEntityTypes) - { - var discriminatorValue - = Expression.Constant( - concreteEntityType.GetDiscriminatorValue(), - discriminatorProperty.ClrType); - - materializationExpression = Expression.Condition( - Expression.Equal(discriminatorValueVariable, discriminatorValue), - CreateFullMaterializeExpression(concreteEntityType, expressionContext), - materializationExpression); - } + materializationExpression = CreateFullMaterializeExpression( + concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType, + expressionContext); } expressions.Add(Expression.Assign(instanceVariable, materializationExpression)); @@ -551,12 +540,7 @@ private BlockExpression CreateFullMaterializeExpression( concreteEntityTypeVariable, shadowValuesVariable) = materializeExpressionContext; - var blockExpressions = new List(3) - { - Expression.Assign( - concreteEntityTypeVariable, - Expression.Constant(concreteEntityType)) - }; + var blockExpressions = new List(2); var materializer = _entityMaterializerSource .CreateMaterializeExpression(concreteEntityType, "instance", materializationContextVariable);