diff --git a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
index dd33f35e..b8ea67be 100644
--- a/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
+++ b/src/CommunityToolkit.Mvvm.CodeFixers/ClassUsingAttributeInsteadOfInheritanceCodeFixer.cs
@@ -13,6 +13,7 @@
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Text;
using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
@@ -63,7 +64,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
context.RegisterCodeFix(
CodeAction.Create(
title: "Inherit from ObservableObject",
- createChangedDocument: token => UpdateReference(context.Document, classDeclaration, attributeTypeName, token),
+ createChangedDocument: token => UpdateReference(context.Document, root, classDeclaration, attributeTypeName),
equivalenceKey: "Inherit from ObservableObject"),
diagnostic);
@@ -76,21 +77,16 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
/// Applies the code fix to a target class declaration and returns an updated document.
///
/// The original document being fixed.
+ /// The original tree root belonging to the current document.
/// The to update.
/// The name of the attribute that should be removed.
- /// The cancellation token for the operation.
/// An updated document with the applied code fix, and inheriting from ObservableObject.
- private static async Task UpdateReference(Document document, ClassDeclarationSyntax classDeclaration, string attributeTypeName, CancellationToken cancellationToken)
+ private static Task UpdateReference(Document document, SyntaxNode root, ClassDeclarationSyntax classDeclaration, string attributeTypeName)
{
// Insert ObservableObject always in first position in the base list. The type might have
// some interfaces in the base list, so we just copy them back after ObservableObject.
- ClassDeclarationSyntax updatedClassDeclaration =
- classDeclaration.WithBaseList(BaseList(SingletonSeparatedList(
- (BaseTypeSyntax)SimpleBaseType(IdentifierName("ObservableObject"))))
- .AddTypes(classDeclaration.BaseList?.Types.ToArray() ?? Array.Empty()));
-
- AttributeListSyntax? targetAttributeList = null;
- AttributeSyntax? targetAttribute = null;
+ SyntaxGenerator generator = SyntaxGenerator.GetGenerator(document);
+ ClassDeclarationSyntax updatedClassDeclaration = (ClassDeclarationSyntax)generator.AddBaseType(classDeclaration, IdentifierName("ObservableObject"));
// Find the attribute list and attribute to remove
foreach (AttributeListSyntax attributeList in updatedClassDeclaration.AttributeLists)
@@ -101,35 +97,13 @@ private static async Task UpdateReference(Document document, ClassDecl
(identifierName == attributeTypeName || (identifierName + "Attribute") == attributeTypeName))
{
// We found the attribute to remove and the list to update
- targetAttributeList = attributeList;
- targetAttribute = attribute;
+ updatedClassDeclaration = (ClassDeclarationSyntax)generator.RemoveNode(updatedClassDeclaration, attribute);
break;
}
}
}
- // If we found an attribute to remove, do that
- if (targetAttribute is not null)
- {
- // If the target list has more than one attribute, keep it and just remove the target one
- if (targetAttributeList!.Attributes.Count > 1)
- {
- updatedClassDeclaration =
- updatedClassDeclaration.ReplaceNode(
- targetAttributeList,
- targetAttributeList.RemoveNode(targetAttribute, SyntaxRemoveOptions.KeepNoTrivia)!);
- }
- else
- {
- // Otherwise, remove the entire attribute list
- updatedClassDeclaration = updatedClassDeclaration.RemoveNode(targetAttributeList, SyntaxRemoveOptions.KeepExteriorTrivia)!;
- }
- }
-
- SyntaxNode originalRoot = await classDeclaration.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
- SyntaxTree updatedTree = originalRoot.ReplaceNode(classDeclaration, updatedClassDeclaration).SyntaxTree;
-
- return document.WithSyntaxRoot(await updatedTree.GetRootAsync(cancellationToken).ConfigureAwait(false));
+ return Task.FromResult(document.WithSyntaxRoot(root.ReplaceNode(classDeclaration, updatedClassDeclaration)));
}
}