From 05cf17c06178d439939ccacd4897b98a4098b425 Mon Sep 17 00:00:00 2001 From: "Steven.Darby" Date: Tue, 18 Apr 2023 23:42:59 +0100 Subject: [PATCH] Fixes #27752 - Allow root provider DI for pooled DbContexts --- src/EFCore/Internal/DbContextPool.cs | 22 +++++---- src/EFCore/Properties/CoreStrings.Designer.cs | 2 +- .../DbContextPoolingTest.cs | 45 +++++++++++++++---- 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/src/EFCore/Internal/DbContextPool.cs b/src/EFCore/Internal/DbContextPool.cs index 83cc7fecd25..60fa35b627a 100644 --- a/src/EFCore/Internal/DbContextPool.cs +++ b/src/EFCore/Internal/DbContextPool.cs @@ -35,7 +35,7 @@ public class DbContextPool : IDbContextPool, IDisposable, IA /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public DbContextPool(DbContextOptions options) + public DbContextPool(DbContextOptions options, IServiceProvider? serviceProvider = default) { _maxSize = options.FindExtension()?.MaxPoolSize ?? DefaultPoolSize; @@ -46,23 +46,29 @@ public DbContextPool(DbContextOptions options) options.Freeze(); - _activator = CreateActivator(options); + _activator = CreateActivator(options, serviceProvider); } - private static Func CreateActivator(DbContextOptions options) + private static Func CreateActivator(DbContextOptions options, IServiceProvider? serviceProvider) { var constructors = typeof(TContext).GetTypeInfo().DeclaredConstructors .Where(c => !c.IsStatic && c.IsPublic && c.GetParameters().Length > 0).ToArray(); - if (constructors.Length == 1) + if (constructors.Length == 1 + && constructors[0].GetParameters() is { } parameters + && parameters.Any(p => p.ParameterType == typeof(DbContextOptions) || p.ParameterType == typeof(DbContextOptions))) { - var parameters = constructors[0].GetParameters(); - if (parameters.Length == 1 - && (parameters[0].ParameterType == typeof(DbContextOptions) - || parameters[0].ParameterType == typeof(DbContextOptions))) + if (parameters.Length == 1) { return Expression.Lambda>(Expression.New(constructors[0], Expression.Constant(options))).Compile(); } + + if (serviceProvider is not null) + { + var factory = ActivatorUtilities.CreateFactory(new[] { typeof(DbContextOptions) }); + var activatorParameters = new object[] { options }; + return () => factory.Invoke(serviceProvider, activatorParameters); + } } throw new InvalidOperationException(CoreStrings.PoolingContextCtorError(typeof(TContext).ShortDisplayName())); diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index cda5e38a44b..356730b18cd 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2018,7 +2018,7 @@ public static string OwnershipToDependent(object? navigation, object? principalE navigation, principalEntityType, dependentEntityType); /// - /// The DbContext of type '{contextType}' cannot be pooled because it does not have a public constructor accepting a single parameter of type DbContextOptions or has more than one constructor. + /// The DbContext of type '{contextType}' cannot be pooled because it does not have a public constructor accepting a parameter of type DbContextOptions or has more than one constructor. /// public static string PoolingContextCtorError(object? contextType) => string.Format( diff --git a/test/EFCore.SqlServer.FunctionalTests/DbContextPoolingTest.cs b/test/EFCore.SqlServer.FunctionalTests/DbContextPoolingTest.cs index e3b774b6d18..6b0c90aa798 100644 --- a/test/EFCore.SqlServer.FunctionalTests/DbContextPoolingTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/DbContextPoolingTest.cs @@ -454,7 +454,7 @@ public void Throws_when_used_with_parameterless_constructor_context() } [ConditionalFact] - public void Throws_when_pooled_context_constructor_has_more_than_one_parameter() + public void Throws_when_pooled_context_constructor_has_second_parameter_that_cannot_be_resolved_from_service_provider() { var serviceProvider = new ServiceCollection().AddDbContextPool(_ => { }) @@ -462,21 +462,22 @@ var serviceProvider using var scope = serviceProvider.CreateScope(); - Assert.Equal( - CoreStrings.PoolingContextCtorError(nameof(TwoParameterConstructorContext)), - Assert.Throws(() => scope.ServiceProvider.GetService()).Message); + Assert.Throws(() => scope.ServiceProvider.GetService()); } private class TwoParameterConstructorContext : DbContext { + public string StringParameter { get; } + public TwoParameterConstructorContext(DbContextOptions options, string x) : base(options) { + StringParameter = x; } } [ConditionalFact] - public void Throws_when_pooled_context_constructor_wrong_parameter() + public void Throws_when_pooled_context_constructor_has_single_parameter_that_cannot_be_resolved_from_service_provider() { var serviceProvider = new ServiceCollection().AddDbContextPool(_ => { }) @@ -484,10 +485,7 @@ var serviceProvider using var scope = serviceProvider.CreateScope(); - Assert.Equal( - CoreStrings.PoolingContextCtorError(nameof(WrongParameterConstructorContext)), - Assert.Throws(() => scope.ServiceProvider.GetService()) - .Message); + Assert.Throws(() => scope.ServiceProvider.GetService()); } private class WrongParameterConstructorContext : DbContext @@ -498,6 +496,35 @@ public WrongParameterConstructorContext(string x) } } + [ConditionalFact] + public void Throws_when_pooled_context_constructor_has_scoped_service() + { + var serviceProvider + = new ServiceCollection() + .AddDbContextPool(_ => { }) + .AddScoped(sp => "string") + .BuildServiceProvider(validateScopes: true); + + using var scope = serviceProvider.CreateScope(); + + Assert.Throws(() => scope.ServiceProvider.GetService()); + } + + [ConditionalFact] + public void Does_not_throw_when_pooled_context_constructor_has_singleton_service() + { + var serviceProvider + = new ServiceCollection() + .AddDbContextPool(_ => { }) + .AddSingleton("string") + .BuildServiceProvider(validateScopes: true); + + using var scope = serviceProvider.CreateScope(); + var context = scope.ServiceProvider.GetService(); + + Assert.Equal("string", context.StringParameter); + } + [ConditionalFact] public void Does_not_throw_when_parameterless_and_correct_constructor() {