diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs index 0dd9d582..34870256 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs @@ -20,13 +20,15 @@ public sealed partial class ObservableValidatorValidateAllPropertiesGenerator : /// public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all class declarations + // Get all class declarations. We intentionally skip generating code for abstract types, as that would never be used. + // The methods that are generated by this generator are retrieved through reflection using the type of the invoking + // instance as discriminator, which means a type that is abstract could never be used (since it couldn't be instantiated). IncrementalValuesProvider typeSymbols = context.SyntaxProvider .CreateSyntaxProvider( static (node, _) => node is ClassDeclarationSyntax, static (context, _) => (context.Node, Symbol: (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!)) - .Where(static item => item.Node.IsFirstSyntaxDeclarationForSymbol(item.Symbol)) + .Where(static item => !item.Symbol.IsAbstract && item.Node.IsFirstSyntaxDeclarationForSymbol(item.Symbol)) .Select(static (item, _) => item.Symbol); // Get the types that inherit from ObservableValidator and gather their info diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs index 8fc53859..4dae6550 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs @@ -25,12 +25,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // definitions (it might happen if a recipient has partial declarations). To do this, all pairs // of class declarations and associated symbols are gathered, and then only the pair where the // class declaration is the first syntax reference for the associated symbol is kept. + // Just like with the ObservableValidator generator, we also intentionally skip abstract types. IncrementalValuesProvider typeSymbols = context.SyntaxProvider .CreateSyntaxProvider( static (node, _) => node is ClassDeclarationSyntax, static (context, _) => (context.Node, Symbol: (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!)) - .Where(static item => item.Node.IsFirstSyntaxDeclarationForSymbol(item.Symbol)) + .Where(static item => !item.Symbol.IsAbstract && item.Node.IsFirstSyntaxDeclarationForSymbol(item.Symbol)) .Select(static (item, _) => item.Symbol); // Get the target IRecipient interfaces and filter out other types diff --git a/tests/CommunityToolkit.Mvvm.UnitTests/Test_IRecipientGenerator.cs b/tests/CommunityToolkit.Mvvm.UnitTests/Test_IRecipientGenerator.cs index b03fbde7..782e4145 100644 --- a/tests/CommunityToolkit.Mvvm.UnitTests/Test_IRecipientGenerator.cs +++ b/tests/CommunityToolkit.Mvvm.UnitTests/Test_IRecipientGenerator.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Linq; +using System.Reflection; using CommunityToolkit.Mvvm.Messaging; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -52,9 +54,20 @@ public void Test_IRecipientGenerator_TypeWithMultipleClassDeclarations() _ = Messaging.__Internals.__IMessengerExtensions.CreateAllMessagesRegistratorWithToken(recipient); } - public sealed class RecipientWithSomeMessages : - IRecipient, - IRecipient + [TestMethod] + public void Test_IRecipientGenerator_AbstractTypesDoNotTriggerCodeGeneration() + { + MethodInfo? createAllPropertiesValidatorMethod = typeof(Messaging.__Internals.__IMessengerExtensions) + .GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(static m => m.Name == "CreateAllMessagesRegistratorWithToken") + .Where(static m => m.GetParameters() is { Length: 1 } parameters && parameters[0].ParameterType == typeof(AbstractModelWithValidatablePropertyIRecipientInterfaces)) + .FirstOrDefault(); + + // We need to validate that no methods are generated for abstract types, so we just check this method doesn't exist + Assert.IsNull(createAllPropertiesValidatorMethod); + } + + public sealed class RecipientWithSomeMessages : IRecipient, IRecipient { public MessageA? A { get; private set; } @@ -91,4 +104,14 @@ public void Receive(MessageA message) partial class RecipientWithMultipleClassDeclarations { } + + public abstract class AbstractModelWithValidatablePropertyIRecipientInterfaces : IRecipient, IRecipient + { + public abstract void Receive(MessageA message); + + public void Receive(MessageB message) + { + + } + } } diff --git a/tests/CommunityToolkit.Mvvm.UnitTests/Test_ObservableValidator.cs b/tests/CommunityToolkit.Mvvm.UnitTests/Test_ObservableValidator.cs index faade404..778be31d 100644 --- a/tests/CommunityToolkit.Mvvm.UnitTests/Test_ObservableValidator.cs +++ b/tests/CommunityToolkit.Mvvm.UnitTests/Test_ObservableValidator.cs @@ -526,6 +526,19 @@ public void Test_ObservableValidator_VerifyTrimmingAnnotation() #endif } + [TestMethod] + public void Test_ObservableRecipient_AbstractTypesDoNotTriggerCodeGeneration() + { + MethodInfo? createAllPropertiesValidatorMethod = typeof(ComponentModel.__Internals.__ObservableValidatorExtensions) + .GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(static m => m.Name == "CreateAllPropertiesValidator") + .Where(static m => m.GetParameters() is { Length: 1 } parameters && parameters[0].ParameterType == typeof(AbstractModelWithValidatableProperty)) + .FirstOrDefault(); + + // We need to validate that no methods are generated for abstract types, so we just check this method doesn't exist + Assert.IsNull(createAllPropertiesValidatorMethod); + } + public class Person : ObservableValidator { private string? name; @@ -777,4 +790,11 @@ public partial class PersonWithPartialDeclaration [Range(10, 1000)] public int Number { get; set; } } + + public abstract class AbstractModelWithValidatableProperty : ObservableValidator + { + [Required] + [MinLength(2)] + public string? Name { get; set; } + } }