Skip to content

Commit

Permalink
Query: Removing Queryable conversion for sources which are not queryable
Browse files Browse the repository at this point in the history
Our EnumerableToQueryable converter works on all the methods. But during nav expansion, we actually inject query sources.
If item is not a query source (like postgre arrays or collection property mapped as scalar via value conversion), then we convert it to enumerable again to be handled by provider.

Resolves #17374
  • Loading branch information
smitpatel committed Oct 30, 2019
1 parent 0929120 commit 0aeae4e
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,44 +35,44 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

var enumerableMethod = methodCallExpression.Method;
var enumerableParameters = enumerableMethod.GetParameters();
Type[] genericArguments = null;
Type[] genericTypeArguments = null;
if (enumerableMethod.Name == nameof(Enumerable.Min)
|| enumerableMethod.Name == nameof(Enumerable.Max))
{
genericArguments = new Type[methodCallExpression.Arguments.Count];
genericTypeArguments = new Type[methodCallExpression.Arguments.Count];

if (!enumerableMethod.IsGenericMethod)
{
genericArguments[0] = enumerableMethod.ReturnType;
genericTypeArguments[0] = enumerableMethod.ReturnType;
}
else
{
var argumentTypes = enumerableMethod.GetGenericArguments();
if (argumentTypes.Length == genericArguments.Length)
if (argumentTypes.Length == genericTypeArguments.Length)
{
genericArguments = argumentTypes;
genericTypeArguments = argumentTypes;
}
else
{
genericArguments[0] = argumentTypes[0];
genericArguments[1] = enumerableMethod.ReturnType;
genericTypeArguments[0] = argumentTypes[0];
genericTypeArguments[1] = enumerableMethod.ReturnType;
}
}
}
else if (enumerableMethod.IsGenericMethod)
{
genericArguments = enumerableMethod.GetGenericArguments();
genericTypeArguments = enumerableMethod.GetGenericArguments();
}

foreach (var method in typeof(Queryable).GetTypeInfo().GetDeclaredMethods(methodCallExpression.Method.Name))
{
var queryableMethod = method;
if (queryableMethod.IsGenericMethod)
{
if (genericArguments != null
&& queryableMethod.GetGenericArguments().Length == genericArguments.Length)
if (genericTypeArguments != null
&& queryableMethod.GetGenericArguments().Length == genericTypeArguments.Length)
{
queryableMethod = queryableMethod.MakeGenericMethod(genericArguments);
queryableMethod = queryableMethod.MakeGenericMethod(genericTypeArguments);
}
else
{
Expand Down
175 changes: 156 additions & 19 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
|| method.DeclaringType == typeof(QueryableExtensions)
|| method.DeclaringType == typeof(EntityFrameworkQueryableExtensions))
{
var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null;
var firstArgument = Visit(methodCallExpression.Arguments[0]);
if (firstArgument is NavigationExpansionExpression source)
{
var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null;

if (source.PendingOrderings.Any()
&& genericMethod != QueryableMethods.ThenBy
&& genericMethod != QueryableMethods.ThenByDescending)
Expand Down Expand Up @@ -559,27 +558,43 @@ when QueryableMethods.IsSumWithSelector(method):
throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
}
}
else if (firstArgument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression
&& method.Name == nameof(Queryable.AsQueryable))

if (genericMethod == QueryableMethods.AsQueryable)
{
var subquery = materializeCollectionNavigationExpression.Subquery;
return subquery is OwnedNavigationReference ownedNavigationReference
&& ownedNavigationReference.Navigation.IsCollection()
? Visit(
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(subquery.Type.TryGetSequenceType()),
subquery))
: subquery;
if (firstArgument is MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression)
{
var subquery = materializeCollectionNavigationExpression.Subquery;

return subquery is OwnedNavigationReference innerOwnedNavigationReference
&& innerOwnedNavigationReference.Navigation.IsCollection()
? Visit(
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(subquery.Type.TryGetSequenceType()),
subquery))
: subquery;
}

if (firstArgument is OwnedNavigationReference ownedNavigationReference
&& ownedNavigationReference.Navigation.IsCollection())
{
var parameterName = GetParameterName("o");
var entityReference = ownedNavigationReference.EntityReference;
var currentTree = new NavigationTreeExpression(entityReference);

return new NavigationExpansionExpression(methodCallExpression, currentTree, currentTree, parameterName);
}

return firstArgument;
}
else if (firstArgument is OwnedNavigationReference ownedNavigationReference
&& ownedNavigationReference.Navigation.IsCollection()
&& method.Name == nameof(Queryable.AsQueryable))

if (firstArgument.Type.TryGetElementType(typeof(IQueryable<>)) == null)
{
var parameterName = GetParameterName("o");
var entityReference = ownedNavigationReference.EntityReference;
var currentTree = new NavigationTreeExpression(entityReference);
// firstArgument was not an queryable
var visitedArguments = new[] { firstArgument }
.Concat(methodCallExpression.Arguments.Skip(1).Select(Visit))
.ToList();

return new NavigationExpansionExpression(methodCallExpression, currentTree, currentTree, parameterName);
return ConvertToEnumerable(method, visitedArguments);
}

throw new InvalidOperationException(CoreStrings.QueryFailed(methodCallExpression.Print(), GetType().Name));
Expand Down Expand Up @@ -618,6 +633,128 @@ when QueryableMethods.IsSumWithSelector(method):
return ProcessUnknownMethod(methodCallExpression);
}

private MethodCallExpression ConvertToEnumerable(MethodInfo queryableMethod, List<Expression> arguments)
{
var genericTypeArguments = queryableMethod.IsGenericMethod
? queryableMethod.GetGenericArguments()
: null;
var enumerableArguments = arguments.Select(
arg => arg is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Quote
&& unaryExpression.Operand is LambdaExpression
? unaryExpression.Operand
: arg)
.ToList();

if (queryableMethod.Name == nameof(Enumerable.Min))
{
if (genericTypeArguments.Length == 1)
{
var resultType = genericTypeArguments[0];
var enumerableMethod = EnumerableMethods.GetMinWithoutSelector(resultType);

if (!IsNumericType(resultType))
{
enumerableMethod = enumerableMethod.MakeGenericMethod(resultType);
}

return Expression.Call(enumerableMethod, enumerableArguments);
}

if (genericTypeArguments.Length == 2)
{
var resultType = genericTypeArguments[1];
var enumerableMethod = EnumerableMethods.GetMinWithSelector(resultType);

enumerableMethod = IsNumericType(resultType)
? enumerableMethod.MakeGenericMethod(resultType)
: enumerableMethod.MakeGenericMethod(genericTypeArguments);

return Expression.Call(enumerableMethod, enumerableArguments);
}
}

if (queryableMethod.Name == nameof(Enumerable.Max))
{
if (genericTypeArguments.Length == 1)
{
var resultType = genericTypeArguments[0];
var enumerableMethod = EnumerableMethods.GetMaxWithoutSelector(resultType);

if (!IsNumericType(resultType))
{
enumerableMethod = enumerableMethod.MakeGenericMethod(resultType);
}

return Expression.Call(enumerableMethod, enumerableArguments);
}

if (genericTypeArguments.Length == 2)
{
var resultType = genericTypeArguments[1];
var enumerableMethod = EnumerableMethods.GetMaxWithSelector(resultType);

enumerableMethod = IsNumericType(resultType)
? enumerableMethod.MakeGenericMethod(resultType)
: enumerableMethod.MakeGenericMethod(genericTypeArguments);

return Expression.Call(enumerableMethod, enumerableArguments);
}
}


foreach (var method in typeof(Enumerable).GetTypeInfo().GetDeclaredMethods(queryableMethod.Name))
{
var enumerableMethod = method;
if (enumerableMethod.IsGenericMethod)
{
if (genericTypeArguments != null
&& enumerableMethod.GetGenericArguments().Length == genericTypeArguments.Length)
{
enumerableMethod = enumerableMethod.MakeGenericMethod(genericTypeArguments);
}
else
{
continue;
}
}

var enumerableMethodParameters = enumerableMethod.GetParameters();
if (enumerableMethodParameters.Length != enumerableArguments.Count)
{
continue;
}

var validMapping = true;
for (var i = 0; i < enumerableMethodParameters.Length; i++)
{
if (!enumerableMethodParameters[i].ParameterType.IsAssignableFrom(enumerableArguments[i].Type))
{
validMapping = false;
break;
}
}

if (validMapping)
{
return Expression.Call(enumerableMethod, enumerableArguments);
}
}

throw new InvalidOperationException("Unable to convert queryable method to enumerable method.");

static bool IsNumericType(Type type)
{
type = type.UnwrapNullableType();

return type == typeof(int)
|| type == typeof(long)
|| type == typeof(float)
|| type == typeof(double)
|| type == typeof(decimal);
}
}

private Expression ProcessDefaultIfEmpty(NavigationExpansionExpression source)
{
source.UpdateSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1301,9 +1301,9 @@ FROM root c
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task List_Contains_with_parameter_HashSet(bool isAsync)
public override async Task HashSet_Contains_with_parameter(bool isAsync)
{
await base.List_Contains_with_parameter_HashSet(isAsync);
await base.HashSet_Contains_with_parameter(isAsync);

AssertSql(
@"SELECT c
Expand All @@ -1312,9 +1312,9 @@ FROM root c
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task List_Contains_with_parameter_ImmutableHashSet(bool isAsync)
public override async Task ImmutableHashSet_Contains_with_parameter(bool isAsync)
{
await base.List_Contains_with_parameter_ImmutableHashSet(isAsync);
await base.ImmutableHashSet_Contains_with_parameter(isAsync);

AssertSql(
@"SELECT c
Expand Down Expand Up @@ -1344,16 +1344,11 @@ FROM root c
WHERE ((c[""Discriminator""] = ""OrderDetail"") AND ((c[""OrderID""] = 10248) AND (c[""ProductID""] = 42)))");
}

public override void Paging_operation_on_string_doesnt_issue_warning()
public override async Task String_FirstOrDefault_in_projection_does_client_eval(bool isAsync)
{
base.Paging_operation_on_string_doesnt_issue_warning();
await base.String_FirstOrDefault_in_projection_does_client_eval(isAsync);

Assert.DoesNotContain(
#pragma warning disable CS0612 // Type or member is obsolete
CoreResources.LogFirstWithoutOrderByAndFilter(new TestLogger<TestLoggingDefinitions>()).GenerateMessage(
#pragma warning restore CS0612 // Type or member is obsolete
"(from char <generated>_1 in [c].CustomerID select [<generated>_1]).FirstOrDefault()"),
Fixture.TestSqlLoggerFactory.Log.Select(l => l.Message));
AssertSql(" ");
}

[ConditionalTheory(Skip = "Issue #17246")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ public override void Value_conversion_with_property_named_value()
{
}

[ConditionalFact(Skip = "Issue#17050")]
public override void Collection_property_as_scalar()
{
base.Collection_property_as_scalar();
}

public class CustomConvertersInMemoryFixture : CustomConvertersFixtureBase
{
public override bool StrictEquality => true;
Expand Down
42 changes: 41 additions & 1 deletion test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,23 @@ protected class ValueWrapper
public string Value { get; set; }
}

[ConditionalFact]
public virtual void Collection_property_as_scalar()
{
using var context = CreateContext();
Assert.Equal(
@"The LINQ expression 'DbSet<CollectionScalar> .Where(c => c.Tags .Any())' could not be translated. Either rewrite the query in a form that can be translated, or switch to client evaluation explicitly by inserting a call to either AsEnumerable(), AsAsyncEnumerable(), ToList(), or ToListAsync(). See https://go.microsoft.com/fwlink/?linkid=2101038 for more information.",
Assert.Throws<InvalidOperationException>(
() => context.Set<CollectionScalar>().Where(e => e.Tags.Any()).ToList())
.Message.Replace("\r","").Replace("\n",""));
}

protected class CollectionScalar
{
public int Id { get; set; }
public List<string> Tags { get; set; }
}

public abstract class CustomConvertersFixtureBase : BuiltInDataTypesFixtureBase
{
protected override string StoreName { get; } = "CustomConverters";
Expand Down Expand Up @@ -954,9 +971,32 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
);
e.HasData(new EntityWithValueWrapper { Id = 1, Wrapper = new ValueWrapper { Value = "foo" } });
});

modelBuilder.Entity<CollectionScalar>(
b =>
{
b.Property(e => e.Tags).HasConversion(
c => string.Join(",", c),
s => s.Split(',', StringSplitOptions.None).ToList()).Metadata
.SetValueComparer(new ListOfStringComparer());

b.HasData(new CollectionScalar
{
Id = 1,
Tags = new List<string> { "A", "B", "C" }
});
});
}

private class ListOfStringComparer : ValueComparer<List<string>>
{
public ListOfStringComparer()
: base(favorStructuralComparisons: true)
{
}
}

public static class StringToDictionarySerializer
private static class StringToDictionarySerializer
{
public static string Serialize(IDictionary<string, string> dictionary)
{
Expand Down
Loading

0 comments on commit 0aeae4e

Please sign in to comment.