From 828a86c6117e99ecff41afa6a965802e846fcaf0 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Wed, 22 Jul 2020 15:32:47 -0700 Subject: [PATCH] Query: Key comparison should use object.Equals internally in query And associated changes to support translation. Resolves #19407 --- .../Query/Internal/EqualsTranslator.cs | 6 +- ...yExpressionTranslatingExpressionVisitor.cs | 15 +- .../Query/Internal/EqualsTranslator.cs | 11 +- ...lationalSqlTranslatingExpressionVisitor.cs | 15 +- src/EFCore/Internal/EntityFinder.cs | 47 ++- ...ingExpressionVisitor.ExpressionVisitors.cs | 10 +- .../KeysWithConvertersTestBase.cs | 278 +++++++----------- .../Query/GearsOfWarQueryTestBase.cs | 32 +- 8 files changed, 203 insertions(+), 211 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs b/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs index 5b02718cd19..b04cafc5a75 100644 --- a/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs @@ -64,8 +64,10 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method && right != null) { return left.Type.UnwrapNullableType() == right.Type.UnwrapNullableType() - ? (SqlExpression)_sqlExpressionFactory.Equal(left, right) - : _sqlExpressionFactory.Constant(false); + || (right.Type == typeof(object) && right is SqlParameterExpression) + || (left.Type == typeof(object) && left is SqlParameterExpression) + ? _sqlExpressionFactory.Equal(left, right) + : (SqlExpression)_sqlExpressionFactory.Constant(false); } return null; diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 266f5aea8e3..4b6c58fffd1 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -145,7 +145,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) && binaryExpression.Left is NewArrayExpression && binaryExpression.NodeType == ExpressionType.Equal) { - return Visit(ConvertObjectArrayEqualityComparison(binaryExpression)); + return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right)); } var newLeft = Visit(binaryExpression.Left); @@ -557,6 +557,13 @@ MethodInfo GetMethod() && methodCallExpression.Object == null && methodCallExpression.Arguments.Count == 2) { + if (methodCallExpression.Arguments[0].Type == typeof(object[]) + && methodCallExpression.Arguments[0] is NewArrayExpression) + { + return Visit(ConvertObjectArrayEqualityComparison( + methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])); + } + var left = Visit(methodCallExpression.Arguments[0]); var right = Visit(methodCallExpression.Arguments[1]); @@ -1262,10 +1269,10 @@ private static bool CanEvaluate(Expression expression) } } - private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression) + private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right) { - var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions; - var rightExpressions = ((NewArrayExpression)binaryExpression.Right).Expressions; + var leftExpressions = ((NewArrayExpression)left).Expressions; + var rightExpressions = ((NewArrayExpression)right).Expressions; return leftExpressions.Zip( rightExpressions, diff --git a/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs b/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs index c3be6ba5f36..a94ef7703e1 100644 --- a/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs @@ -64,12 +64,11 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method if (left != null && right != null) { - if (left.Type == right.Type) - { - return _sqlExpressionFactory.Equal(left, right); - } - - return _sqlExpressionFactory.Constant(false); + return left.Type == right.Type + || (right.Type == typeof(object) && right is SqlParameterExpression) + || (left.Type == typeof(object) && left is SqlParameterExpression) + ? _sqlExpressionFactory.Equal(left, right) + : (SqlExpression)_sqlExpressionFactory.Constant(false); } return null; diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index f6d15fd6bc7..804928facb8 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -335,7 +335,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) && binaryExpression.Left is NewArrayExpression && binaryExpression.NodeType == ExpressionType.Equal) { - return Visit(ConvertObjectArrayEqualityComparison(binaryExpression)); + return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right)); } var left = TryRemoveImplicitConvert(binaryExpression.Left); @@ -624,6 +624,13 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) && methodCallExpression.Object == null && methodCallExpression.Arguments.Count == 2) { + if (methodCallExpression.Arguments[0].Type == typeof(object[]) + && methodCallExpression.Arguments[0] is NewArrayExpression) + { + return Visit(ConvertObjectArrayEqualityComparison( + methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])); + } + var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1])); @@ -1000,10 +1007,10 @@ private static Expression RemoveObjectConvert(Expression expression) ? unaryExpression.Operand : expression; - private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression) + private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right) { - var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions; - var rightExpressions = ((NewArrayExpression)binaryExpression.Right).Expressions; + var leftExpressions = ((NewArrayExpression)left).Expressions; + var rightExpressions = ((NewArrayExpression)right).Expressions; return leftExpressions.Zip( rightExpressions, diff --git a/src/EFCore/Internal/EntityFinder.cs b/src/EFCore/Internal/EntityFinder.cs index d1c757a8f08..40d7c3ce3bd 100644 --- a/src/EFCore/Internal/EntityFinder.cs +++ b/src/EFCore/Internal/EntityFinder.cs @@ -26,6 +26,9 @@ namespace Microsoft.EntityFrameworkCore.Internal public class EntityFinder : IEntityFinder where TEntity : class { + private static readonly MethodInfo _objectEqualsMethodInfo + = typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) }); + private readonly IStateManager _stateManager; private readonly IDbSetSource _setSource; private readonly IDbSetCache _setCache; @@ -354,34 +357,50 @@ private static IQueryable Select( parameter)); } - private static BinaryExpression BuildPredicate( + private static Expression BuildPredicate( IReadOnlyList keyProperties, ValueBuffer keyValues, ParameterExpression entityParameter) { var keyValuesConstant = Expression.Constant(keyValues); - var predicate = GenerateEqualExpression(keyProperties[0], 0); + var predicate = GenerateEqualExpression(entityParameter, keyValuesConstant, keyProperties[0], 0); for (var i = 1; i < keyProperties.Count; i++) { - predicate = Expression.AndAlso(predicate, GenerateEqualExpression(keyProperties[i], i)); + predicate = Expression.AndAlso(predicate, GenerateEqualExpression(entityParameter, keyValuesConstant, keyProperties[i], i)); } return predicate; - BinaryExpression GenerateEqualExpression(IProperty property, int i) => - Expression.Equal( - Expression.Call( - EF.PropertyMethod.MakeGenericMethod(property.ClrType), - entityParameter, - Expression.Constant(property.Name, typeof(string))), - Expression.Convert( + static Expression GenerateEqualExpression( + Expression entityParameterExpression, Expression keyValuesConstantExpression, IProperty property, int i) + => property.ClrType.IsValueType + && property.ClrType.UnwrapNullableType() is Type nonNullableType + && !(nonNullableType == typeof(bool) || nonNullableType.IsNumeric() || nonNullableType.IsEnum) + ? Expression.Call( + _objectEqualsMethodInfo, Expression.Call( - keyValuesConstant, - ValueBuffer.GetValueMethod, - Expression.Constant(i)), - property.ClrType)); + EF.PropertyMethod.MakeGenericMethod(typeof(object)), + entityParameterExpression, + Expression.Constant(property.Name, typeof(string))), + Expression.Convert( + Expression.Call( + keyValuesConstantExpression, + ValueBuffer.GetValueMethod, + Expression.Constant(i)), + typeof(object))) + : (Expression)Expression.Equal( + Expression.Call( + EF.PropertyMethod.MakeGenericMethod(property.ClrType), + entityParameterExpression, + Expression.Constant(property.Name, typeof(string))), + Expression.Convert( + Expression.Call( + keyValuesConstantExpression, + ValueBuffer.GetValueMethod, + Expression.Constant(i)), + property.ClrType)); } private static Expression> BuildProjection(IEntityType entityType) diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs index 8b3b17f7604..85f707d2e4f 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs @@ -23,6 +23,9 @@ public partial class NavigationExpandingExpressionVisitor /// private class ExpandingExpressionVisitor : ExpressionVisitor { + private static readonly MethodInfo _objectEqualsMethodInfo + = typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) }); + private readonly NavigationExpandingExpressionVisitor _navigationExpandingExpressionVisitor; private readonly NavigationExpansionExpression _source; @@ -393,7 +396,7 @@ outerKey is NewArrayExpression newArrayExpression }) .Aggregate((l, r) => Expression.AndAlso(l, r)) : Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)), - Expression.Equal(outerKey, innerKey)); + Expression.Call(_objectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey))); // Caller should take care of wrapping MaterializeCollectionNavigation return Expression.Call( @@ -455,6 +458,11 @@ outerKey is NewArrayExpression newArrayExpression return innerSource.PendingSelector; } + + static Expression AddConvertToObject(Expression expression) + => expression.Type.IsValueType + ? Expression.Convert(expression, typeof(object)) + : expression; } /// diff --git a/test/EFCore.Specification.Tests/KeysWithConvertersTestBase.cs b/test/EFCore.Specification.Tests/KeysWithConvertersTestBase.cs index 054a9aba66d..2f807d843fc 100644 --- a/test/EFCore.Specification.Tests/KeysWithConvertersTestBase.cs +++ b/test/EFCore.Specification.Tests/KeysWithConvertersTestBase.cs @@ -86,9 +86,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new IntStructKey { Id = 101 })), context.Set().Single(e => e.Id.Equals(new IntStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new IntStructKey { Id = 104 }), - context.Set().Single(e => e.Id == new IntStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new IntStructKey { Id = 104 })), + context.Set().Single(e => e.Id.Equals(new IntStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new IntStructKey { Id = 101 })); @@ -180,9 +180,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new ComparableIntStructKey { Id = 101 })), context.Set().Single(e => e.Id.Equals(new ComparableIntStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new ComparableIntStructKey { Id = 104 }), - context.Set().Single(e => e.Id == new ComparableIntStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new ComparableIntStructKey { Id = 104 })), + context.Set().Single(e => e.Id.Equals(new ComparableIntStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new ComparableIntStructKey { Id = 101 })); @@ -274,9 +274,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new GenericComparableIntStructKey { Id = 101 })), context.Set().Single(e => e.Id.Equals(new GenericComparableIntStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new GenericComparableIntStructKey { Id = 104 }), - context.Set().Single(e => e.Id == new GenericComparableIntStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new GenericComparableIntStructKey { Id = 104 })), + context.Set().Single(e => e.Id.Equals(new GenericComparableIntStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new GenericComparableIntStructKey { Id = 101 })); @@ -368,9 +368,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new IntStructKey { Id = 111 })), context.Set().FirstOrDefault(e => e.Id.Equals(new IntStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new IntStructKey { Id = 114 }), - context.Set().FirstOrDefault(e => e.Id == new IntStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new IntStructKey { Id = 114 })), + context.Set().FirstOrDefault(e => e.Id.Equals(new IntStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new IntStructKey { Id = 111 })); @@ -462,9 +462,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableIntStructKey { Id = 111 })), context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableIntStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new ComparableIntStructKey { Id = 114 }), - context.Set().FirstOrDefault(e => e.Id == new ComparableIntStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableIntStructKey { Id = 114 })), + context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableIntStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new ComparableIntStructKey { Id = 111 })); @@ -556,9 +556,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableIntStructKey { Id = 111 })), context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableIntStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new GenericComparableIntStructKey { Id = 114 }), - context.Set().FirstOrDefault(e => e.Id == new GenericComparableIntStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableIntStructKey { Id = 114 })), + context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableIntStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new GenericComparableIntStructKey { Id = 111 })); @@ -844,9 +844,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new BytesStructKey { Id = new byte[] { 101 } })), context.Set().Single(e => e.Id.Equals(new BytesStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new BytesStructKey { Id = new byte[] { 104 } }), - context.Set().Single(e => e.Id == new BytesStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new BytesStructKey { Id = new byte[] { 104 } })), + context.Set().Single(e => e.Id.Equals(new BytesStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new BytesStructKey { Id = new byte[] { 101 } })); @@ -938,9 +938,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = new byte[] { 101 } })), context.Set().Single(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new StructuralComparableBytesStructKey { Id = new byte[] { 104 } }), - context.Set().Single(e => e.Id == new StructuralComparableBytesStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = new byte[] { 104 } })), + context.Set().Single(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new StructuralComparableBytesStructKey { Id = new byte[] { 101 } })); @@ -1032,9 +1032,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new ComparableBytesStructKey { Id = new byte[] { 101 } })), context.Set().Single(e => e.Id.Equals(new ComparableBytesStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new ComparableBytesStructKey { Id = new byte[] { 104 } }), - context.Set().Single(e => e.Id == new ComparableBytesStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new ComparableBytesStructKey { Id = new byte[] { 104 } })), + context.Set().Single(e => e.Id.Equals(new ComparableBytesStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new ComparableBytesStructKey { Id = new byte[] { 101 } })); @@ -1126,9 +1126,9 @@ void RunQueries( context.Set().Single(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = new byte[] { 101 } })), context.Set().Single(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = oneOhTwo })), context.Set().Single(e => e.Id.Equals(oneOhThree)), - context.Set().Single(e => e.Id == new GenericComparableBytesStructKey { Id = new byte[] { 104 } }), - context.Set().Single(e => e.Id == new GenericComparableBytesStructKey { Id = oneOhFive }), - context.Set().Single(e => e.Id == oneOhSix) + context.Set().Single(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = new byte[] { 104 } })), + context.Set().Single(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = oneOhFive })), + context.Set().Single(e => e.Id.Equals(oneOhSix)) }; Assert.Same(dependents[0], context.Set().Find(new GenericComparableBytesStructKey { Id = new byte[] { 101 } })); @@ -1220,9 +1220,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new BytesStructKey { Id = new byte[] { 111 } })), context.Set().FirstOrDefault(e => e.Id.Equals(new BytesStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new BytesStructKey { Id = new byte[] { 114 } }), - context.Set().FirstOrDefault(e => e.Id == new BytesStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new BytesStructKey { Id = new byte[] { 114 } })), + context.Set().FirstOrDefault(e => e.Id.Equals(new BytesStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new BytesStructKey { Id = new byte[] { 111 } })); @@ -1314,9 +1314,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableBytesStructKey { Id = new byte[] { 111 } })), context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableBytesStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new ComparableBytesStructKey { Id = new byte[] { 114 } }), - context.Set().FirstOrDefault(e => e.Id == new ComparableBytesStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableBytesStructKey { Id = new byte[] { 114 } })), + context.Set().FirstOrDefault(e => e.Id.Equals(new ComparableBytesStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new ComparableBytesStructKey { Id = new byte[] { 111 } })); @@ -1408,9 +1408,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = new byte[] { 111 } })), context.Set().FirstOrDefault(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new StructuralComparableBytesStructKey { Id = new byte[] { 114 } }), - context.Set().FirstOrDefault(e => e.Id == new StructuralComparableBytesStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = new byte[] { 114 } })), + context.Set().FirstOrDefault(e => e.Id.Equals(new StructuralComparableBytesStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new StructuralComparableBytesStructKey { Id = new byte[] { 111 } })); @@ -1502,9 +1502,9 @@ void RunQueries( context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = new byte[] { 111 } })), context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = oneTwelve })), context.Set().FirstOrDefault(e => e.Id.Equals(oneThirteen)), - context.Set().FirstOrDefault(e => e.Id == new GenericComparableBytesStructKey { Id = new byte[] { 114 } }), - context.Set().FirstOrDefault(e => e.Id == new GenericComparableBytesStructKey { Id = oneFifteeen }), - context.Set().FirstOrDefault(e => e.Id == oneSixteen) + context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = new byte[] { 114 } })), + context.Set().FirstOrDefault(e => e.Id.Equals(new GenericComparableBytesStructKey { Id = oneFifteeen })), + context.Set().FirstOrDefault(e => e.Id.Equals(oneSixteen)) }; Assert.Same(dependents[0], context.Set().Find(new GenericComparableBytesStructKey { Id = new byte[] { 111 } })); @@ -1535,48 +1535,46 @@ private void InsertOptionalGraph() where TPrincipal : class, IIntPrincipal, new() where TDependent : class, IIntOptionalDependent, new() { - using (var context = CreateContext()) - { - context.Set().AddRange( - new TPrincipal { BackingId = 1, Foo = "X1" }, - new TPrincipal { BackingId = 2, Foo = "X2" }, - new TPrincipal { BackingId = 3, Foo = "X3" }, - new TPrincipal { BackingId = 4, Foo = "X4" }); + using var context = CreateContext(); - context.Set().AddRange( - new TDependent { BackingId = 101, BackingPrincipalId = 1 }, - new TDependent { BackingId = 102, BackingPrincipalId = 2 }, - new TDependent { BackingId = 103, BackingPrincipalId = 3 }, - new TDependent { BackingId = 104, BackingPrincipalId = 3 }, - new TDependent { BackingId = 105, BackingPrincipalId = 3 }, - new TDependent { BackingId = 106 }); + context.Set().AddRange( + new TPrincipal { BackingId = 1, Foo = "X1" }, + new TPrincipal { BackingId = 2, Foo = "X2" }, + new TPrincipal { BackingId = 3, Foo = "X3" }, + new TPrincipal { BackingId = 4, Foo = "X4" }); - Assert.Equal(10, context.SaveChanges()); - } + context.Set().AddRange( + new TDependent { BackingId = 101, BackingPrincipalId = 1 }, + new TDependent { BackingId = 102, BackingPrincipalId = 2 }, + new TDependent { BackingId = 103, BackingPrincipalId = 3 }, + new TDependent { BackingId = 104, BackingPrincipalId = 3 }, + new TDependent { BackingId = 105, BackingPrincipalId = 3 }, + new TDependent { BackingId = 106 }); + + Assert.Equal(10, context.SaveChanges()); } private void InsertRequiredGraph() where TPrincipal : class, IIntPrincipal, new() where TDependent : class, IIntRequiredDependent, new() { - using (var context = CreateContext()) - { - context.Set().AddRange( - new TPrincipal { BackingId = 11, Foo = "X1" }, - new TPrincipal { BackingId = 12, Foo = "X2" }, - new TPrincipal { BackingId = 13, Foo = "X3" }, - new TPrincipal { BackingId = 14, Foo = "X4" }); + using var context = CreateContext(); - context.Set().AddRange( - new TDependent { BackingId = 111, BackingPrincipalId = 11 }, - new TDependent { BackingId = 112, BackingPrincipalId = 12 }, - new TDependent { BackingId = 113, BackingPrincipalId = 13 }, - new TDependent { BackingId = 114, BackingPrincipalId = 13 }, - new TDependent { BackingId = 115, BackingPrincipalId = 13 }, - new TDependent { BackingId = 116, BackingPrincipalId = 13 }); + context.Set().AddRange( + new TPrincipal { BackingId = 11, Foo = "X1" }, + new TPrincipal { BackingId = 12, Foo = "X2" }, + new TPrincipal { BackingId = 13, Foo = "X3" }, + new TPrincipal { BackingId = 14, Foo = "X4" }); - Assert.Equal(10, context.SaveChanges()); - } + context.Set().AddRange( + new TDependent { BackingId = 111, BackingPrincipalId = 11 }, + new TDependent { BackingId = 112, BackingPrincipalId = 12 }, + new TDependent { BackingId = 113, BackingPrincipalId = 13 }, + new TDependent { BackingId = 114, BackingPrincipalId = 13 }, + new TDependent { BackingId = 115, BackingPrincipalId = 13 }, + new TDependent { BackingId = 116, BackingPrincipalId = 13 }); + + Assert.Equal(10, context.SaveChanges()); } protected void ValidateOptional( @@ -1668,48 +1666,44 @@ private void InsertOptionalBytesGraph() where TPrincipal : class, IBytesPrincipal, new() where TDependent : class, IBytesOptionalDependent, new() { - using (var context = CreateContext()) - { - context.Set().AddRange( - new TPrincipal { BackingId = new byte[]{ 1 }, Foo = "X1" }, - new TPrincipal { BackingId = new byte[]{ 2, 2 }, Foo = "X2" }, - new TPrincipal { BackingId = new byte[]{ 3, 3, 3 }, Foo = "X3" }, - new TPrincipal { BackingId = new byte[]{ 4, 4, 4, 4 }, Foo = "X4" }); + using var context = CreateContext(); + context.Set().AddRange( + new TPrincipal { BackingId = new byte[] { 1 }, Foo = "X1" }, + new TPrincipal { BackingId = new byte[] { 2, 2 }, Foo = "X2" }, + new TPrincipal { BackingId = new byte[] { 3, 3, 3 }, Foo = "X3" }, + new TPrincipal { BackingId = new byte[] { 4, 4, 4, 4 }, Foo = "X4" }); - context.Set().AddRange( - new TDependent { BackingId = new byte[]{ 101 }, BackingPrincipalId = new byte[]{ 1 } }, - new TDependent { BackingId = new byte[]{ 102 }, BackingPrincipalId = new byte[]{ 2, 2 } }, - new TDependent { BackingId = new byte[]{ 103 }, BackingPrincipalId = new byte[]{ 3, 3, 3 } }, - new TDependent { BackingId = new byte[]{ 104 }, BackingPrincipalId = new byte[]{ 3, 3, 3 } }, - new TDependent { BackingId = new byte[]{ 105 }, BackingPrincipalId = new byte[]{ 3, 3, 3 } }, - new TDependent { BackingId = new byte[]{ 106 } }); + context.Set().AddRange( + new TDependent { BackingId = new byte[] { 101 }, BackingPrincipalId = new byte[] { 1 } }, + new TDependent { BackingId = new byte[] { 102 }, BackingPrincipalId = new byte[] { 2, 2 } }, + new TDependent { BackingId = new byte[] { 103 }, BackingPrincipalId = new byte[] { 3, 3, 3 } }, + new TDependent { BackingId = new byte[] { 104 }, BackingPrincipalId = new byte[] { 3, 3, 3 } }, + new TDependent { BackingId = new byte[] { 105 }, BackingPrincipalId = new byte[] { 3, 3, 3 } }, + new TDependent { BackingId = new byte[] { 106 } }); - Assert.Equal(10, context.SaveChanges()); - } + Assert.Equal(10, context.SaveChanges()); } private void InsertRequiredBytesGraph() where TPrincipal : class, IBytesPrincipal, new() where TDependent : class, IBytesRequiredDependent, new() { - using (var context = CreateContext()) - { - context.Set().AddRange( - new TPrincipal { BackingId = new byte[]{ 11 }, Foo = "X1" }, - new TPrincipal { BackingId = new byte[]{ 12, 12 }, Foo = "X2" }, - new TPrincipal { BackingId = new byte[]{ 13, 13, 13 }, Foo = "X3" }, - new TPrincipal { BackingId = new byte[]{ 14, 14, 14, 14 }, Foo = "X4" }); + using var context = CreateContext(); + context.Set().AddRange( + new TPrincipal { BackingId = new byte[] { 11 }, Foo = "X1" }, + new TPrincipal { BackingId = new byte[] { 12, 12 }, Foo = "X2" }, + new TPrincipal { BackingId = new byte[] { 13, 13, 13 }, Foo = "X3" }, + new TPrincipal { BackingId = new byte[] { 14, 14, 14, 14 }, Foo = "X4" }); - context.Set().AddRange( - new TDependent { BackingId = new byte[]{ 111 }, BackingPrincipalId = new byte[]{ 11 } }, - new TDependent { BackingId = new byte[]{ 112 }, BackingPrincipalId = new byte[]{ 12, 12 } }, - new TDependent { BackingId = new byte[]{ 113 }, BackingPrincipalId = new byte[]{ 13, 13, 13 } }, - new TDependent { BackingId = new byte[]{ 114 }, BackingPrincipalId = new byte[]{ 13, 13, 13 } }, - new TDependent { BackingId = new byte[]{ 115 }, BackingPrincipalId = new byte[]{ 13, 13, 13 } }, - new TDependent { BackingId = new byte[]{ 116 }, BackingPrincipalId = new byte[]{ 13, 13, 13 } }); + context.Set().AddRange( + new TDependent { BackingId = new byte[] { 111 }, BackingPrincipalId = new byte[] { 11 } }, + new TDependent { BackingId = new byte[] { 112 }, BackingPrincipalId = new byte[] { 12, 12 } }, + new TDependent { BackingId = new byte[] { 113 }, BackingPrincipalId = new byte[] { 13, 13, 13 } }, + new TDependent { BackingId = new byte[] { 114 }, BackingPrincipalId = new byte[] { 13, 13, 13 } }, + new TDependent { BackingId = new byte[] { 115 }, BackingPrincipalId = new byte[] { 13, 13, 13 } }, + new TDependent { BackingId = new byte[] { 116 }, BackingPrincipalId = new byte[] { 13, 13, 13 } }); - Assert.Equal(10, context.SaveChanges()); - } + Assert.Equal(10, context.SaveChanges()); } protected void ValidateOptionalBytes( @@ -1804,25 +1798,15 @@ protected void ValidateRequiredBytes( } } -#pragma warning disable 660,661 // Issue #19407 protected struct IntStructKey -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new IntStructKey { Id = v }); public int Id { get; set; } - - public static bool operator ==(IntStructKey left, IntStructKey right) - => left.Id == right.Id; - - public static bool operator !=(IntStructKey left, IntStructKey right) - => left.Id != right.Id; } -#pragma warning disable 660,661 // Issue #19407 protected struct BytesStructKey -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new BytesStructKey { Id = v }); @@ -1848,36 +1832,20 @@ public override int GetHashCode() return code.ToHashCode(); } - - public static bool operator ==(BytesStructKey left, BytesStructKey right) - => left.Equals(right); - - public static bool operator !=(BytesStructKey left, BytesStructKey right) - => !left.Equals(right); } -#pragma warning disable 660,661 // Issue #19407 protected struct ComparableIntStructKey : IComparable -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new ComparableIntStructKey { Id = v }); public int Id { get; set; } - public static bool operator ==(ComparableIntStructKey left, ComparableIntStructKey right) - => left.Id == right.Id; - - public static bool operator !=(ComparableIntStructKey left, ComparableIntStructKey right) - => left.Id != right.Id; - public int CompareTo(object other) => Id - ((ComparableIntStructKey)other).Id; } -#pragma warning disable 660,661 // Issue #19407 protected struct ComparableBytesStructKey : IComparable -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new ComparableBytesStructKey { Id = v }); @@ -1904,46 +1872,28 @@ public override int GetHashCode() return code.ToHashCode(); } - public static bool operator ==(ComparableBytesStructKey left, ComparableBytesStructKey right) - => left.Equals(right); - - public static bool operator !=(ComparableBytesStructKey left, ComparableBytesStructKey right) - => !left.Equals(right); - public int CompareTo(object other) { var result = Id.Length - ((ComparableBytesStructKey)other).Id.Length; - if (result != 0) - { - return result; - } - return StructuralComparisons.StructuralComparer.Compare(Id, ((ComparableBytesStructKey)other).Id); + return result != 0 + ? result + : StructuralComparisons.StructuralComparer.Compare(Id, ((ComparableBytesStructKey)other).Id); } } -#pragma warning disable 660,661 // Issue #19407 protected struct GenericComparableIntStructKey : IComparable -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new GenericComparableIntStructKey { Id = v }); public int Id { get; set; } - public static bool operator ==(GenericComparableIntStructKey left, GenericComparableIntStructKey right) - => left.Id == right.Id; - - public static bool operator !=(GenericComparableIntStructKey left, GenericComparableIntStructKey right) - => !(left == right); - public int CompareTo(GenericComparableIntStructKey other) => Id - other.Id; } -#pragma warning disable 660,661 // Issue #19407 protected struct GenericComparableBytesStructKey : IComparable -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new GenericComparableBytesStructKey { Id = v }); @@ -1970,27 +1920,17 @@ public override int GetHashCode() return code.ToHashCode(); } - public static bool operator ==(GenericComparableBytesStructKey left, GenericComparableBytesStructKey right) - => left.Equals(right); - - public static bool operator !=(GenericComparableBytesStructKey left, GenericComparableBytesStructKey right) - => !left.Equals(right); - public int CompareTo(GenericComparableBytesStructKey other) { var result = Id.Length - other.Id.Length; - if (result != 0) - { - return result; - } - return StructuralComparisons.StructuralComparer.Compare(Id, other.Id); + return result != 0 + ? result + : StructuralComparisons.StructuralComparer.Compare(Id, other.Id); } } -#pragma warning disable 660,661 // Issue #19407 protected struct StructuralComparableBytesStructKey : IStructuralComparable -#pragma warning restore 660,661 { public static ValueConverter Converter = new ValueConverter(v => v.Id, v => new StructuralComparableBytesStructKey { Id = v }); @@ -2017,12 +1957,6 @@ public override int GetHashCode() return code.ToHashCode(); } - public static bool operator ==(StructuralComparableBytesStructKey left, StructuralComparableBytesStructKey right) - => left.Equals(right); - - public static bool operator !=(StructuralComparableBytesStructKey left, StructuralComparableBytesStructKey right) - => !left.Equals(right); - public int CompareTo(object other, IComparer comparer) { var typedOther = ((StructuralComparableBytesStructKey)other); diff --git a/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs index 18a1b68ac71..d5766d59a2e 100644 --- a/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs @@ -1532,14 +1532,22 @@ public virtual async Task Select_navigation_with_concat_and_count(bool async) @"MaterializeCollectionNavigation( Navigation: Gear.Weapons, subquery: DbSet() - .Where(w => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(w, ""OwnerFullName"")) - .Where(i => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(i, ""OwnerFullName"")) + .Where(w => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(w, ""OwnerFullName""))) + .Where(i => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(i, ""OwnerFullName""))) .AsQueryable() .Concat(MaterializeCollectionNavigation( Navigation: Gear.Weapons, subquery: DbSet() - .Where(w0 => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(w0, ""OwnerFullName"")) - .Where(i => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(i, ""OwnerFullName"")))"), + .Where(w0 => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(w0, ""OwnerFullName""))) + .Where(i => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(i, ""OwnerFullName""))))"), message, ignoreLineEndingDifferences: true); } @@ -1557,14 +1565,22 @@ public virtual async Task Concat_with_collection_navigations(bool async) @"MaterializeCollectionNavigation( Navigation: Gear.Weapons, subquery: DbSet() - .Where(w => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(w, ""OwnerFullName"")) - .Where(i => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(i, ""OwnerFullName"")) + .Where(w => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(w, ""OwnerFullName""))) + .Where(i => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(i, ""OwnerFullName""))) .AsQueryable() .Union(MaterializeCollectionNavigation( Navigation: Gear.Weapons, subquery: DbSet() - .Where(w0 => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(w0, ""OwnerFullName"")) - .Where(i => EF.Property(g, ""FullName"") != null && EF.Property(g, ""FullName"") == EF.Property(i, ""OwnerFullName"")))"), + .Where(w0 => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(w0, ""OwnerFullName""))) + .Where(i => EF.Property(g, ""FullName"") != null && object.Equals( + objA: EF.Property(g, ""FullName""), + objB: EF.Property(i, ""OwnerFullName""))))"), message, ignoreLineEndingDifferences: true); }