Skip to content

Commit

Permalink
Translate ToString() over enums
Browse files Browse the repository at this point in the history
Fixes #33635 and #20604
  • Loading branch information
Danevandy99 committed May 27, 2024
1 parent 5eec48b commit 56841ce
Show file tree
Hide file tree
Showing 15 changed files with 569 additions and 223 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -882,10 +882,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

// if object is nullable, add null safeguard before calling the function
// we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard
// we special-case Nullable<>.GetValueOrDefault, which doesn't need the safeguard,
// and Nullable<>.ToString when the object is a nullable value type.
if (methodCallExpression.Object != null
&& @object!.Type.IsNullableType()
&& methodCallExpression.Method.Name != nameof(Nullable<int>.GetValueOrDefault))
&& methodCallExpression.Method.Name != nameof(Nullable<int>.GetValueOrDefault)
&& (!@object!.Type.IsNullableValueType()
|| methodCallExpression.Method.Name != nameof(Nullable<int>.ToString)))
{
var result = (Expression)methodCallExpression.Update(
Expression.Convert(@object, methodCallExpression.Object.Type),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal;
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class EnumHasFlagTranslator : IMethodCallTranslator
public class EnumMethodTranslator : IMethodCallTranslator
{
private static readonly MethodInfo MethodInfo
= typeof(Enum).GetRuntimeMethod(nameof(Enum.HasFlag), [typeof(Enum)])!;
private static readonly MethodInfo HasFlagMethodInfo
= typeof(Enum).GetRuntimeMethod(nameof(Enum.HasFlag), new[] { typeof(Enum) })!;

private static readonly MethodInfo ToStringMethodInfo
= typeof(object).GetRuntimeMethod(nameof(ToString), new Type[] { })!;

private readonly ISqlExpressionFactory _sqlExpressionFactory;

Expand All @@ -25,7 +28,7 @@ private static readonly MethodInfo MethodInfo
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public EnumHasFlagTranslator(ISqlExpressionFactory sqlExpressionFactory)
public EnumMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand All @@ -42,7 +45,7 @@ public EnumHasFlagTranslator(ISqlExpressionFactory sqlExpressionFactory)
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (Equals(method, MethodInfo)
if (Equals(method, HasFlagMethodInfo)
&& instance != null)
{
var argument = arguments[0];
Expand All @@ -51,6 +54,37 @@ public EnumHasFlagTranslator(ISqlExpressionFactory sqlExpressionFactory)
: _sqlExpressionFactory.Equal(_sqlExpressionFactory.And(instance, argument), argument);
}

if (Equals(method, ToStringMethodInfo)
&& instance != null
&& instance.Type.IsEnum)
{
var converterType = instance.TypeMapping?.Converter?.GetType();

if (converterType is not null
&& converterType.IsGenericType)
{
if (converterType.GetGenericTypeDefinition() == typeof(EnumToNumberConverter<,>)
&& converterType.GetGenericArguments().Length == 2
&& converterType.GetGenericArguments()[1] == typeof(int)
&& (instance is SqlParameterExpression || instance is ColumnExpression))
{
var cases = Enum.GetValues(instance.Type)
.Cast<object>()
.Select(value => new CaseWhenClause(
_sqlExpressionFactory.Equal(instance, _sqlExpressionFactory.Constant(value)),
_sqlExpressionFactory.Constant(value.ToString(), typeof(string))))
.ToArray();

return _sqlExpressionFactory.Case(cases, _sqlExpressionFactory.Constant(string.Empty, typeof(string)));
}
else if (converterType.GetGenericTypeDefinition() == typeof(EnumToStringConverter<>))
{
// TODO: Unnecessary cast to string, #33733
return _sqlExpressionFactory.MakeUnary(ExpressionType.Convert, instance, typeof(string));
}
}
}

return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public RelationalMethodCallTranslatorProvider(RelationalMethodCallTranslatorProv
new CollateTranslator(),
new ContainsTranslator(sqlExpressionFactory),
new LikeTranslator(sqlExpressionFactory),
new EnumHasFlagTranslator(sqlExpressionFactory),
new EnumMethodTranslator(sqlExpressionFactory),
new GetValueOrDefaultTranslator(sqlExpressionFactory),
new ComparisonTranslator(sqlExpressionFactory),
new ByteArraySequenceEqualTranslator(sqlExpressionFactory),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ public SqlServerObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFact
_sqlExpressionFactory.Constant(true.ToString()));
}

// Enums are handled by EnumMethodTranslator

return TypeMapping.TryGetValue(instance.Type, out var storeType)
? _sqlExpressionFactory.Function(
"CONVERT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory
_sqlExpressionFactory.Constant(true.ToString()));
}

// Enums are handled by EnumMethodTranslator

return TypeMapping.Contains(instance.Type)
? _sqlExpressionFactory.Convert(instance, typeof(string))
: null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,12 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.HasOne(w => w.Owner).WithMany(g => g.Weapons).HasForeignKey(w => w.OwnerFullName).HasPrincipalKey(g => g.FullName);
});

modelBuilder.Entity<Mission>().Property(m => m.Id).ValueGeneratedNever();
modelBuilder.Entity<Mission>(
b =>
{
b.Property(m => m.Id).ValueGeneratedNever();
b.Property(m => m.Difficulty).HasConversion<string>();
});

modelBuilder.Entity<SquadMission>(
b =>
Expand Down
38 changes: 28 additions & 10 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,34 @@ public virtual Task ToString_boolean_property_nullable(bool async)
async,
ss => ss.Set<LocustHorde>().Select(lh => lh.Eradicated.ToString()));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task ToString_enum_property_projection(bool async)
=> AssertQuery(
async,
ss => ss.Set<Gear>().Select(g => g.Rank.ToString()));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task ToString_nullable_enum_property_projection(bool async)
=> AssertQuery(
async,
ss => ss.Set<Weapon>().Select(w => w.AmmunitionType.ToString()));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task ToString_enum_contains(bool async)
=> AssertQuery(
async,
ss => ss.Set<Mission>().Where(g => g.Difficulty.ToString().Contains("Med")).Select(g => g.CodeName));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task ToString_nullable_enum_contains(bool async)
=> AssertQuery(
async,
ss => ss.Set<Weapon>().Where(w => w.AmmunitionType.ToString().Contains("Cart")).Select(g => g.Name));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Include_multiple_one_to_one_and_one_to_many_self_reference(bool async)
Expand Down Expand Up @@ -3121,16 +3149,6 @@ public virtual Task Projecting_nullable_bool_in_conditional_works(bool async)
new { Prop = cg.Gear != null ? cg.Gear.HasSoulPatch : false }),
e => e.Prop);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Enum_ToString_is_client_eval(bool async)
=> AssertQuery(
async,
ss =>
ss.Set<Gear>().OrderBy(g => g.SquadId)
.ThenBy(g => g.Nickname)
.Select(g => g.Rank.ToString()));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_naked_navigation_with_ToList(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ public static IReadOnlyList<Mission> CreateMissions()
Timeline = new DateTimeOffset(599898024001234567, new TimeSpan(1, 30, 0)),
Duration = new TimeSpan(1, 2, 3),
Date = new DateOnly(2020, 1, 1),
Time = new TimeOnly(15, 30, 10)
Time = new TimeOnly(15, 30, 10),
Difficulty = MissionDifficulty.Low
},
new()
{
Expand All @@ -143,7 +144,8 @@ public static IReadOnlyList<Mission> CreateMissions()
Timeline = new DateTimeOffset(2, 3, 1, 8, 0, 0, new TimeSpan(-5, 0, 0)),
Duration = new TimeSpan(0, 1, 2, 3, 456),
Date = new DateOnly(1990, 11, 10),
Time = new TimeOnly(10, 15, 50, 500)
Time = new TimeOnly(10, 15, 50, 500),
Difficulty = MissionDifficulty.Medium
},
new()
{
Expand All @@ -153,7 +155,8 @@ public static IReadOnlyList<Mission> CreateMissions()
Timeline = new DateTimeOffset(10, 5, 3, 12, 0, 0, new TimeSpan()),
Duration = new TimeSpan(0, 1, 0, 15, 456),
Date = new DateOnly(1, 1, 1),
Time = new TimeOnly(0, 0, 0)
Time = new TimeOnly(0, 0, 0),
Difficulty = MissionDifficulty.Unknown
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class Mission
public TimeSpan Duration { get; set; }
public DateOnly Date { get; set; }
public TimeOnly Time { get; set; }
public MissionDifficulty Difficulty { get; set; }

public virtual ICollection<SquadMission> ParticipatingSquads { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.TestModels.GearsOfWarModel;

public enum MissionDifficulty
{
Unknown = 0,
Low = 1,
Medium = 2,
High = 3,
Extreme = 4
}
Loading

0 comments on commit 56841ce

Please sign in to comment.