Skip to content

Commit

Permalink
Allow DI for pooled DbContexts (#30739)
Browse files Browse the repository at this point in the history
Co-authored-by: Steven.Darby <[email protected]>
  • Loading branch information
stevendarby and Steven.Darby authored Apr 21, 2023
1 parent 0adaa62 commit b35364d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 18 deletions.
22 changes: 14 additions & 8 deletions src/EFCore/Internal/DbContextPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class DbContextPool<TContext> : IDbContextPool<TContext>, IDisposable, IA
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public DbContextPool(DbContextOptions<TContext> options)
public DbContextPool(DbContextOptions<TContext> options, IServiceProvider? serviceProvider = default)
{
_maxSize = options.FindExtension<CoreOptionsExtension>()?.MaxPoolSize ?? DefaultPoolSize;

Expand All @@ -46,23 +46,29 @@ public DbContextPool(DbContextOptions<TContext> options)

options.Freeze();

_activator = CreateActivator(options);
_activator = CreateActivator(options, serviceProvider);
}

private static Func<DbContext> CreateActivator(DbContextOptions<TContext> options)
private static Func<DbContext> CreateActivator(DbContextOptions<TContext> options, IServiceProvider? serviceProvider)
{
var constructors = typeof(TContext).GetTypeInfo().DeclaredConstructors
.Where(c => !c.IsStatic && c.IsPublic && c.GetParameters().Length > 0).ToArray();

if (constructors.Length == 1)
if (constructors.Length == 1
&& constructors[0].GetParameters() is { } parameters
&& parameters.Any(p => p.ParameterType == typeof(DbContextOptions) || p.ParameterType == typeof(DbContextOptions<TContext>)))
{
var parameters = constructors[0].GetParameters();
if (parameters.Length == 1
&& (parameters[0].ParameterType == typeof(DbContextOptions)
|| parameters[0].ParameterType == typeof(DbContextOptions<TContext>)))
if (parameters.Length == 1)
{
return Expression.Lambda<Func<TContext>>(Expression.New(constructors[0], Expression.Constant(options))).Compile();
}

if (serviceProvider is not null)
{
var factory = ActivatorUtilities.CreateFactory<TContext>(new[] { typeof(DbContextOptions<TContext>) });
var activatorParameters = new object[] { options };
return () => factory.Invoke(serviceProvider, activatorParameters);
}
}

throw new InvalidOperationException(CoreStrings.PoolingContextCtorError(typeof(TContext).ShortDisplayName()));
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 36 additions & 9 deletions test/EFCore.SqlServer.FunctionalTests/DbContextPoolingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -454,40 +454,38 @@ public void Throws_when_used_with_parameterless_constructor_context()
}

[ConditionalFact]
public void Throws_when_pooled_context_constructor_has_more_than_one_parameter()
public void Throws_when_pooled_context_constructor_has_second_parameter_that_cannot_be_resolved_from_service_provider()
{
var serviceProvider
= new ServiceCollection().AddDbContextPool<TwoParameterConstructorContext>(_ => { })
.BuildServiceProvider(validateScopes: true);

using var scope = serviceProvider.CreateScope();

Assert.Equal(
CoreStrings.PoolingContextCtorError(nameof(TwoParameterConstructorContext)),
Assert.Throws<InvalidOperationException>(() => scope.ServiceProvider.GetService<TwoParameterConstructorContext>()).Message);
Assert.Throws<InvalidOperationException>(() => scope.ServiceProvider.GetService<TwoParameterConstructorContext>());
}

private class TwoParameterConstructorContext : DbContext
{
public string StringParameter { get; }

public TwoParameterConstructorContext(DbContextOptions options, string x)
: base(options)
{
StringParameter = x;
}
}

[ConditionalFact]
public void Throws_when_pooled_context_constructor_wrong_parameter()
public void Throws_when_pooled_context_constructor_has_single_parameter_that_cannot_be_resolved_from_service_provider()
{
var serviceProvider
= new ServiceCollection().AddDbContextPool<WrongParameterConstructorContext>(_ => { })
.BuildServiceProvider(validateScopes: true);

using var scope = serviceProvider.CreateScope();

Assert.Equal(
CoreStrings.PoolingContextCtorError(nameof(WrongParameterConstructorContext)),
Assert.Throws<InvalidOperationException>(() => scope.ServiceProvider.GetService<WrongParameterConstructorContext>())
.Message);
Assert.Throws<InvalidOperationException>(() => scope.ServiceProvider.GetService<WrongParameterConstructorContext>());
}

private class WrongParameterConstructorContext : DbContext
Expand All @@ -498,6 +496,35 @@ public WrongParameterConstructorContext(string x)
}
}

[ConditionalFact]
public void Throws_when_pooled_context_constructor_has_scoped_service()
{
var serviceProvider
= new ServiceCollection()
.AddDbContextPool<TwoParameterConstructorContext>(_ => { })
.AddScoped(sp => "string")
.BuildServiceProvider(validateScopes: true);

using var scope = serviceProvider.CreateScope();

Assert.Throws<InvalidOperationException>(() => scope.ServiceProvider.GetService<TwoParameterConstructorContext>());
}

[ConditionalFact]
public void Does_not_throw_when_pooled_context_constructor_has_singleton_service()
{
var serviceProvider
= new ServiceCollection()
.AddDbContextPool<TwoParameterConstructorContext>(_ => { })
.AddSingleton("string")
.BuildServiceProvider(validateScopes: true);

using var scope = serviceProvider.CreateScope();
var context = scope.ServiceProvider.GetService<TwoParameterConstructorContext>();

Assert.Equal("string", context.StringParameter);
}

[ConditionalFact]
public void Does_not_throw_when_parameterless_and_correct_constructor()
{
Expand Down

0 comments on commit b35364d

Please sign in to comment.