diff --git a/src/EFCore/Query/EntityShaperExpression.cs b/src/EFCore/Query/EntityShaperExpression.cs index 2b45efec8a9..d0e34bedd61 100644 --- a/src/EFCore/Query/EntityShaperExpression.cs +++ b/src/EFCore/Query/EntityShaperExpression.cs @@ -45,9 +45,7 @@ protected EntityShaperExpression( if (discriminatorCondition == null) { - // Generate condition to discriminator if TPH - discriminatorCondition = GenerateDiscriminatorCondition(entityType); - + discriminatorCondition = GenerateDiscriminatorCondition(entityType, nullable); } else if (discriminatorCondition.Parameters.Count != 1 || discriminatorCondition.Parameters[0].Type != typeof(ValueBuffer) @@ -63,11 +61,11 @@ protected EntityShaperExpression( DiscriminatorCondition = discriminatorCondition; } - private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType) + private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType, bool nullable) { var valueBufferParameter = Parameter(typeof(ValueBuffer)); Expression body; - var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); + var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToArray(); var discriminatorProperty = entityType.GetDiscriminatorProperty(); if (discriminatorProperty != null) { @@ -80,8 +78,8 @@ private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType) discriminatorProperty.ClrType, discriminatorProperty.GetIndex(), discriminatorProperty)) }; - var switchCases = new SwitchCase[concreteEntityTypes.Count]; - for (var i = 0; i < concreteEntityTypes.Count; i++) + var switchCases = new SwitchCase[concreteEntityTypes.Length]; + for (var i = 0; i < concreteEntityTypes.Length; i++) { var discriminatorValue = Constant(concreteEntityTypes[i].GetDiscriminatorValue(), discriminatorProperty.ClrType); switchCases[i] = SwitchCase(Constant(concreteEntityTypes[i], typeof(IEntityType)), discriminatorValue); @@ -97,7 +95,20 @@ private LambdaExpression GenerateDiscriminatorCondition(IEntityType entityType) } else { - body = Constant(concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType)); + body = Constant(concreteEntityTypes.Length == 1 ? concreteEntityTypes[0] : entityType, typeof(IEntityType)); + } + + if (entityType.FindPrimaryKey() == null + && nullable) + { + body = Condition( + entityType.GetProperties() + .Select(p => NotEqual( + valueBufferParameter.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p), + Constant(null))) + .Aggregate((a, b) => OrElse(a, b)), + body, + Default(typeof(IEntityType))); } return Lambda(body, valueBufferParameter); @@ -128,7 +139,8 @@ public virtual EntityShaperExpression WithEntityType([NotNull] IEntityType entit public virtual EntityShaperExpression MarkAsNullable() => !IsNullable - ? new EntityShaperExpression(EntityType, ValueBufferExpression, true, DiscriminatorCondition) + // Marking nullable requires recomputation of Discriminator condition + ? new EntityShaperExpression(EntityType, ValueBufferExpression, true) : this; public virtual EntityShaperExpression Update([NotNull] Expression valueBufferExpression) diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index d9a5eb5e970..2a6e633f5f7 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -414,29 +414,9 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres } else { - if (entityShaperExpression.IsNullable) - { - expressions.Add( - Expression.IfThen( - entityType.GetProperties() - .Select( - p => - Expression.NotEqual( - valueBufferExpression.CreateValueBufferReadValueExpression( - typeof(object), - p.GetIndex(), - p), - Expression.Constant(null))) - .Aggregate((a, b) => Expression.OrElse(a, b)), - MaterializeEntity( - entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null))); - } - else - { - expressions.Add( - MaterializeEntity( - entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)); - } + expressions.Add( + MaterializeEntity( + entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, null)); } } @@ -476,30 +456,27 @@ private Expression MaterializeEntity( valueBufferExpression, entityShaperExpression.DiscriminatorCondition.Body))); - var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); + var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToArray(); var discriminatorProperty = entityType.GetDiscriminatorProperty(); - if (discriminatorProperty != null) + if (discriminatorProperty == null + && concreteEntityTypes.Length > 1) { - var switchCases = new SwitchCase[concreteEntityTypes.Count]; - for (var i = 0; i < concreteEntityTypes.Count; i++) - { - switchCases[i] = Expression.SwitchCase( - CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext), - Expression.Constant(concreteEntityTypes[i], typeof(IEntityType))); - } - - materializationExpression = Expression.Switch( - concreteEntityTypeVariable, - Expression.Constant(null, returnType), - switchCases); + concreteEntityTypes = new [] { entityType }; } - else + + var switchCases = new SwitchCase[concreteEntityTypes.Length]; + for (var i = 0; i < concreteEntityTypes.Length; i++) { - materializationExpression = CreateFullMaterializeExpression( - concreteEntityTypes.Count == 1 ? concreteEntityTypes[0] : entityType, - expressionContext); + switchCases[i] = Expression.SwitchCase( + CreateFullMaterializeExpression(concreteEntityTypes[i], expressionContext), + Expression.Constant(concreteEntityTypes[i], typeof(IEntityType))); } + materializationExpression = Expression.Switch( + concreteEntityTypeVariable, + Expression.Constant(null, returnType), + switchCases); + expressions.Add(Expression.Assign(instanceVariable, materializationExpression)); if (_trackQueryResults