Skip to content

Commit

Permalink
Query: Create Discriminator Condition on EntityShaper and use it in s…
Browse files Browse the repository at this point in the history
…haper

Part of #18923
Part of #10140
  • Loading branch information
smitpatel committed Mar 18, 2020
1 parent d97332e commit e72297c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 50 deletions.
84 changes: 82 additions & 2 deletions src/EFCore/Query/EntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,111 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query
{
public class EntityShaperExpression : Expression, IPrintableExpression
{
private static readonly MethodInfo _createUnableToDiscriminateException
= typeof(EntityShaperExpression).GetTypeInfo()
.GetDeclaredMethod(nameof(CreateUnableToDiscriminateException));

[UsedImplicitly]
private static Exception CreateUnableToDiscriminateException(IEntityType entityType, object discriminator)
=> new InvalidOperationException(CoreStrings.UnableToDiscriminate(entityType.DisplayName(), discriminator?.ToString()));

public EntityShaperExpression(
[NotNull] IEntityType entityType,
[NotNull] Expression valueBufferExpression,
bool nullable)
: this(entityType, valueBufferExpression, nullable, null)
{
}

protected EntityShaperExpression(
[NotNull] IEntityType entityType,
[NotNull] Expression valueBufferExpression,
bool nullable,
[CanBeNull] LambdaExpression discriminatorCondition)
{
Check.NotNull(entityType, nameof(entityType));
Check.NotNull(valueBufferExpression, nameof(valueBufferExpression));

if (discriminatorCondition == null)
{
// Generate condition to discriminator if TPH
discriminatorCondition = GenerateDiscriminatorCondition(entityType);

}
else if (discriminatorCondition.Parameters.Count != 1
|| discriminatorCondition.Parameters[0].Type != typeof(ValueBuffer)
|| discriminatorCondition.ReturnType != typeof(IEntityType))
{
throw new InvalidOperationException(
"Discriminator condition must be lambda expression of type Func<ValueBuffer, IEntityType>.");
}

EntityType = entityType;
ValueBufferExpression = valueBufferExpression;
IsNullable = nullable;
DiscriminatorCondition = discriminatorCondition;
}

private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType)
{
var valueBufferParameter = Parameter(typeof(ValueBuffer));
Expression body;
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var discriminatorProperty = entityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
var discriminatorValueVariable = Variable(discriminatorProperty.ClrType, "discriminator");
var expressions = new List<Expression>
{
Assign(
discriminatorValueVariable,
valueBufferParameter.CreateValueBufferReadValueExpression(
discriminatorProperty.ClrType, discriminatorProperty.GetIndex(), discriminatorProperty))
};

var switchCases = new SwitchCase[concreteEntityTypes.Count];
for (var i = 0; i < concreteEntityTypes.Count; i++)
{
var discriminatorValue = Constant(concreteEntityTypes[i].GetDiscriminatorValue(), discriminatorProperty.ClrType);
switchCases[i] = SwitchCase(Constant(concreteEntityTypes[i], typeof(IEntityType)), discriminatorValue);
}

var exception = Block(
Throw(Call(
_createUnableToDiscriminateException, Constant(entityType), Convert(discriminatorValueVariable, typeof(object)))),
Constant(null, typeof(IEntityType)));

expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));
body = Block(new[] { discriminatorValueVariable }, expressions);
}
else
{
body = Constant(concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType));
}

return Lambda(body, valueBufferParameter);
}

public virtual IEntityType EntityType { get; }
public virtual Expression ValueBufferExpression { get; }
public virtual bool IsNullable { get; }
public virtual LambdaExpression DiscriminatorCondition { get; }

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
Expand All @@ -48,15 +128,15 @@ public virtual EntityShaperExpression WithEntityType([NotNull] IEntityType entit

public virtual EntityShaperExpression MarkAsNullable()
=> !IsNullable
? new EntityShaperExpression(EntityType, ValueBufferExpression, true)
? new EntityShaperExpression(EntityType, ValueBufferExpression, true, DiscriminatorCondition)
: this;

public virtual EntityShaperExpression Update([NotNull] Expression valueBufferExpression)
{
Check.NotNull(valueBufferExpression, nameof(valueBufferExpression));

return valueBufferExpression != ValueBufferExpression
? new EntityShaperExpression(EntityType, valueBufferExpression, IsNullable)
? new EntityShaperExpression(EntityType, valueBufferExpression, IsNullable, DiscriminatorCondition)
: this;
}

Expand Down
80 changes: 32 additions & 48 deletions src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -396,22 +396,21 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres
Expression.MakeMemberAccess(entryVariable, _entityMemberInfo),
entityType.ClrType))),
MaterializeEntity(
entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable,
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable,
entryVariable))));
}
else
{
if (primaryKey != null)
{

expressions.Add(Expression.IfThen(
primaryKey.Properties.Select(
p => Expression.NotEqual(
valueBufferExpression.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p),
Expression.Constant(null)))
.Aggregate((a, b) => Expression.AndAlso(a, b)),
MaterializeEntity(
entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)));
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)));
}
else
{
Expand All @@ -430,13 +429,13 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres
Expression.Constant(null)))
.Aggregate((a, b) => Expression.OrElse(a, b)),
MaterializeEntity(
entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)));
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)));
}
else
{
expressions.Add(
MaterializeEntity(
entityType, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null));
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null));
}
}
}
Expand All @@ -446,12 +445,14 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres
}

private Expression MaterializeEntity(
IEntityType entityType,
EntityShaperExpression entityShaperExpression,
ParameterExpression materializationContextVariable,
ParameterExpression concreteEntityTypeVariable,
ParameterExpression instanceVariable,
ParameterExpression entryVariable)
{
var entityType = entityShaperExpression.EntityType;

var expressions = new List<Expression>();
var variables = new List<ParameterExpression>();

Expand All @@ -468,47 +469,35 @@ private Expression MaterializeEntity(
Expression materializationExpression;
var valueBufferExpression = Expression.Call(materializationContextVariable, MaterializationContext.GetValueBufferMethod);
var expressionContext = (returnType, materializationContextVariable, concreteEntityTypeVariable, shadowValuesVariable);
expressions.Add(
Expression.Assign(concreteEntityTypeVariable,
ReplacingExpressionVisitor.Replace(
entityShaperExpression.DiscriminatorCondition.Parameters[0],
valueBufferExpression,
entityShaperExpression.DiscriminatorCondition.Body)));

var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var firstEntityType = concreteEntityTypes[0];
if (concreteEntityTypes.Count == 1)
var discriminatorProperty = entityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
materializationExpression = CreateFullMaterializeExpression(firstEntityType, expressionContext);
var switchCases = new SwitchCase[concreteEntityTypes.Count];
for (var i = 0; i < concreteEntityTypes.Count; i++)
{
switchCases[i] = Expression.SwitchCase(
CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext),
Expression.Constant(concreteEntityTypes[i], typeof(IEntityType)));
}

materializationExpression = Expression.Switch(
concreteEntityTypeVariable,
Expression.Constant(null, returnType),
switchCases);
}
else
{
var discriminatorProperty = firstEntityType.GetDiscriminatorProperty();
var discriminatorValueVariable = Expression.Variable(
discriminatorProperty.ClrType, "discriminator" + _currentEntityIndex);
variables.Add(discriminatorValueVariable);

expressions.Add(
Expression.Assign(
discriminatorValueVariable,
valueBufferExpression.CreateValueBufferReadValueExpression(
discriminatorProperty.ClrType,
discriminatorProperty.GetIndex(),
discriminatorProperty)));

materializationExpression = Expression.Block(
Expression.Throw(
Expression.Call(
_createUnableToDiscriminateException,
Expression.Constant(entityType),
Expression.Convert(discriminatorValueVariable, typeof(object)))),
Expression.Constant(null, returnType));

foreach (var concreteEntityType in concreteEntityTypes)
{
var discriminatorValue
= Expression.Constant(
concreteEntityType.GetDiscriminatorValue(),
discriminatorProperty.ClrType);

materializationExpression = Expression.Condition(
Expression.Equal(discriminatorValueVariable, discriminatorValue),
CreateFullMaterializeExpression(concreteEntityType, expressionContext),
materializationExpression);
}
materializationExpression = CreateFullMaterializeExpression(
concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType,
expressionContext);
}

expressions.Add(Expression.Assign(instanceVariable, materializationExpression));
Expand Down Expand Up @@ -551,12 +540,7 @@ private BlockExpression CreateFullMaterializeExpression(
concreteEntityTypeVariable,
shadowValuesVariable) = materializeExpressionContext;

var blockExpressions = new List<Expression>(3)
{
Expression.Assign(
concreteEntityTypeVariable,
Expression.Constant(concreteEntityType))
};
var blockExpressions = new List<Expression>(2);

var materializer = _entityMaterializerSource
.CreateMaterializeExpression(concreteEntityType, "instance", materializationContextVariable);
Expand Down

0 comments on commit e72297c

Please sign in to comment.