Skip to content

Commit

Permalink
Cosmos: Fixes around array projection
Browse files Browse the repository at this point in the history
Closes #33797
  • Loading branch information
roji committed Jun 22, 2024
1 parent ecd6104 commit 3fe7200
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ public static bool TryConvertToArray(
{
subquery.ApplyProjection();

// TODO: Should the type be an array, or enumerable/queryable?
var arrayClrType = projection.Type.MakeArrayType();
var arrayClrType = typeof(IEnumerable<>).MakeGenericType(projection.Type);

switch (projection)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand All @@ -21,7 +23,9 @@ private static readonly MethodInfo GetParameterValueMethodInfo
= typeof(CosmosProjectionBindingExpressionVisitor)
.GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue))!;

private readonly CosmosQueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor;
private readonly CosmosSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ITypeMappingSource _typeMappingSource;
private readonly IModel _model;
private SelectExpression _selectExpression;
private bool _clientEval;
Expand All @@ -39,10 +43,14 @@ private static readonly MethodInfo GetParameterValueMethodInfo
/// </summary>
public CosmosProjectionBindingExpressionVisitor(
IModel model,
CosmosSqlTranslatingExpressionVisitor sqlTranslator)
CosmosQueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
CosmosSqlTranslatingExpressionVisitor sqlTranslator,
ITypeMappingSource typeMappingSource)
{
_model = model;
_queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor;
_sqlTranslator = sqlTranslator;
_typeMappingSource = typeMappingSource;
_selectExpression = null!;
}

Expand Down Expand Up @@ -570,6 +578,50 @@ UnaryExpression unaryExpression
lambda);
}
}
else if (method is { Name: nameof(Enumerable.ToList), IsGenericMethod: true }
&& method.DeclaringType == typeof(Enumerable)
&& methodCallExpression.Arguments is [var argument]
&& argument.Type.TryGetElementType(typeof(IQueryable<>)) != null)
{
if (_queryableMethodTranslatingExpressionVisitor.TranslateSubquery(argument) is not ShapedQueryExpression subquery
|| !subquery.TryConvertToArray(_typeMappingSource, out var array))
{
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

// If ToList() was composed over a subquery with operators, the result here is an ArrayExpression (ARRAY(SELECT ...)), whose
// CLR Type is IEnumerable<T>. This can be directly used in the resulting ProjectingBindingExpression - the shaper will
// simply read the JSON results out successfully.
// But if ToList() is composed directly over an array property, that property could have type e.g. T[], which will be read
// in the shaper, and then the cast from T[] to List<T> will fail. As a result, wrap the array in an additional
// "reprojection" subquery, effectively to change the CLR type.
if (array is SqlExpression scalarArray
&& !(array.Type.IsGenericType && array.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>)))
{
Check.DebugAssert(
array is not ScalarArrayExpression and not ObjectArrayExpression, "ArrayExpression should be IEnumerable");

if (scalarArray is not { TypeMapping.ElementTypeMapping: CosmosTypeMapping elementTypeMapping })
{
throw new UnreachableException("Scalar array with no element type mapping");
}

// TODO: Proper alias management (#33894).
var arrayReprojectionSubquery = SelectExpression.CreateForCollection(
array, "i", new ScalarReferenceExpression("i", elementTypeMapping.ClrType, elementTypeMapping));
arrayReprojectionSubquery.ApplyProjection();

array = new ScalarArrayExpression(
arrayReprojectionSubquery,
methodCallExpression.Type, // List<>
_typeMappingSource.FindMapping(methodCallExpression.Type, _model, elementTypeMapping));
}

return new ProjectionBindingExpression(
_selectExpression,
_selectExpression.AddToProjection(array),
methodCallExpression.Type);
}
}

var @object = Visit(methodCallExpression.Object);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand Down Expand Up @@ -54,7 +55,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor(
_methodCallTranslatorProvider,
this);
_projectionBindingExpressionVisitor =
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, _sqlTranslator);
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, this, _sqlTranslator, _typeMappingSource);
_subquery = false;
}

Expand All @@ -81,7 +82,7 @@ protected CosmosQueryableMethodTranslatingExpressionVisitor(
_methodCallTranslatorProvider,
parentVisitor);
_projectionBindingExpressionVisitor =
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, _sqlTranslator);
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, this, _sqlTranslator, _typeMappingSource);
_subquery = true;
}

Expand Down Expand Up @@ -1125,8 +1126,10 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
// ElementAtOrDefault over an array of scalars
case SqlExpression scalarArray when projection is SqlExpression element:
{
var slice = _sqlExpressionFactory.Function(
"ARRAY_SLICE", [scalarArray, translatedCount], scalarArray.Type, scalarArray.TypeMapping);
var arrayType = typeof(IEnumerable<>).MakeGenericType(projection.Type);
var arrayTypeMapping = _typeMappingSource.FindMapping(arrayType, _queryCompilationContext.Model, element.TypeMapping);

var slice = _sqlExpressionFactory.Function("ARRAY_SLICE", [scalarArray, translatedCount], arrayType, arrayTypeMapping);

// TODO: Proper alias management (#33894). Ideally reach into the source of the original SelectExpression and use that alias.
var translatedSelect = SelectExpression.CreateForCollection(
Expand All @@ -1139,8 +1142,10 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
// ElementAtOrDefault over an array os structural types
case not null when projectedStructuralTypeShaper is not null:
{
var arrayType = typeof(IEnumerable<>).MakeGenericType(projectedStructuralTypeShaper.Type);

// TODO: Proper alias management (#33894).
var slice = new ObjectFunctionExpression("ARRAY_SLICE", [array, translatedCount], projectedStructuralTypeShaper.Type);
var slice = new ObjectFunctionExpression("ARRAY_SLICE", [array, translatedCount], arrayType);
var translatedSelect = SelectExpression.CreateForCollection(
slice,
"i",
Expand Down Expand Up @@ -1585,7 +1590,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
// value conversion). #34026.
var elementClrType = inlineQueryRootExpression.ElementType;
var elementTypeMapping = _typeMappingSource.FindMapping(elementClrType)!;
var arrayTypeMapping = _typeMappingSource.FindMapping(elementClrType.MakeArrayType()); // TODO: IEnumerable?
var arrayTypeMapping = _typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(elementClrType));
var inlineArray = new ArrayConstantExpression(elementClrType, translatedItems, arrayTypeMapping);

// TODO: Do proper alias management: #33894
Expand Down Expand Up @@ -1614,7 +1619,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
// TODO: Temporary hack - need to perform proper derivation of the array type mapping from the element (e.g. for
// value conversion). #34026.
var elementClrType = parameterQueryRootExpression.ElementType;
var arrayTypeMapping = _typeMappingSource.FindMapping(elementClrType.MakeArrayType()); // TODO: IEnumerable?
var arrayTypeMapping = _typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(elementClrType));
var elementTypeMapping = _typeMappingSource.FindMapping(elementClrType)!;
var sqlParameterExpression = new SqlParameterExpression(parameterQueryRootExpression.ParameterExpression, arrayTypeMapping);

Expand Down Expand Up @@ -1683,13 +1688,17 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
&& source2.TryConvertToArray(_typeMappingSource, out var array2, out var projection2, ignoreOrderings)
&& projection1.Type == projection2.Type)
{
var arrayType = typeof(IEnumerable<>).MakeGenericType(projection1.Type);

// Set operation over arrays of scalars
if (projection1 is SqlExpression sqlProjection1
&& projection2 is SqlExpression sqlProjection2
&& (sqlProjection1.TypeMapping ?? sqlProjection2.TypeMapping) is CoreTypeMapping typeMapping)
&& (sqlProjection1.TypeMapping ?? sqlProjection2.TypeMapping) is CosmosTypeMapping typeMapping)
{
var arrayTypeMapping = _typeMappingSource.FindMapping(arrayType, _queryCompilationContext.Model, typeMapping);

// TODO: Proper alias management (#33894).
var translation = _sqlExpressionFactory.Function(functionName, [array1, array2], projection1.Type, typeMapping);
var translation = _sqlExpressionFactory.Function(functionName, [array1, array2], arrayType, arrayTypeMapping);
var select = SelectExpression.CreateForCollection(
translation, "i", new ScalarReferenceExpression("i", projection1.Type, typeMapping));
return source1.UpdateQueryExpression(select);
Expand All @@ -1701,7 +1710,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
&& structuralType1 == structuralType2)
{
// TODO: Proper alias management (#33894).
var translation = new ObjectFunctionExpression(functionName, [array1, array2], projection1.Type);
var translation = new ObjectFunctionExpression(functionName, [array1, array2], arrayType);
var select = SelectExpression.CreateForCollection(
translation, "i", new ObjectReferenceExpression((IEntityType)structuralType1, "i"));
return CreateShapedQueryExpression(select, structuralType1.ClrType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
[DebuggerDisplay("{Microsoft.EntityFrameworkCore.Query.ExpressionPrinter.Print(this), nq}")]
public class ScalarAccessExpression(Expression @object, string propertyName, Type clrType, CoreTypeMapping? typeMapping)
: SqlExpression(clrType, typeMapping), IAccessExpression
{
Expand Down
20 changes: 1 addition & 19 deletions src/EFCore.Cosmos/Query/Internal/Expressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -322,27 +322,9 @@ public virtual void ReplaceProjectionMapping(IDictionary<ProjectionMember, Expre
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual int AddToProjection(SqlExpression sqlExpression)
public virtual int AddToProjection(Expression sqlExpression)
=> AddToProjection(sqlExpression, null);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual int AddToProjection(EntityProjectionExpression entityProjection)
=> AddToProjection(entityProjection, null);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual int AddToProjection(ObjectArrayAccessExpression objectArrayAccess)
=> AddToProjection(objectArrayAccess, null);

private int AddToProjection(Expression expression, string? alias)
{
var existingIndex = _projection.FindIndex(pe => pe.Expression.Equals(expression));
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
// TODO: This infers based on the CLR type; need to properly infer based on the element type mapping
// TODO: being applied here (e.g. WHERE @p[1] = c.PropertyWithValueConverter). #34026
var arrayTypeMapping = left.TypeMapping
?? (typeMapping is null ? null : typeMappingSource.FindMapping(typeMapping.ClrType.MakeArrayType()));
?? (typeMapping is null ? null : typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(typeMapping.ClrType)));
return new SqlBinaryExpression(
ExpressionType.ArrayIndex,
ApplyTypeMapping(left, arrayTypeMapping),
Expand Down Expand Up @@ -290,7 +290,7 @@ private InExpression ApplyTypeMappingOnIn(InExpression inExpression)
var arrayClrType = arrayExpression.Type switch
{
var t when t.TryGetSequenceType() != typeof(object) => t,
{ IsArray: true } => itemExpression.Type.MakeArrayType(),
{ IsArray: true } => typeof(IEnumerable<>).MakeGenericType(itemExpression.Type),
{ IsConstructedGenericType: true, GenericTypeArguments.Length: 1 } t
=> t.GetGenericTypeDefinition().MakeGenericType(itemExpression.Type),
_ => throw new InvalidOperationException(
Expand Down
20 changes: 10 additions & 10 deletions test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ public override Task Where_owned_collection_navigation_ToList_Count(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -323,7 +323,7 @@ await AssertQuery(
AssertSql(
"""
SELECT a
SELECT a["Details"]
FROM root c
JOIN a IN c["Orders"]
WHERE (c["Discriminator"] IN ("OwnedPerson", "Branch", "LeafB", "LeafA") AND (ARRAY_LENGTH(a["Details"]) = 1))
Expand All @@ -336,8 +336,8 @@ public override Task Where_collection_navigation_ToArray_Count(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -363,8 +363,8 @@ public override Task Where_collection_navigation_AsEnumerable_Count(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -390,8 +390,8 @@ public override Task Where_collection_navigation_ToList_Count_member(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -404,7 +404,7 @@ await AssertQuery(
AssertSql(
"""
SELECT a
SELECT a["Details"]
FROM root c
JOIN a IN c["Orders"]
WHERE (c["Discriminator"] IN ("OwnedPerson", "Branch", "LeafB", "LeafA") AND (ARRAY_LENGTH(a["Details"]) = 1))
Expand Down
Loading

0 comments on commit 3fe7200

Please sign in to comment.