Skip to content

Commit

Permalink
Visit arguments in QueryableMethodNormalizingExpressionVisitor after …
Browse files Browse the repository at this point in the history
…converting List.Contains

Fixes dotnet#32215
Fixes dotnet#32218
  • Loading branch information
roji committed Nov 3, 2023
1 parent 338b76a commit 7fe11ed
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// Server), we need to fall back to the previous IN translation.
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition() == QueryableMethods.Contains
&& methodCallExpression.Arguments[0] is ParameterQueryRootExpression parameterSource
&& UnwrapAsQueryable(methodCallExpression.Arguments[0]) is ParameterQueryRootExpression parameterSource
&& TranslateExpression(methodCallExpression.Arguments[1]) is SqlExpression item
&& _sqlTranslator.Visit(parameterSource.ParameterExpression) is SqlParameterExpression sqlParameterExpression)
{
Expand All @@ -300,6 +300,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
.UpdateResultCardinality(ResultCardinality.Single);
return shapedQueryExpression;
}

static Expression UnwrapAsQueryable(Expression expression)
=> expression is MethodCallExpression { Method: { IsGenericMethod: true } method } methodCall
&& method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable
? methodCall.Arguments[0]
: expression;
}

return translated;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,13 @@ private Expression TryConvertListContainsToQueryableContains(MethodCallExpressio

var sourceType = methodCallExpression.Method.DeclaringType!.GetGenericArguments()[0];

return Expression.Call(
QueryableMethods.Contains.MakeGenericMethod(sourceType),
return VisitMethodCall(
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
methodCallExpression.Object!),
methodCallExpression.Arguments[0]);
QueryableMethods.Contains.MakeGenericMethod(sourceType),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
methodCallExpression.Object!),
methodCallExpression.Arguments[0]));
}

private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ public virtual Task Project_primitive_collections_element(bool async)
},
assertOrder: true);

[ConditionalTheory] // #32208
[ConditionalTheory] // #32208, #32215
[MemberData(nameof(IsAsyncData))]
public virtual Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool async)
{
Expand All @@ -821,6 +821,20 @@ public virtual Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => strings.Contains(ints.Contains(e.Int) ? "one" : "two")));
}

[ConditionalTheory] // #32208, #32215
[MemberData(nameof(IsAsyncData))]
public virtual Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
var ints = new[] { 1, 2, 3 };
var strings = new[] { "one", "two", "three" };

// Note that in this query, the outer Contains really has no type mapping, neither for its source (collection parameter), nor
// for its item (the conditional expression returns constants). The default type mapping must be applied.
return AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => strings.Contains(ints.Contains(e.Int) ? "one" : "two")));
}

public abstract class PrimitiveCollectionsQueryFixtureBase : SharedStoreFixtureBase<PrimitiveCollectionsContext>, IQueryFixtureBase
{
private PrimitiveArrayData? _expectedData;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,21 @@ END IN (N'one', N'two', N'three')
""");
}

public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (1, 2, 3) THEN N'one'
ELSE N'two'
END IN (N'one', N'two', N'three')
""");
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1233,12 +1233,40 @@ public override async Task Nested_contains_with_Lists_and_no_inferred_type_mappi

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 4000)
@__strings_0='["one","two","three"]' (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (1, 2, 3) THEN N'one'
WHEN [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i]
) THEN N'one'
ELSE N'two'
END IN (
SELECT [s].[value]
FROM OPENJSON(@__strings_0) WITH ([value] nvarchar(max) '$') AS [s]
)
""");
}

public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 4000)
@__strings_0='["one","two","three"]' (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i]
) THEN N'one'
ELSE N'two'
END IN (
SELECT [s].[value]
Expand Down
6 changes: 5 additions & 1 deletion test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3982,13 +3982,17 @@ public virtual async Task Nested_contains_with_enum()

AssertSql(
"""
@__todoTypes_1='[0]' (Size = 4000)
@__key_2='5f221fb9-66f4-442a-92c9-d97ed5989cc7'
@__keys_0='["0a47bcb7-a1cb-4345-8944-c58f82d6aac7","5f221fb9-66f4-442a-92c9-d97ed5989cc7"]' (Size = 4000)

SELECT [t].[Id], [t].[Type]
FROM [Todos] AS [t]
WHERE CASE
WHEN [t].[Type] = 0 THEN @__key_2
WHEN [t].[Type] IN (
SELECT [t0].[value]
FROM OPENJSON(@__todoTypes_1) WITH ([value] int '$') AS [t0]
) THEN @__key_2
ELSE @__key_2
END IN (
SELECT [k].[value]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1115,12 +1115,16 @@ public override async Task Nested_contains_with_Lists_and_no_inferred_type_mappi

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 7)
@__strings_0='["one","two","three"]' (Size = 21)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE CASE
WHEN "p"."Int" IN (1, 2, 3) THEN 'one'
WHEN "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@__ints_1) AS "i"
) THEN 'one'
ELSE 'two'
END IN (
SELECT "s"."value"
Expand All @@ -1129,6 +1133,30 @@ FROM json_each(@__strings_0) AS "s"
""");
}

public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 7)
@__strings_0='["one","two","three"]' (Size = 21)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE CASE
WHEN "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@__ints_1) AS "i"
) THEN 'one'
ELSE 'two'
END IN (
SELECT "s"."value"
FROM json_each(@__strings_0) AS "s"
)
""");
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down

0 comments on commit 7fe11ed

Please sign in to comment.