Skip to content

Commit

Permalink
Fix to #8652 - InvalidCastException when casting from one value type …
Browse files Browse the repository at this point in the history
…to another in a simple select statement

Problem was for queries with value types being projected in a select expression when also using convert. The problem was that getValue is typed as object which then was converted to an expected type. However if the value returned by SQL was not exactly the same type, exception would get thrown due to boxing/unboxing.

Fix is to detect when we apply convert on a top level projection, and in this case use explicit cast, so that type returned by SQL was exactly the same as the type that was expected after unboxing.

Also added support for translating Negate expression, which was previously evaluated on the client.
  • Loading branch information
maumar committed Jun 15, 2017
1 parent a15be13 commit 081c18b
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ private static readonly Dictionary<ExpressionType, ExpressionType> _inverseOpera
private readonly bool _inProjection;
private readonly NullCheckRemovalTestingVisitor _nullCheckRemovalTestingVisitor;

private bool _isTopLevelProjection;

/// <summary>
/// Creates a new instance of <see cref="SqlTranslatingExpressionVisitor" />.
/// </summary>
Expand Down Expand Up @@ -79,6 +81,7 @@ public SqlTranslatingExpressionVisitor(
_topLevelPredicate = topLevelPredicate;
_inProjection = inProjection;
_nullCheckRemovalTestingVisitor = new NullCheckRemovalTestingVisitor(_queryModelVisitor);
_isTopLevelProjection = inProjection;
}

/// <summary>
Expand Down Expand Up @@ -107,7 +110,22 @@ public override Expression Visit(Expression expression)
return Visit(translatedExpression);
}

return base.Visit(expression);
if ((expression is UnaryExpression || expression is NewExpression))
{
return base.Visit(expression);
}

var isTopLevelProjection = _isTopLevelProjection;
_isTopLevelProjection = false;

try
{
return base.Visit(expression);
}
finally
{
_isTopLevelProjection = isTopLevelProjection;
}
}

/// <summary>
Expand All @@ -129,9 +147,9 @@ protected override Expression VisitBinary(BinaryExpression expression)
var left = Visit(expression.Left);
var right = Visit(expression.Right);

return left != null
&& right != null
&& left.Type != typeof(Expression[])
return left != null
&& right != null
&& left.Type != typeof(Expression[])
&& right.Type != typeof(Expression[])
? expression.Update(left, expression.Conversion, right)
: null;
Expand Down Expand Up @@ -202,7 +220,7 @@ var structuralComparisonExpression
var rightExpression = Visit(expression.Right);

return leftExpression != null
&& rightExpression != null
&& rightExpression != null
? Expression.MakeBinary(
expression.NodeType,
leftExpression,
Expand Down Expand Up @@ -303,7 +321,7 @@ public NullCheckRemovalTestingVisitor(RelationalQueryModelVisitor queryModelVisi
}

public bool CanRemoveNullCheck(
Expression testExpression,
Expression testExpression,
Expression resultExpression)
{
AnalyzeTestExpression(testExpression);
Expand Down Expand Up @@ -449,7 +467,7 @@ protected override Expression VisitExtension(Expression extensionExpression)

return extensionExpression;
}
}
}

private static Expression UnfoldStructuralComparison(ExpressionType expressionType, Expression expression)
{
Expand Down Expand Up @@ -639,7 +657,7 @@ var projectionIndex

private bool IsNonTranslatableSubquery(Expression expression)
=> expression is SubQueryExpression subQueryExpression
&& !(subQueryExpression.QueryModel.GetOutputDataInfo() is StreamedScalarValueInfo
&& !(subQueryExpression.QueryModel.GetOutputDataInfo() is StreamedScalarValueInfo
|| subQueryExpression.QueryModel.GetOutputDataInfo() is StreamedSingleValueInfo streamedSingleValueInfo
&& IsStreamedSingleValueSupportedType(streamedSingleValueInfo));

Expand Down Expand Up @@ -754,6 +772,16 @@ protected override Expression VisitUnary(UnaryExpression expression)

switch (expression.NodeType)
{
case ExpressionType.Negate:
{
var operand = Visit(expression.Operand);
if (operand != null)
{
return Expression.Negate(operand);
}

break;
}
case ExpressionType.Not:
{
var operand = Visit(expression.Operand);
Expand All @@ -766,10 +794,20 @@ protected override Expression VisitUnary(UnaryExpression expression)
}
case ExpressionType.Convert:
{
var isTopLevelProjection = _isTopLevelProjection;
_isTopLevelProjection = false;
var operand = Visit(expression.Operand);
_isTopLevelProjection = isTopLevelProjection;

if (operand != null)
{
return Expression.Convert(operand, expression.Type);
return _isTopLevelProjection
&& operand.Type.IsValueType
&& expression.Type.IsValueType
&& expression.Type.UnwrapNullableType() != operand.Type.UnwrapNullableType()
&& expression.Type.UnwrapEnumType() != operand.Type.UnwrapEnumType()
? (Expression)new ExplicitCastExpression(operand, expression.Type)
: Expression.Convert(operand, expression.Type);
}

break;
Expand Down Expand Up @@ -985,7 +1023,7 @@ protected override Expression VisitExtension(Expression expression)
}

return newLeft != stringCompare.Left
|| newRight != stringCompare.Right
|| newRight != stringCompare.Right
? new StringCompareExpression(stringCompare.Operator, newLeft, newRight)
: expression;
}
Expand Down Expand Up @@ -1129,4 +1167,4 @@ protected override TResult VisitUnhandledItem<TItem, TResult>(
protected override Exception CreateUnhandledItemException<T>(T unhandledItem, string visitMethod)
=> null; // Never called
}
}
}
111 changes: 111 additions & 0 deletions src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
Expand Down Expand Up @@ -477,5 +478,115 @@ var efObjects
Assert.Equal(l2oObjects, efObjects);
});
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_int_to_long_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (long)o.OrderID),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_nullable_int_to_long_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (long)o.EmployeeID),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_nullable_int_to_int_doesnt_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (int)o.EmployeeID),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_int_to_nullable_int_doesnt_introduce_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (int?)o.OrderID),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_from_binary_expression_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (long)(o.OrderID + o.OrderID)),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_from_binary_expression_nested_introduces_top_level_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (short)((long)o.OrderID + (long)o.OrderID)),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_from_unary_expression_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (long)-o.OrderID),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_from_length_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (long)o.CustomerID.Length),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_from_method_call_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => (long)Math.Abs(o.OrderID)),
assertOrder: true);
}

[ConditionalFact]
public virtual void Select_non_matching_value_types_from_anonymous_type_introduces_explicit_cast()
{
AssertQuery<Order>(
os => os
.Where(o => o.CustomerID == "ALFKI")
.OrderBy(o => o.OrderID)
.Select(o => new { LongOrder = (long)o.OrderID, ShortOrder = (short)o.OrderID, Order = o.OrderID }),
assertOrder: true);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ public override void Select_count_plus_sum()

AssertSql(
@"SELECT (
SELECT SUM([od].[Quantity])
SELECT SUM(CAST([od].[Quantity] AS int))
FROM [Order Details] AS [od]
WHERE [o].[OrderID] = [od].[OrderID]
) + (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,5 +428,115 @@ FROM [Orders] AS [o]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] LIKE N'A' + N'%' AND (LEFT([c].[CustomerID], LEN(N'A')) = N'A')");
}

public override void Select_non_matching_value_types_int_to_long_introduces_explicit_cast()
{
base.Select_non_matching_value_types_int_to_long_introduces_explicit_cast();

AssertSql(
@"SELECT CAST([o].[OrderID] AS bigint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_nullable_int_to_long_introduces_explicit_cast()
{
base.Select_non_matching_value_types_nullable_int_to_long_introduces_explicit_cast();

AssertSql(
@"SELECT CAST([o].[EmployeeID] AS bigint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_nullable_int_to_int_doesnt_introduces_explicit_cast()
{
base.Select_non_matching_value_types_nullable_int_to_int_doesnt_introduces_explicit_cast();

AssertSql(
@"SELECT [o].[EmployeeID]
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_int_to_nullable_int_doesnt_introduce_explicit_cast()
{
base.Select_non_matching_value_types_int_to_nullable_int_doesnt_introduce_explicit_cast();

AssertSql(
@"SELECT [o].[OrderID]
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_from_binary_expression_introduces_explicit_cast()
{
base.Select_non_matching_value_types_from_binary_expression_introduces_explicit_cast();

AssertSql(
@"SELECT CAST([o].[OrderID] + [o].[OrderID] AS bigint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_from_binary_expression_nested_introduces_top_level_explicit_cast()
{
base.Select_non_matching_value_types_from_binary_expression_nested_introduces_top_level_explicit_cast();

AssertSql(
@"SELECT CAST([o].[OrderID] + [o].[OrderID] AS smallint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_from_unary_expression_introduces_explicit_cast()
{
base.Select_non_matching_value_types_from_unary_expression_introduces_explicit_cast();

AssertSql(
@"SELECT CAST(-[o].[OrderID] AS bigint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_from_length_introduces_explicit_cast()
{
base.Select_non_matching_value_types_from_length_introduces_explicit_cast();

AssertSql(
@"SELECT CAST(CAST(LEN([o].[CustomerID]) AS int) AS bigint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_from_method_call_introduces_explicit_cast()
{
base.Select_non_matching_value_types_from_method_call_introduces_explicit_cast();

AssertSql(
@"SELECT CAST(ABS([o].[OrderID]) AS bigint)
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [o].[OrderID]");
}

public override void Select_non_matching_value_types_from_anonymous_type_introduces_explicit_cast()
{
base.Select_non_matching_value_types_from_anonymous_type_introduces_explicit_cast();

AssertSql(
@"SELECT CAST([o].[OrderID] AS bigint) AS [LongOrder], CAST([o].[OrderID] AS smallint) AS [ShortOrder], [o].[OrderID] AS [Order]
FROM [Orders] AS [o]
WHERE [o].[CustomerID] = N'ALFKI'
ORDER BY [Order]");
}
}
}

0 comments on commit 081c18b

Please sign in to comment.