Skip to content

Commit

Permalink
SE: Learn from bool collection methods (#9497)
Browse files Browse the repository at this point in the history
  • Loading branch information
mary-georgiou-sonarsource authored Jul 4, 2024
1 parent fd7a9be commit db48fa2
Show file tree
Hide file tree
Showing 4 changed files with 598 additions and 448 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@ public static bool HasThisReceiver(this IInvocationOperationWrapper invocation,
|| (invocation.TargetMethod.IsExtensionMethod
&& !invocation.Arguments.IsEmpty
&& state.ResolveCaptureAndUnwrapConversion(invocation.Arguments[0].ToArgument().Value).Kind == OperationKindEx.InstanceReference);

public static IOperation GetInstance(this IInvocationOperationWrapper invocation, ProgramState state) =>
invocation.Instance
?? (invocation.TargetMethod.IsExtensionMethod
? state.ResolveCaptureAndUnwrapConversion(invocation.Arguments[0].ToArgument().Value)
: null);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace SonarAnalyzer.SymbolicExecution.Roslyn.OperationProcessors;

internal sealed partial class Invocation
{
private static readonly HashSet<string> ReturningNotNull = new()
{
private static readonly HashSet<string> ReturningNotNull =
[
nameof(Enumerable.Append),
nameof(Enumerable.AsEnumerable),
nameof(Queryable.AsQueryable),
Expand Down Expand Up @@ -69,25 +69,65 @@ internal sealed partial class Invocation
"UnionBy",
nameof(Enumerable.Where),
nameof(Enumerable.Zip),
};
];

private static readonly HashSet<string> ElementExistsCheckMethods =
[
nameof(Enumerable.Contains),
nameof(Enumerable.Any),
nameof(List<int>.Exists)
];

private static ProgramState[] ProcessLinqEnumerableAndQueryable(ProgramState state, IInvocationOperationWrapper invocation)
{
var name = invocation.TargetMethod.Name;
var states = ProcessElementExistsCheckMethods(state, invocation);
if (ReturningNotNull.Contains(name))
{
return state.SetOperationConstraint(invocation, ObjectConstraint.NotNull).ToArray();
return states.Select(x => x.SetOperationConstraint(invocation, ObjectConstraint.NotNull)).ToArray();
}
// ElementAtOrDefault is intentionally not supported. It's causing many FPs
else if (name is nameof(Enumerable.FirstOrDefault) or nameof(Enumerable.LastOrDefault) or nameof(Enumerable.SingleOrDefault))
{
return invocation.TargetMethod.ReturnType.IsReferenceType
? [state.SetOperationConstraint(invocation, ObjectConstraint.Null), state.SetOperationConstraint(invocation, ObjectConstraint.NotNull)]
: state.ToArray();
return states.SelectMany(x => new List<ProgramState>
{
x.SetOperationConstraint(invocation, ObjectConstraint.Null),
x.SetOperationConstraint(invocation, ObjectConstraint.NotNull)
}).ToArray();
}
else
{
return state.ToArray();
return states;
}
}

private static ProgramState[] ProcessElementExistsCheckMethods(ProgramState state, IInvocationOperationWrapper invocation)
{
if (ElementExistsCheckMethods.Contains(invocation.TargetMethod.Name) && invocation.GetInstance(state).TrackedSymbol(state) is { } instanceSymbol)
{
return state[instanceSymbol]?.Constraint<CollectionConstraint>() switch
{
CollectionConstraint constraint when constraint == CollectionConstraint.Empty => state.SetOperationConstraint(invocation, BoolConstraint.False).ToArray(),
CollectionConstraint constraint when constraint == CollectionConstraint.NotEmpty =>
HasNoParameters(invocation.TargetMethod)
? state.SetOperationConstraint(invocation, BoolConstraint.True).ToArray()
: state.ToArray(),
_ when HasNoParameters(invocation.TargetMethod) =>
[
state.SetOperationConstraint(invocation, BoolConstraint.True).SetSymbolConstraint(instanceSymbol, CollectionConstraint.NotEmpty),
state.SetOperationConstraint(invocation, BoolConstraint.False).SetSymbolConstraint(instanceSymbol, CollectionConstraint.Empty)
],
_ =>
[
state.SetOperationConstraint(invocation, BoolConstraint.True).SetSymbolConstraint(instanceSymbol, CollectionConstraint.NotEmpty),
state
]
};
}
return state.ToArray();

static bool HasNoParameters(IMethodSymbol symbol) =>
(symbol.IsExtensionMethod && symbol.Parameters.Length == 1)
|| symbol.Parameters.IsEmpty;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ internal sealed partial class Invocation : MultiProcessor<IInvocationOperationWr
nameof(Enumerable.Zip),
];

private static readonly ImmutableArray<KnownType> CollectionTypes = ImmutableArray.Create(
KnownType.System_Linq_Enumerable, // Has all the extension methods for IEnumerable <T>
KnownType.System_Linq_Queryable, // Has it's own versions of several extension methods (like Any() with IQueryable<T> parameter)
KnownType.System_Collections_Generic_IList_T); // Has it's own implementation of certain methods (like Contains())

protected override IInvocationOperationWrapper Convert(IOperation operation) =>
IInvocationOperationWrapper.FromOperation(operation);

Expand Down Expand Up @@ -121,7 +126,7 @@ _ when invocation.TargetMethod.Is(KnownType.Microsoft_VisualBasic_Information, "
_ when invocation.TargetMethod.Is(KnownType.System_Diagnostics_Debug, nameof(Debug.Assert)) => ProcessDebugAssert(context, invocation),
_ when invocation.TargetMethod.Is(KnownType.System_Object, nameof(ReferenceEquals)) => ProcessReferenceEquals(context, invocation),
_ when invocation.TargetMethod.Is(KnownType.System_Nullable_T, "get_HasValue") => ProcessNullableHasValue(state, invocation),
_ when invocation.TargetMethod.ContainingType.IsAny(KnownType.System_Linq_Enumerable, KnownType.System_Linq_Queryable) => ProcessLinqEnumerableAndQueryable(state, invocation),
_ when invocation.TargetMethod.ContainingType.DerivesOrImplementsAny(CollectionTypes) => ProcessLinqEnumerableAndQueryable(state, invocation),
_ when invocation.TargetMethod.Name == nameof(Equals) => ProcessEquals(context, invocation),
_ when invocation.TargetMethod.ContainingType.Is(KnownType.System_String) => ProcessSystemStringInvocation(state, invocation),
_ => ProcessArgumentAttributes(state, invocation),
Expand Down Expand Up @@ -204,17 +209,15 @@ _ when argument.WrappedOperation.TrackedSymbol(state) is { } argumentSymbol =>

ProgramState[] ExplodeStates(ISymbol argumentSymbol) =>
learnNull
? new[]
{
state.SetOperationConstraint(invocation, whenBoolConstraint).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
state.SetOperationConstraint(invocation, whenBoolConstraint.Opposite).SetSymbolConstraint(argumentSymbol, ObjectConstraint.Null),
state.SetOperationConstraint(invocation, whenBoolConstraint.Opposite).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
}
: new[]
{
state.SetOperationConstraint(invocation, whenBoolConstraint).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
state.SetOperationConstraint(invocation, whenBoolConstraint.Opposite),
};
? [
state.SetOperationConstraint(invocation, whenBoolConstraint).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
state.SetOperationConstraint(invocation, whenBoolConstraint.Opposite).SetSymbolConstraint(argumentSymbol, ObjectConstraint.Null),
state.SetOperationConstraint(invocation, whenBoolConstraint.Opposite).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
]
: [
state.SetOperationConstraint(invocation, whenBoolConstraint).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
state.SetOperationConstraint(invocation, whenBoolConstraint.Opposite),
];
}
private static ProgramState[] ProcessDoesNotReturnIf(ProgramState state, IArgumentOperationWrapper argument, bool when) =>
Expand Down Expand Up @@ -316,11 +319,11 @@ private static ProgramState[] ProcessEqualsObject(SymbolicContext context, IOper
}
else if ((leftConstraint == ObjectConstraint.Null ? rightOperation : leftOperation).TrackedSymbol(context.State) is { } symbol)
{
return new[]
{
return
[
context.SetOperationConstraint(BoolConstraint.True).SetSymbolConstraint(symbol, ObjectConstraint.Null),
context.SetOperationConstraint(BoolConstraint.False).SetSymbolConstraint(symbol, ObjectConstraint.NotNull)
};
];
}
}
return context.State.ToArray();
Expand Down Expand Up @@ -352,11 +355,11 @@ private static ProgramState[] ProcessNullableHasValue(ProgramState state, IInvoc
}
else if (invocation.Instance.TrackedSymbol(state) is { } symbol)
{
return new[]
{
return
[
state.SetSymbolConstraint(symbol, ObjectConstraint.Null).SetOperationConstraint(invocation, BoolConstraint.False),
state.SetSymbolConstraint(symbol, ObjectConstraint.NotNull).SetOperationConstraint(invocation, BoolConstraint.True),
};
];
}
else
{
Expand All @@ -377,11 +380,11 @@ private static ProgramState[] ProcessInformationIsNothing(SymbolicContext contex
ObjectConstraint constraint when constraint == ObjectConstraint.Null => context.SetOperationConstraint(BoolConstraint.True).ToArray(),
ObjectConstraint constraint when constraint == ObjectConstraint.NotNull => context.SetOperationConstraint(BoolConstraint.False).ToArray(),
_ when invocation.Arguments[0].ToArgument().Value.UnwrapConversion().Type is { } type && !type.CanBeNull() => context.SetOperationConstraint(BoolConstraint.False).ToArray(),
_ when invocation.Arguments[0].TrackedSymbol(context.State) is { } argumentSymbol => new[]
{
_ when invocation.Arguments[0].TrackedSymbol(context.State) is { } argumentSymbol =>
[
context.SetOperationConstraint(BoolConstraint.True).SetSymbolConstraint(argumentSymbol, ObjectConstraint.Null),
context.SetOperationConstraint(BoolConstraint.False).SetSymbolConstraint(argumentSymbol, ObjectConstraint.NotNull),
},
],
_ => context.State.ToArray()
};

Expand All @@ -391,5 +394,5 @@ private static bool IsNullableGetValueOrDefault(IInvocationOperationWrapper invo
private static bool IsSupportedEqualsType(ITypeSymbol type) =>
type.IsNullableValueType() // int?.Equals
|| (type.IsStruct() && type.SpecialType != SpecialType.None) // int.Equals and similar build-in basic value types
|| type.SpecialType == SpecialType.System_ValueType; // struct.Equals that was not overriden
|| type.SpecialType == SpecialType.System_ValueType; // struct.Equals that was not overriden
}
Loading

0 comments on commit db48fa2

Please sign in to comment.