Skip to content

Commit

Permalink
Array translation improvements
Browse files Browse the repository at this point in the history
And redo array/list tests
  • Loading branch information
roji committed Oct 8, 2021
1 parent 92e5412 commit 8370f16
Show file tree
Hide file tree
Showing 14 changed files with 2,203 additions and 1,343 deletions.
1 change: 1 addition & 0 deletions EFCore.PG.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
<s:Boolean x:Key="/Default/UserDictionary/Words/=datetimeoffset/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=doesnt/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=fallbacks/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=ilike/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=initializers/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=keyless/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=materializer/@EntryIndexedValue">True</s:Boolean>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,57 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte
/// </remarks>
public class NpgsqlArrayTranslator : IMethodCallTranslator, IMemberTranslator
{
private static readonly MethodInfo SequenceEqual =
#region Methods

private static readonly MethodInfo Array_IndexOf1 =
typeof(Array).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Array.IndexOf) && m.IsGenericMethod && m.GetParameters().Length == 2);

private static readonly MethodInfo Array_IndexOf2 =
typeof(Array).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Array.IndexOf) && m.IsGenericMethod && m.GetParameters().Length == 3);

private static readonly MethodInfo Enumerable_Append =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2);
.Single(m => m.Name == nameof(Enumerable.Append) && m.GetParameters().Length == 2);

private static readonly MethodInfo Enumerable_AnyWithoutPredicate =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1);

private static readonly MethodInfo EnumerableContains =
private static readonly MethodInfo Enumerable_Concat =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Concat) && m.GetParameters().Length == 2);

private static readonly MethodInfo Enumerable_Contains =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2);

private static readonly MethodInfo EnumerableAnyWithoutPredicate =
private static readonly MethodInfo Enumerable_SequenceEqual =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1);
.Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2);

private static readonly MethodInfo String_Join1 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(object[]) })!;

private static readonly MethodInfo String_Join2 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(string[]) })!;

private static readonly MethodInfo String_Join3 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(object[]) })!;

private static readonly MethodInfo String_Join4 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(string[]) })!;

private static readonly MethodInfo String_Join_generic1 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(string));

private static readonly MethodInfo String_Join_generic2 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(char));

#endregion Methods

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
private readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator;
Expand Down Expand Up @@ -72,14 +112,31 @@ public NpgsqlArrayTranslator(
if (instance is null && arguments.Count > 0 && arguments[0].Type.IsArrayOrGenericList() && !IsMappedToNonArray(arguments[0]))
{
// Extension method over an array or list
if (method.IsClosedFormOf(SequenceEqual) && arguments[1].Type.IsArray)
if (method.IsClosedFormOf(Enumerable_SequenceEqual) && arguments[1].Type.IsArray)
{
return _sqlExpressionFactory.Equal(arguments[0], arguments[1]);
}

return TranslateCommon(arguments[0], arguments.Slice(1));
}

if (method.DeclaringType == typeof(string)
&& (method == String_Join1
|| method == String_Join2
|| method == String_Join3
|| method == String_Join4
|| method.IsClosedFormOf(String_Join_generic1)
|| method.IsClosedFormOf(String_Join_generic2))
&& !IsMappedToNonArray(arguments[0]))
{
return _sqlExpressionFactory.Function(
"array_to_string",
new[] { arguments[1], arguments[0], _sqlExpressionFactory.Constant("") },
nullable: true,
argumentsPropagateNullability: TrueArrays[3],
typeof(string));
}

// Not an array/list
return null;

Expand All @@ -92,7 +149,7 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList)
SqlExpression? TranslateCommon(SqlExpression arrayOrList, IReadOnlyList<SqlExpression> arguments)
{
// Predicate-less Any - translate to a simple length check.
if (method.IsClosedFormOf(EnumerableAnyWithoutPredicate))
if (method.IsClosedFormOf(Enumerable_AnyWithoutPredicate))
{
return _sqlExpressionFactory.GreaterThan(
_jsonPocoTranslator.TranslateArrayLength(arrayOrList)
Expand All @@ -109,7 +166,7 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList)
// is pattern-matched in AllAnyToContainsRewritingExpressionVisitor, which transforms it to
// new[] { "a", "b", "c" }.Contains(e.Some Text).

if ((method.IsClosedFormOf(EnumerableContains)
if ((method.IsClosedFormOf(Enumerable_Contains)
||
method.Name == nameof(List<int>.Contains)
&& method.DeclaringType.IsGenericList()
Expand Down Expand Up @@ -176,6 +233,77 @@ arrayOrList.TypeMapping is NpgsqlArrayTypeMapping or null
// Note: we also translate .Where(e => new[] { "a", "b", "c" }.Any(p => EF.Functions.Like(e.SomeText, p)))
// to LIKE ANY (...). See NpgsqlSqlTranslatingExpressionVisitor.VisitArrayMethodCall.

if (method.IsClosedFormOf(Enumerable_Append))
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);

return _sqlExpressionFactory.Function(
"array_append",
new[] { array, item },
nullable: true,
TrueArrays[2],
arrayOrList.Type,
arrayOrList.TypeMapping);
}

if (method.IsClosedFormOf(Enumerable_Concat))
{
var inferredMapping = ExpressionExtensions.InferTypeMapping(arrayOrList, arguments[0]);

return _sqlExpressionFactory.Function(
"array_cat",
new[]
{
_sqlExpressionFactory.ApplyTypeMapping(arrayOrList, inferredMapping),
_sqlExpressionFactory.ApplyTypeMapping(arguments[0], inferredMapping)
},
nullable: true,
TrueArrays[2],
arrayOrList.Type,
inferredMapping);
}

if (method.IsClosedFormOf(Array_IndexOf1)
||
method.Name == nameof(List<int>.IndexOf)
&& method.DeclaringType.IsGenericList()
&& method.GetParameters().Length == 1)
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);

return _sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Subtract(
_sqlExpressionFactory.Function(
"array_position",
new[] { array, item },
nullable: true,
TrueArrays[2],
arrayOrList.Type),
_sqlExpressionFactory.Constant(1)),
_sqlExpressionFactory.Constant(-1));
}

if (method.IsClosedFormOf(Array_IndexOf2)
||
method.Name == nameof(List<int>.IndexOf)
&& method.DeclaringType.IsGenericList()
&& method.GetParameters().Length == 2)
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);
var startIndex = _sqlExpressionFactory.GenerateOneBasedIndexExpression(arguments[1]);

return _sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Subtract(
_sqlExpressionFactory.Function(
"array_position",
new[] { array, item, startIndex },
nullable: true,
TrueArrays[3],
arrayOrList.Type),
_sqlExpressionFactory.Constant(1)),
_sqlExpressionFactory.Constant(-1));
}

return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator
private static readonly MethodInfo Contains = typeof(string).GetRuntimeMethod(nameof(string.Contains), new[] { typeof(string) })!;
private static readonly MethodInfo DbFunctionsReverse = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.Reverse), new[] { typeof(DbFunctions), typeof(string) })!;
private static readonly MethodInfo EndsWith = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), new[] { typeof(string) })!;
private static readonly MethodInfo Indexer = typeof(string).GetRuntimeMethod("get_Item", new[] { typeof(int) })!;
private static readonly MethodInfo IndexOfChar = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), new[] { typeof(char) })!;
private static readonly MethodInfo IndexOfString = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string) })!;
private static readonly MethodInfo IsNullOrWhiteSpace = typeof(string).GetRuntimeMethod(nameof(string.IsNullOrWhiteSpace), new[] { typeof(string) })!;
Expand Down Expand Up @@ -87,6 +88,11 @@ public NpgsqlStringMethodTranslator(
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method == Indexer)
{

}

if (method == IndexOfString || method == IndexOfChar)
{
var argument = arguments[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
// Try translating ArrayIndex inside json column
_jsonPocoTranslator.TranslateMemberAccess(sqlLeft!, sqlRight!, binaryExpression.Type) ??
// Other types should be subscriptable - but PostgreSQL arrays are 1-based, so adjust the index.
_sqlExpressionFactory.ArrayIndex(sqlLeft!, GenerateOneBasedIndexExpression(sqlRight!));
_sqlExpressionFactory.ArrayIndex(sqlLeft!, _sqlExpressionFactory.GenerateOneBasedIndexExpression(sqlRight!));
}

return base.VisitBinary(binaryExpression);
Expand Down Expand Up @@ -509,15 +509,6 @@ bool TryTranslateArguments(out SqlExpression[] sqlArguments)
}
}

/// <summary>
/// PostgreSQL array indexing is 1-based. If the index happens to be a constant,
/// just increment it. Otherwise, append a +1 in the SQL.
/// </summary>
private SqlExpression GenerateOneBasedIndexExpression(SqlExpression expression)
=> expression is SqlConstantExpression constant
? _sqlExpressionFactory.Constant(Convert.ToInt32(constant.Value) + 1, constant.TypeMapping)
: _sqlExpressionFactory.Add(expression, _sqlExpressionFactory.Constant(1));

#region Copied from RelationalSqlTranslatingExpressionVisitor

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down
24 changes: 20 additions & 4 deletions src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ private SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExp
return new PostgresAllExpression(item, array, postgresAllExpression.OperatorType, _boolTypeMapping);
}

private (SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(SqlExpression itemExpression, SqlExpression arrayExpression)
public (SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(SqlExpression itemExpression, SqlExpression arrayExpression)
{
// Attempt type inference either from the operand to the array or the other way around
var arrayMapping = (NpgsqlArrayTypeMapping?)arrayExpression.TypeMapping;
Expand Down Expand Up @@ -464,16 +464,23 @@ private SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExp
private SqlExpression ApplyTypeMappingOnArrayIndex(
PostgresArrayIndexExpression postgresArrayIndexExpression,
RelationalTypeMapping? typeMapping)
=> new PostgresArrayIndexExpression(
// TODO: Infer the array's mapping from the element
ApplyDefaultTypeMapping(postgresArrayIndexExpression.Array),
{
// If a (non-null) type mapping is being applied, it's to the element being indexed.
// Infer the array's mapping from that.
var (_, array) = typeMapping is not null
? ApplyTypeMappingsOnItemAndArray(Constant(null, typeMapping), postgresArrayIndexExpression.Array)
: (null, ApplyDefaultTypeMapping(postgresArrayIndexExpression.Array));

return new PostgresArrayIndexExpression(
array,
ApplyDefaultTypeMapping(postgresArrayIndexExpression.Index),
postgresArrayIndexExpression.Type,
// If the array has a type mapping (i.e. column), prefer that just like we prefer column mappings in general
postgresArrayIndexExpression.Array.TypeMapping is NpgsqlArrayTypeMapping arrayMapping
? arrayMapping.ElementMapping
: typeMapping
?? (RelationalTypeMapping?)_typeMappingSource.FindMapping(postgresArrayIndexExpression.Type, Dependencies.Model));
}

private SqlExpression ApplyTypeMappingOnILike(PostgresILikeExpression ilikeExpression)
{
Expand Down Expand Up @@ -749,5 +756,14 @@ private SqlExpression ApplyTypeMappingOnPostgresNewArray(
newExpressions ?? postgresNewArrayExpression.Expressions,
postgresNewArrayExpression.Type, arrayTypeMapping);
}

/// <summary>
/// PostgreSQL array indexing is 1-based. If the index happens to be a constant,
/// just increment it. Otherwise, append a +1 in the SQL.
/// </summary>
public SqlExpression GenerateOneBasedIndexExpression(SqlExpression expression)
=> expression is SqlConstantExpression constant
? Constant(System.Convert.ToInt32(constant.Value) + 1, constant.TypeMapping)
: Add(expression, Constant(1));
}
}
Loading

0 comments on commit 8370f16

Please sign in to comment.