Skip to content

Commit

Permalink
Query: API review changes (#21346)
Browse files Browse the repository at this point in the history
Part of #20409
  • Loading branch information
smitpatel authored Jun 20, 2020
1 parent 973af84 commit 8c4bc0e
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ private SqlExpression TranslateExpression(Expression expression)
var translation = _sqlTranslator.Translate(expression);
if (translation == null && _sqlTranslator.TranslationErrorDetails != null)
{
ProvideTranslationErrorDetails(_sqlTranslator.TranslationErrorDetails);
AddTranslationErrorDetails(_sqlTranslator.TranslationErrorDetails);
}

return translation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public CosmosSqlTranslatingExpressionVisitor(
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual void ProvideTranslationErrorDetails([NotNull] string details)
protected virtual void AddTranslationErrorDetails([NotNull] string details)
{
Check.NotNull(details, nameof(details));

Expand Down Expand Up @@ -497,11 +497,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
if (methodCallExpression.Method == _stringEqualsWithStringComparison
|| methodCallExpression.Method == _stringEqualsWithStringComparisonStatic)
{
ProvideTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison);
AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison);
}
else
{
ProvideTranslationErrorDetails(CoreStrings.QueryUnableToTranslateMethod(
AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateMethod(
methodCallExpression.Method.Name,
methodCallExpression.Method.DeclaringType?.DisplayName()));
}
Expand Down Expand Up @@ -611,7 +611,7 @@ private Expression TryBindMember(Expression source, MemberIdentity member)

if (result == null)
{
ProvideTranslationErrorDetails(
AddTranslationErrorDetails(
CoreStrings.QueryUnableToTranslateMember(
member.Name,
entityReferenceExpression.EntityType.DisplayName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public InMemoryExpressionTranslatingExpressionVisitor(
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual void ProvideTranslationErrorDetails([NotNull] string details)
protected virtual void AddTranslationErrorDetails([NotNull] string details)
{
Check.NotNull(details, nameof(details));

Expand Down Expand Up @@ -872,7 +872,7 @@ private Expression TryBindMember(Expression source, MemberIdentity member, Type

}

ProvideTranslationErrorDetails(
AddTranslationErrorDetails(
CoreStrings.QueryUnableToTranslateMember(
member.Name,
entityReferenceExpression.EntityType.DisplayName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ private Expression TranslateExpression(Expression expression, bool preserveType
var translation = _expressionTranslator.Translate(expression);
if (translation == null && _expressionTranslator.TranslationErrorDetails != null)
{
ProvideTranslationErrorDetails(_expressionTranslator.TranslationErrorDetails);
AddTranslationErrorDetails(_expressionTranslator.TranslationErrorDetails);
}

if (expression != null
Expand Down
10 changes: 0 additions & 10 deletions src/EFCore.Relational/Query/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@ namespace Microsoft.EntityFrameworkCore.Query
/// </summary>
public static class ExpressionExtensions
{
/// <summary>
/// Checks if the given sql unary expression represents a logical NOT operation.
/// </summary>
/// <param name="sqlUnaryExpression"> A sql unary expression to check. </param>
/// <returns> A bool value indicating if the given expression represents a logical NOT operation. </returns>
public static bool IsLogicalNot([NotNull] this SqlUnaryExpression sqlUnaryExpression)
=> sqlUnaryExpression.OperatorType == ExpressionType.Not
&& (sqlUnaryExpression.Type == typeof(bool)
|| sqlUnaryExpression.Type == typeof(bool?));

/// <summary>
/// Infers type mapping from given <see cref="SqlExpression"/>s.
/// </summary>
Expand Down
11 changes: 11 additions & 0 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,18 @@ SqlUnaryExpression Convert(
/// <param name="operand"> An expression to compare with <see cref="CaseWhenClause.Test"/> in <paramref name="whenClauses"/>. </param>
/// <param name="whenClauses"> A list of <see cref="CaseWhenClause"/> to compare and get result from. </param>
/// <returns> An expression representing a CASE statement in a SQL tree. </returns>
[Obsolete("Use overload which takes IReadOnlyList instead of params")]
CaseExpression Case([NotNull] SqlExpression operand, [NotNull] params CaseWhenClause[] whenClauses);

/// <summary>
/// Creates a new <see cref="CaseExpression"/> which represent a CASE statement in a SQL tree.
/// </summary>
/// <param name="operand"> An expression to compare with <see cref="CaseWhenClause.Test"/> in <paramref name="whenClauses"/>. </param>
/// <param name="whenClauses"> A list of <see cref="CaseWhenClause"/> to compare and get result from. </param>
/// <param name="elseResult"> A value to return if no <paramref name="whenClauses"/> matches, if any. </param>
/// <returns> An expression representing a CASE statement in a SQL tree. </returns>
CaseExpression Case([NotNull] SqlExpression operand, [NotNull] IReadOnlyList<CaseWhenClause> whenClauses, [CanBeNull] SqlExpression elseResult);

/// <summary>
/// Creates a new <see cref="CaseExpression"/> which represent a CASE statement in a SQL tree.
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio
}

case ExpressionType.Not
when sqlUnaryExpression.IsLogicalNot():
when sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool):
{
_relationalCommandBuilder.Append("NOT (");
Visit(sqlUnaryExpression.Operand);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ private SqlExpression TranslateExpression(Expression expression)
var translation = _sqlTranslator.Translate(expression);
if (translation == null && _sqlTranslator.TranslationErrorDetails != null)
{
ProvideTranslationErrorDetails(_sqlTranslator.TranslationErrorDetails);
AddTranslationErrorDetails(_sqlTranslator.TranslationErrorDetails);
}

return translation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ public RelationalSqlTranslatingExpressionVisitor(
public virtual string TranslationErrorDetails { get; private set; }

/// <summary>
/// Provides detailed information about error encountered during translation.
/// Adds detailed information about error encountered during translation.
/// </summary>
/// <param name="details">Detailed information about error encountered during translation.</param>
protected virtual void ProvideTranslationErrorDetails([NotNull] string details)
/// <param name="details"> Detailed information about error encountered during translation. </param>
protected virtual void AddTranslationErrorDetails([NotNull] string details)
{
Check.NotNull(details, nameof(details));

Expand Down Expand Up @@ -702,11 +702,11 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method)
if (methodCallExpression.Method == _stringEqualsWithStringComparison
|| methodCallExpression.Method == _stringEqualsWithStringComparisonStatic)
{
ProvideTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison);
AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison);
}
else
{
ProvideTranslationErrorDetails(CoreStrings.QueryUnableToTranslateMethod(
AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateMethod(
methodCallExpression.Method.Name,
methodCallExpression.Method.DeclaringType?.DisplayName()));
}
Expand Down Expand Up @@ -842,7 +842,7 @@ private Expression TryBindMember(Expression source, MemberIdentity member)
return BindProperty(entityReferenceExpression, property);
}

ProvideTranslationErrorDetails(
AddTranslationErrorDetails(
CoreStrings.QueryUnableToTranslateMember(
member.Name,
entityReferenceExpression.EntityType.DisplayName()));
Expand Down
27 changes: 18 additions & 9 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private SqlExpression ApplyTypeMappingOnSqlUnary(
case ExpressionType.Equal:
case ExpressionType.NotEqual:
case ExpressionType.Not
when sqlUnaryExpression.IsLogicalNot():
when sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool):
{
resultTypeMapping = _boolTypeMapping;
resultType = typeof(bool);
Expand Down Expand Up @@ -417,12 +417,9 @@ public virtual SqlUnaryExpression MakeUnary(
Check.NotNull(operand, nameof(operand));
Check.NotNull(type, nameof(type));

if (!SqlUnaryExpression.IsValidOperator(operatorType))
{
return null;
}

return (SqlUnaryExpression)ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping);
return !SqlUnaryExpression.IsValidOperator(operatorType)
? null
: (SqlUnaryExpression)ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping);
}

/// <inheritdoc />
Expand Down Expand Up @@ -467,21 +464,33 @@ public virtual SqlUnaryExpression Negate(SqlExpression operand)
}

/// <inheritdoc />
[Obsolete("Use overload which takes IReadOnlyList instead of params")]
public virtual CaseExpression Case(SqlExpression operand, params CaseWhenClause[] whenClauses)
{
Check.NotNull(operand, nameof(operand));
Check.NotNull(whenClauses, nameof(whenClauses));

return Case(operand, whenClauses, null);
}

/// <inheritdoc />
public virtual CaseExpression Case(SqlExpression operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression elseResult)
{
Check.NotNull(operand, nameof(operand));
Check.NotNull(whenClauses, nameof(whenClauses));

var operandTypeMapping = operand.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
// Since we never look at type of Operand/Test after this place,
// we need to find actual typeMapping based on non-object type.
?? new[] { operand.Type }.Concat(whenClauses.Select(wc => wc.Test.Type))
.Where(t => t != typeof(object)).Select(t => _typeMappingSource.FindMapping(t)).FirstOrDefault();

var resultTypeMapping = whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);
var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);

operand = ApplyTypeMapping(operand, operandTypeMapping);
elseResult = ApplyTypeMapping(elseResult, resultTypeMapping);

var typeMappedWhenClauses = new List<CaseWhenClause>();
foreach (var caseWhenClause in whenClauses)
Expand All @@ -492,7 +501,7 @@ public virtual CaseExpression Case(SqlExpression operand, params CaseWhenClause[
ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
}

return new CaseExpression(operand, typeMappedWhenClauses);
return new CaseExpression(operand, typeMappedWhenClauses, elseResult);
}

/// <inheritdoc />
Expand Down
13 changes: 9 additions & 4 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1115,8 +1115,8 @@ private SqlExpression OptimizeComparison(
var leftUnary = left as SqlUnaryExpression;
var rightUnary = right as SqlUnaryExpression;

var leftNegated = leftUnary?.IsLogicalNot() == true;
var rightNegated = rightUnary?.IsLogicalNot() == true;
var leftNegated = IsLogicalNot(leftUnary);
var rightNegated = IsLogicalNot(rightUnary);

if (leftNegated)
{
Expand Down Expand Up @@ -1152,8 +1152,8 @@ private SqlExpression RewriteNullSemantics(
var leftUnary = left as SqlUnaryExpression;
var rightUnary = right as SqlUnaryExpression;

var leftNegated = leftUnary?.IsLogicalNot() == true;
var rightNegated = rightUnary?.IsLogicalNot() == true;
var leftNegated = IsLogicalNot(leftUnary);
var rightNegated = IsLogicalNot(rightUnary);

if (leftNegated)
{
Expand Down Expand Up @@ -1608,6 +1608,11 @@ private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression,
return sqlUnaryExpression;
}

private static bool IsLogicalNot(SqlUnaryExpression sqlUnaryExpression)
=> sqlUnaryExpression != null
&& sqlUnaryExpression.OperatorType == ExpressionType.Not
&& sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool);

// ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null))
//
// a | b | F1 = a == b | F2 = (a != null && b != null) | F3 = F1 && F2 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ public virtual SqlExpression Translate(SqlExpression instance, MemberInfo member
instancePropagatesNullability: true,
argumentsPropagateNullability: Array.Empty<bool>(),
typeof(string)),
whenClauses.ToArray());
whenClauses,
null);
}

if (Equals(member, _srid))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ private SqlExpression BuildCompareToExpression(SqlExpression sqlExpression)
// !(a != b) -> (a == b)
private SqlExpression SimplifyNegatedBinary(SqlExpression sqlExpression)
=> sqlExpression is SqlUnaryExpression sqlUnaryExpression
&& sqlUnaryExpression.IsLogicalNot()
&& sqlUnaryExpression.OperatorType == ExpressionType.Not
&& sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool)
&& sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand
&& (sqlBinaryOperand.OperatorType == ExpressionType.Equal || sqlBinaryOperand.OperatorType == ExpressionType.NotEqual)
? _sqlExpressionFactory.MakeBinary(
Expand Down Expand Up @@ -350,7 +351,7 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio
switch (sqlUnaryExpression.OperatorType)
{
case ExpressionType.Not
when sqlUnaryExpression.IsLogicalNot():
when sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool):
{
_isSearchCondition = true;
resultCondition = true;
Expand Down
Loading

0 comments on commit 8c4bc0e

Please sign in to comment.