Skip to content

Commit

Permalink
Translate List<T> operations to PG array
Browse files Browse the repository at this point in the history
* Match our List<T> translation capabilities to CLR array.
* Improve some mapping and inference aspects.

Closes #395
  • Loading branch information
roji committed Jan 10, 2020
1 parent 0dcdaa7 commit 6759e00
Show file tree
Hide file tree
Showing 15 changed files with 767 additions and 319 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,20 @@ RelationalTypeMapping FindArrayMapping(in RelationalTypeMappingInfo mappingInfo)
var elementStoreType = storeType.Substring(0, storeType.Length - 2);
var elementMapping = FindExistingMapping(new RelationalTypeMappingInfo(elementStoreType, elementStoreType,
mappingInfo.IsUnicode, mappingInfo.Size, mappingInfo.Precision, mappingInfo.Scale));

if (elementMapping != null)
return StoreTypeMappings.GetOrAdd(storeType,
new RelationalTypeMapping[] { new NpgsqlArrayTypeMapping(storeType, elementMapping) })[0];
{
var added = StoreTypeMappings.TryAdd(storeType,
new RelationalTypeMapping[]
{
new NpgsqlArrayArrayTypeMapping(storeType, elementMapping),
new NpgsqlArrayListTypeMapping(storeType, elementMapping)
});
Debug.Assert(added);
var mapping = FindExistingMapping(mappingInfo);
Debug.Assert(mapping != null);
return mapping;
}
}

var clrType = mappingInfo.ClrType;
Expand All @@ -178,7 +189,7 @@ RelationalTypeMapping FindArrayMapping(in RelationalTypeMappingInfo mappingInfo)
if (elementMapping is NpgsqlArrayTypeMapping)
return null;

return ClrTypeMappings.GetOrAdd(clrType, new NpgsqlArrayTypeMapping(elementMapping, clrType));
return ClrTypeMappings.GetOrAdd(clrType, new NpgsqlArrayArrayTypeMapping(elementMapping, clrType));
}

if (clrType.IsGenericType && clrType.GetGenericTypeDefinition() == typeof(List<>))
Expand All @@ -194,7 +205,7 @@ RelationalTypeMapping FindArrayMapping(in RelationalTypeMappingInfo mappingInfo)
if (elementMapping is NpgsqlArrayTypeMapping)
return null;

return ClrTypeMappings.GetOrAdd(clrType, new NpgsqlListTypeMapping(elementMapping, clrType));
return ClrTypeMappings.GetOrAdd(clrType, new NpgsqlArrayListTypeMapping(elementMapping, clrType));
}

return null;
Expand Down
15 changes: 15 additions & 0 deletions src/EFCore.PG/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;

// ReSharper disable once CheckNamespace
namespace Npgsql.EntityFrameworkCore.PostgreSQL
{
internal static class TypeExtensions
{
internal static bool IsGenericList(this Type type)
=> type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>);

internal static bool IsArrayOrGenericList(this Type type)
=> type.IsArray || type.IsGenericList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal
{
/// <summary>
/// Translates functions on arrays into their corresponding PostgreSQL operations.
/// Translates method and property calls on arrays/lists into their corresponding PostgreSQL operations.
/// </summary>
/// <remarks>
/// https://www.postgresql.org/docs/current/static/functions-array.html
/// </remarks>
public class NpgsqlArrayMethodTranslator : IMethodCallTranslator
public class NpgsqlArrayTranslator : IMethodCallTranslator, IMemberTranslator
{
[NotNull] static readonly MethodInfo SequenceEqual =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
Expand All @@ -38,7 +38,7 @@ public class NpgsqlArrayMethodTranslator : IMethodCallTranslator
[NotNull]
readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator;

public NpgsqlArrayMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory, NpgsqlJsonPocoTranslator jsonPocoTranslator)
public NpgsqlArrayTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory, NpgsqlJsonPocoTranslator jsonPocoTranslator)
{
_sqlExpressionFactory = sqlExpressionFactory;
_jsonPocoTranslator = jsonPocoTranslator;
Expand All @@ -47,7 +47,14 @@ public NpgsqlArrayMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFacto
[CanBeNull]
public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments)
{
// TODO: Fully support List<>
if (instance != null && instance.Type.IsGenericList() && method.Name == "get_Item" && arguments.Count == 1)
{
return
// Try translating indexing inside json column
_jsonPocoTranslator.TranslateMemberAccess(instance, arguments[0], method.ReturnType) ??
// Other types should be subscriptable - but PostgreSQL arrays are 1-based, so adjust the index.
_sqlExpressionFactory.ArrayIndex(instance, GenerateOneBasedIndexExpression(arguments[0]));
}

if (arguments.Count == 0)
return null;
Expand All @@ -56,7 +63,7 @@ public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadO

var operandElementType = operand.Type.IsArray
? operand.Type.GetElementType()
: operand.Type.IsGenericType && operand.Type.GetGenericTypeDefinition() == typeof(List<>)
: operand.Type.IsGenericList()
? operand.Type.GetGenericArguments()[0]
: null;

Expand Down Expand Up @@ -122,5 +129,25 @@ public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadO

return null;
}

public SqlExpression Translate(SqlExpression instance, MemberInfo member, Type returnType)
{
if (instance != null && instance.Type.IsGenericList() && member.Name == nameof(List<object>.Count))
{
return _jsonPocoTranslator.TranslateArrayLength(instance) ??
_sqlExpressionFactory.Function("cardinality", new[] { instance }, typeof(int?));
}

return null;
}

/// <summary>
/// PostgreSQL array indexing is 1-based. If the index happens to be a constant,
/// just increment it. Otherwise, append a +1 in the SQL.
/// </summary>
SqlExpression GenerateOneBasedIndexExpression([NotNull] SqlExpression expression)
=> expression is SqlConstantExpression constant
? _sqlExpressionFactory.Constant(Convert.ToInt32(constant.Value) + 1, constant.TypeMapping)
: (SqlExpression)_sqlExpressionFactory.Add(expression, _sqlExpressionFactory.Constant(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public NpgsqlMemberTranslatorProvider(

AddTranslators(
new IMemberTranslator[] {
new NpgsqlArrayTranslator(npgsqlSqlExpressionFactory, JsonPocoTranslator),
new NpgsqlStringMemberTranslator(npgsqlSqlExpressionFactory),
new NpgsqlDateTimeMemberTranslator(npgsqlSqlExpressionFactory),
new NpgsqlRangeTranslator(npgsqlSqlExpressionFactory),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public NpgsqlMethodCallTranslatorProvider(

AddTranslators(new IMethodCallTranslator[]
{
new NpgsqlArrayMethodTranslator(npgsqlSqlExpressionFactory, jsonTranslator),
new NpgsqlArrayTranslator(npgsqlSqlExpressionFactory, jsonTranslator),
new NpgsqlConvertTranslator(npgsqlSqlExpressionFactory),
new NpgsqlDateTimeMethodTranslator(npgsqlSqlExpressionFactory, npgsqlTypeMappingSource),
new NpgsqlNewGuidTranslator(npgsqlSqlExpressionFactory),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query;
Expand Down Expand Up @@ -28,8 +29,7 @@ public ArrayIndexExpression(
Check.NotNull(array, nameof(array));
Check.NotNull(index, nameof(index));

// TODO: Support also List<>
if (!array.Type.IsArray)
if (!array.Type.IsArray && !array.Type.IsGenericList())
throw new ArgumentException("Array expression must of an array type", nameof(array));
if (index.Type != typeof(int))
throw new ArgumentException("Index expression must of type int", nameof(index));
Expand Down
20 changes: 13 additions & 7 deletions src/EFCore.PG/Query/Internal/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ public ArrayIndexExpression ArrayIndex(
SqlExpression index,
RelationalTypeMapping typeMapping = null)
{
// TODO: Support List<>
if (!array.Type.IsArray)
throw new ArgumentException("Array expression must of an array type", nameof(array));
Type elementType;
if (array.Type.IsArray)
elementType = array.Type.GetElementType();
else if (array.Type.IsGenericList())
elementType = array.Type.GetGenericArguments()[0];
else
throw new ArgumentException("Array expression must be of an array or List<> type", nameof(array));

var elementType = array.Type.GetElementType();
return (ArrayIndexExpression)ApplyTypeMapping(new ArrayIndexExpression(array, index, elementType, null), typeMapping);
}

Expand Down Expand Up @@ -105,7 +108,7 @@ public override SqlExpression ApplyTypeMapping(SqlExpression sqlExpression, Rela
// PostgreSQL-specific expression types
RegexMatchExpression e => ApplyTypeMappingOnRegexMatch(e),
ArrayAnyAllExpression e => ApplyTypeMappingOnArrayAnyAll(e),
ArrayIndexExpression e => ApplyTypeMappingOnArrayIndex(e),
ArrayIndexExpression e => ApplyTypeMappingOnArrayIndex(e, typeMapping),
ILikeExpression e => ApplyTypeMappingOnILike(e),
PgFunctionExpression e => e.ApplyTypeMapping(typeMapping),

Expand Down Expand Up @@ -194,14 +197,17 @@ SqlExpression ApplyTypeMappingOnArrayAnyAll(ArrayAnyAllExpression arrayAnyAllExp
_boolTypeMapping);
}

SqlExpression ApplyTypeMappingOnArrayIndex(ArrayIndexExpression arrayIndexExpression)
SqlExpression ApplyTypeMappingOnArrayIndex(
ArrayIndexExpression arrayIndexExpression, RelationalTypeMapping typeMapping)
=> new ArrayIndexExpression(
// TODO: Infer the array's mapping from the element
ApplyDefaultTypeMapping(arrayIndexExpression.Array),
ApplyDefaultTypeMapping(arrayIndexExpression.Index),
arrayIndexExpression.Type,
// If the array has a type mapping (i.e. column), prefer that just like we prefer column mappings in general
arrayIndexExpression.Array.TypeMapping is NpgsqlArrayTypeMapping arrayMapping
? arrayMapping.ElementMapping
: FindMapping(arrayIndexExpression.Type));
: typeMapping ?? FindMapping(arrayIndexExpression.Type));

SqlExpression ApplyTypeMappingOnILike(ILikeExpression ilikeExpression)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
Expand Down Expand Up @@ -127,9 +128,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall)
if (visited != null)
return visited;

// TODO: Handle List<>
if (methodCall.Arguments.Count > 0 && methodCall.Arguments[0].Type.IsArray)
if (methodCall.Arguments.Count > 0 && (
methodCall.Arguments[0].Type.IsArray || methodCall.Arguments[0].Type.IsGenericList()))
{
return VisitArrayMethodCall(methodCall.Method, methodCall.Arguments);
}

return null;
}
Expand Down Expand Up @@ -183,7 +186,7 @@ arguments[1] is LambdaExpression wherePredicate &&
arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall &&
wherePredicateMethodCall.Method.IsClosedFormOf(Contains) &&
wherePredicateMethodCall.Arguments[0].Type.IsArray &&
wherePredicateMethodCall.Arguments[0].Type.IsArrayOrGenericList() &&
wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression &&
parameterExpression == wherePredicate.Parameters[0])
{
Expand All @@ -207,7 +210,7 @@ wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression
arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall &&
wherePredicateMethodCall.Method.IsClosedFormOf(Contains) &&
wherePredicateMethodCall.Arguments[0].Type.IsArray &&
wherePredicateMethodCall.Arguments[0].Type.IsArrayOrGenericList() &&
wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression &&
parameterExpression == wherePredicate.Parameters[0])
{
Expand Down
Loading

0 comments on commit 6759e00

Please sign in to comment.