Skip to content

Commit

Permalink
Merge pull request #5807 from Youssef1313/compare-symbols-correctly-fp
Browse files Browse the repository at this point in the history
Fix 'Compare symbols correctly' false positive with custom comparers
  • Loading branch information
mavasani authored Jan 31, 2022
2 parents 50d35c8 + a70ef8e commit 3405da8
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,29 @@ public override void Initialize(AnalysisContext context)
// Check that the EqualityComparer exists and can be used, otherwise the Roslyn version
// being used it too low to need the change for method references
var symbolEqualityComparerType = compilation.GetOrCreateTypeByMetadataName(SymbolEqualityComparerName);
var hasSymbolEqualityComparer = UseSymbolEqualityComparer(compilation);
context.RegisterOperationAction(
context => HandleBinaryOperator(in context, symbolType),
OperationKind.BinaryOperator);
var equalityComparerMethods = GetEqualityComparerMethodsToCheck(compilation);
var systemHashCode = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemHashCode);
var iEqualityComparer = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsGenericIEqualityComparer1);
context.RegisterOperationAction(
context => HandleInvocationOperation(in context, symbolType, symbolEqualityComparerType, equalityComparerMethods, systemHashCode),
context => HandleInvocationOperation(in context, symbolType, hasSymbolEqualityComparer, equalityComparerMethods, systemHashCode, iEqualityComparer),
OperationKind.Invocation);
if (symbolEqualityComparerType != null)
if (hasSymbolEqualityComparer && iEqualityComparer is not null)
{
var collectionTypesBuilder = ImmutableHashSet.CreateBuilder<INamedTypeSymbol>(SymbolEqualityComparer.Default);
collectionTypesBuilder.AddIfNotNull(compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsGenericDictionary2));
collectionTypesBuilder.AddIfNotNull(compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsGenericHashSet1));
collectionTypesBuilder.AddIfNotNull(compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemCollectionsConcurrentConcurrentDictionary2));
context.RegisterOperationAction(
context => HandleObjectCreation(in context, symbolType, symbolEqualityComparerType, collectionTypesBuilder.ToImmutable()),
context => HandleObjectCreation(in context, symbolType, iEqualityComparer, collectionTypesBuilder.ToImmutable()),
OperationKind.ObjectCreation);
}
});
Expand Down Expand Up @@ -144,9 +145,10 @@ private static void HandleBinaryOperator(in OperationAnalysisContext context, IN
private static void HandleInvocationOperation(
in OperationAnalysisContext context,
INamedTypeSymbol symbolType,
INamedTypeSymbol? symbolEqualityComparerType,
bool hasSymbolEqualityComparer,
ImmutableDictionary<string, ImmutableHashSet<INamedTypeSymbol>> equalityComparerMethods,
INamedTypeSymbol? systemHashCodeType)
INamedTypeSymbol? systemHashCodeType,
INamedTypeSymbol? iEqualityComparer)
{
var invocationOperation = (IInvocationOperation)context.Operation;
var method = invocationOperation.TargetMethod;
Expand All @@ -163,7 +165,7 @@ private static void HandleInvocationOperation(
break;

case s_symbolEqualsName:
if (symbolEqualityComparerType is not null && IsNotInstanceInvocationOrNotOnSymbol(invocationOperation, symbolType))
if (hasSymbolEqualityComparer && IsNotInstanceInvocationOrNotOnSymbol(invocationOperation, symbolType))
{
var parameters = invocationOperation.Arguments;
if (parameters.All(p => IsSymbolType(p.Value, symbolType)))
Expand All @@ -186,10 +188,10 @@ invocationOperation.Instance is null &&

default:
if (equalityComparerMethods.TryGetValue(method.Name, out var possibleMethodTypes) &&
symbolEqualityComparerType is not null &&
hasSymbolEqualityComparer &&
possibleMethodTypes.Contains(method.ContainingType.OriginalDefinition) &&
IsBehavingOnSymbolType(method, symbolType) &&
!invocationOperation.Arguments.Any(arg => IsSymbolType(arg.Value, symbolEqualityComparerType)))
!invocationOperation.Arguments.Any(arg => IsSymbolType(arg.Value, iEqualityComparer)))
{
context.ReportDiagnostic(invocationOperation.CreateDiagnostic(CollectionRule));
}
Expand Down Expand Up @@ -230,23 +232,23 @@ static bool IsBehavingOnSymbolType(IMethodSymbol? method, INamedTypeSymbol symbo
}

private static void HandleObjectCreation(in OperationAnalysisContext context, INamedTypeSymbol symbolType,
INamedTypeSymbol symbolEqualityComparerType, ImmutableHashSet<INamedTypeSymbol> collectionTypes)
INamedTypeSymbol iEqualityComparerType, ImmutableHashSet<INamedTypeSymbol> collectionTypes)
{
var objectCreation = (IObjectCreationOperation)context.Operation;

if (objectCreation.Type is INamedTypeSymbol createdType &&
collectionTypes.Contains(createdType.OriginalDefinition) &&
!createdType.TypeArguments.IsEmpty &&
IsSymbolType(createdType.TypeArguments[0], symbolType) &&
!objectCreation.Arguments.Any(arg => IsSymbolType(arg.Value, symbolEqualityComparerType)))
!objectCreation.Arguments.Any(arg => IsSymbolType(arg.Value, iEqualityComparerType)))
{
context.ReportDiagnostic(objectCreation.CreateDiagnostic(CollectionRule));
}
}

private static bool IsSymbolType(IOperation? operation, INamedTypeSymbol symbolType)
private static bool IsSymbolType(IOperation? operation, INamedTypeSymbol? symbolType)
{
if (operation?.Type is object && IsSymbolType(operation.Type, symbolType))
if (operation?.Type is object && IsSymbolType(operation.Type.OriginalDefinition, symbolType))
{
return true;
}
Expand All @@ -259,7 +261,7 @@ private static bool IsSymbolType(IOperation? operation, INamedTypeSymbol symbolT
return false;
}

private static bool IsSymbolType(ITypeSymbol typeSymbol, INamedTypeSymbol symbolType)
private static bool IsSymbolType(ITypeSymbol typeSymbol, INamedTypeSymbol? symbolType)
=> typeSymbol != null
&& (SymbolEqualityComparer.Default.Equals(typeSymbol, symbolType)
|| typeSymbol.AllInterfaces.Contains(symbolType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1248,9 +1248,168 @@ public int GetHashCode(object o1, object o2)
}.RunAsync();
}

[Fact]
[WorkItem(5715, "https://github.com/dotnet/roslyn-analyzers/issues/5715")]
public async Task RS1024_CustomComparer_Instance_Is_InterfaceAsync()
{
var csCode = @"
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
public class C
{
public void M(IEnumerable<ITypeSymbol> symbols)
{
_ = new HashSet<ISymbol>(SymbolNameComparer.Instance);
_ = symbols.ToDictionary(s => s, s => s.ToDisplayString(), SymbolNameComparer.Instance);
_ = symbols.ToDictionary(s => s, s => s.ToDisplayString(), SymbolEqualityComparer.Default);
}
}
internal sealed class SymbolNameComparer : EqualityComparer<ISymbol>
{
private SymbolNameComparer() { }
internal static IEqualityComparer<ISymbol> Instance { get; } = new SymbolNameComparer();
public override bool Equals(ISymbol x, ISymbol y) => true;
public override int GetHashCode(ISymbol obj) => 0;
}
";

await new VerifyCS.Test
{
TestCode = csCode,
FixedCode = csCode,
ReferenceAssemblies = CreateNetCoreReferenceAssemblies(),

}.RunAsync();

var vbCode = @"
Imports System.Collections.Generic
Imports System.Linq
Imports Microsoft.CodeAnalysis
Public Class C
Public Sub M(symbols As IEnumerable(Of ITypeSymbol))
Dim x As New HashSet(Of ISymbol)(SymbolNameComparer.Instance)
Dim y = symbols.ToDictionary(Function(s) s, Function(s) s.ToDisplayString(), SymbolNameComparer.Instance)
Dim z = symbols.ToDictionary(Function(s) s, Function(s) s.ToDisplayString(), SymbolEqualityComparer.Default)
End Sub
End Class
Class SymbolNameComparer
Inherits EqualityComparer(Of ISymbol)
Private Sub New()
End Sub
Friend Shared Property Instance As IEqualityComparer(Of ISymbol) = New SymbolNameComparer()
Public Overrides Function Equals(x As ISymbol, y As ISymbol) As Boolean
Return True
End Function
Public Overrides Function GetHashCode(obj As ISymbol) As Integer
Return 0
End Function
End Class
";

await new VerifyVB.Test
{
TestCode = vbCode,
FixedCode = vbCode,
ReferenceAssemblies = CreateNetCoreReferenceAssemblies(),

}.RunAsync();
}

[Fact]
[WorkItem(5715, "https://github.com/dotnet/roslyn-analyzers/issues/5715")]
public async Task RS1024_CustomComparer_Instance_Is_TypeAsync()
{
var csCode = @"
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
public class C
{
public void M(IEnumerable<ITypeSymbol> symbols)
{
_ = new HashSet<ISymbol>(SymbolNameComparer.Instance);
_ = symbols.ToDictionary(s => s, s => s.ToDisplayString(), SymbolNameComparer.Instance);
_ = symbols.ToDictionary(s => s, s => s.ToDisplayString(), SymbolEqualityComparer.Default);
}
}
internal sealed class SymbolNameComparer : EqualityComparer<ISymbol>
{
private SymbolNameComparer() { }
internal static SymbolNameComparer Instance { get; } = new SymbolNameComparer();
public override bool Equals(ISymbol x, ISymbol y) => true;
public override int GetHashCode(ISymbol obj) => 0;
}
";

await new VerifyCS.Test
{
TestCode = csCode,
FixedCode = csCode,
ReferenceAssemblies = CreateNetCoreReferenceAssemblies(),

}.RunAsync();

var vbCode = @"
Imports System.Collections.Generic
Imports System.Linq
Imports Microsoft.CodeAnalysis
Public Class C
Public Sub M(symbols As IEnumerable(Of ITypeSymbol))
Dim x As New HashSet(Of ISymbol)(SymbolNameComparer.Instance)
Dim y = symbols.ToDictionary(Function(s) s, Function(s) s.ToDisplayString(), SymbolNameComparer.Instance)
Dim z = symbols.ToDictionary(Function(s) s, Function(s) s.ToDisplayString(), SymbolEqualityComparer.Default)
End Sub
End Class
Class SymbolNameComparer
Inherits EqualityComparer(Of ISymbol)
Private Sub New()
End Sub
Friend Shared Property Instance As SymbolNameComparer = New SymbolNameComparer()
Public Overrides Function Equals(x As ISymbol, y As ISymbol) As Boolean
Return True
End Function
Public Overrides Function GetHashCode(obj As ISymbol) As Integer
Return 0
End Function
End Class
";

await new VerifyVB.Test
{
TestCode = vbCode,
FixedCode = vbCode,
ReferenceAssemblies = CreateNetCoreReferenceAssemblies(),

}.RunAsync();
}

private static ReferenceAssemblies CreateNetCoreReferenceAssemblies()
=> ReferenceAssemblies.NetCore.NetCoreApp31.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.CodeAnalysis", "3.0.0"),
new PackageIdentity("Microsoft.CodeAnalysis", "4.0.1"),
new PackageIdentity("System.Runtime.Serialization.Formatters", "4.3.0"),
new PackageIdentity("System.Configuration.ConfigurationManager", "4.7.0"),
new PackageIdentity("System.Security.Cryptography.Cng", "4.7.0"),
Expand Down

0 comments on commit 3405da8

Please sign in to comment.