Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collection expressions: nullable analysis of spread element expression #74686

Merged
merged 11 commits into from
Oct 18, 2024
7 changes: 4 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -976,9 +976,10 @@ BoundNode bindSpreadElement(BoundCollectionExpressionSpreadElement element, Type
Debug.Assert(enumeratorInfo is { });
Debug.Assert(enumeratorInfo.ElementType is { }); // ElementType is set always, even for IEnumerable.

var elementPlaceholder = new BoundValuePlaceholder(syntax, enumeratorInfo.ElementType) { WasCompilerGenerated = true };
var expressionSyntax = element.Expression.Syntax;
var elementPlaceholder = new BoundValuePlaceholder(expressionSyntax, enumeratorInfo.ElementType) { WasCompilerGenerated = true };
var convertElement = CreateConversion(
element.Syntax,
expressionSyntax,
elementPlaceholder,
elementConversion,
isCast: false,
Expand All @@ -991,7 +992,7 @@ BoundNode bindSpreadElement(BoundCollectionExpressionSpreadElement element, Type
conversion: element.Conversion,
enumeratorInfo,
elementPlaceholder: elementPlaceholder,
iteratorBody: new BoundExpressionStatement(syntax, convertElement) { WasCompilerGenerated = true },
iteratorBody: new BoundExpressionStatement(expressionSyntax, convertElement) { WasCompilerGenerated = true },
lengthOrCount: element.LengthOrCount);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,30 @@ private void VerifyExpression(BoundExpression expression, bool overrideSkippedEx
return null;
}

return base.VisitCollectionExpression(node);
// See NullableWalker.VisitCollectionExpression.getCollectionDetails() which
// does not have an element type for the ImplementsIEnumerable case.
bool hasElementType = node.CollectionTypeKind is not (CollectionExpressionTypeKind.None or CollectionExpressionTypeKind.ImplementsIEnumerable);
foreach (var element in node.Elements)
{
if (element is BoundCollectionExpressionSpreadElement spread)
{
Visit(spread.Expression);
Visit(spread.Conversion);
if (spread.EnumeratorInfoOpt != null)
{
VisitForEachEnumeratorInfo(spread.EnumeratorInfoOpt);
}
if (hasElementType)
{
Visit(((BoundExpressionStatement?)spread.IteratorBody)?.Expression);
Copy link
Contributor

@RikkiGibson RikkiGibson Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to dig into the expression here, instead of just visiting the statement, for the verifier to pick it up? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NullableWalker only visits the expression within the BoundExpressionStatement of spread.IteratorBody. (The containing BoundExpressionStatement is only there to allow sharing code in binding between foreach and .. since the foreach infrastructure expects the body to be a statement.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, sounds like if we visited the statement here, then we would fail verification because NullableWalker did not also visit that statement.

}
}
else
{
Visit(element);
}
}
return null;
}

public override BoundNode? VisitDeconstructionAssignmentOperator(BoundDeconstructionAssignmentOperator node)
Expand Down Expand Up @@ -196,14 +219,7 @@ private void VerifyExpression(BoundExpression expression, bool overrideSkippedEx
Visit(node.AwaitOpt);
if (node.EnumeratorInfoOpt != null)
{
Visit(node.EnumeratorInfoOpt.DisposeAwaitableInfo);
if (node.EnumeratorInfoOpt.GetEnumeratorInfo.Method.IsExtensionMethod)
{
foreach (var arg in node.EnumeratorInfoOpt.GetEnumeratorInfo.Arguments)
{
Visit(arg);
}
}
VisitForEachEnumeratorInfo(node.EnumeratorInfoOpt);
}
Visit(node.Expression);
// https://github.com/dotnet/roslyn/issues/35010: handle the deconstruction
Expand All @@ -212,6 +228,18 @@ private void VerifyExpression(BoundExpression expression, bool overrideSkippedEx
return null;
}

private void VisitForEachEnumeratorInfo(ForEachEnumeratorInfo enumeratorInfo)
{
Visit(enumeratorInfo.DisposeAwaitableInfo);
if (enumeratorInfo.GetEnumeratorInfo.Method.IsExtensionMethod)
{
foreach (var arg in enumeratorInfo.GetEnumeratorInfo.Arguments)
{
Visit(arg);
}
}
}

public override BoundNode? VisitGotoStatement(BoundGotoStatement node)
{
// There's no need to verify the label children. They do not have types or nullabilities
Expand Down
109 changes: 89 additions & 20 deletions src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3664,8 +3664,20 @@ protected override void VisitStatement(BoundStatement statement)
VisitRvalue(initializer.Arguments[0]);
break;
case BoundCollectionExpressionSpreadElement spread:
// https://github.com/dotnet/roslyn/issues/68786: We should check the spread
Visit(spread);
if (elementType.HasType &&
spread.ElementPlaceholder is { } elementPlaceholder &&
spread.IteratorBody is { })
{
var itemResult = spread.EnumeratorInfoOpt == null ? default : _visitResult;
var iteratorBody = ((BoundExpressionStatement)spread.IteratorBody).Expression;
AddPlaceholderReplacement(elementPlaceholder, expression: elementPlaceholder, itemResult);
var completion = VisitOptionalImplicitConversion(iteratorBody, elementType,
Copy link
Contributor

@RikkiGibson RikkiGibson Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we are assuming the iteratorBody is something that would be convertible to the elementType? But I thought it could represent a call to a void-returning Add method, for example. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Add() cases are not handled in this PR. For those cases, elementType.HasType == false so we shouldn't reach this code path.

Copy link
Contributor

@RikkiGibson RikkiGibson Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't find the connection between elementType.HasType and the shape of the spread IteratorBody straightforward to follow. It seems like there are several places where we dig into the IteratorBody and expect to get an expression which meets certain assumptions. Perhaps it would be helpful to add some extension(s) to BoundCollectionExpressionSpreadElement which implies and verifies the assumption the developer is making.

Also, it feels like using a name like targetElementType instead of elementType would make this code a little easier to follow, to be clear we are not talking about the element type of the spread operand itself.

That said, none of the above changes need to happen in this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I've renamed elementType to targetElementType. We can consider adding methods to BoundCollectionExpressionSpreadElement as needed in separate PRs.

useLegacyWarnings: false, trackMembers: false, AssignmentKind.Assignment, delayCompletionForTargetType: true).completion;
Debug.Assert(completion is not null);
elementConversionCompletions.Add(completion);
RemovePlaceholderReplacement(elementPlaceholder);
}
break;
default:
var elementExpr = (BoundExpression)element;
Expand Down Expand Up @@ -3767,6 +3779,34 @@ static NullableFlowState getResultState(BoundCollectionExpression node, Collecti
}
}

public override BoundNode? VisitCollectionExpressionSpreadElement(BoundCollectionExpressionSpreadElement node)
{
VisitRvalue(node.Expression);
Copy link
Contributor

@RikkiGibson RikkiGibson Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like a case like int[] x = [..(int[]?)null] is handled by the VisitForEachExpression itself? Since we would be visiting a loop like foreach (var elem in (int[]?)null) { ... }? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the top-level nullability of the expression being spread is handled in VisitForEachExpression(). For an example test, see CollectionExpressionTests.SpreadNullability.


if (node.Conversion is BoundConversion { Conversion: var conversion })
{
Debug.Assert(node.ExpressionPlaceholder is { });
Debug.Assert(node.EnumeratorInfoOpt is { });
AddPlaceholderReplacement(node.ExpressionPlaceholder, node.Expression, _visitResult);
VisitForEachExpression(
node,
node.Conversion,
conversion,
node.ExpressionPlaceholder,
node.EnumeratorInfoOpt,
awaitOpt: null);
RikkiGibson marked this conversation as resolved.
Show resolved Hide resolved
RemovePlaceholderReplacement(node.ExpressionPlaceholder);
}
else
{
Debug.Assert(node.HasErrors);
Debug.Assert(node.Conversion is null);
Debug.Assert(node.EnumeratorInfoOpt is null);
}

return null;
}

private void VisitObjectCreationExpressionBase(BoundObjectCreationExpressionBase node)
{
Debug.Assert(!IsConditionalState);
Expand Down Expand Up @@ -10612,6 +10652,18 @@ private TypeWithAnnotations GetDeclaredParameterResult(ParameterSymbol parameter
return null;
}

public override BoundNode? VisitCollectionExpressionSpreadExpressionPlaceholder(BoundCollectionExpressionSpreadExpressionPlaceholder node)
{
VisitPlaceholderWithReplacement(node);
return null;
}

public override BoundNode? VisitValuePlaceholder(BoundValuePlaceholder node)
{
VisitPlaceholderWithReplacement(node);
return null;
}

public override BoundNode? VisitEventAccess(BoundEventAccess node)
{
var updatedSymbol = VisitMemberAccess(node, node.ReceiverOpt, node.EventSymbol);
Expand Down Expand Up @@ -10711,6 +10763,23 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
var (expr, conversion) = RemoveConversion(node.Expression, includeExplicitConversions: false);
SnapshotWalkerThroughConversionGroup(node.Expression, expr);

VisitForEachExpression(
node,
node.Expression,
conversion,
expr,
node.EnumeratorInfoOpt,
node.AwaitOpt);
}

private void VisitForEachExpression(
BoundNode node,
BoundExpression collectionExpression,
Conversion conversion,
BoundExpression expr,
ForEachEnumeratorInfo? enumeratorInfoOpt,
BoundAwaitableInfo? awaitOpt)
{
// There are 7 ways that a foreach can be created:
// 1. The collection type is an array type. For this, initial binding will generate an implicit reference conversion to
// IEnumerable, and we do not need to do any reinferring of enumerators here.
Expand Down Expand Up @@ -10744,7 +10813,7 @@ protected override void VisitForEachExpression(BoundForEachStatement node)

MethodSymbol? reinferredGetEnumeratorMethod = null;

if (node.EnumeratorInfoOpt?.GetEnumeratorInfo is { Method: { IsExtensionMethod: true, Parameters: var parameters } } enumeratorMethodInfo)
if (enumeratorInfoOpt?.GetEnumeratorInfo is { Method: { IsExtensionMethod: true, Parameters: var parameters } } enumeratorMethodInfo)
{
// this is case 7
// We do not need to do this same analysis for non-extension methods because they do not have generic parameters that
Expand Down Expand Up @@ -10772,13 +10841,13 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
}
else if (conversion.IsImplicit)
{
bool isAsync = node.AwaitOpt != null;
if (node.Expression.Type!.SpecialType == SpecialType.System_Collections_IEnumerable)
bool isAsync = awaitOpt != null;
if (collectionExpression.Type!.SpecialType == SpecialType.System_Collections_IEnumerable)
{
// If this is a conversion to IEnumerable (non-generic), nothing to do. This is cases 1, 2, and 5.
targetTypeWithAnnotations = TypeWithAnnotations.Create(node.Expression.Type);
targetTypeWithAnnotations = TypeWithAnnotations.Create(collectionExpression.Type);
}
else if (ForEachLoopBinder.IsIEnumerableT(node.Expression.Type.OriginalDefinition, isAsync, compilation))
else if (ForEachLoopBinder.IsIEnumerableT(collectionExpression.Type.OriginalDefinition, isAsync, compilation))
{
// This is case 4. We need to look for the IEnumerable<T> that this reinferred expression implements,
// so that we pick up any nested type substitutions that could have occurred.
Expand All @@ -10801,7 +10870,7 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
}

var convertedResult = VisitConversion(
GetConversionIfApplicable(node.Expression, expr),
GetConversionIfApplicable(collectionExpression, expr),
expr,
conversion,
targetTypeWithAnnotations,
Expand All @@ -10811,15 +10880,15 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
useLegacyWarnings: false,
AssignmentKind.Assignment);

bool reportedDiagnostic = node.EnumeratorInfoOpt?.GetEnumeratorInfo.Method is { IsExtensionMethod: true }
bool reportedDiagnostic = enumeratorInfoOpt?.GetEnumeratorInfo.Method is { IsExtensionMethod: true }
? false
: CheckPossibleNullReceiver(expr);

SetAnalyzedNullability(node.Expression, new VisitResult(convertedResult, convertedResult.ToTypeWithAnnotations(compilation)));
SetAnalyzedNullability(collectionExpression, new VisitResult(convertedResult, convertedResult.ToTypeWithAnnotations(compilation)));

TypeWithState currentPropertyGetterTypeWithState;

if (node.EnumeratorInfoOpt is null)
if (enumeratorInfoOpt is null)
{
currentPropertyGetterTypeWithState = default;
}
Expand All @@ -10834,17 +10903,17 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
// There are frameworks where System.String does not implement IEnumerable, but we still lower it to a for loop
// using the indexer over the individual characters anyway. So the type must be not annotated char.
currentPropertyGetterTypeWithState =
TypeWithAnnotations.Create(node.EnumeratorInfoOpt.ElementType, NullableAnnotation.NotAnnotated).ToTypeWithState();
TypeWithAnnotations.Create(enumeratorInfoOpt.ElementType, NullableAnnotation.NotAnnotated).ToTypeWithState();
}
else
{
// Reinfer the return type of the node.Expression.GetEnumerator().Current property, so that if
// Reinfer the return type of the collectionExpression.GetEnumerator().Current property, so that if
// the collection changed nested generic types we pick up those changes.
if (reinferredGetEnumeratorMethod is null)
{
TypeSymbol? getEnumeratorType;

if (node.EnumeratorInfoOpt is { InlineArraySpanType: not WellKnownType.Unknown and var wellKnownSpan })
if (enumeratorInfoOpt is { InlineArraySpanType: not WellKnownType.Unknown and var wellKnownSpan })
{
Debug.Assert(wellKnownSpan is WellKnownType.System_Span_T or WellKnownType.System_ReadOnlySpan_T);
NamedTypeSymbol spanType = compilation.GetWellKnownType(wellKnownSpan);
Expand All @@ -10855,29 +10924,29 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
getEnumeratorType = convertedResult.Type;
}

reinferredGetEnumeratorMethod = (MethodSymbol)AsMemberOfType(getEnumeratorType, node.EnumeratorInfoOpt.GetEnumeratorInfo.Method);
reinferredGetEnumeratorMethod = (MethodSymbol)AsMemberOfType(getEnumeratorType, enumeratorInfoOpt.GetEnumeratorInfo.Method);
}

var enumeratorReturnType = GetReturnTypeWithState(reinferredGetEnumeratorMethod);

if (enumeratorReturnType.State != NullableFlowState.NotNull)
{
if (!reportedDiagnostic && !(node.Expression is BoundConversion { Operand: { IsSuppressed: true } }))
if (!reportedDiagnostic && !(collectionExpression is BoundConversion { Operand: { IsSuppressed: true } }))
{
ReportDiagnostic(ErrorCode.WRN_NullReferenceReceiver, expr.Syntax.GetLocation());
}
}

var currentPropertyGetter = (MethodSymbol)AsMemberOfType(enumeratorReturnType.Type, node.EnumeratorInfoOpt.CurrentPropertyGetter);
var currentPropertyGetter = (MethodSymbol)AsMemberOfType(enumeratorReturnType.Type, enumeratorInfoOpt.CurrentPropertyGetter);

currentPropertyGetterTypeWithState = ApplyUnconditionalAnnotations(
currentPropertyGetter.ReturnTypeWithAnnotations.ToTypeWithState(),
currentPropertyGetter.ReturnTypeFlowAnalysisAnnotations);

// Analyze `await MoveNextAsync()`
if (node.AwaitOpt is { AwaitableInstancePlaceholder: BoundAwaitableValuePlaceholder moveNextPlaceholder } awaitMoveNextInfo)
if (awaitOpt is { AwaitableInstancePlaceholder: BoundAwaitableValuePlaceholder moveNextPlaceholder } awaitMoveNextInfo)
{
var moveNextAsyncMethod = (MethodSymbol)AsMemberOfType(reinferredGetEnumeratorMethod.ReturnType, node.EnumeratorInfoOpt.MoveNextInfo.Method);
var moveNextAsyncMethod = (MethodSymbol)AsMemberOfType(reinferredGetEnumeratorMethod.ReturnType, enumeratorInfoOpt.MoveNextInfo.Method);

var result = new VisitResult(GetReturnTypeWithState(moveNextAsyncMethod), moveNextAsyncMethod.ReturnTypeWithAnnotations);
AddPlaceholderReplacement(moveNextPlaceholder, moveNextPlaceholder, result);
Expand All @@ -10886,11 +10955,11 @@ protected override void VisitForEachExpression(BoundForEachStatement node)
}

// Analyze `await DisposeAsync()`
if (node.EnumeratorInfoOpt is { NeedsDisposal: true, DisposeAwaitableInfo: BoundAwaitableInfo awaitDisposalInfo })
if (enumeratorInfoOpt is { NeedsDisposal: true, DisposeAwaitableInfo: BoundAwaitableInfo awaitDisposalInfo })
{
var disposalPlaceholder = awaitDisposalInfo.AwaitableInstancePlaceholder;
bool addedPlaceholder = false;
if (node.EnumeratorInfoOpt.PatternDisposeInfo is { Method: var originalDisposeMethod }) // no statically known Dispose method if doing a runtime check
if (enumeratorInfoOpt.PatternDisposeInfo is { Method: var originalDisposeMethod }) // no statically known Dispose method if doing a runtime check
{
Debug.Assert(disposalPlaceholder is not null);
var disposeAsyncMethod = (MethodSymbol)AsMemberOfType(reinferredGetEnumeratorMethod.ReturnType, originalDisposeMethod);
Expand Down
Loading
Loading