Skip to content

Commit

Permalink
Support top-level statements in regex analyzer (#72046)
Browse files Browse the repository at this point in the history
* Support top-level statements in regex analyzer

* Apply test suggestion

* Address feedback
  • Loading branch information
Youssef1313 authored Jul 13, 2022
1 parent a5f3676 commit 750157d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ public override void Initialize(AnalysisContext context)
return;
}

// Validate that the project is not using top-level statements, since if it were, the code-fixer
// can't easily convert to the source generator without having to make the program not use top-level
// statements any longer.
if (ProjectUsesTopLevelStatements(compilation))
{
return;
}

// Pre-compute a hash with all of the method symbols that we want to analyze for possibly emitting
// a diagnostic.
HashSet<IMethodSymbol> staticMethodsToDetect = GetMethodSymbolHash(regexTypeSymbol,
Expand Down Expand Up @@ -250,15 +242,6 @@ private static bool TryValidateParametersAndExtractArgumentIndices(ImmutableArra
private static bool IsConstant(IArgumentOperation argument)
=> argument.Value.ConstantValue.HasValue;

/// <summary>
/// Detects whether or not the current project is using top-level statements.
/// </summary>
private static bool ProjectUsesTopLevelStatements(Compilation compilation)
{
INamedTypeSymbol? programType = compilation.GetTypeByMetadataName("Program");
return programType is not null && !programType.GetMembers("<Main>$").IsEmpty;
}

/// <summary>
/// Ensures that the compilation can find the Regex and RegexAttribute types, and also validates that the
/// LangVersion of the project is >= 10.0 (which is the current requirement for the Regex source generator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,20 @@ private static async Task<Document> ConvertToSourceGenerator(Document document,
}

// Get the parent type declaration so that we can inspect its methods as well as check if we need to add the partial keyword.
TypeDeclarationSyntax? typeDeclaration = nodeToFix.Ancestors().OfType<TypeDeclarationSyntax>().FirstOrDefault();
SyntaxNode? typeDeclarationOrCompilationUnit = nodeToFix.Ancestors().OfType<TypeDeclarationSyntax>().FirstOrDefault();

if (typeDeclaration is null)
if (typeDeclarationOrCompilationUnit is null)
{
return document;
typeDeclarationOrCompilationUnit = await nodeToFix.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
}

// Calculate what name should be used for the generated static partial method
string methodName = DefaultRegexMethodName;
ITypeSymbol? typeSymbol = semanticModel.GetDeclaredSymbol(typeDeclaration, cancellationToken) as ITypeSymbol;

INamedTypeSymbol? typeSymbol = typeDeclarationOrCompilationUnit is TypeDeclarationSyntax typeDeclaration ?
semanticModel.GetDeclaredSymbol(typeDeclaration, cancellationToken) :
semanticModel.GetDeclaredSymbol((CompilationUnitSyntax)typeDeclarationOrCompilationUnit, cancellationToken)?.ContainingType;

if (typeSymbol is not null)
{
IEnumerable<ISymbol> members = GetAllMembers(typeSymbol);
Expand Down Expand Up @@ -147,9 +151,12 @@ private static async Task<Document> ConvertToSourceGenerator(Document document,
}

// We need to find the typeDeclaration again, but now using the new root.
typeDeclaration = nodeToFix.Ancestors().OfType<TypeDeclarationSyntax>().FirstOrDefault();
Debug.Assert(typeDeclaration is not null);
TypeDeclarationSyntax newTypeDeclaration = typeDeclaration;
typeDeclarationOrCompilationUnit = typeDeclarationOrCompilationUnit is TypeDeclarationSyntax ?
nodeToFix.Ancestors().OfType<TypeDeclarationSyntax>().FirstOrDefault() :
await nodeToFix.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);

Debug.Assert(typeDeclarationOrCompilationUnit is not null);
SyntaxNode newTypeDeclarationOrCompilationUnit = typeDeclarationOrCompilationUnit;

// We generate a new invocation node to call our new partial method, and use it to replace the nodeToFix.
DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -177,12 +184,12 @@ private static async Task<Document> ConvertToSourceGenerator(Document document,
SyntaxNode createRegexMethod = generator.InvocationExpression(generator.IdentifierName(methodName));
SyntaxNode method = generator.InvocationExpression(generator.MemberAccessExpression(createRegexMethod, invocationOperation.TargetMethod.Name), arguments.Select(arg => arg.Syntax).ToArray());

newTypeDeclaration = newTypeDeclaration.ReplaceNode(nodeToFix, method);
newTypeDeclarationOrCompilationUnit = newTypeDeclarationOrCompilationUnit.ReplaceNode(nodeToFix, method);
}
else // When using a Regex constructor
{
SyntaxNode invokeMethod = generator.InvocationExpression(generator.IdentifierName(methodName));
newTypeDeclaration = newTypeDeclaration.ReplaceNode(nodeToFix, invokeMethod);
newTypeDeclarationOrCompilationUnit = newTypeDeclarationOrCompilationUnit.ReplaceNode(nodeToFix, invokeMethod);
}

// Initialize the inputs for the RegexGenerator attribute.
Expand Down Expand Up @@ -223,10 +230,12 @@ private static async Task<Document> ConvertToSourceGenerator(Document document,
newMethod = (MethodDeclarationSyntax)generator.AddAttributes(newMethod, attributes);

// Add the method to the type.
newTypeDeclaration = newTypeDeclaration.AddMembers(newMethod);
newTypeDeclarationOrCompilationUnit = newTypeDeclarationOrCompilationUnit is TypeDeclarationSyntax newTypeDeclaration ?
newTypeDeclaration.AddMembers(newMethod) :
((CompilationUnitSyntax)newTypeDeclarationOrCompilationUnit).AddMembers((ClassDeclarationSyntax)generator.ClassDeclaration("Program", modifiers: DeclarationModifiers.Partial, members: new[] { newMethod }));

// Replace the old type declaration with the new modified one, and return the document.
return document.WithSyntaxRoot(root.ReplaceNode(typeDeclaration, newTypeDeclaration));
return document.WithSyntaxRoot(root.ReplaceNode(typeDeclarationOrCompilationUnit, newTypeDeclarationOrCompilationUnit));

static IEnumerable<ISymbol> GetAllMembers(ITypeSymbol? symbol)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,22 @@ public async Task NoDiagnosticForConstructorWithTimeout(string test)
await VerifyCS.VerifyAnalyzerAsync(test);
}

[Fact]
public async Task NoDiagnosticForTopLevelStatements()
[Theory]
[MemberData(nameof(InvocationTypes))]
public async Task TopLevelStatements(InvocationType invocationType)
{
string isMatchInvocation = invocationType == InvocationType.Constructor ? @".IsMatch("""")" : string.Empty;
string test = @"using System.Text.RegularExpressions;
var isMatch = [|" + ConstructRegexInvocation(invocationType, pattern: "\"\"") + @"|]" + isMatchInvocation + ";";
string fixedCode = @"using System.Text.RegularExpressions;
var isMatch = MyRegex().IsMatch("""");
Regex r = new Regex("""");";

await VerifyCS.VerifyAnalyzerAsync(test);
partial class Program
{
[RegexGenerator("""")]
private static partial Regex MyRegex();
}";
await VerifyCS.VerifyCodeFixAsync(test, fixedCode);
}

public static IEnumerable<object[]> StaticInvocationWithTimeoutTestData()
Expand Down Expand Up @@ -737,17 +745,26 @@ static void Main(string[] args)
}

[Fact]
public async Task NoDiagnosticForTopLevelStatements_MultipleSourceFiles()
public async Task TopLevelStatements_MultipleSourceFiles()
{
await new VerifyCS.Test(references: null, usePreviewLanguageVersion: true, numberOfIterations: 1)
{
TestState =
{
Sources = { "public class C { }", @"var r = new System.Text.RegularExpressions.Regex("""");" },
Sources = { "public class C { }", @"var r = [|new System.Text.RegularExpressions.Regex("""")|];" },
},
FixedState =
{
Sources = { "public class C { }", @"var r = MyRegex();
partial class Program
{
[System.Text.RegularExpressions.RegexGenerator("""")]
private static partial System.Text.RegularExpressions.Regex MyRegex();
}" }
}
}.RunAsync();
}

#region Test helpers

private static string ConstructRegexInvocation(InvocationType invocationType, string pattern, string? options = null)
Expand Down

0 comments on commit 750157d

Please sign in to comment.