diff --git a/src/Temporalio.Extensions.Hosting/IServiceProviderAccessor.cs b/src/Temporalio.Extensions.Hosting/IServiceProviderAccessor.cs new file mode 100644 index 00000000..ff0639c6 --- /dev/null +++ b/src/Temporalio.Extensions.Hosting/IServiceProviderAccessor.cs @@ -0,0 +1,16 @@ +using System; + +namespace Temporalio.Extensions.Hosting +{ + /// + /// Provides access to the current, scoped if + /// one is available. + /// + public interface IServiceProviderAccessor + { + /// + /// Gets or sets the current service provider. + /// + IServiceProvider? ServiceProvider { get; set; } + } +} \ No newline at end of file diff --git a/src/Temporalio.Extensions.Hosting/ServiceProviderAccessor.cs b/src/Temporalio.Extensions.Hosting/ServiceProviderAccessor.cs new file mode 100644 index 00000000..ad3f9802 --- /dev/null +++ b/src/Temporalio.Extensions.Hosting/ServiceProviderAccessor.cs @@ -0,0 +1,42 @@ +using System; +using System.Threading; + +namespace Temporalio.Extensions.Hosting +{ + /// + /// Provides an implementation of based on + /// the current execution context. + /// + public class ServiceProviderAccessor : IServiceProviderAccessor + { + private static readonly AsyncLocal ServiceProviderCurrent = new(); + + /// + public IServiceProvider? ServiceProvider + { + get => ServiceProviderCurrent.Value?.ServiceProvider; + + set + { + var holder = ServiceProviderCurrent.Value; + if (holder != null) + { + // Clear current IServiceProvider trapped in the AsyncLocals, as its done. + holder.ServiceProvider = null; + } + + if (value != null) + { + // Use an object indirection to hold the IServiceProvider in the AsyncLocal, + // so it can be cleared in all ExecutionContexts when its cleared. + ServiceProviderCurrent.Value = new ServiceProviderHolder { ServiceProvider = value }; + } + } + } + + private sealed class ServiceProviderHolder + { + public IServiceProvider? ServiceProvider { get; set; } + } + } +} \ No newline at end of file diff --git a/src/Temporalio.Extensions.Hosting/ServiceProviderExtensions.cs b/src/Temporalio.Extensions.Hosting/ServiceProviderExtensions.cs index 16a5ef7c..c6c0e113 100644 --- a/src/Temporalio.Extensions.Hosting/ServiceProviderExtensions.cs +++ b/src/Temporalio.Extensions.Hosting/ServiceProviderExtensions.cs @@ -68,6 +68,14 @@ public static ActivityDefinition CreateTemporalActivityDefinition( #else var scope = provider.CreateScope(); #endif + IServiceProviderAccessor? serviceProviderAccessor = + scope.ServiceProvider.GetService(); + + if (serviceProviderAccessor is not null) + { + serviceProviderAccessor.ServiceProvider = scope.ServiceProvider; + } + try { object? result; @@ -111,6 +119,10 @@ public static ActivityDefinition CreateTemporalActivityDefinition( } finally { + if (serviceProviderAccessor is not null) + { + serviceProviderAccessor.ServiceProvider = null; + } #if NET6_0_OR_GREATER await scope.DisposeAsync().ConfigureAwait(false); #else