Skip to content

Commit

Permalink
Don't use service discovery for DbContext types if corresponding IDes…
Browse files Browse the repository at this point in the history
…ignTimeDbContextFactory implementations are found.

Throw for MSBuild-based execution if no DbContext types are found.

Fixes #27322
  • Loading branch information
AndriySvyryd committed Jun 25, 2024
1 parent 5ac6eca commit fe50f0d
Show file tree
Hide file tree
Showing 11 changed files with 402 additions and 115 deletions.
238 changes: 142 additions & 96 deletions src/EFCore.Design/Design/Internal/DbContextOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ public virtual IReadOnlyList<string> Optimize(

List<string> generatedFiles = [];
HashSet<string> generatedFileNames = [];
var contextOptimized = false;
foreach (var context in contexts)
{
using (context)
Expand All @@ -158,6 +159,20 @@ public virtual IReadOnlyList<string> Optimize(
optimizeAllInAssembly,
generatedFiles,
generatedFileNames);
contextOptimized = true;
}
}

if (optimizeAllInAssembly)
{
if (!contextOptimized)
{
throw new OperationException(DesignStrings.NoContextsToOptimize);
}

if (generatedFiles.Count == 0)
{
_reporter.WriteWarning(DesignStrings.OptimizeNoFilesGenerated);
}
}

Expand Down Expand Up @@ -373,6 +388,44 @@ static async Task<object> FormatCode(Project project, ScaffoldedFile generatedFi
: null;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ContextInfo GetContextInfo(string? contextType)
{
using var context = CreateContext(contextType);
var info = new ContextInfo { Type = context.GetType().FullName! };

var provider = context.GetService<IDatabaseProvider>();
info.ProviderName = provider.Name;

if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies)
{
try
{
var connection = context.Database.GetDbConnection();
info.DataSource = connection.DataSource;
info.DatabaseName = connection.Database;
}
catch (Exception exception)
{
info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message);
}
}
else
{
info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection;
}

var options = context.GetService<IDbContextOptions>();
info.Options = options.BuildOptionsFragment().Trim();

return info;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -445,11 +498,11 @@ public virtual IEnumerable<Type> GetContextTypes()
public virtual Type GetContextType(string? name)
=> FindContextType(name).Key;

private IDictionary<Type, Func<DbContext>> FindContextTypes()
private IDictionary<Type, Func<DbContext>> FindContextTypes(string? name = null)
{
_reporter.WriteVerbose(DesignStrings.FindingContexts);

var contexts = new Dictionary<Type, Func<DbContext>>();
var contexts = new Dictionary<Type, Func<DbContext>?>();

try
{
Expand All @@ -475,28 +528,18 @@ where i.IsGenericType
}

// Look for DbContextAttribute on the assembly
var appServices = _appServicesFactory.Create(_args);
foreach (var contextAttribute in _startupAssembly.GetCustomAttributes<DbContextAttribute>())
{
var context = contextAttribute.ContextType;
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
}
if (contexts.ContainsKey(context))
{
continue;
}

// Look for DbContext classes registered in the service provider
var registeredContexts = appServices.GetServices<DbContextOptions>()
.Select(o => o.ContextType);
foreach (var context in registeredContexts.Where(c => !contexts.ContainsKey(c)))
{
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? FindContextFromRuntimeDbContextFactory(appServices, context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
FindContextFactory(context));
}

// Look for DbContext classes in assemblies
Expand All @@ -507,87 +550,93 @@ where i.IsGenericType

var contextTypes = types.Where(t => typeof(DbContext).IsAssignableFrom(t)).Select(
t => t.AsType())
.Concat(
.Concat<Type>(
types.Where(t => typeof(Migration).IsAssignableFrom(t))
.Select(t => t.GetCustomAttribute<DbContextAttribute>()?.ContextType)
.Where(t => t != null)
.Cast<Type>())
.Where(t => t != null)!)
.Distinct();

foreach (var context in contextTypes.Where(c => !contexts.ContainsKey(c)))
foreach (var context in contextTypes)
{
if (contexts.ContainsKey(context))
{
continue;
}

_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
FindContextFactory(context));
}
}
catch (Exception ex)
{
if (ex is OperationException)

if (!string.IsNullOrEmpty(name))
{
throw;
contexts = FilterTypes(contexts, name, throwOnEmpty: false);
}

if (ex is TargetInvocationException)
if (contexts.Values.All(f => f != null)
&& (string.IsNullOrEmpty(name) || contexts.Count == 1))
{
ex = ex.InnerException!;
return contexts!;
}

throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex);
}

return contexts;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ContextInfo GetContextInfo(string? contextType)
{
using var context = CreateContext(contextType);
var info = new ContextInfo { Type = context.GetType().FullName! };
// Look for DbContext classes registered in the service provider
var appServices = _appServicesFactory.Create(_args);
foreach (var options in appServices.GetServices<DbContextOptions>())
{
var context = options.ContextType;
if (contexts.ContainsKey(context))
{
continue;
}

var provider = context.GetService<IDatabaseProvider>();
info.ProviderName = provider.Name;
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context));
}

if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies)
{
try
if (!string.IsNullOrEmpty(name))
{
var connection = context.Database.GetDbConnection();
info.DataSource = connection.DataSource;
info.DatabaseName = connection.Database;
contexts = FilterTypes(contexts, name, throwOnEmpty: true);
}
catch (Exception exception)

foreach (var contextPair in contexts)
{
info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message);
if (contextPair.Value == null)
{
var context = contextPair.Key;
contexts[context] = CreateContextFromServiceProvider(appServices, context);
}
}
}
else
catch (Exception ex)
{
info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection;
}
if (ex is OperationException)
{
throw;
}

var options = context.GetService<IDbContextOptions>();
info.Options = options.BuildOptionsFragment().Trim();
if (ex is TargetInvocationException)
{
ex = ex.InnerException!;
}

return info;
throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex);
}

return contexts!;
}

private static Func<DbContext>? FindContextFromRuntimeDbContextFactory(IServiceProvider appServices, Type contextType)
private static Func<DbContext> CreateContextFromServiceProvider(IServiceProvider appServices, Type contextType)
{
var factoryInterface = typeof(IDbContextFactory<>).MakeGenericType(contextType);
var service = appServices.GetService(factoryInterface);
return service == null
? null
var factoryService = appServices.GetService(factoryInterface);
return factoryService == null
? () => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, contextType)
: () => (DbContext)factoryInterface
.GetMethod(nameof(IDbContextFactory<DbContext>.CreateDbContext))
!.Invoke(service, null)!;
!.Invoke(factoryService, null)!;
}

private Func<DbContext>? FindContextFactory(Type contextType)
Expand All @@ -609,44 +658,45 @@ private DbContext CreateContextFromFactory(Type factory, Type contextType)

private KeyValuePair<Type, Func<DbContext>> FindContextType(string? name)
{
var types = FindContextTypes();

if (string.IsNullOrEmpty(name))
{
if (types.Count == 0)
var types = FindContextTypes(name);
return !string.IsNullOrEmpty(name)
? types.First()
: types.Count switch
{
throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name));
}

if (types.Count == 1)
{
return types.First();
}

throw new OperationException(DesignStrings.MultipleContexts);
}
0 => throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name)),
1 => types.First(),
_ => throw new OperationException(DesignStrings.MultipleContexts)
};
}

var candidates = FilterTypes(types, name, ignoreCase: true);
private Dictionary<Type, Func<DbContext>?> FilterTypes(
Dictionary<Type, Func<DbContext>?> types,
string name,
bool throwOnEmpty)
{
var candidates = FilterTypes(types, name, StringComparison.OrdinalIgnoreCase);
if (candidates.Count == 0)
{
throw new OperationException(DesignStrings.NoContextWithName(name));
return throwOnEmpty
? throw new OperationException(DesignStrings.NoContextWithName(name))
: candidates;
}

if (candidates.Count == 1)
{
return candidates.First();
return candidates;
}

// Disambiguate using case
candidates = FilterTypes(candidates, name);
candidates = FilterTypes(candidates, name, StringComparison.Ordinal);
if (candidates.Count == 0)
{
throw new OperationException(DesignStrings.MultipleContextsWithName(name));
}

if (candidates.Count == 1)
{
return candidates.First();
return candidates;
}

// Allow selecting types in the default namespace
Expand All @@ -658,21 +708,17 @@ private KeyValuePair<Type, Func<DbContext>> FindContextType(string? name)

Check.DebugAssert(candidates.Count == 1, $"candidates.Count is {candidates.Count}");

return candidates.First();
return candidates;
}

private static IDictionary<Type, Func<DbContext>> FilterTypes(
IDictionary<Type, Func<DbContext>> types,
private static Dictionary<Type, Func<DbContext>?> FilterTypes(
Dictionary<Type, Func<DbContext>?> types,
string name,
bool ignoreCase = false)
{
var comparisonType = ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal;

return types
StringComparison comparisonType)
=> types
.Where(
t => string.Equals(t.Key.Name, name, comparisonType)
|| string.Equals(t.Key.FullName, name, comparisonType)
|| string.Equals(t.Key.AssemblyQualifiedName, name, comparisonType))
.ToDictionary();
}
}
12 changes: 12 additions & 0 deletions src/EFCore.Design/Properties/DesignStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fe50f0d

Please sign in to comment.