diff --git a/src/Test.Utilities/CodeFixTestBase.cs b/src/Test.Utilities/CodeFixTestBase.cs index 8dcebd1eb2..cc9a3ed5ff 100644 --- a/src/Test.Utilities/CodeFixTestBase.cs +++ b/src/Test.Utilities/CodeFixTestBase.cs @@ -24,49 +24,52 @@ public abstract class CodeFixTestBase : DiagnosticAnalyzerTestBase protected abstract CodeFixProvider GetBasicCodeFixProvider(); - protected void VerifyCSharpUnsafeCodeFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false) + protected void VerifyCSharpUnsafeCodeFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false, bool testFixAll = true) { - VerifyFix(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, DefaultTestValidationMode, true); + VerifyFix(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode: DefaultTestValidationMode, testFixAll: testFixAll, allowUnsafeCode: true); } - protected void VerifyCSharpFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false, TestValidationMode validationMode = DefaultTestValidationMode) + protected void VerifyCSharpFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false, TestValidationMode validationMode = DefaultTestValidationMode, bool testFixAll = true) { - VerifyFix(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode, false); + VerifyFix(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode, testFixAll, allowUnsafeCode: false); } - protected void VerifyCSharpFixAll(string oldSource, string newSource, bool allowNewCompilerDiagnostics = false, TestValidationMode validationMode = DefaultTestValidationMode) + protected void VerifyCSharpFixAll(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, TestValidationMode validationMode = DefaultTestValidationMode) { - VerifyFixAll(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, allowNewCompilerDiagnostics, validationMode, false); + VerifyFixAll(LanguageNames.CSharp, GetCSharpDiagnosticAnalyzer(), GetCSharpCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, validationMode, allowUnsafeCode: false); } - protected void VerifyBasicFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false, TestValidationMode validationMode = DefaultTestValidationMode) + protected void VerifyBasicFix(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false, bool testFixAll = true, TestValidationMode validationMode = DefaultTestValidationMode) { - VerifyFix(LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), GetBasicCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode, false); + VerifyFix(LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), GetBasicCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode, testFixAll, allowUnsafeCode: false); } - protected void VerifyBasicFixAll(string oldSource, string newSource, bool allowNewCompilerDiagnostics = false, TestValidationMode validationMode = DefaultTestValidationMode) + protected void VerifyBasicFixAll(string oldSource, string newSource, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, TestValidationMode validationMode = DefaultTestValidationMode) { - VerifyFixAll(LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), GetBasicCodeFixProvider(), oldSource, newSource, allowNewCompilerDiagnostics, validationMode, false); + VerifyFixAll(LanguageNames.VisualBasic, GetBasicDiagnosticAnalyzer(), GetBasicCodeFixProvider(), oldSource, newSource, codeFixIndex, allowNewCompilerDiagnostics, validationMode, allowUnsafeCode: false); } - private void VerifyFix(string language, DiagnosticAnalyzer analyzerOpt, CodeFixProvider codeFixProvider, string oldSource, string newSource, int? codeFixIndex, bool allowNewCompilerDiagnostics, bool onlyFixFirstFixableDiagnostic, TestValidationMode validationMode, bool allowUnsafeCode) + private void VerifyFix(string language, DiagnosticAnalyzer analyzerOpt, CodeFixProvider codeFixProvider, string oldSource, string newSource, int? codeFixIndex, bool allowNewCompilerDiagnostics, bool onlyFixFirstFixableDiagnostic, TestValidationMode validationMode, bool testFixAll, bool allowUnsafeCode) { Document document = CreateDocument(oldSource, language, allowUnsafeCode: allowUnsafeCode); var newSourceFileName = document.Name; - VerifyFix(document, analyzerOpt, codeFixProvider, newSource, newSourceFileName, ImmutableArray.Empty, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode); + VerifyFix(document, analyzerOpt, codeFixProvider, newSource, newSourceFileName, ImmutableArray.Empty, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, validationMode, testFixAll); } - private void VerifyFixAll(string language, DiagnosticAnalyzer analyzerOpt, CodeFixProvider codeFixProvider, string oldSource, string newSource, bool allowNewCompilerDiagnostics, TestValidationMode validationMode, bool allowUnsafeCode) + private void VerifyFixAll(string language, DiagnosticAnalyzer analyzerOpt, CodeFixProvider codeFixProvider, string oldSource, string newSource, int? codeFixIndex, bool allowNewCompilerDiagnostics, TestValidationMode validationMode, bool allowUnsafeCode) { Document document = CreateDocument(oldSource, language, allowUnsafeCode: allowUnsafeCode); var newSourceFileName = document.Name; - - VerifyFixAll(document, analyzerOpt, codeFixProvider, newSource, newSourceFileName, ImmutableArray.Empty, allowNewCompilerDiagnostics, validationMode); + var additionalFiles = ImmutableArray.Empty; + + VerifyFixOrFixAllCore(document, analyzerOpt, codeFixProvider, newSource, newSourceFileName, additionalFiles, + codeFixIndex, allowNewCompilerDiagnostics, validationMode, testFixAll: true, applySingleFixOrFixAll: true); } protected void VerifyAdditionalFileFix(string language, DiagnosticAnalyzer analyzerOpt, CodeFixProvider codeFixProvider, string source, - IEnumerable additionalFiles, TestAdditionalDocument newAdditionalFileToVerify, int? codeFixIndex = null, bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false) + IEnumerable additionalFiles, TestAdditionalDocument newAdditionalFileToVerify, int? codeFixIndex = null, + bool allowNewCompilerDiagnostics = false, bool onlyFixFirstFixableDiagnostic = false, bool testFixAll = true) { Document document = CreateDocument(source, language); if (additionalFiles != null) @@ -83,7 +86,7 @@ protected void VerifyAdditionalFileFix(string language, DiagnosticAnalyzer analy var additionalFileName = newAdditionalFileToVerify.Name; var additionalFileText = newAdditionalFileToVerify.GetText().ToString(); - VerifyFix(document, analyzerOpt, codeFixProvider, additionalFileText, additionalFileName, additionalFiles, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, DefaultTestValidationMode); + VerifyFix(document, analyzerOpt, codeFixProvider, additionalFileText, additionalFileName, additionalFiles, codeFixIndex, allowNewCompilerDiagnostics, onlyFixFirstFixableDiagnostic, DefaultTestValidationMode, testFixAll); } private void VerifyFix( @@ -96,7 +99,33 @@ private void VerifyFix( int? codeFixIndex, bool allowNewCompilerDiagnostics, bool onlyFixFirstFixableDiagnostic, - TestValidationMode validationMode) + TestValidationMode validationMode, + bool testFixAll) + { + // Verify code fix. + VerifyFixOrFixAllCore(document, analyzerOpt, codeFixProvider, newSource, newSourceFileName, additionalFiles, + codeFixIndex, allowNewCompilerDiagnostics, validationMode, testFixAll: false, applySingleFixOrFixAll: onlyFixFirstFixableDiagnostic); + + // Also verify FixAll. + if (testFixAll && codeFixProvider.GetFixAllProvider() != null) + { + VerifyFixOrFixAllCore(document, analyzerOpt, codeFixProvider, newSource, newSourceFileName, additionalFiles, + codeFixIndex, allowNewCompilerDiagnostics, validationMode, testFixAll: true, applySingleFixOrFixAll: true); + } + } + + private void VerifyFixOrFixAllCore( + Document document, + DiagnosticAnalyzer analyzerOpt, + CodeFixProvider codeFixProvider, + string newSource, + string newSourceFileName, + IEnumerable additionalFiles, + int? codeFixIndex, + bool allowNewCompilerDiagnostics, + TestValidationMode validationMode, + bool testFixAll, + bool applySingleFixOrFixAll) { var fixableDiagnosticIds = codeFixProvider.FixableDiagnosticIds.ToSet(); Func, ImmutableArray> getFixableDiagnostics = diags => @@ -110,8 +139,9 @@ private void VerifyFix( while (diagnosticIndexToFix < fixableDiagnostics.Length) { var actions = new List(); + Diagnostic triggerDiagnostic = fixableDiagnostics[diagnosticIndexToFix]; - var context = new CodeFixContext(document, fixableDiagnostics[diagnosticIndexToFix], (a, d) => actions.Add(a), CancellationToken.None); + var context = new CodeFixContext(document, triggerDiagnostic, (a, d) => actions.Add(a), CancellationToken.None); codeFixProvider.RegisterCodeFixesAsync(context).Wait(); if (!actions.Any()) { @@ -128,14 +158,25 @@ private void VerifyFix( throw new Exception($"Unable to invoke code fix at index '{codeFixIndex.Value}', only '{actions.Count}' code fixes were registered."); } - document = document.Apply(actions.ElementAt(codeFixIndex.Value)); - additionalFiles = document.Project.AdditionalDocuments.Select(a => new TestAdditionalDocument(a)); + var codeAction = actions.ElementAt(codeFixIndex.Value); + + if (!testFixAll) + { + document = document.Apply(codeAction); + } + else + { + string diagnosticIdToFix = triggerDiagnostic.Id; + var diagnosticsToFix = fixableDiagnostics.Where(d => d.Id == diagnosticIdToFix); + document = ApplyFixAll(document, codeAction, codeFixProvider, codeFixProvider.GetFixAllProvider(), diagnosticIdToFix, diagnosticsToFix); + } - if (onlyFixFirstFixableDiagnostic) + if (applySingleFixOrFixAll) { break; } + additionalFiles = document.Project.AdditionalDocuments.Select(a => new TestAdditionalDocument(a)); analyzerDiagnostics = GetSortedDiagnostics(analyzerOpt, new[] { document }, additionalFiles: additionalFiles, validationMode: validationMode); var updatedCompilerDiagnostics = document.GetSemanticModelAsync().Result.GetDiagnostics(); @@ -167,51 +208,19 @@ private void VerifyFix( Assert.Equal(newSource, actualText.ToString()); } - private void VerifyFixAll( + private Document ApplyFixAll( Document document, - DiagnosticAnalyzer analyzerOpt, + CodeAction codeAction, CodeFixProvider codeFixProvider, - string newSource, - string newSourceFileName, - IEnumerable additionalFiles, - bool allowNewCompilerDiagnostics, - TestValidationMode validationMode) + FixAllProvider fixAllProvider, + string diagnosticIdToFix, + IEnumerable fixableDiagnostics) { - var fixableDiagnosticIds = codeFixProvider.FixableDiagnosticIds.ToSet(); - Func, ImmutableArray> getFixableDiagnostics = diags => - diags.Where(d => fixableDiagnosticIds.Contains(d.Id)).ToImmutableArrayOrEmpty(); - - var analyzerDiagnostics = GetSortedDiagnostics(analyzerOpt, new[] { document }, additionalFiles: additionalFiles, validationMode: validationMode); - var compilerDiagnostics = document.GetSemanticModelAsync().Result.GetDiagnostics(); - var fixableDiagnostics = getFixableDiagnostics(analyzerDiagnostics.Concat(compilerDiagnostics)); - - var fixAllProvider = codeFixProvider.GetFixAllProvider(); - var diagnosticProvider = new FixAllDiagnosticProvider(analyzerOpt, additionalFiles, validationMode, getFixableDiagnostics); - var fixAllContext = new FixAllContext(document, codeFixProvider, FixAllScope.Document, string.Empty, fixableDiagnostics.Select(d => d.Id), diagnosticProvider, CancellationToken.None); - var codeAction = fixAllProvider.GetFixAsync(fixAllContext).Result; - document = document.Apply(codeAction); - additionalFiles = document.Project.AdditionalDocuments.Select(a => new TestAdditionalDocument(a)); - - additionalFiles = document.Project.AdditionalDocuments.Select(a => new TestAdditionalDocument(a)); - - analyzerDiagnostics = GetSortedDiagnostics(analyzerOpt, new[] { document }, additionalFiles: additionalFiles, validationMode: validationMode); - - var updatedCompilerDiagnostics = document.GetSemanticModelAsync().Result.GetDiagnostics(); - var newCompilerDiagnostics = GetNewDiagnostics(compilerDiagnostics, updatedCompilerDiagnostics); - if (!allowNewCompilerDiagnostics && newCompilerDiagnostics.Any()) - { - // Format and get the compiler diagnostics again so that the locations make sense in the output - document = document.WithSyntaxRoot(Formatter.Format(document.GetSyntaxRootAsync().Result, Formatter.Annotation, document.Project.Solution.Workspace)); - newCompilerDiagnostics = GetNewDiagnostics(compilerDiagnostics, document.GetSemanticModelAsync().Result.GetDiagnostics()); - - Assert.True(false, - string.Format("Fix introduced new compiler diagnostics:\r\n{0}\r\n\r\nNew document:\r\n{1}\r\n", - newCompilerDiagnostics.Select(d => d.ToString()).Join("\r\n"), - document.GetSyntaxRootAsync().Result.ToFullString())); - } - - var actualText = GetActualTextForNewDocument(document, newSourceFileName); - Assert.Equal(newSource, actualText.ToString()); + var diagnosticProvider = new FixAllDiagnosticProvider(fixableDiagnostics); + IEnumerable fixableDiagnosticIds = new string[] { diagnosticIdToFix }; + var fixAllContext = new FixAllContext(document, codeFixProvider, FixAllScope.Document, codeAction.EquivalenceKey, fixableDiagnosticIds, diagnosticProvider, CancellationToken.None); + var fixAllCodeAction = fixAllProvider.GetFixAsync(fixAllContext).Result; + return fixAllCodeAction != null ? document.Apply(fixAllCodeAction) : document; } private sealed class DiagnosticComparer : IEqualityComparer @@ -240,37 +249,21 @@ public int GetHashCode(Diagnostic obj) private class FixAllDiagnosticProvider : FixAllContext.DiagnosticProvider { - private DiagnosticAnalyzer _analyzerOpt; - private IEnumerable _additionalFiles; - private TestValidationMode _testValidationMode; - private Func, ImmutableArray> _getFixableDiagnostics; - - public FixAllDiagnosticProvider( - DiagnosticAnalyzer analyzerOpt, - IEnumerable additionalFiles, - TestValidationMode testValidationMode, - Func, ImmutableArray> getFixableDiagnostics) - { - _analyzerOpt = analyzerOpt; - _additionalFiles = additionalFiles; - _testValidationMode = testValidationMode; - _getFixableDiagnostics = getFixableDiagnostics; - } + private IEnumerable _fixableDiagnostics; - public override async Task> GetDocumentDiagnosticsAsync(Document document, CancellationToken cancellationToken) + public FixAllDiagnosticProvider(IEnumerable fixableDiagnostics) { - var analyzerDiagnostics = GetSortedDiagnostics(_analyzerOpt, new[] { document }, additionalFiles: _additionalFiles, validationMode: _testValidationMode); - var semanticModel = await document.GetSemanticModelAsync().ConfigureAwait(false); - var compilerDiagnostics = semanticModel.GetDiagnostics(); - var fixableDiagnostics = _getFixableDiagnostics(analyzerDiagnostics.Concat(compilerDiagnostics)); - return fixableDiagnostics; + _fixableDiagnostics = fixableDiagnostics; } + public override Task> GetDocumentDiagnosticsAsync(Document document, CancellationToken cancellationToken) + => Task.FromResult(_fixableDiagnostics); + public override Task> GetAllDiagnosticsAsync(Project project, CancellationToken cancellationToken) - => throw new NotImplementedException(); + => Task.FromResult(_fixableDiagnostics); public override Task> GetProjectDiagnosticsAsync(Project project, CancellationToken cancellationToken) - => throw new NotImplementedException(); + => Task.FromResult(_fixableDiagnostics); } private static SourceText GetActualTextForNewDocument(Document documentInNewWorkspace, string newSourceFileName)