Skip to content

Commit

Permalink
Merge pull request #1013 from zachpainter77/master
Browse files Browse the repository at this point in the history
Add Support for Generic Handlers
  • Loading branch information
jbogard authored Jun 6, 2024
2 parents b3147c8 + 1552d92 commit 3b8bf44
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 124 deletions.
100 changes: 89 additions & 11 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,37 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
bool addIfAlreadyExists,
MediatRServiceConfiguration configuration)
{
var concretions = new List<Type>();
var concretions = new List<Type>();
var interfaces = new List<Type>();
foreach (var type in assembliesToScan.SelectMany(a => a.DefinedTypes).Where(t => !t.IsOpenGeneric()).Where(configuration.TypeEvaluator))
var genericConcretions = new List<Type>();
var genericInterfaces = new List<Type>();

var types = assembliesToScan
.SelectMany(a => a.DefinedTypes)
.Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any())
.Where(configuration.TypeEvaluator)
.ToList();

foreach (var type in types)
{
var interfaceTypes = type.FindInterfacesThatClose(openRequestInterface).ToArray();
if (!interfaceTypes.Any()) continue;

if (type.IsConcrete())
if (!type.IsOpenGeneric())
{
concretions.Add(type);
}

foreach (var interfaceType in interfaceTypes)
foreach (var interfaceType in interfaceTypes)
{
interfaces.Fill(interfaceType);
}
}
else
{
interfaces.Fill(interfaceType);
genericConcretions.Add(type);
foreach (var interfaceType in interfaceTypes)
{
genericInterfaces.Fill(interfaceType);
}
}
}

Expand Down Expand Up @@ -111,6 +127,12 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
AddConcretionsThatCouldBeClosed(@interface, concretions, services);
}
}

foreach (var @interface in genericInterfaces)
{
var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList();
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan);
}
}

private static bool IsMatchingWithInterface(Type? handlerType, Type handlerInterface)
Expand Down Expand Up @@ -150,6 +172,62 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
}
}

private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation)
{
var closingType = concreteGenericTRequest.GetGenericArguments().First();

var concreteTResponse = concreteGenericTRequest.GetInterfaces()
.FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>))
?.GetGenericArguments()
.FirstOrDefault();

var typeDefinition = openRequestHandlerInterface.GetGenericTypeDefinition();

var serviceType = concreteTResponse != null ?
typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) :
typeDefinition.MakeGenericType(concreteGenericTRequest);

return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingType));
}

private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan)
{
var constraints = openRequestHandlerImplementation.GetGenericArguments().First().GetGenericParameterConstraints();

var typesThatCanClose = assembliesToScan
.SelectMany(assembly => assembly.GetTypes())
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type)))
.ToList();

var requestType = openRequestHandlerInterface.GenericTypeArguments.First();

if (requestType.IsGenericParameter)
return null;

var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition();

return typesThatCanClose.Select(type => requestGenericTypeDefinition.MakeGenericType(type)).ToList();
}

private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan)
{
foreach (var concretion in concretions)
{
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan);

if (concreteRequests is null)
continue;

var registrationTypes = concreteRequests
.Select(concreteRequest => GetConcreteRegistrationTypes(openRequestInterface, concreteRequest, concretion));

foreach (var (Service, Implementation) in registrationTypes)
{
services.AddTransient(Service, Implementation);
}
}
}

internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
{
var openInterface = closedInterface.GetGenericTypeDefinition();
Expand Down Expand Up @@ -259,8 +337,8 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister)
{
services.TryAddEnumerable(serviceDescriptor);
}
}

foreach (var serviceDescriptor in serviceConfiguration.StreamBehaviorsToRegister)
{
services.TryAddEnumerable(serviceDescriptor);
Expand All @@ -270,7 +348,7 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType)
{
var hasAnyRegistrationsOfSubBehaviorType = services
.Where(service => !service.IsKeyedService)
.Where(service => !service.IsKeyedService)
.Select(service => service.ImplementationType)
.OfType<Type>()
.SelectMany(type => type.GetInterfaces())
Expand All @@ -283,4 +361,4 @@ private static void RegisterBehaviorIfImplementationsExist(IServiceCollection se
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), behaviorType, ServiceLifetime.Transient));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace MediatR.Extensions.Microsoft.DependencyInjection.Tests;

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Shouldly;
Expand Down Expand Up @@ -55,8 +56,32 @@ public void ShouldRequireAtLeastOneAssembly()
{
var services = new ServiceCollection();

Action registration = () => services.AddMediatR(_ => {});
Action registration = () => services.AddMediatR(_ => { });

registration.ShouldThrow<ArgumentException>();
}

[Fact]
public void ShouldResolveGenericVoidRequestHandler()
{
_provider.GetService<IRequestHandler<OpenGenericVoidRequest<ConcreteTypeArgument>>>().ShouldNotBeNull();
}

[Fact]
public void ShouldResolveGenericReturnTypeRequestHandler()
{
_provider.GetService<IRequestHandler<OpenGenericReturnTypeRequest<ConcreteTypeArgument>, string>>().ShouldNotBeNull();
}

[Fact]
public void ShouldResolveGenericPingRequestHandler()
{
_provider.GetService<IRequestHandler<GenericPing<Pong>, Pong>>().ShouldNotBeNull();
}

[Fact]
public void ShouldResolveVoidGenericPingRequestHandler()
{
_provider.GetService<IRequestHandler<VoidGenericPing<Pong>>>().ShouldNotBeNull();
}
}
42 changes: 42 additions & 0 deletions test/MediatR.Tests/MicrosoftExtensionsDI/Handlers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,48 @@ public Task Send<TRequest>(TRequest request, CancellationToken cancellationToken
throw new System.NotImplementedException();
}
}

interface ITypeArgument { }
class ConcreteTypeArgument : ITypeArgument { }
class OpenGenericVoidRequest<T> : IRequest
where T : class, ITypeArgument
{ }
class OpenGenericVoidRequestHandler<T> : IRequestHandler<OpenGenericVoidRequest<T>>
where T : class, ITypeArgument
{
public Task Handle(OpenGenericVoidRequest<T> request, CancellationToken cancellationToken) => Task.CompletedTask;
}
class OpenGenericReturnTypeRequest<T> : IRequest<string>
where T : class, ITypeArgument
{ }
class OpenGenericReturnTypeRequestHandler<T> : IRequestHandler<OpenGenericReturnTypeRequest<T>, string>
where T : class, ITypeArgument
{
public Task<string> Handle(OpenGenericReturnTypeRequest<T> request, CancellationToken cancellationToken) => Task.FromResult(nameof(request));
}

public class GenericPing<T> : IRequest<T>
where T : Pong
{
public T? Pong { get; set; }
}

public class GenericPingHandler<T> : IRequestHandler<GenericPing<T>, T>
where T : Pong
{
public Task<T> Handle(GenericPing<T> request, CancellationToken cancellationToken) => Task.FromResult(request.Pong!);
}

public class VoidGenericPing<T> : IRequest
where T : Pong
{ }

public class VoidGenericPingHandler<T> : IRequestHandler<VoidGenericPing<T>>
where T : Pong
{
public Task Handle(VoidGenericPing<T> request, CancellationToken cancellationToken) => Task.CompletedTask;
}

}

namespace MediatR.Extensions.Microsoft.DependencyInjection.Tests.Included
Expand Down
Loading

0 comments on commit 3b8bf44

Please sign in to comment.