diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index fe91fac010e..4ff2b0733d1 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -373,7 +373,7 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method { // The source list is a constant, evaluate and replace with a list of the keys var listValue = (IEnumerable)listConstant.Value; - var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType); + var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType.MakeNullable()); var keyList = (IList)Activator.CreateInstance(keyListType); var getter = keyProperty.GetGetter(); foreach (var listItem in listValue) @@ -386,7 +386,6 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method && listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) { // The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution. - var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType); var lambda = Expression.Lambda( Expression.Call( _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()), @@ -397,7 +396,10 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method ); var newParameterName = $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}"; - rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, keyListType); + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter( + newParameterName, + lambda, + typeof(List<>).MakeGenericType(keyProperty.ClrType.MakeNullable())); } else { @@ -911,8 +913,6 @@ private static readonly MethodInfo _parameterValueExtractor /// private static List ParameterListValueExtractor(QueryContext context, string baseParameterName, IProperty property) { - Debug.Assert(property.ClrType == typeof(TProperty)); - var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable; if (baseListParameter == null) { diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs index 65f3acd907f..7754276ff66 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs @@ -1278,6 +1278,28 @@ FROM root c WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); } + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] + public override async Task Contains_with_parameter_list_value_type_id(bool isAsync) + { + await base.Contains_with_parameter_list_value_type_id(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); + } + + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] + public override async Task Contains_with_constant_list_value_type_id(bool isAsync) + { + await base.Contains_with_constant_list_value_type_id(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); + } + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs index e7b49846ce0..4b71410b480 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs @@ -1614,6 +1614,32 @@ public virtual Task List_Contains_with_parameter_list(bool isAsync) entryCount: 2); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_with_parameter_list_value_type_id(bool isAsync) + { + var orders = new List + { + new Order { OrderID = 10248 }, + new Order { OrderID = 10249 } + }; + + return AssertQuery(isAsync, od => od.Where(o => orders.Contains(o)), + entryCount: 2); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_with_constant_list_value_type_id(bool isAsync) + { + return AssertQuery(isAsync, od => od.Where(o => new List + { + new Order { OrderID = 10248 }, + new Order { OrderID = 10249 } + }.Contains(o)), + entryCount: 2); + } + [ConditionalFact] public virtual void Contains_over_keyless_entity_throws() { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs index 3840e074a8c..f8dee1defb3 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs @@ -1225,6 +1225,26 @@ FROM [Customers] AS [c] WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')"); } + public override async Task Contains_with_parameter_list_value_type_id(bool isAsync) + { + await base.Contains_with_parameter_list_value_type_id(isAsync); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +WHERE [o].[OrderID] IN (10248, 10249)"); + } + + public override async Task Contains_with_constant_list_value_type_id(bool isAsync) + { + await base.Contains_with_constant_list_value_type_id(isAsync); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +WHERE [o].[OrderID] IN (10248, 10249)"); + } + public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();