Skip to content

Commit

Permalink
Query: Convert FromSql methods to custom query roots
Browse files Browse the repository at this point in the history
- Avoids any specific processing in core provider
- Takes care of parameterization correctly
- Also improved parameter extraction by creating parameter for any expression outside of lambda even when not a method call argument

Part of #20146
  • Loading branch information
smitpatel committed Mar 4, 2020
1 parent 1ecd859 commit 1303423
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 98 deletions.
44 changes: 21 additions & 23 deletions src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ public static DbCommand CreateDbCommand([NotNull] this IQueryable source)
throw new NotSupportedException(RelationalStrings.NoDbCommand);
}

internal static readonly MethodInfo FromSqlOnQueryableMethodInfo
= typeof(RelationalQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(FromSqlOnQueryable))
.Single();

/// <summary>
/// <para>
/// Creates a LINQ query based on a raw SQL query.
Expand Down Expand Up @@ -100,12 +95,10 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(

var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql),
Expression.Constant(parameters)));
GenerateFromSqlQueryRoot(
queryableSource,
sql,
parameters));
}

/// <summary>
Expand Down Expand Up @@ -140,19 +133,24 @@ public static IQueryable<TEntity> FromSqlInterpolated<TEntity>(

var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
GenerateFromSqlQueryRoot(
queryableSource,
sql.Format,
sql.GetArguments()));
}

internal static IQueryable<TEntity> FromSqlOnQueryable<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotParameterized] string sql,
[NotNull] params object[] parameters)
where TEntity : class
=> throw new NotImplementedException();
private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
IQueryable source,
string sql,
object[] arguments)
{
var queryRootExpression = (QueryRootExpression)source.Expression;

return new FromSqlQueryRootExpression(
queryRootExpression.QueryProvider,
queryRootExpression.EntityType,
sql,
Expression.Constant(arguments));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
Expand Down Expand Up @@ -40,30 +42,26 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& (methodName == nameof(RelationalQueryableExtensions.FromSqlRaw)
|| methodName == nameof(RelationalQueryableExtensions.FromSqlInterpolated)))
{
var newSource = Visit(methodCallExpression.Arguments[0]);
var fromSqlOnQueryableMethod =
RelationalQueryableExtensions.FromSqlOnQueryableMethodInfo.MakeGenericMethod(
newSource.Type.GetGenericArguments()[0]);
var newSource = (QueryRootExpression)Visit(methodCallExpression.Arguments[0]);

string sql;
Expression argument;

if (methodName == nameof(RelationalQueryableExtensions.FromSqlRaw))
{
return Expression.Call(
null,
fromSqlOnQueryableMethod,
newSource,
methodCallExpression.Arguments[1],
methodCallExpression.Arguments[2]);
sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value;
argument = methodCallExpression.Arguments[2];
}
else
{
var formattableString = Expression.Lambda<Func<FormattableString>>(
Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))).Compile().Invoke();

var formattableString = Expression.Lambda<Func<FormattableString>>(
Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))).Compile().Invoke();
sql = formattableString.Format;
argument = Expression.Constant(formattableString.GetArguments());
}

return Expression.Call(
null,
fromSqlOnQueryableMethod,
newSource,
Expression.Constant(formattableString.Format),
Expression.Constant(formattableString.GetArguments()));
return new FromSqlQueryRootExpression(newSource.EntityType, sql, argument);
}

return base.VisitMethodCall(methodCallExpression);
Expand Down
72 changes: 72 additions & 0 deletions src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class FromSqlQueryRootExpression : QueryRootExpression
{
public FromSqlQueryRootExpression(
[NotNull] IAsyncQueryProvider queryProvider, [NotNull] IEntityType entityType, [NotNull] string sql, [NotNull] Expression argument)
: base(queryProvider, entityType)
{
Check.NotEmpty(sql, nameof(sql));
Check.NotNull(argument, nameof(argument));

Sql = sql;
Argument = argument;
}

public FromSqlQueryRootExpression(
[NotNull] IEntityType entityType, [NotNull] string sql, [NotNull] Expression argument)
: base(entityType)
{
Check.NotEmpty(sql, nameof(sql));
Check.NotNull(argument, nameof(argument));

Sql = sql;
Argument = argument;
}

public virtual string Sql { get; }
public virtual Expression Argument { get; }

public override Expression DetachQueryProvider() => new FromSqlQueryRootExpression(EntityType, Sql, Argument);

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var argument = visitor.Visit(Argument);

return argument != Argument
? new FromSqlQueryRootExpression(EntityType, Sql, argument)
: this;
}

public override void Print(ExpressionPrinter expressionPrinter)
{
base.Print(expressionPrinter);
expressionPrinter.Append($".FromSql({Sql}, ");
expressionPrinter.Visit(Argument);
expressionPrinter.AppendLine(")");
}

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is FromSqlQueryRootExpression queryRootExpression
&& Equals(queryRootExpression));

private bool Equals(FromSqlQueryRootExpression queryRootExpression)
=> base.Equals(queryRootExpression)
&& string.Equals(Sql, queryRootExpression.Sql, StringComparison.OrdinalIgnoreCase)
&& ExpressionEqualityComparer.Instance.Equals(Argument, queryRootExpression.Argument);

public override int GetHashCode()
=> HashCode.Combine(base.GetHashCode(), Sql, ExpressionEqualityComparer.Instance.GetHashCode(Argument));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,26 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor(
_subquery = true;
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
protected override Expression VisitExtension(Expression extensionExpression)
{
Check.NotNull(methodCallExpression, nameof(methodCallExpression));

if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions)
&& methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlOnQueryable))
if (extensionExpression is FromSqlQueryRootExpression fromSqlQueryRootExpression)
{
var sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value;
var entityType = ((QueryRootExpression)methodCallExpression.Arguments[0]).EntityType;

return CreateShapedQueryExpression(
entityType, _sqlExpressionFactory.Select(entityType, sql, methodCallExpression.Arguments[2]));
fromSqlQueryRootExpression.EntityType,
_sqlExpressionFactory.Select(
fromSqlQueryRootExpression.EntityType,
fromSqlQueryRootExpression.Sql,
fromSqlQueryRootExpression.Argument));
}

var dbFunction = this._model.FindDbFunction(methodCallExpression.Method);
return base.VisitExtension(extensionExpression);
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
Check.NotNull(methodCallExpression, nameof(methodCallExpression));

var dbFunction = _model.FindDbFunction(methodCallExpression.Method);
if (dbFunction != null && dbFunction.IsIQueryable)
{
return CreateShapedQueryExpression(methodCallExpression);
Expand All @@ -94,7 +99,7 @@ protected virtual ShapedQueryExpression CreateShapedQueryExpression([NotNull] Me
var sqlFuncExpression = _sqlTranslator.TranslateMethodCall(methodCallExpression) as SqlFunctionExpression;

var elementType = methodCallExpression.Method.ReturnType.GetGenericArguments()[0];
var entityType =_model.FindEntityType(elementType);
var entityType = _model.FindEntityType(elementType);
var queryExpression = _sqlExpressionFactory.Select(entityType, sqlFuncExpression);

return CreateShapedQueryExpression(entityType, queryExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -932,14 +932,14 @@ public Expression ApplyCollectionJoin(
var (parentIdentifier, parentIdentifierValueComparers) = GetIdentifierAccessor(_identifier);
var (outerIdentifier, outerIdentifierValueComparers) = GetIdentifierAccessor(_identifier.Concat(_childIdentifiers));
innerSelectExpression.ApplyProjection();
if (innerSelectExpression._identifier.Count == 0 && innerSelectExpression.Tables.FirstOrDefault(

if (innerSelectExpression._identifier.Count == 0 && innerSelectExpression.Tables.FirstOrDefault(
t => t is QueryableSqlFunctionExpression expression && expression.SqlFunctionExpression.Arguments.Count != 0) is QueryableSqlFunctionExpression queryableFunctionExpression)
{
throw new InvalidOperationException(RelationalStrings.DbFunctionProjectedCollectionMustHavePK(queryableFunctionExpression.SqlFunctionExpression.Name));
}

var (selfIdentifier, selfIdentifierValueComparers) = innerSelectExpression.GetIdentifierAccessor(innerSelectExpression._identifier);
var (selfIdentifier, selfIdentifierValueComparers) = innerSelectExpression.GetIdentifierAccessor(innerSelectExpression._identifier);

if (collectionIndex == 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

// Methods with a typed first argument (source), and with no lambda arguments or a single lambda
// argument that has one parameter are rewritten automatically (e.g. Where(), FromSql(), Average()
// argument that has one parameter are rewritten automatically (e.g. Where(), Average()
var newArguments = new Expression[arguments.Count];
var lambdaArgs = arguments.Select(a => a.GetLambdaOrNull()).Where(l => l != null).ToArray();
newSource = Visit(arguments[0]);
Expand Down
21 changes: 3 additions & 18 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
var entityType = queryRootExpression.EntityType;
var definingQuery = entityType.GetDefiningQuery();
NavigationExpansionExpression navigationExpansionExpression;
if (definingQuery != null)
if (definingQuery != null
// Apply defining query only when it is not custom query root
&& queryRootExpression.GetType() == typeof(QueryRootExpression))
{
var processedDefiningQueryBody = _parameterExtractingExpressionVisitor.ExtractParameters(definingQuery.Body);
processedDefiningQueryBody = _queryTranslationPreprocessor.NormalizeQueryableMethodCall(processedDefiningQueryBody);
Expand Down Expand Up @@ -518,23 +520,6 @@ when QueryableMethods.IsSumWithSelector(method):
return methodCallExpression.Update(null, new[] { argument });
}

if (method.IsGenericMethod
&& method.Name == "FromSqlOnQueryable"
&& methodCallExpression.Arguments.Count == 3
&& methodCallExpression.Arguments[0] is QueryRootExpression queryRootExpression
&& methodCallExpression.Arguments[1] is ConstantExpression
&& (methodCallExpression.Arguments[2] is ParameterExpression || methodCallExpression.Arguments[2] is ConstantExpression))
{
var entityType = queryRootExpression.EntityType;
var source = CreateNavigationExpansionExpression(queryRootExpression, entityType);
source.UpdateSource(
methodCallExpression.Update(
null,
new[] { source.Source, methodCallExpression.Arguments[1], methodCallExpression.Arguments[2] }));

return ApplyQueryFilter(source);
}

return ProcessUnknownMethod(methodCallExpression);
}

Expand Down
21 changes: 8 additions & 13 deletions src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
throw new InvalidOperationException(CoreStrings.ErrorInvalidQueryable);
}

return queryRootExpression.DetachQueryProvider();
// Visit after detaching query provider since custom query roots can have additional components
extensionExpression = queryRootExpression.DetachQueryProvider();
}

return base.VisitExtension(extensionExpression);
Expand Down Expand Up @@ -500,7 +501,8 @@ public override Expression Visit(Expression expression)

if (_evaluatable)
{
_evaluatableExpressions[expression] = _containsClosure;
// Force parameterization when not in lambda
_evaluatableExpressions[expression] = _containsClosure || !_inLambda;
}

_evaluatable = parentEvaluatable && _evaluatable;
Expand Down Expand Up @@ -575,18 +577,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

Visit(methodCallExpression.Arguments[i]);

if (_evaluatableExpressions.ContainsKey(methodCallExpression.Arguments[i]))
if (_evaluatableExpressions.ContainsKey(methodCallExpression.Arguments[i])
&& (parameterInfos[i].GetCustomAttribute<NotParameterizedAttribute>() != null
|| _model.IsIndexerMethod(methodCallExpression.Method)))
{
if (parameterInfos[i].GetCustomAttribute<NotParameterizedAttribute>() != null
|| _model.IsIndexerMethod(methodCallExpression.Method))
{
_evaluatableExpressions[methodCallExpression.Arguments[i]] = false;
}
else if (!_inLambda)
{
// Force parameterization when not in lambda
_evaluatableExpressions[methodCallExpression.Arguments[i]] = true;
}
_evaluatableExpressions[methodCallExpression.Arguments[i]] = false;
}
}

Expand Down
Loading

0 comments on commit 1303423

Please sign in to comment.