Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TINY] Fix to #15763 - "collection selector was not NavigationExpansionExpression" when joining two FromSql()s #15963

Merged
merged 1 commit into from
Jun 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
using System.Runtime.Versioning;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Extensions.Internal;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;
Expand Down Expand Up @@ -363,5 +361,74 @@ private static readonly Type _assignBinaryExpressionType

private static readonly MethodInfo _fieldInfoSetValueMethod
= typeof(FieldInfo).GetRuntimeMethod(nameof(FieldInfo.SetValue), new[] { typeof(object), typeof(object) });

public static LambdaExpression GetLambdaOrNull(this Expression expression)
=> expression is LambdaExpression lambda
? lambda
: expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote
? (LambdaExpression)unary.Operand
: null;

public static LambdaExpression UnwrapQuote(this Expression expression)
=> expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote
? (LambdaExpression)unary.Operand
: (LambdaExpression)expression;

public static bool IsIncludeMethod(this MethodCallExpression methodCallExpression)
=> methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
&& methodCallExpression.Method.Name == nameof(EntityFrameworkQueryableExtensions.Include);

public static Expression BuildPropertyAccess(this Expression root, List<string> path)
{
var result = root;
foreach (var pathElement in path)
{
result = Expression.PropertyOrField(result, pathElement);
}

return result;
}

public static Expression CombineAndRemap(
Expression source,
ParameterExpression sourceParameter,
Expression replaceWith)
=> new ExpressionCombiningVisitor(sourceParameter, replaceWith).Visit(source);

public class ExpressionCombiningVisitor : ExpressionVisitor
{
private ParameterExpression _sourceParameter;
private Expression _replaceWith;

public ExpressionCombiningVisitor(
ParameterExpression sourceParameter,
Expression replaceWith)
{
_sourceParameter = sourceParameter;
_replaceWith = replaceWith;
}

protected override Expression VisitParameter(ParameterExpression parameterExpression)
=> parameterExpression == _sourceParameter
? _replaceWith
: base.VisitParameter(parameterExpression);

protected override Expression VisitMember(MemberExpression memberExpression)
{
var newSource = Visit(memberExpression.Expression);
if (newSource is NewExpression newExpression)
{
var matchingMemberIndex = newExpression.Members.Select((m, i) => new { index = i, match = m == memberExpression.Member }).Where(r => r.match).SingleOrDefault()?.index;
if (matchingMemberIndex.HasValue)
{
return newExpression.Arguments[matchingMemberIndex.Value];
}
}

return newSource != memberExpression.Expression
? memberExpression.Update(newSource)
: memberExpression;
}
}
}
}
75 changes: 0 additions & 75 deletions src/EFCore/Query/NavigationExpansion/ExpressionExtensions.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Extensions.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Extensions.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.DefaultIfEmpty):
return ProcessDefaultIfEmpty(methodCallExpression);

case "AsTracking":
case "AsNoTracking":
return ProcessBasicTerminatingOperation(methodCallExpression);

case nameof(Queryable.First):
case nameof(Queryable.FirstOrDefault):
case nameof(Queryable.Single):
Expand All @@ -99,18 +95,48 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case "ThenInclude":
return ProcessInclude(methodCallExpression);

//TODO: should we have relational version of this? - probably
case "FromSqlRaw":
return ProcessFromRawSql(methodCallExpression);

case nameof(EntityFrameworkQueryableExtensions.TagWith):
return ProcessWithTag(methodCallExpression);

default:
return base.VisitMethodCall(methodCallExpression);
return ProcessUnknownMethod(methodCallExpression);
}
}

private Expression ProcessUnknownMethod(MethodCallExpression methodCallExpression)
{
var resultSequenceType = TryGetNonPrimitiveSequenceType(methodCallExpression.Type);

// result is a sequence, no lambda arguments, exactly one generic argument corresponding to result sequence type
if (methodCallExpression.Object == null
&& resultSequenceType != null
&& methodCallExpression.Arguments.All(a => a.GetLambdaOrNull() == null)
&& methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericArguments().Length == 1
&& methodCallExpression.Method.GetGenericArguments()[0] == resultSequenceType)
{
var argumentSequenceTypes = methodCallExpression.Arguments.Select(a => TryGetNonPrimitiveSequenceType(a.Type)).ToList();
if (argumentSequenceTypes.FirstOrDefault() == resultSequenceType
&& argumentSequenceTypes.Count(t => t != null) == 1)
{
var source = VisitSourceExpression(methodCallExpression.Arguments[0]);
var preProcessResult = PreProcessTerminatingOperation(source);
var newArguments = methodCallExpression.Arguments.Skip(1).Select(Visit).ToList();
newArguments.Insert(0, preProcessResult.source);

var methodInfo = methodCallExpression.Method.GetGenericMethodDefinition().MakeGenericMethod(preProcessResult.state.CurrentParameter.Type);
var rewritten = Expression.Call(methodInfo, newArguments);

return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type);
}
}

return base.VisitMethodCall(methodCallExpression);
}

private Type TryGetNonPrimitiveSequenceType(Type type)
=> type == typeof(string) || type.IsArray ? null : type.TryGetSequenceType();

private NavigationExpansionExpression VisitSourceExpression(Expression sourceExpression)
{
var result = Visit(sourceExpression);
Expand Down Expand Up @@ -845,17 +871,6 @@ private Expression ProcessSkipTake(MethodCallExpression methodCallExpression)
return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type);
}

private Expression ProcessBasicTerminatingOperation(MethodCallExpression methodCallExpression)
{
var source = VisitSourceExpression(methodCallExpression.Arguments[0]);
var preProcessResult = PreProcessTerminatingOperation(source);
var newArguments = methodCallExpression.Arguments.Skip(1).ToList();
newArguments.Insert(0, preProcessResult.source);
var rewritten = methodCallExpression.Update(methodCallExpression.Object, newArguments);

return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type);
}

private (Expression source, NavigationExpansionExpressionState state) PreProcessTerminatingOperation(NavigationExpansionExpression source)
{
var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State);
Expand Down Expand Up @@ -1172,7 +1187,7 @@ private Expression ProcessCardinalityReducingOperation(MethodCallExpression meth
return new NavigationExpansionExpression(applyOrderingsResult.source, applyOrderingsResult.state, methodCallExpression.Type);
}

private Expression ProcessFromRawSql(MethodCallExpression methodCallExpression)
private Expression ProcessFromSql(MethodCallExpression methodCallExpression)
{
var source = VisitSourceExpression(methodCallExpression.Arguments[0]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Extensions.Internal;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// Methods with a typed first argument (source), and with no lambda arguments or a single lambda
// argument that has one parameter are rewritten automatically (e.g. Where(), FromSql(), Average()
var newArguments = new Expression[arguments.Count];
var lambdaArgs = arguments.Select(GetLambdaOrNull).Where(l => l != null).ToArray();
var lambdaArgs = arguments.Select(a => a.GetLambdaOrNull()).Where(l => l != null).ToArray();
newSource = Visit(arguments[0]);
newArguments[0] = Unwrap(newSource);
if (methodCallExpression.Object == null
Expand All @@ -194,7 +194,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
for (var i = 1; i < arguments.Count; i++)
{
// Visit all arguments, rewriting the single lambda to replace its parameter expression
newArguments[i] = GetLambdaOrNull(arguments[i]) is LambdaExpression lambda
newArguments[i] = arguments[i].GetLambdaOrNull() is LambdaExpression lambda
? Unwrap(RewriteAndVisitLambda(lambda, newSourceWrapper))
: Unwrap(Visit(arguments[i]));
}
Expand Down Expand Up @@ -593,13 +593,6 @@ protected static Expression UnwrapLastNavigation(Expression expression)
? methodCallExpression.Arguments[0]
: null);

protected static LambdaExpression GetLambdaOrNull(Expression expression)
=> expression is LambdaExpression lambda
? lambda
: expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote
? (LambdaExpression)unary.Operand
: null;

protected static Expression Unwrap(Expression expression)
=> expression switch {
EntityReferenceExpression wrapper => wrapper.Underlying,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public virtual async Task FromSqlRaw_queryable_composed()
}
}

[Fact(Skip = "#15763")]
[Fact]
public virtual async Task FromSqlRaw_queryable_multiple_composed()
{
using (var context = CreateContext())
Expand All @@ -101,7 +101,7 @@ from o in context.Set<Order>().FromSqlRaw(NormalizeDelimetersInRawString("SELECT
}
}

[Fact(Skip = "Issue#15763")]
[Fact]
public virtual async Task FromSqlRaw_queryable_multiple_composed_with_closure_parameters()
{
var startDate = new DateTime(1997, 1, 1);
Expand All @@ -125,7 +125,7 @@ from o in context.Set<Order>().FromSqlRaw(
}
}

[Fact(Skip = "Issue#15763")]
[Fact]
public virtual async Task FromSqlRaw_queryable_multiple_composed_with_parameters_and_closure_parameters()
{
var city = "London";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ var actual
}
}

[Fact(Skip = "#15763")]
[Fact]
public virtual void FromSqlRaw_queryable_multiple_composed()
{
using (var context = CreateContext())
Expand All @@ -325,7 +325,7 @@ from o in context.Set<Order>().FromSqlRaw(NormalizeDelimetersInRawString("SELECT
}
}

[Fact(Skip = "Issue#15763")]
[Fact]
public virtual void FromSqlRaw_queryable_multiple_composed_with_closure_parameters()
{
var startDate = new DateTime(1997, 1, 1);
Expand All @@ -351,7 +351,7 @@ from o in context.Set<Order>().FromSqlRaw(
}
}

[Fact(Skip = "Issue#15763")]
[Fact]
public virtual void FromSqlRaw_queryable_multiple_composed_with_parameters_and_closure_parameters()
{
var city = "London";
Expand Down Expand Up @@ -503,7 +503,7 @@ public virtual void FromSqlInterpolated_queryable_with_parameters_inline_interpo
}
}

[Fact(Skip = "Issue#15763")]
[Fact]
public virtual void FromSqlInterpolated_queryable_multiple_composed_with_parameters_and_closure_parameters_interpolated()
{
var city = "London";
Expand Down
Loading