Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring back enumerable Concat/Append translations for ExecuteUpdate #3005

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Npgsql.EntityFrameworkCore.PostgreSQL.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;

Expand Down Expand Up @@ -33,6 +34,17 @@ public class NpgsqlArrayMethodTranslator : IMethodCallTranslator
private static readonly MethodInfo Enumerable_SequenceEqual =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2);

// TODO: Enumerable Append and Concat are only here because primitive collections aren't handled in ExecuteUpdate,
// https://github.com/dotnet/efcore/issues/32494
private static readonly MethodInfo Enumerable_Append =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Append) && m.GetParameters().Length == 2);

private static readonly MethodInfo Enumerable_Concat =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Concat) && m.GetParameters().Length == 2);

// ReSharper restore InconsistentNaming

#endregion Methods
Expand Down Expand Up @@ -155,6 +167,38 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList)
_sqlExpressionFactory.Constant(-1));
}

// TODO: Enumerable Append and Concat are only here because primitive collections aren't handled in ExecuteUpdate,
// https://github.com/dotnet/efcore/issues/32494
if (method.IsClosedFormOf(Enumerable_Append))
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);

return _sqlExpressionFactory.Function(
"array_append",
new[] { array, item },
nullable: true,
TrueArrays[2],
arrayOrList.Type,
arrayOrList.TypeMapping);
}

if (method.IsClosedFormOf(Enumerable_Concat))
{
var inferredMapping = ExpressionExtensions.InferTypeMapping(arrayOrList, arguments[0]);

return _sqlExpressionFactory.Function(
"array_cat",
new[]
{
_sqlExpressionFactory.ApplyTypeMapping(arrayOrList, inferredMapping),
_sqlExpressionFactory.ApplyTypeMapping(arguments[0], inferredMapping)
},
nullable: true,
TrueArrays[2],
arrayOrList.Type,
inferredMapping);
}

return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,41 @@ SELECT COALESCE(sum(o0."Amount"), 0)::int
""");
}

[ConditionalTheory] // #3001
[MemberData(nameof(IsAsyncData))]
public virtual async Task Update_with_primitive_collection_in_value_selector(bool async)
{
var contextFactory = await InitializeAsync<Context3001>(
seed: ctx =>
{
ctx.AddRange(new EntityWithPrimitiveCollection { Tags = new List<string> { "tag1", "tag2" }});
ctx.SaveChanges();
});

await AssertUpdate(
async,
contextFactory.CreateContext,
ss => ss.EntitiesWithPrimitiveCollection,
s => s.SetProperty(x => x.Tags, x => x.Tags.Append("another_tag")),
rowsAffectedCount: 1);
}

protected class Context3001 : DbContext
{
public Context3001(DbContextOptions options)
: base(options)
{
}

public DbSet<EntityWithPrimitiveCollection> EntitiesWithPrimitiveCollection { get; set; }
}

protected class EntityWithPrimitiveCollection
{
public int Id { get; set; }
public List<string> Tags { get; set; }
}

private void AssertSql(params string[] expected)
=> TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down