diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs index dd33f35e..b8ea67be 100644 --- a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs +++ b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs @@ -13,6 +13,7 @@ using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Editing; using Microsoft.CodeAnalysis.Text; using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -63,7 +64,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) context.RegisterCodeFix( CodeAction.Create( title: "Inherit from ObservableObject", - createChangedDocument: token => UpdateReference(context.Document, classDeclaration, attributeTypeName, token), + createChangedDocument: token => UpdateReference(context.Document, root, classDeclaration, attributeTypeName), equivalenceKey: "Inherit from ObservableObject"), diagnostic); @@ -76,21 +77,16 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) /// Applies the code fix to a target class declaration and returns an updated document. /// /// The original document being fixed. + /// The original tree root belonging to the current document. /// The to update. /// The name of the attribute that should be removed. - /// The cancellation token for the operation. /// An updated document with the applied code fix, and inheriting from ObservableObject. - private static async Task UpdateReference(Document document, ClassDeclarationSyntax classDeclaration, string attributeTypeName, CancellationToken cancellationToken) + private static Task UpdateReference(Document document, SyntaxNode root, ClassDeclarationSyntax classDeclaration, string attributeTypeName) { // Insert ObservableObject always in first position in the base list. The type might have // some interfaces in the base list, so we just copy them back after ObservableObject. - ClassDeclarationSyntax updatedClassDeclaration = - classDeclaration.WithBaseList(BaseList(SingletonSeparatedList( - (BaseTypeSyntax)SimpleBaseType(IdentifierName("ObservableObject")))) - .AddTypes(classDeclaration.BaseList?.Types.ToArray() ?? Array.Empty())); - - AttributeListSyntax? targetAttributeList = null; - AttributeSyntax? targetAttribute = null; + SyntaxGenerator generator = SyntaxGenerator.GetGenerator(document); + ClassDeclarationSyntax updatedClassDeclaration = (ClassDeclarationSyntax)generator.AddBaseType(classDeclaration, IdentifierName("ObservableObject")); // Find the attribute list and attribute to remove foreach (AttributeListSyntax attributeList in updatedClassDeclaration.AttributeLists) @@ -101,35 +97,13 @@ private static async Task UpdateReference(Document document, ClassDecl (identifierName == attributeTypeName || (identifierName + "Attribute") == attributeTypeName)) { // We found the attribute to remove and the list to update - targetAttributeList = attributeList; - targetAttribute = attribute; + updatedClassDeclaration = (ClassDeclarationSyntax)generator.RemoveNode(updatedClassDeclaration, attribute); break; } } } - // If we found an attribute to remove, do that - if (targetAttribute is not null) - { - // If the target list has more than one attribute, keep it and just remove the target one - if (targetAttributeList!.Attributes.Count > 1) - { - updatedClassDeclaration = - updatedClassDeclaration.ReplaceNode( - targetAttributeList, - targetAttributeList.RemoveNode(targetAttribute, SyntaxRemoveOptions.KeepNoTrivia)!); - } - else - { - // Otherwise, remove the entire attribute list - updatedClassDeclaration = updatedClassDeclaration.RemoveNode(targetAttributeList, SyntaxRemoveOptions.KeepExteriorTrivia)!; - } - } - - SyntaxNode originalRoot = await classDeclaration.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false); - SyntaxTree updatedTree = originalRoot.ReplaceNode(classDeclaration, updatedClassDeclaration).SyntaxTree; - - return document.WithSyntaxRoot(await updatedTree.GetRootAsync(cancellationToken).ConfigureAwait(false)); + return Task.FromResult(document.WithSyntaxRoot(root.ReplaceNode(classDeclaration, updatedClassDeclaration))); } }