Skip to content

Commit

Permalink
Fix to #16078 - Query/Null semantics: when checking if expression is …
Browse files Browse the repository at this point in the history
…null, just check it's constituents rather than entire expression

Problem was that during null semantics rewrite we create IS NULL calls on the operands of the comparison. If the operands themselves are complicated, we were simply comparing the entire complex expression to null. In some cases, we only need to look at constituents, e.g. a + b == null <=> a == null || b == null.

Also added other minor optimizations around null semantics:

- non_nullable_column IS NULL resolves to false,
- try to simplify expression after applying de Morgan transformations

Also fixed a bug exposed by these changes, where column nullability would be incorrect for scenarios with owned types.
  • Loading branch information
maumar committed Oct 24, 2019
1 parent 4e796ce commit 249ae62
Show file tree
Hide file tree
Showing 13 changed files with 408 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBina
newRight = rightUnary.Operand;
}

// TODO: optimize this by looking at subcomponents, e.g. f(a, b) == null <=> a == null || b == null
var leftIsNull = _sqlExpressionFactory.IsNull(newLeft);
var rightIsNull = _sqlExpressionFactory.IsNull(newRight);

// doing a full null semantics rewrite - removing all nulls from truth table
// this will NOT be correct once we introduce simplified null semantics
_isNullable = false;

if (sqlBinaryExpression.OperatorType == ExpressionType.Equal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand Down Expand Up @@ -47,42 +48,62 @@ protected override Expression VisitExtension(Expression extensionExpression)

protected virtual Expression VisitSqlUnaryExpression(SqlUnaryExpression sqlUnaryExpression)
{
if (sqlUnaryExpression.OperatorType == ExpressionType.Not)
switch (sqlUnaryExpression.OperatorType)
{
return VisitNot(sqlUnaryExpression);
}
case ExpressionType.Not:
return VisitNot(sqlUnaryExpression);

// NULL IS NULL -> true
// non_nullable_constant IS NULL -> false
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal
&& sqlUnaryExpression.Operand is SqlConstantExpression innerConstantNull1)
{
return SqlExpressionFactory.Constant(innerConstantNull1.Value == null, sqlUnaryExpression.TypeMapping);
}
case ExpressionType.Equal:
switch (sqlUnaryExpression.Operand)
{
case SqlConstantExpression constantOperand:
return SqlExpressionFactory.Constant(constantOperand.Value == null, sqlUnaryExpression.TypeMapping);

// NULL IS NOT NULL -> false
// non_nullable_constant IS NOT NULL -> true
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual
&& sqlUnaryExpression.Operand is SqlConstantExpression innerConstantNull2)
{
return SqlExpressionFactory.Constant(innerConstantNull2.Value != null, sqlUnaryExpression.TypeMapping);
}
case ColumnExpression columnOperand
when !columnOperand.IsNullable:
return SqlExpressionFactory.Constant(false, sqlUnaryExpression.TypeMapping);

if (sqlUnaryExpression.Operand is SqlUnaryExpression innerUnary)
{
// (!a) IS NULL <==> a IS NULL
if (sqlUnaryExpression.OperatorType == ExpressionType.Equal
&& innerUnary.OperatorType == ExpressionType.Not)
{
return Visit(SqlExpressionFactory.IsNull(innerUnary.Operand));
}
case SqlUnaryExpression sqlUnaryOperand
when sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate:
return (SqlExpression)Visit(SqlExpressionFactory.IsNull(sqlUnaryOperand.Operand));

// (!a) IS NOT NULL <==> a IS NOT NULL
if (sqlUnaryExpression.OperatorType == ExpressionType.NotEqual
&& innerUnary.OperatorType == ExpressionType.Not)
{
return Visit(SqlExpressionFactory.IsNotNull(innerUnary.Operand));
}
case SqlBinaryExpression sqlBinaryOperand:
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.IsNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.IsNull(sqlBinaryOperand.Right));

return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? SimplifyLogicalSqlBinaryExpression(ExpressionType.AndAlso, newLeft, newRight, sqlBinaryOperand.TypeMapping)
: SimplifyLogicalSqlBinaryExpression(ExpressionType.OrElse, newLeft, newRight, sqlBinaryOperand.TypeMapping);
}
break;

case ExpressionType.NotEqual:
switch (sqlUnaryExpression.Operand)
{
case SqlConstantExpression constantOperand:
return SqlExpressionFactory.Constant(constantOperand.Value != null, sqlUnaryExpression.TypeMapping);

case ColumnExpression columnOperand
when !columnOperand.IsNullable:
return SqlExpressionFactory.Constant(true, sqlUnaryExpression.TypeMapping);

case SqlUnaryExpression sqlUnaryOperand
when sqlUnaryOperand.OperatorType == ExpressionType.Convert
|| sqlUnaryOperand.OperatorType == ExpressionType.Not
|| sqlUnaryOperand.OperatorType == ExpressionType.Negate:
return (SqlExpression)Visit(SqlExpressionFactory.IsNotNull(sqlUnaryOperand.Operand));

case SqlBinaryExpression sqlBinaryOperand:
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.IsNotNull(sqlBinaryOperand.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.IsNotNull(sqlBinaryOperand.Right));

return sqlBinaryOperand.OperatorType == ExpressionType.Coalesce
? SimplifyLogicalSqlBinaryExpression(ExpressionType.OrElse, newLeft, newRight, sqlBinaryOperand.TypeMapping)
: SimplifyLogicalSqlBinaryExpression(ExpressionType.AndAlso, newLeft, newRight, sqlBinaryOperand.TypeMapping);
}
break;
}

var newOperand = (SqlExpression)Visit(sqlUnaryExpression.Operand);
Expand Down Expand Up @@ -135,9 +156,13 @@ private Expression VisitNot(SqlUnaryExpression sqlUnaryExpression)
var newLeft = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Left));
var newRight = (SqlExpression)Visit(SqlExpressionFactory.Not(innerBinary.Right));

return innerBinary.OperatorType == ExpressionType.AndAlso
? SqlExpressionFactory.OrElse(newLeft, newRight)
: SqlExpressionFactory.AndAlso(newLeft, newRight);
return SimplifyLogicalSqlBinaryExpression(
innerBinary.OperatorType == ExpressionType.AndAlso
? ExpressionType.OrElse
: ExpressionType.AndAlso,
newLeft,
newRight,
innerBinary.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -168,36 +193,11 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress
if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return sqlBinaryExpression.Update(newLeft, newRight);
return SimplifyLogicalSqlBinaryExpression(
sqlBinaryExpression.OperatorType,
newLeft,
newRight,
sqlBinaryExpression.TypeMapping);
}

// those optimizations are only valid in 2-value logic
Expand Down Expand Up @@ -227,5 +227,43 @@ private Expression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpress

return sqlBinaryExpression.Update(newLeft, newRight);
}

private SqlExpression SimplifyLogicalSqlBinaryExpression(
ExpressionType operatorType,
SqlExpression newLeft,
SqlExpression newRight,
RelationalTypeMapping typeMapping)
{
// true && a -> a
// true || a -> true
// false && a -> false
// false || a -> a
if (newLeft is SqlConstantExpression newLeftConstant)
{
return operatorType == ExpressionType.AndAlso
? (bool)newLeftConstant.Value
? newRight
: newLeftConstant
: (bool)newLeftConstant.Value
? newLeftConstant
: newRight;
}
else if (newRight is SqlConstantExpression newRightConstant)
{
// a && true -> a
// a || true -> true
// a && false -> false
// a || false -> a
return operatorType == ExpressionType.AndAlso
? (bool)newRightConstant.Value
? newLeft
: newRightConstant
: (bool)newRightConstant.Value
? newRightConstant
: newLeft;
}

return SqlExpressionFactory.MakeBinary(operatorType, newLeft, newRight, typeMapping);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,33 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
var newSelectExpression = (SelectExpression)base.VisitExtension(extensionExpression);

return newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
// if predicate is optimized to true, we can simply remove it
var newPredicate = newSelectExpression.Predicate is SqlConstantExpression newSelectPredicateConstant
&& !(selectExpression.Predicate is SqlConstantExpression)
? (bool)newSelectPredicateConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping))
: newSelectExpression.Predicate;

var newHaving = newSelectExpression.Having is SqlConstantExpression newSelectHavingConstant
&& !(selectExpression.Having is SqlConstantExpression)
? (bool)newSelectHavingConstant.Value
? null
: SqlExpressionFactory.Equal(
newSelectHavingConstant,
SqlExpressionFactory.Constant(true, newSelectHavingConstant.TypeMapping))
: newSelectExpression.Having;

return newPredicate != newSelectExpression.Predicate
|| newHaving != newSelectExpression.Having
? newSelectExpression.Update(
newSelectExpression.Projection.ToList(),
newSelectExpression.Tables.ToList(),
SqlExpressionFactory.Equal(
newSelectPredicateConstant,
SqlExpressionFactory.Constant(true, newSelectPredicateConstant.TypeMapping)),
newPredicate,
newSelectExpression.GroupBy.ToList(),
newSelectExpression.Having,
newHaving,
newSelectExpression.Orderings.ToList(),
newSelectExpression.Limit,
newSelectExpression.Offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class ColumnExpression : SqlExpression
internal ColumnExpression(IProperty property, TableExpressionBase table, bool nullable)
: this(
property.GetColumnName(), table, property.ClrType, property.GetRelationalTypeMapping(),
nullable || property.IsNullable || property.DeclaringEntityType.BaseType != null)
nullable || property.IsColumnNullable())
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,25 @@ public virtual void Null_semantics_with_null_check_complex()
}
}

[ConditionalFact]
public virtual void IsNull_on_complex_expression()
{
using (var ctx = CreateContext())
{
var query1 = ctx.Entities1.Where(e => -e.NullableIntA != null).ToList();
Assert.Equal(18, query1.Count);

var query2 = ctx.Entities1.Where(e => (e.NullableIntA + e.NullableIntB) == null).ToList();
Assert.Equal(15, query2.Count);

var query3 = ctx.Entities1.Where(e => (e.NullableIntA ?? e.NullableIntB) == null).ToList();
Assert.Equal(3, query3.Count);

var query4 = ctx.Entities1.Where(e => (e.NullableIntA ?? e.NullableIntB) != null).ToList();
Assert.Equal(24, query4.Count);
}
}

protected static TResult Maybe<TResult>(object caller, Func<TResult> expression)
where TResult : class
{
Expand Down
101 changes: 101 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8006,6 +8006,107 @@ public virtual Task Group_by_with_aggregate_max_on_entity_type(bool isAsync)
})));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_on_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().GroupBy(g => g.FullName.StartsWith(prm)).Select(g => g.Key));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_with_having_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().GroupBy(g => g.FullName).Where(g => g.Key.StartsWith(prm)).Select(g => g.Key),
ss => ss.Set<Gear>().GroupBy(g => g.FullName).Where(g => false).Select(g => g.Key));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().Select(g => g.FullName.StartsWith(prm)),
ss => ss.Set<Gear>().Select(g => false));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_null_parameter_is_not_null(bool isAsync)
{
var prm = (string)null;

return AssertQueryScalar(
isAsync,
ss => ss.Set<Gear>().Select(g => prm != null),
ss => ss.Set<Gear>().Select(g => false));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_null_parameter_is_not_null(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => prm != null),
ss => ss.Set<Gear>().Where(g => false));
}

[ConditionalTheory(Skip = "issue #18492")]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_StartsWith_with_null_parameter_as_argument(bool isAsync)
{
var prm = (string)null;

return AssertQuery(
isAsync,
ss => ss.Set<Gear>().OrderBy(g => g.FullName.StartsWith(prm)).ThenBy(g => g.Nickname),
ss => ss.Set<Gear>().OrderBy(g => false).ThenBy(g => g.Nickname),
assertOrder: true);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Where_with_enum_flags_parameter(bool isAsync)
{
MilitaryRank? rank = MilitaryRank.Private;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank & rank) == rank));

rank = null;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank & rank) == rank));

rank = MilitaryRank.Corporal;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));

rank = null;

await AssertQuery(
isAsync,
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));
}

protected async Task AssertTranslationFailed(Func<Task> testCode)
{
Assert.Contains(
Expand Down
Loading

0 comments on commit 249ae62

Please sign in to comment.