Skip to content

Commit

Permalink
Fixing issue #91 where where disposable objects are not disposed by D…
Browse files Browse the repository at this point in the history
…I when decorated.
  • Loading branch information
DanHarltey authored and khellang committed May 23, 2022
1 parent 45b6cef commit d2ca182
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 41 deletions.
9 changes: 4 additions & 5 deletions src/Scrutor/Decoration/ClosedTypeDecorationStrategy.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.Extensions.DependencyInjection;
using System;
using System;

namespace Scrutor.Decoration
{
Expand All @@ -20,16 +19,16 @@ public ClosedTypeDecorationStrategy(Type serviceType, Type? decoratorType, Func<

public bool CanDecorate(Type serviceType) => _serviceType == serviceType;

public Func<IServiceProvider, object> CreateDecorator(ServiceDescriptor descriptor)
public Func<IServiceProvider, object> CreateDecorator(Type serviceType)
{
if (_decoratorType is not null)
{
return DecoratorInstanceFactory.Default(descriptor, _decoratorType);
return DecoratorInstanceFactory.Default(serviceType, _decoratorType);
}

if (_decoratorFactory is not null)
{
return DecoratorInstanceFactory.Custom(descriptor, _decoratorFactory);
return DecoratorInstanceFactory.Custom(serviceType, _decoratorFactory);
}

throw new InvalidOperationException($"Both serviceType and decoratorFactory can not be null.");
Expand Down
184 changes: 184 additions & 0 deletions src/Scrutor/Decoration/DecoratedType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Reflection;
using System.Runtime.InteropServices;

namespace Scrutor.Decoration
{
internal class DecoratedType : Type
{
private readonly Type _proxiedType;

public DecoratedType(Type type) => _proxiedType = type;

// We use object reference equality here to ensure that only the decorating object can match.

public override bool Equals(Type? o) => ReferenceEquals(this, o);

public override bool Equals(object? o) => ReferenceEquals(this, o);

public override int GetHashCode() => _proxiedType.GetHashCode();

public override string? Namespace => _proxiedType.Namespace;
public override string? AssemblyQualifiedName => _proxiedType.AssemblyQualifiedName;
public override string? FullName => _proxiedType.FullName;


public override Assembly Assembly => _proxiedType.Assembly;
public override Module Module => _proxiedType.Module;

public override Type? DeclaringType => _proxiedType.DeclaringType;
public override MethodBase? DeclaringMethod => _proxiedType.DeclaringMethod;

public override Type? ReflectedType => _proxiedType.ReflectedType;
public override Type UnderlyingSystemType => _proxiedType.UnderlyingSystemType;

#if NETCOREAPP3_1_OR_GREATER
public override bool IsTypeDefinition => _proxiedType.IsTypeDefinition;
#endif
protected override bool IsArrayImpl() => _proxiedType.HasElementType;
protected override bool IsByRefImpl() => _proxiedType.IsByRef;
protected override bool IsPointerImpl() => _proxiedType.IsPointer;

public override bool IsConstructedGenericType => _proxiedType.IsConstructedGenericType;
public override bool IsGenericParameter => _proxiedType.IsGenericParameter;
#if NETCOREAPP3_1_OR_GREATER
public override bool IsGenericTypeParameter => _proxiedType.IsGenericTypeParameter;
public override bool IsGenericMethodParameter => _proxiedType.IsGenericMethodParameter;
#endif
public override bool IsGenericType => _proxiedType.IsGenericType;
public override bool IsGenericTypeDefinition => _proxiedType.IsGenericTypeDefinition;

#if NETCOREAPP3_1_OR_GREATER
public override bool IsSZArray => _proxiedType.IsSZArray;
public override bool IsVariableBoundArray => _proxiedType.IsVariableBoundArray;

public override bool IsByRefLike => _proxiedType.IsByRefLike;
#endif
protected override bool HasElementTypeImpl() => _proxiedType.HasElementType;
public override Type? GetElementType() => _proxiedType.GetElementType();

public override int GetArrayRank() => _proxiedType.GetArrayRank();

public override Type GetGenericTypeDefinition() => _proxiedType.GetGenericTypeDefinition();
public override Type[] GetGenericArguments() => _proxiedType.GetGenericArguments();

public override int GenericParameterPosition => _proxiedType.GenericParameterPosition;
public override GenericParameterAttributes GenericParameterAttributes => _proxiedType.GenericParameterAttributes;
public override Type[] GetGenericParameterConstraints() => _proxiedType.GetGenericParameterConstraints();

protected override TypeAttributes GetAttributeFlagsImpl() => _proxiedType.Attributes;

protected override bool IsCOMObjectImpl() => _proxiedType.IsCOMObject;
protected override bool IsContextfulImpl() => _proxiedType.IsContextful;

public override bool IsEnum => _proxiedType.IsEnum;
protected override bool IsMarshalByRefImpl() => _proxiedType.IsMarshalByRef;
protected override bool IsPrimitiveImpl() => _proxiedType.IsPrimitive;

protected override bool IsValueTypeImpl() => _proxiedType.IsValueType;
#if NETCOREAPP3_1_OR_GREATER
public override bool IsSignatureType =>_proxiedType.IsSignatureType;
#endif
public override bool IsSecurityCritical => _proxiedType.IsSecurityCritical;
public override bool IsSecuritySafeCritical => _proxiedType.IsSecuritySafeCritical;
public override bool IsSecurityTransparent => _proxiedType.IsSecurityTransparent;

public override StructLayoutAttribute? StructLayoutAttribute => _proxiedType.StructLayoutAttribute;

protected override ConstructorInfo? GetConstructorImpl(BindingFlags bindingAttr, Binder? binder, CallingConventions callConvention, Type[] types, ParameterModifier[]? modifiers)
=> _proxiedType.GetConstructor(bindingAttr, binder, callConvention, types, modifiers);

public override ConstructorInfo[] GetConstructors(BindingFlags bindingAttr) => _proxiedType.GetConstructors(bindingAttr);

public override EventInfo? GetEvent(string name, BindingFlags bindingAttr) => _proxiedType.GetEvent(name, bindingAttr);

public override EventInfo[] GetEvents() => _proxiedType.GetEvents();

public override EventInfo[] GetEvents(BindingFlags bindingAttr) => _proxiedType.GetEvents(bindingAttr);

public override FieldInfo? GetField(string name, BindingFlags bindingAttr) => _proxiedType.GetField(name, bindingAttr);

public override FieldInfo[] GetFields(BindingFlags bindingAttr) => _proxiedType.GetFields(bindingAttr);

public override MemberInfo[] GetMember(string name, BindingFlags bindingAttr) => _proxiedType.GetMember(name, bindingAttr);

public override MemberInfo[] GetMember(string name, MemberTypes type, BindingFlags bindingAttr) => _proxiedType.GetMember(name, type, bindingAttr);

#if NET6_0
public override MemberInfo GetMemberWithSameMetadataDefinitionAs(MemberInfo member) => _proxiedType.GetMemberWithSameMetadataDefinitionAs(member);
#endif
public override MemberInfo[] GetMembers(BindingFlags bindingAttr) => _proxiedType.GetMembers(bindingAttr);

protected override MethodInfo? GetMethodImpl(string name, BindingFlags bindingAttr, Binder? binder, CallingConventions callConvention, Type[]? types, ParameterModifier[]? modifiers)
=> _proxiedType.GetMethod(name, bindingAttr, binder, callConvention, types!, modifiers);

public override MethodInfo[] GetMethods(BindingFlags bindingAttr) => _proxiedType.GetMethods(bindingAttr);

public override Type? GetNestedType(string name, BindingFlags bindingAttr) => _proxiedType.GetNestedType(name, bindingAttr);

public override Type[] GetNestedTypes(BindingFlags bindingAttr) => _proxiedType.GetNestedTypes(bindingAttr);

protected override PropertyInfo? GetPropertyImpl(string name, BindingFlags bindingAttr, Binder? binder, Type? returnType, Type[]? types, ParameterModifier[]? modifiers)
=> _proxiedType.GetProperty(name, bindingAttr, binder, returnType, types!, modifiers);

public override PropertyInfo[] GetProperties(BindingFlags bindingAttr) => _proxiedType.GetProperties(bindingAttr);

public override MemberInfo[] GetDefaultMembers() => _proxiedType.GetDefaultMembers();

public override RuntimeTypeHandle TypeHandle => _proxiedType.TypeHandle;

protected override TypeCode GetTypeCodeImpl() => Type.GetTypeCode(_proxiedType);

public override Guid GUID => _proxiedType.GUID;

public override Type? BaseType => _proxiedType.BaseType;

public override object? InvokeMember(string name, BindingFlags invokeAttr, Binder? binder, object? target, object?[]? args, ParameterModifier[]? modifiers, CultureInfo? culture, string[]? namedParameters) =>
_proxiedType.InvokeMember(name, invokeAttr, binder, target, args, modifiers, culture, namedParameters);

public override Type? GetInterface(string name, bool ignoreCase) => _proxiedType.GetInterface(name, ignoreCase);
public override Type[] GetInterfaces() => _proxiedType.GetInterfaces();

public override InterfaceMapping GetInterfaceMap(Type interfaceType) => _proxiedType.GetInterfaceMap(interfaceType);

public override bool IsInstanceOfType(object? o) => _proxiedType.IsInstanceOfType(o);

public override bool IsEquivalentTo(Type? other) => _proxiedType.IsEquivalentTo(other);

public override Type GetEnumUnderlyingType() => _proxiedType.GetEnumUnderlyingType();

public override Array GetEnumValues() => _proxiedType.GetEnumValues();

public override Type MakeArrayType() => _proxiedType.MakeArrayType();
public override Type MakeArrayType(int rank) => _proxiedType.MakeArrayType(rank);
public override Type MakeByRefType() => _proxiedType.MakeByRefType();

public override Type MakeGenericType(params Type[] typeArguments) => _proxiedType.MakeGenericType(typeArguments);

public override Type MakePointerType() => _proxiedType.MakePointerType();

public override string ToString() => "Type: " + Name;

#region MemberInfo overrides

public override MemberTypes MemberType => _proxiedType.MemberType;

public override string Name => "Decorated " + _proxiedType.Name;

public override IEnumerable<CustomAttributeData> CustomAttributes => _proxiedType.CustomAttributes;

public override int MetadataToken => _proxiedType.MetadataToken;

public override object[] GetCustomAttributes(bool inherit) => _proxiedType.GetCustomAttributes(inherit);

public override object[] GetCustomAttributes(Type attributeType, bool inherit) => _proxiedType.GetCustomAttributes(attributeType, inherit);

public override bool IsDefined(Type attributeType, bool inherit) => _proxiedType.IsDefined(attributeType, inherit);

public override IList<CustomAttributeData> GetCustomAttributesData() => _proxiedType.GetCustomAttributesData();

#endregion
}
}
21 changes: 19 additions & 2 deletions src/Scrutor/Decoration/Decoration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ private int DecorateServices(IServiceCollection services)
{
var serviceDescriptor = services[i];

if (_decorationStrategy.CanDecorate(serviceDescriptor.ServiceType))
if (IsNotAlreadyDecorated(serviceDescriptor)
&& _decorationStrategy.CanDecorate(serviceDescriptor.ServiceType))
{
var decoratorFactory = _decorationStrategy.CreateDecorator(serviceDescriptor);
var decoratedType = new DecoratedType(serviceDescriptor.ServiceType);

var decoratorFactory = _decorationStrategy.CreateDecorator(decoratedType);

// insert decorated
var decoratedServiceDescriptor = CreateDecoratedServiceDescriptor(serviceDescriptor, decoratedType);
services.Add(decoratedServiceDescriptor);

// replace decorator
services[i] = new ServiceDescriptor(serviceDescriptor.ServiceType, decoratorFactory, serviceDescriptor.Lifetime);
Expand All @@ -49,5 +56,15 @@ private int DecorateServices(IServiceCollection services)

return decorated;
}

private static bool IsNotAlreadyDecorated(ServiceDescriptor serviceDescriptor) => serviceDescriptor.ServiceType is not DecoratedType;

private static ServiceDescriptor CreateDecoratedServiceDescriptor(ServiceDescriptor serviceDescriptor, Type decoratedType) => serviceDescriptor switch
{
{ ImplementationType: not null } => new ServiceDescriptor(decoratedType, serviceDescriptor.ImplementationType, serviceDescriptor.Lifetime),
{ ImplementationFactory: not null } => new ServiceDescriptor(decoratedType, serviceDescriptor.ImplementationFactory, serviceDescriptor.Lifetime),
{ ImplementationInstance: not null } => new ServiceDescriptor(decoratedType, serviceDescriptor.ImplementationInstance),
_ => throw new ArgumentException($"No implementation factory or instance or type found for {serviceDescriptor.ServiceType}.", nameof(serviceDescriptor))
};
}
}
29 changes: 4 additions & 25 deletions src/Scrutor/Decoration/DecoratorInstanceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,18 @@ namespace Scrutor.Decoration
{
internal static class DecoratorInstanceFactory
{
internal static Func<IServiceProvider, object> Default(ServiceDescriptor descriptor, Type decorator) =>
internal static Func<IServiceProvider, object> Default(Type decorated, Type decorator) =>
(serviceProvider) =>
{
var instanceToDecorate = GetInstance(serviceProvider, descriptor);
var instanceToDecorate = serviceProvider.GetRequiredService(decorated);
return ActivatorUtilities.CreateInstance(serviceProvider, decorator, instanceToDecorate);
};

internal static Func<IServiceProvider, object> Custom(ServiceDescriptor descriptor, Func<object, IServiceProvider, object> creationFactory) =>
internal static Func<IServiceProvider, object> Custom(Type decorated, Func<object, IServiceProvider, object> creationFactory) =>
(serviceProvider) =>
{
var instanceToDecorate = GetInstance(serviceProvider, descriptor);
var instanceToDecorate = serviceProvider.GetRequiredService(decorated);
return creationFactory(instanceToDecorate, serviceProvider);
};

private static object GetInstance(IServiceProvider provider, ServiceDescriptor descriptor)
{
if (descriptor.ImplementationInstance != null)
{
return descriptor.ImplementationInstance;
}

var implementationType = descriptor.ImplementationType;
if (implementationType != null)
{
return ActivatorUtilities.CreateInstance(provider, implementationType);
}

if (descriptor.ImplementationFactory != null)
{
return descriptor.ImplementationFactory(provider);
}

throw new InvalidOperationException($"No implementation factory or instance or type found for {descriptor.ServiceType}.");
}
}
}
5 changes: 2 additions & 3 deletions src/Scrutor/Decoration/IDecorationStrategy.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
using Microsoft.Extensions.DependencyInjection;
using System;
using System;

namespace Scrutor.Decoration
{
internal interface IDecorationStrategy
{
public Type ServiceType { get; }
public bool CanDecorate(Type serviceType);
public Func<IServiceProvider, object> CreateDecorator(ServiceDescriptor descriptor);
public Func<IServiceProvider, object> CreateDecorator(Type serviceType);
}
}
11 changes: 5 additions & 6 deletions src/Scrutor/Decoration/OpenGenericDecorationStrategy.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.Extensions.DependencyInjection;
using System;
using System;

namespace Scrutor.Decoration
{
Expand Down Expand Up @@ -28,19 +27,19 @@ public bool CanDecorate(Type serviceType)
return canHandle;
}

public Func<IServiceProvider, object> CreateDecorator(ServiceDescriptor descriptor)
public Func<IServiceProvider, object> CreateDecorator(Type serviceType)
{
if (_decoratorType is not null)
{
var genericArguments = descriptor.ServiceType.GetGenericArguments();
var genericArguments = serviceType.GetGenericArguments();
var closedDecorator = _decoratorType.MakeGenericType(genericArguments);

return DecoratorInstanceFactory.Default(descriptor, closedDecorator);
return DecoratorInstanceFactory.Default(serviceType, closedDecorator);
}

if (_decoratorFactory is not null)
{
return DecoratorInstanceFactory.Custom(descriptor, _decoratorFactory);
return DecoratorInstanceFactory.Custom(serviceType, _decoratorFactory);
}

throw new InvalidOperationException($"Both serviceType and decoratorFactory can not be null.");
Expand Down

0 comments on commit d2ca182

Please sign in to comment.