Skip to content

Commit

Permalink
Don't use service discovery for DbContext types if corresponding IDes…
Browse files Browse the repository at this point in the history
…ignTimeDbContextFactory implementations are found. (#34082)

Throw for MSBuild-based execution if no DbContext types are found.

Fixes #27322
  • Loading branch information
AndriySvyryd authored Jul 8, 2024
1 parent a4aa68d commit 5f0887d
Show file tree
Hide file tree
Showing 11 changed files with 406 additions and 117 deletions.
244 changes: 146 additions & 98 deletions src/EFCore.Design/Design/Internal/DbContextOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,9 @@ public virtual IReadOnlyList<string> Optimize(
var optimizeAllInAssembly = contextTypeName == "*";
var contexts = optimizeAllInAssembly ? CreateAllContexts() : [CreateContext(contextTypeName)];

MSBuildLocator.RegisterDefaults();

List<string> generatedFiles = [];
HashSet<string> generatedFileNames = [];
var contextOptimized = false;
foreach (var context in contexts)
{
using (context)
Expand All @@ -158,6 +157,20 @@ public virtual IReadOnlyList<string> Optimize(
optimizeAllInAssembly,
generatedFiles,
generatedFileNames);
contextOptimized = true;
}
}

if (optimizeAllInAssembly)
{
if (!contextOptimized)
{
throw new OperationException(DesignStrings.NoContextsToOptimize);
}

if (generatedFiles.Count == 0)
{
_reporter.WriteWarning(DesignStrings.OptimizeNoFilesGenerated);
}
}

Expand Down Expand Up @@ -269,6 +282,10 @@ private IReadOnlyList<string> PrecompileQueries(string? outputDir, DbContext con
{
outputDir = Path.GetFullPath(Path.Combine(_projectDir, outputDir ?? "Generated"));

if (!MSBuildLocator.IsRegistered)
{
MSBuildLocator.RegisterDefaults();
}
// TODO: pass through properties
var workspace = MSBuildWorkspace.Create();
workspace.LoadMetadataForReferencedProjects = true;
Expand Down Expand Up @@ -373,6 +390,44 @@ static async Task<object> FormatCode(Project project, ScaffoldedFile generatedFi
: null;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ContextInfo GetContextInfo(string? contextType)
{
using var context = CreateContext(contextType);
var info = new ContextInfo { Type = context.GetType().FullName! };

var provider = context.GetService<IDatabaseProvider>();
info.ProviderName = provider.Name;

if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies)
{
try
{
var connection = context.Database.GetDbConnection();
info.DataSource = connection.DataSource;
info.DatabaseName = connection.Database;
}
catch (Exception exception)
{
info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message);
}
}
else
{
info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection;
}

var options = context.GetService<IDbContextOptions>();
info.Options = options.BuildOptionsFragment().Trim();

return info;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -445,11 +500,11 @@ public virtual IEnumerable<Type> GetContextTypes()
public virtual Type GetContextType(string? name)
=> FindContextType(name).Key;

private IDictionary<Type, Func<DbContext>> FindContextTypes()
private IDictionary<Type, Func<DbContext>> FindContextTypes(string? name = null)
{
_reporter.WriteVerbose(DesignStrings.FindingContexts);

var contexts = new Dictionary<Type, Func<DbContext>>();
var contexts = new Dictionary<Type, Func<DbContext>?>();

try
{
Expand All @@ -475,28 +530,18 @@ where i.IsGenericType
}

// Look for DbContextAttribute on the assembly
var appServices = _appServicesFactory.Create(_args);
foreach (var contextAttribute in _startupAssembly.GetCustomAttributes<DbContextAttribute>())
{
var context = contextAttribute.ContextType;
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
}
if (contexts.ContainsKey(context))
{
continue;
}

// Look for DbContext classes registered in the service provider
var registeredContexts = appServices.GetServices<DbContextOptions>()
.Select(o => o.ContextType);
foreach (var context in registeredContexts.Where(c => !contexts.ContainsKey(c)))
{
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? FindContextFromRuntimeDbContextFactory(appServices, context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
FindContextFactory(context));
}

// Look for DbContext classes in assemblies
Expand All @@ -507,87 +552,93 @@ where i.IsGenericType

var contextTypes = types.Where(t => typeof(DbContext).IsAssignableFrom(t)).Select(
t => t.AsType())
.Concat(
.Concat<Type>(
types.Where(t => typeof(Migration).IsAssignableFrom(t))
.Select(t => t.GetCustomAttribute<DbContextAttribute>()?.ContextType)
.Where(t => t != null)
.Cast<Type>())
.Where(t => t != null)!)
.Distinct();

foreach (var context in contextTypes.Where(c => !contexts.ContainsKey(c)))
foreach (var context in contextTypes)
{
if (contexts.ContainsKey(context))
{
continue;
}

_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
FindContextFactory(context));
}
}
catch (Exception ex)
{
if (ex is OperationException)

if (!string.IsNullOrEmpty(name))
{
throw;
contexts = FilterTypes(contexts, name, throwOnEmpty: false);
}

if (ex is TargetInvocationException)
if (contexts.Values.All(f => f != null)
&& (string.IsNullOrEmpty(name) || contexts.Count == 1))
{
ex = ex.InnerException!;
return contexts!;
}

throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex);
}

return contexts;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ContextInfo GetContextInfo(string? contextType)
{
using var context = CreateContext(contextType);
var info = new ContextInfo { Type = context.GetType().FullName! };
// Look for DbContext classes registered in the service provider
var appServices = _appServicesFactory.Create(_args);
foreach (var options in appServices.GetServices<DbContextOptions>())
{
var context = options.ContextType;
if (contexts.ContainsKey(context))
{
continue;
}

var provider = context.GetService<IDatabaseProvider>();
info.ProviderName = provider.Name;
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context));
}

if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies)
{
try
if (!string.IsNullOrEmpty(name))
{
var connection = context.Database.GetDbConnection();
info.DataSource = connection.DataSource;
info.DatabaseName = connection.Database;
contexts = FilterTypes(contexts, name, throwOnEmpty: true);
}
catch (Exception exception)

foreach (var contextPair in contexts)
{
info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message);
if (contextPair.Value == null)
{
var context = contextPair.Key;
contexts[context] = CreateContextFromServiceProvider(appServices, context);
}
}
}
else
catch (Exception ex)
{
info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection;
}
if (ex is OperationException)
{
throw;
}

var options = context.GetService<IDbContextOptions>();
info.Options = options.BuildOptionsFragment().Trim();
if (ex is TargetInvocationException)
{
ex = ex.InnerException!;
}

return info;
throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex);
}

return contexts!;
}

private static Func<DbContext>? FindContextFromRuntimeDbContextFactory(IServiceProvider appServices, Type contextType)
private static Func<DbContext> CreateContextFromServiceProvider(IServiceProvider appServices, Type contextType)
{
var factoryInterface = typeof(IDbContextFactory<>).MakeGenericType(contextType);
var service = appServices.GetService(factoryInterface);
return service == null
? null
var factoryService = appServices.GetService(factoryInterface);
return factoryService == null
? () => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, contextType)
: () => (DbContext)factoryInterface
.GetMethod(nameof(IDbContextFactory<DbContext>.CreateDbContext))
!.Invoke(service, null)!;
!.Invoke(factoryService, null)!;
}

private Func<DbContext>? FindContextFactory(Type contextType)
Expand All @@ -609,44 +660,45 @@ private DbContext CreateContextFromFactory(Type factory, Type contextType)

private KeyValuePair<Type, Func<DbContext>> FindContextType(string? name)
{
var types = FindContextTypes();

if (string.IsNullOrEmpty(name))
{
if (types.Count == 0)
var types = FindContextTypes(name);
return !string.IsNullOrEmpty(name)
? types.First()
: types.Count switch
{
throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name));
}

if (types.Count == 1)
{
return types.First();
}

throw new OperationException(DesignStrings.MultipleContexts);
}
0 => throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name)),
1 => types.First(),
_ => throw new OperationException(DesignStrings.MultipleContexts)
};
}

var candidates = FilterTypes(types, name, ignoreCase: true);
private Dictionary<Type, Func<DbContext>?> FilterTypes(
Dictionary<Type, Func<DbContext>?> types,
string name,
bool throwOnEmpty)
{
var candidates = FilterTypes(types, name, StringComparison.OrdinalIgnoreCase);
if (candidates.Count == 0)
{
throw new OperationException(DesignStrings.NoContextWithName(name));
return throwOnEmpty
? throw new OperationException(DesignStrings.NoContextWithName(name))
: candidates;
}

if (candidates.Count == 1)
{
return candidates.First();
return candidates;
}

// Disambiguate using case
candidates = FilterTypes(candidates, name);
candidates = FilterTypes(candidates, name, StringComparison.Ordinal);
if (candidates.Count == 0)
{
throw new OperationException(DesignStrings.MultipleContextsWithName(name));
}

if (candidates.Count == 1)
{
return candidates.First();
return candidates;
}

// Allow selecting types in the default namespace
Expand All @@ -658,21 +710,17 @@ private KeyValuePair<Type, Func<DbContext>> FindContextType(string? name)

Check.DebugAssert(candidates.Count == 1, $"candidates.Count is {candidates.Count}");

return candidates.First();
return candidates;
}

private static IDictionary<Type, Func<DbContext>> FilterTypes(
IDictionary<Type, Func<DbContext>> types,
private static Dictionary<Type, Func<DbContext>?> FilterTypes(
Dictionary<Type, Func<DbContext>?> types,
string name,
bool ignoreCase = false)
{
var comparisonType = ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal;

return types
StringComparison comparisonType)
=> types
.Where(
t => string.Equals(t.Key.Name, name, comparisonType)
|| string.Equals(t.Key.FullName, name, comparisonType)
|| string.Equals(t.Key.AssemblyQualifiedName, name, comparisonType))
.ToDictionary();
}
}
12 changes: 12 additions & 0 deletions src/EFCore.Design/Properties/DesignStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5f0887d

Please sign in to comment.