From 13034236ab488b12815d27998dbabaf4ebda6437 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 3 Mar 2020 16:36:37 -0800 Subject: [PATCH] Query: Convert FromSql methods to custom query roots - Avoids any specific processing in core provider - Takes care of parameterization correctly - Also improved parameter extraction by creating parameter for any expression outside of lambda even when not a method call argument Part of #20146 --- .../RelationalQueryableExtensions.cs | 44 ++++++------ ...yFilterDefiningQueryRewritingConvention.cs | 34 +++++---- .../Internal/FromSqlQueryRootExpression.cs | 72 +++++++++++++++++++ ...yableMethodTranslatingExpressionVisitor.cs | 27 ++++--- .../Query/SqlExpressions/SelectExpression.cs | 6 +- ...yFilterDefiningQueryRewritingConvention.cs | 1 - ...ntityEqualityRewritingExpressionVisitor.cs | 2 +- .../NavigationExpandingExpressionVisitor.cs | 21 +----- .../ParameterExtractingExpressionVisitor.cs | 21 +++--- .../Query/UdfDbFunctionSqlServerTests.cs | 23 +++--- 10 files changed, 153 insertions(+), 98 deletions(-) create mode 100644 src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs diff --git a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs index 0a949e20fd4..4b415757ef4 100644 --- a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs @@ -54,11 +54,6 @@ public static DbCommand CreateDbCommand([NotNull] this IQueryable source) throw new NotSupportedException(RelationalStrings.NoDbCommand); } - internal static readonly MethodInfo FromSqlOnQueryableMethodInfo - = typeof(RelationalQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(FromSqlOnQueryable)) - .Single(); - /// /// /// Creates a LINQ query based on a raw SQL query. @@ -100,12 +95,10 @@ public static IQueryable FromSqlRaw( var queryableSource = (IQueryable)source; return queryableSource.Provider.CreateQuery( - Expression.Call( - null, - FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)), - queryableSource.Expression, - Expression.Constant(sql), - Expression.Constant(parameters))); + GenerateFromSqlQueryRoot( + queryableSource, + sql, + parameters)); } /// @@ -140,19 +133,24 @@ public static IQueryable FromSqlInterpolated( var queryableSource = (IQueryable)source; return queryableSource.Provider.CreateQuery( - Expression.Call( - null, - FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)), - queryableSource.Expression, - Expression.Constant(sql.Format), - Expression.Constant(sql.GetArguments()))); + GenerateFromSqlQueryRoot( + queryableSource, + sql.Format, + sql.GetArguments())); } - internal static IQueryable FromSqlOnQueryable( - [NotNull] this IQueryable source, - [NotParameterized] string sql, - [NotNull] params object[] parameters) - where TEntity : class - => throw new NotImplementedException(); + private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot( + IQueryable source, + string sql, + object[] arguments) + { + var queryRootExpression = (QueryRootExpression)source.Expression; + + return new FromSqlQueryRootExpression( + queryRootExpression.QueryProvider, + queryRootExpression.EntityType, + sql, + Expression.Constant(arguments)); + } } } diff --git a/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterDefiningQueryRewritingConvention.cs b/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterDefiningQueryRewritingConvention.cs index b59737adf3f..3e7dd8ef3cf 100644 --- a/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterDefiningQueryRewritingConvention.cs +++ b/src/EFCore.Relational/Metadata/Conventions/RelationalQueryFilterDefiningQueryRewritingConvention.cs @@ -5,6 +5,8 @@ using System.Linq.Expressions; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Metadata.Conventions @@ -40,30 +42,26 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && (methodName == nameof(RelationalQueryableExtensions.FromSqlRaw) || methodName == nameof(RelationalQueryableExtensions.FromSqlInterpolated))) { - var newSource = Visit(methodCallExpression.Arguments[0]); - var fromSqlOnQueryableMethod = - RelationalQueryableExtensions.FromSqlOnQueryableMethodInfo.MakeGenericMethod( - newSource.Type.GetGenericArguments()[0]); + var newSource = (QueryRootExpression)Visit(methodCallExpression.Arguments[0]); + + string sql; + Expression argument; if (methodName == nameof(RelationalQueryableExtensions.FromSqlRaw)) { - return Expression.Call( - null, - fromSqlOnQueryableMethod, - newSource, - methodCallExpression.Arguments[1], - methodCallExpression.Arguments[2]); + sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value; + argument = methodCallExpression.Arguments[2]; } + else + { + var formattableString = Expression.Lambda>( + Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))).Compile().Invoke(); - var formattableString = Expression.Lambda>( - Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))).Compile().Invoke(); + sql = formattableString.Format; + argument = Expression.Constant(formattableString.GetArguments()); + } - return Expression.Call( - null, - fromSqlOnQueryableMethod, - newSource, - Expression.Constant(formattableString.Format), - Expression.Constant(formattableString.GetArguments())); + return new FromSqlQueryRootExpression(newSource.EntityType, sql, argument); } return base.VisitMethodCall(methodCallExpression); diff --git a/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs b/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs new file mode 100644 index 00000000000..f10c59edab7 --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq.Expressions; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + public class FromSqlQueryRootExpression : QueryRootExpression + { + public FromSqlQueryRootExpression( + [NotNull] IAsyncQueryProvider queryProvider, [NotNull] IEntityType entityType, [NotNull] string sql, [NotNull] Expression argument) + : base(queryProvider, entityType) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(argument, nameof(argument)); + + Sql = sql; + Argument = argument; + } + + public FromSqlQueryRootExpression( + [NotNull] IEntityType entityType, [NotNull] string sql, [NotNull] Expression argument) + : base(entityType) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(argument, nameof(argument)); + + Sql = sql; + Argument = argument; + } + + public virtual string Sql { get; } + public virtual Expression Argument { get; } + + public override Expression DetachQueryProvider() => new FromSqlQueryRootExpression(EntityType, Sql, Argument); + + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var argument = visitor.Visit(Argument); + + return argument != Argument + ? new FromSqlQueryRootExpression(EntityType, Sql, argument) + : this; + } + + public override void Print(ExpressionPrinter expressionPrinter) + { + base.Print(expressionPrinter); + expressionPrinter.Append($".FromSql({Sql}, "); + expressionPrinter.Visit(Argument); + expressionPrinter.AppendLine(")"); + } + + public override bool Equals(object obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is FromSqlQueryRootExpression queryRootExpression + && Equals(queryRootExpression)); + + private bool Equals(FromSqlQueryRootExpression queryRootExpression) + => base.Equals(queryRootExpression) + && string.Equals(Sql, queryRootExpression.Sql, StringComparison.OrdinalIgnoreCase) + && ExpressionEqualityComparer.Instance.Equals(Argument, queryRootExpression.Argument); + + public override int GetHashCode() + => HashCode.Combine(base.GetHashCode(), Sql, ExpressionEqualityComparer.Instance.GetHashCode(Argument)); + } +} diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index e0758a0b3ce..eed6ebcb752 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -63,21 +63,26 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor( _subquery = true; } - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + protected override Expression VisitExtension(Expression extensionExpression) { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions) - && methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlOnQueryable)) + if (extensionExpression is FromSqlQueryRootExpression fromSqlQueryRootExpression) { - var sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value; - var entityType = ((QueryRootExpression)methodCallExpression.Arguments[0]).EntityType; - return CreateShapedQueryExpression( - entityType, _sqlExpressionFactory.Select(entityType, sql, methodCallExpression.Arguments[2])); + fromSqlQueryRootExpression.EntityType, + _sqlExpressionFactory.Select( + fromSqlQueryRootExpression.EntityType, + fromSqlQueryRootExpression.Sql, + fromSqlQueryRootExpression.Argument)); } - var dbFunction = this._model.FindDbFunction(methodCallExpression.Method); + return base.VisitExtension(extensionExpression); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + + var dbFunction = _model.FindDbFunction(methodCallExpression.Method); if (dbFunction != null && dbFunction.IsIQueryable) { return CreateShapedQueryExpression(methodCallExpression); @@ -94,7 +99,7 @@ protected virtual ShapedQueryExpression CreateShapedQueryExpression([NotNull] Me var sqlFuncExpression = _sqlTranslator.TranslateMethodCall(methodCallExpression) as SqlFunctionExpression; var elementType = methodCallExpression.Method.ReturnType.GetGenericArguments()[0]; - var entityType =_model.FindEntityType(elementType); + var entityType = _model.FindEntityType(elementType); var queryExpression = _sqlExpressionFactory.Select(entityType, sqlFuncExpression); return CreateShapedQueryExpression(entityType, queryExpression); diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 29d55e59b6e..21d002e89f3 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -932,14 +932,14 @@ public Expression ApplyCollectionJoin( var (parentIdentifier, parentIdentifierValueComparers) = GetIdentifierAccessor(_identifier); var (outerIdentifier, outerIdentifierValueComparers) = GetIdentifierAccessor(_identifier.Concat(_childIdentifiers)); innerSelectExpression.ApplyProjection(); - - if (innerSelectExpression._identifier.Count == 0 && innerSelectExpression.Tables.FirstOrDefault( + + if (innerSelectExpression._identifier.Count == 0 && innerSelectExpression.Tables.FirstOrDefault( t => t is QueryableSqlFunctionExpression expression && expression.SqlFunctionExpression.Arguments.Count != 0) is QueryableSqlFunctionExpression queryableFunctionExpression) { throw new InvalidOperationException(RelationalStrings.DbFunctionProjectedCollectionMustHavePK(queryableFunctionExpression.SqlFunctionExpression.Name)); } - var (selfIdentifier, selfIdentifierValueComparers) = innerSelectExpression.GetIdentifierAccessor(innerSelectExpression._identifier); + var (selfIdentifier, selfIdentifierValueComparers) = innerSelectExpression.GetIdentifierAccessor(innerSelectExpression._identifier); if (collectionIndex == 0) { diff --git a/src/EFCore/Metadata/Conventions/QueryFilterDefiningQueryRewritingConvention.cs b/src/EFCore/Metadata/Conventions/QueryFilterDefiningQueryRewritingConvention.cs index 107d0faaca9..9380de3c33d 100644 --- a/src/EFCore/Metadata/Conventions/QueryFilterDefiningQueryRewritingConvention.cs +++ b/src/EFCore/Metadata/Conventions/QueryFilterDefiningQueryRewritingConvention.cs @@ -7,7 +7,6 @@ using Microsoft.EntityFrameworkCore.Metadata.Builders; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; using Microsoft.EntityFrameworkCore.Query; -using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Metadata.Conventions diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index 46932dbd94f..47f2a29bc9a 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -348,7 +348,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } // Methods with a typed first argument (source), and with no lambda arguments or a single lambda - // argument that has one parameter are rewritten automatically (e.g. Where(), FromSql(), Average() + // argument that has one parameter are rewritten automatically (e.g. Where(), Average() var newArguments = new Expression[arguments.Count]; var lambdaArgs = arguments.Select(a => a.GetLambdaOrNull()).Where(l => l != null).ToArray(); newSource = Visit(arguments[0]); diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index e7f3332c07c..6d5e6871b16 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -114,7 +114,9 @@ protected override Expression VisitExtension(Expression extensionExpression) var entityType = queryRootExpression.EntityType; var definingQuery = entityType.GetDefiningQuery(); NavigationExpansionExpression navigationExpansionExpression; - if (definingQuery != null) + if (definingQuery != null + // Apply defining query only when it is not custom query root + && queryRootExpression.GetType() == typeof(QueryRootExpression)) { var processedDefiningQueryBody = _parameterExtractingExpressionVisitor.ExtractParameters(definingQuery.Body); processedDefiningQueryBody = _queryTranslationPreprocessor.NormalizeQueryableMethodCall(processedDefiningQueryBody); @@ -518,23 +520,6 @@ when QueryableMethods.IsSumWithSelector(method): return methodCallExpression.Update(null, new[] { argument }); } - if (method.IsGenericMethod - && method.Name == "FromSqlOnQueryable" - && methodCallExpression.Arguments.Count == 3 - && methodCallExpression.Arguments[0] is QueryRootExpression queryRootExpression - && methodCallExpression.Arguments[1] is ConstantExpression - && (methodCallExpression.Arguments[2] is ParameterExpression || methodCallExpression.Arguments[2] is ConstantExpression)) - { - var entityType = queryRootExpression.EntityType; - var source = CreateNavigationExpansionExpression(queryRootExpression, entityType); - source.UpdateSource( - methodCallExpression.Update( - null, - new[] { source.Source, methodCallExpression.Arguments[1], methodCallExpression.Arguments[2] })); - - return ApplyQueryFilter(source); - } - return ProcessUnknownMethod(methodCallExpression); } diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs index 853516646e1..16439f5ba74 100644 --- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs @@ -250,7 +250,8 @@ protected override Expression VisitExtension(Expression extensionExpression) throw new InvalidOperationException(CoreStrings.ErrorInvalidQueryable); } - return queryRootExpression.DetachQueryProvider(); + // Visit after detaching query provider since custom query roots can have additional components + extensionExpression = queryRootExpression.DetachQueryProvider(); } return base.VisitExtension(extensionExpression); @@ -500,7 +501,8 @@ public override Expression Visit(Expression expression) if (_evaluatable) { - _evaluatableExpressions[expression] = _containsClosure; + // Force parameterization when not in lambda + _evaluatableExpressions[expression] = _containsClosure || !_inLambda; } _evaluatable = parentEvaluatable && _evaluatable; @@ -575,18 +577,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp Visit(methodCallExpression.Arguments[i]); - if (_evaluatableExpressions.ContainsKey(methodCallExpression.Arguments[i])) + if (_evaluatableExpressions.ContainsKey(methodCallExpression.Arguments[i]) + && (parameterInfos[i].GetCustomAttribute() != null + || _model.IsIndexerMethod(methodCallExpression.Method))) { - if (parameterInfos[i].GetCustomAttribute() != null - || _model.IsIndexerMethod(methodCallExpression.Method)) - { - _evaluatableExpressions[methodCallExpression.Arguments[i]] = false; - } - else if (!_inLambda) - { - // Force parameterization when not in lambda - _evaluatableExpressions[methodCallExpression.Arguments[i]] = true; - } + _evaluatableExpressions[methodCallExpression.Arguments[i]] = false; } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs index 6bdbeca535f..8332d5de3ca 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs @@ -468,10 +468,11 @@ public override void QF_Stand_Alone_Parameter() { base.QF_Stand_Alone_Parameter(); - AssertSql(@"@__customerId_0='1' + AssertSql( + @"@__customerId_1='1' SELECT [o].[Count], [o].[CustomerId], [o].[Year] -FROM [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o] +FROM [dbo].[GetCustomerOrderCountByYear](@__customerId_1) AS [o] ORDER BY [o].[Count] DESC"); } @@ -601,14 +602,14 @@ public override void QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter() base.QF_Select_NonCorrelated_Subquery_In_Anonymous_Parameter(); AssertSql( - @"@__amount_0='27' (Nullable = true) + @"@__amount_1='27' (Nullable = true) SELECT [c].[Id], [t0].[ProductId] FROM [Customers] AS [c] OUTER APPLY ( SELECT [t].[ProductId] FROM [dbo].[GetTopTwoSellingProducts]() AS [t] - WHERE [t].[AmountSold] = @__amount_0 + WHERE [t].[AmountSold] = @__amount_1 ) AS [t0] ORDER BY [c].[Id]"); } @@ -637,11 +638,12 @@ public override void QF_CrossJoin_Not_Correlated() { base.QF_CrossJoin_Not_Correlated(); - AssertSql(@"@__customerId_0='2' + AssertSql( + @"@__customerId_1='2' SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] FROM [Customers] AS [c] -CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_1) AS [o] WHERE [c].[Id] = 2 ORDER BY [o].[Count]"); } @@ -650,13 +652,14 @@ public override void QF_CrossJoin_Parameter() { base.QF_CrossJoin_Parameter(); - AssertSql(@"@__customerId_0='2' -@__custId_1='2' + AssertSql( + @"@__customerId_1='2' +@__custId_2='2' SELECT [c].[Id], [c].[LastName], [o].[Year], [o].[Count] FROM [Customers] AS [c] -CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [o] -WHERE [c].[Id] = @__custId_1 +CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__customerId_1) AS [o] +WHERE [c].[Id] = @__custId_2 ORDER BY [o].[Count]"); }