diff --git a/src/MediatR/Registration/ServiceRegistrar.cs b/src/MediatR/Registration/ServiceRegistrar.cs index e4bba2b5..8fd5bf96 100644 --- a/src/MediatR/Registration/ServiceRegistrar.cs +++ b/src/MediatR/Registration/ServiceRegistrar.cs @@ -65,21 +65,37 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa bool addIfAlreadyExists, MediatRServiceConfiguration configuration) { - var concretions = new List(); + var concretions = new List(); var interfaces = new List(); - foreach (var type in assembliesToScan.SelectMany(a => a.DefinedTypes).Where(t => !t.IsOpenGeneric()).Where(configuration.TypeEvaluator)) + var genericConcretions = new List(); + var genericInterfaces = new List(); + + var types = assembliesToScan + .SelectMany(a => a.DefinedTypes) + .Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any()) + .Where(configuration.TypeEvaluator) + .ToList(); + + foreach (var type in types) { var interfaceTypes = type.FindInterfacesThatClose(openRequestInterface).ToArray(); - if (!interfaceTypes.Any()) continue; - if (type.IsConcrete()) + if (!type.IsOpenGeneric()) { concretions.Add(type); - } - foreach (var interfaceType in interfaceTypes) + foreach (var interfaceType in interfaceTypes) + { + interfaces.Fill(interfaceType); + } + } + else { - interfaces.Fill(interfaceType); + genericConcretions.Add(type); + foreach (var interfaceType in interfaceTypes) + { + genericInterfaces.Fill(interfaceType); + } } } @@ -111,6 +127,12 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa AddConcretionsThatCouldBeClosed(@interface, concretions, services); } } + + foreach (var @interface in genericInterfaces) + { + var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList(); + AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan); + } } private static bool IsMatchingWithInterface(Type? handlerType, Type handlerInterface) @@ -150,6 +172,62 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List } } + private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation) + { + var closingType = concreteGenericTRequest.GetGenericArguments().First(); + + var concreteTResponse = concreteGenericTRequest.GetInterfaces() + .FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>)) + ?.GetGenericArguments() + .FirstOrDefault(); + + var typeDefinition = openRequestHandlerInterface.GetGenericTypeDefinition(); + + var serviceType = concreteTResponse != null ? + typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) : + typeDefinition.MakeGenericType(concreteGenericTRequest); + + return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingType)); + } + + private static List? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable assembliesToScan) + { + var constraints = openRequestHandlerImplementation.GetGenericArguments().First().GetGenericParameterConstraints(); + + var typesThatCanClose = assembliesToScan + .SelectMany(assembly => assembly.GetTypes()) + .Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))) + .ToList(); + + var requestType = openRequestHandlerInterface.GenericTypeArguments.First(); + + if (requestType.IsGenericParameter) + return null; + + var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition(); + + return typesThatCanClose.Select(type => requestGenericTypeDefinition.MakeGenericType(type)).ToList(); + } + + private static void AddAllConcretionsThatClose(Type openRequestInterface, List concretions, IServiceCollection services, IEnumerable assembliesToScan) + { + foreach (var concretion in concretions) + { + var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan); + + if (concreteRequests is null) + continue; + + var registrationTypes = concreteRequests + .Select(concreteRequest => GetConcreteRegistrationTypes(openRequestInterface, concreteRequest, concretion)); + + foreach (var (Service, Implementation) in registrationTypes) + { + services.AddTransient(Service, Implementation); + } + } + } + internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface) { var openInterface = closedInterface.GetGenericTypeDefinition(); @@ -259,8 +337,8 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister) { services.TryAddEnumerable(serviceDescriptor); - } - + } + foreach (var serviceDescriptor in serviceConfiguration.StreamBehaviorsToRegister) { services.TryAddEnumerable(serviceDescriptor); @@ -270,7 +348,7 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType) { var hasAnyRegistrationsOfSubBehaviorType = services - .Where(service => !service.IsKeyedService) + .Where(service => !service.IsKeyedService) .Select(service => service.ImplementationType) .OfType() .SelectMany(type => type.GetInterfaces()) @@ -283,4 +361,4 @@ private static void RegisterBehaviorIfImplementationsExist(IServiceCollection se services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), behaviorType, ServiceLifetime.Transient)); } } -} \ No newline at end of file +} diff --git a/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs b/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs index f5175b91..dbbcaefc 100644 --- a/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs +++ b/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs @@ -3,6 +3,7 @@ namespace MediatR.Extensions.Microsoft.DependencyInjection.Tests; using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using Shouldly; @@ -55,8 +56,32 @@ public void ShouldRequireAtLeastOneAssembly() { var services = new ServiceCollection(); - Action registration = () => services.AddMediatR(_ => {}); + Action registration = () => services.AddMediatR(_ => { }); registration.ShouldThrow(); } + + [Fact] + public void ShouldResolveGenericVoidRequestHandler() + { + _provider.GetService>>().ShouldNotBeNull(); + } + + [Fact] + public void ShouldResolveGenericReturnTypeRequestHandler() + { + _provider.GetService, string>>().ShouldNotBeNull(); + } + + [Fact] + public void ShouldResolveGenericPingRequestHandler() + { + _provider.GetService, Pong>>().ShouldNotBeNull(); + } + + [Fact] + public void ShouldResolveVoidGenericPingRequestHandler() + { + _provider.GetService>>().ShouldNotBeNull(); + } } \ No newline at end of file diff --git a/test/MediatR.Tests/MicrosoftExtensionsDI/Handlers.cs b/test/MediatR.Tests/MicrosoftExtensionsDI/Handlers.cs index f71c407e..5dbe0c1e 100644 --- a/test/MediatR.Tests/MicrosoftExtensionsDI/Handlers.cs +++ b/test/MediatR.Tests/MicrosoftExtensionsDI/Handlers.cs @@ -213,6 +213,48 @@ public Task Send(TRequest request, CancellationToken cancellationToken throw new System.NotImplementedException(); } } + + interface ITypeArgument { } + class ConcreteTypeArgument : ITypeArgument { } + class OpenGenericVoidRequest : IRequest + where T : class, ITypeArgument + { } + class OpenGenericVoidRequestHandler : IRequestHandler> + where T : class, ITypeArgument + { + public Task Handle(OpenGenericVoidRequest request, CancellationToken cancellationToken) => Task.CompletedTask; + } + class OpenGenericReturnTypeRequest : IRequest + where T : class, ITypeArgument + { } + class OpenGenericReturnTypeRequestHandler : IRequestHandler, string> + where T : class, ITypeArgument + { + public Task Handle(OpenGenericReturnTypeRequest request, CancellationToken cancellationToken) => Task.FromResult(nameof(request)); + } + + public class GenericPing : IRequest + where T : Pong + { + public T? Pong { get; set; } + } + + public class GenericPingHandler : IRequestHandler, T> + where T : Pong + { + public Task Handle(GenericPing request, CancellationToken cancellationToken) => Task.FromResult(request.Pong!); + } + + public class VoidGenericPing : IRequest + where T : Pong + { } + + public class VoidGenericPingHandler : IRequestHandler> + where T : Pong + { + public Task Handle(VoidGenericPing request, CancellationToken cancellationToken) => Task.CompletedTask; + } + } namespace MediatR.Extensions.Microsoft.DependencyInjection.Tests.Included diff --git a/test/MediatR.Tests/SendTests.cs b/test/MediatR.Tests/SendTests.cs index 7ca41fe7..b489a808 100644 --- a/test/MediatR.Tests/SendTests.cs +++ b/test/MediatR.Tests/SendTests.cs @@ -5,11 +5,25 @@ namespace MediatR.Tests; using System; using System.Threading.Tasks; using Shouldly; -using Lamar; using Xunit; - +using Microsoft.Extensions.DependencyInjection; + public class SendTests { + private readonly IServiceProvider _serviceProvider; + private Dependency _dependency; + private readonly IMediator _mediator; + + public SendTests() + { + _dependency = new Dependency(); + var services = new ServiceCollection(); + services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblies(typeof(Ping).Assembly)); + services.AddSingleton(_dependency); + _serviceProvider = services.BuildServiceProvider(); + _mediator = _serviceProvider.GetService()!; + + } public class Ping : IRequest { @@ -52,73 +66,66 @@ public Task Handle(VoidPing request, CancellationToken cancellationToken) } } - [Fact] - public async Task Should_resolve_main_handler() + public class GenericPing : IRequest + where T : Pong { - var container = new Container(cfg => - { - cfg.Scan(scanner => - { - scanner.AssemblyContainingType(typeof(PublishTests)); - scanner.IncludeNamespaceContainingType(); - scanner.WithDefaultConventions(); - scanner.AddAllTypesOf(typeof(IRequestHandler<,>)); - }); - cfg.For().Use(); - }); + public T? Pong { get; set; } + } - var mediator = container.GetInstance(); + public class GenericPingHandler : IRequestHandler, T> + where T : Pong + { + private readonly Dependency _dependency; - var response = await mediator.Send(new Ping { Message = "Ping" }); + public GenericPingHandler(Dependency dependency) => _dependency = dependency; - response.Message.ShouldBe("Ping Pong"); + public Task Handle(GenericPing request, CancellationToken cancellationToken) + { + _dependency.Called = true; + request.Pong!.Message += " Pong"; + return Task.FromResult(request.Pong!); + } } - [Fact] - public async Task Should_resolve_main_void_handler() + public class VoidGenericPing : IRequest + where T : Pong + { } + + public class VoidGenericPingHandler : IRequestHandler> + where T : Pong { - var dependency = new Dependency(); + private readonly Dependency _dependency; + public VoidGenericPingHandler(Dependency dependency) => _dependency = dependency; - var container = new Container(cfg => + public Task Handle(VoidGenericPing request, CancellationToken cancellationToken) { - cfg.Scan(scanner => - { - scanner.AssemblyContainingType(typeof(PublishTests)); - scanner.IncludeNamespaceContainingType(); - scanner.WithDefaultConventions(); - scanner.AddAllTypesOf(typeof(IRequestHandler<>)); - scanner.AddAllTypesOf(typeof(IRequestHandler<,>)); - }); - cfg.ForSingletonOf().Use(dependency); - cfg.For().Use(); - }); + _dependency.Called = true; - var mediator = container.GetInstance(); + return Task.CompletedTask; + } + } - await mediator.Send(new VoidPing()); + [Fact] + public async Task Should_resolve_main_handler() + { + var response = await _mediator.Send(new Ping { Message = "Ping" }); - dependency.Called.ShouldBeTrue(); + response.Message.ShouldBe("Ping Pong"); } [Fact] - public async Task Should_resolve_main_handler_via_dynamic_dispatch() + public async Task Should_resolve_main_void_handler() { - var container = new Container(cfg => - { - cfg.Scan(scanner => - { - scanner.AssemblyContainingType(typeof(PublishTests)); - scanner.IncludeNamespaceContainingType(); - scanner.WithDefaultConventions(); - scanner.AddAllTypesOf(typeof(IRequestHandler<,>)); - }); - cfg.For().Use(); - }); - - var mediator = container.GetInstance(); + await _mediator.Send(new VoidPing()); + + _dependency.Called.ShouldBeTrue(); + } + [Fact] + public async Task Should_resolve_main_handler_via_dynamic_dispatch() + { object request = new Ping { Message = "Ping" }; - var response = await mediator.Send(request); + var response = await _mediator.Send(request); var pong = response.ShouldBeOfType(); pong.Message.ShouldBe("Ping Pong"); @@ -127,50 +134,18 @@ public async Task Should_resolve_main_handler_via_dynamic_dispatch() [Fact] public async Task Should_resolve_main_void_handler_via_dynamic_dispatch() { - var dependency = new Dependency(); - - var container = new Container(cfg => - { - cfg.Scan(scanner => - { - scanner.AssemblyContainingType(typeof(PublishTests)); - scanner.IncludeNamespaceContainingType(); - scanner.WithDefaultConventions(); - scanner.AddAllTypesOf(typeof(IRequestHandler<>)); - scanner.AddAllTypesOf(typeof(IRequestHandler<,>)); - }); - cfg.ForSingletonOf().Use(dependency); - cfg.For().Use(); - }); - - var mediator = container.GetInstance(); - object request = new VoidPing(); - var response = await mediator.Send(request); + var response = await _mediator.Send(request); response.ShouldBeOfType(); - dependency.Called.ShouldBeTrue(); + _dependency.Called.ShouldBeTrue(); } [Fact] public async Task Should_resolve_main_handler_by_specific_interface() { - var container = new Container(cfg => - { - cfg.Scan(scanner => - { - scanner.AssemblyContainingType(typeof(PublishTests)); - scanner.IncludeNamespaceContainingType(); - scanner.WithDefaultConventions(); - scanner.AddAllTypesOf(typeof(IRequestHandler<,>)); - }); - cfg.For().Use(); - }); - - var mediator = container.GetInstance(); - - var response = await mediator.Send(new Ping { Message = "Ping" }); + var response = await _mediator.Send(new Ping { Message = "Ping" }); response.Message.ShouldBe("Ping Pong"); } @@ -178,39 +153,34 @@ public async Task Should_resolve_main_handler_by_specific_interface() [Fact] public async Task Should_resolve_main_handler_by_given_interface() { - var dependency = new Dependency(); - var container = new Container(cfg => - { - cfg.Scan(scanner => - { - scanner.AssemblyContainingType(typeof(PublishTests)); - scanner.IncludeNamespaceContainingType(); - scanner.WithDefaultConventions(); - scanner.AddAllTypesOf(typeof(IRequestHandler<>)); - }); - cfg.ForSingletonOf().Use(dependency); - cfg.For().Use(); - }); - - var mediator = container.GetInstance(); - // wrap requests in an array, so this test won't break on a 'replace with var' refactoring var requests = new IRequest[] { new VoidPing() }; - await mediator.Send(requests[0]); + await _mediator.Send(requests[0]); - dependency.Called.ShouldBeTrue(); + _dependency.Called.ShouldBeTrue(); } [Fact] - public async Task Should_raise_execption_on_null_request() - { - var container = new Container(cfg => - { - cfg.For().Use(); - }); + public Task Should_raise_execption_on_null_request() => Should.ThrowAsync(async () => await _mediator.Send(default!)); + + [Fact] + public async Task Should_resolve_generic_handler() + { + var request = new GenericPing { Pong = new Pong { Message = "Ping" } }; + var result = await _mediator.Send(request); - var mediator = container.GetInstance(); + var pong = result.ShouldBeOfType(); + pong.Message.ShouldBe("Ping Pong"); + + _dependency.Called.ShouldBeTrue(); + } + + [Fact] + public async Task Should_resolve_generic_void_handler() + { + var request = new VoidGenericPing(); + await _mediator.Send(request); - await Should.ThrowAsync(async () => await mediator.Send(default!)); + _dependency.Called.ShouldBeTrue(); } } \ No newline at end of file