From fe50f0db7f487f05db2c5ed092eaeda379c7dc90 Mon Sep 17 00:00:00 2001 From: Andriy Svyryd Date: Mon, 24 Jun 2024 17:15:14 -0700 Subject: [PATCH] Don't use service discovery for DbContext types if corresponding IDesignTimeDbContextFactory implementations are found. Throw for MSBuild-based execution if no DbContext types are found. Fixes #27322 --- .../Design/Internal/DbContextOperations.cs | 238 +++++++++++------- .../Properties/DesignStrings.Designer.cs | 12 + .../Properties/DesignStrings.resx | 6 + .../Properties/Resources.Designer.cs | 2 +- src/dotnet-ef/Properties/Resources.resx | 2 +- src/ef/Properties/Resources.Designer.cs | 2 +- src/ef/Properties/Resources.resx | 2 +- .../Design/Internal/DatabaseOperationsTest.cs | 2 +- .../Internal/DbContextOperationsTest.cs | 221 +++++++++++++++- .../TestAppServiceProviderFactory.cs | 26 +- ...SharedPrimitiveCollectionsQueryTestBase.cs | 4 +- 11 files changed, 402 insertions(+), 115 deletions(-) diff --git a/src/EFCore.Design/Design/Internal/DbContextOperations.cs b/src/EFCore.Design/Design/Internal/DbContextOperations.cs index 71b61ca58f4..5651eadc9c1 100644 --- a/src/EFCore.Design/Design/Internal/DbContextOperations.cs +++ b/src/EFCore.Design/Design/Internal/DbContextOperations.cs @@ -144,6 +144,7 @@ public virtual IReadOnlyList Optimize( List generatedFiles = []; HashSet generatedFileNames = []; + var contextOptimized = false; foreach (var context in contexts) { using (context) @@ -158,6 +159,20 @@ public virtual IReadOnlyList Optimize( optimizeAllInAssembly, generatedFiles, generatedFileNames); + contextOptimized = true; + } + } + + if (optimizeAllInAssembly) + { + if (!contextOptimized) + { + throw new OperationException(DesignStrings.NoContextsToOptimize); + } + + if (generatedFiles.Count == 0) + { + _reporter.WriteWarning(DesignStrings.OptimizeNoFilesGenerated); } } @@ -373,6 +388,44 @@ static async Task FormatCode(Project project, ScaffoldedFile generatedFi : null; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ContextInfo GetContextInfo(string? contextType) + { + using var context = CreateContext(contextType); + var info = new ContextInfo { Type = context.GetType().FullName! }; + + var provider = context.GetService(); + info.ProviderName = provider.Name; + + if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies) + { + try + { + var connection = context.Database.GetDbConnection(); + info.DataSource = connection.DataSource; + info.DatabaseName = connection.Database; + } + catch (Exception exception) + { + info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message); + } + } + else + { + info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection; + } + + var options = context.GetService(); + info.Options = options.BuildOptionsFragment().Trim(); + + return info; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -445,11 +498,11 @@ public virtual IEnumerable GetContextTypes() public virtual Type GetContextType(string? name) => FindContextType(name).Key; - private IDictionary> FindContextTypes() + private IDictionary> FindContextTypes(string? name = null) { _reporter.WriteVerbose(DesignStrings.FindingContexts); - var contexts = new Dictionary>(); + var contexts = new Dictionary?>(); try { @@ -475,28 +528,18 @@ where i.IsGenericType } // Look for DbContextAttribute on the assembly - var appServices = _appServicesFactory.Create(_args); foreach (var contextAttribute in _startupAssembly.GetCustomAttributes()) { var context = contextAttribute.ContextType; - _reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName())); - contexts.Add( - context, - FindContextFactory(context) - ?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context))); - } + if (contexts.ContainsKey(context)) + { + continue; + } - // Look for DbContext classes registered in the service provider - var registeredContexts = appServices.GetServices() - .Select(o => o.ContextType); - foreach (var context in registeredContexts.Where(c => !contexts.ContainsKey(c))) - { _reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName())); contexts.Add( context, - FindContextFactory(context) - ?? FindContextFromRuntimeDbContextFactory(appServices, context) - ?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context))); + FindContextFactory(context)); } // Look for DbContext classes in assemblies @@ -507,87 +550,93 @@ where i.IsGenericType var contextTypes = types.Where(t => typeof(DbContext).IsAssignableFrom(t)).Select( t => t.AsType()) - .Concat( + .Concat( types.Where(t => typeof(Migration).IsAssignableFrom(t)) .Select(t => t.GetCustomAttribute()?.ContextType) - .Where(t => t != null) - .Cast()) + .Where(t => t != null)!) .Distinct(); - foreach (var context in contextTypes.Where(c => !contexts.ContainsKey(c))) + foreach (var context in contextTypes) { + if (contexts.ContainsKey(context)) + { + continue; + } + _reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName())); contexts.Add( context, - FindContextFactory(context) - ?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context))); + FindContextFactory(context)); } - } - catch (Exception ex) - { - if (ex is OperationException) + + if (!string.IsNullOrEmpty(name)) { - throw; + contexts = FilterTypes(contexts, name, throwOnEmpty: false); } - if (ex is TargetInvocationException) + if (contexts.Values.All(f => f != null) + && (string.IsNullOrEmpty(name) || contexts.Count == 1)) { - ex = ex.InnerException!; + return contexts!; } - throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex); - } - - return contexts; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual ContextInfo GetContextInfo(string? contextType) - { - using var context = CreateContext(contextType); - var info = new ContextInfo { Type = context.GetType().FullName! }; + // Look for DbContext classes registered in the service provider + var appServices = _appServicesFactory.Create(_args); + foreach (var options in appServices.GetServices()) + { + var context = options.ContextType; + if (contexts.ContainsKey(context)) + { + continue; + } - var provider = context.GetService(); - info.ProviderName = provider.Name; + _reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName())); + contexts.Add( + context, + FindContextFactory(context)); + } - if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies) - { - try + if (!string.IsNullOrEmpty(name)) { - var connection = context.Database.GetDbConnection(); - info.DataSource = connection.DataSource; - info.DatabaseName = connection.Database; + contexts = FilterTypes(contexts, name, throwOnEmpty: true); } - catch (Exception exception) + + foreach (var contextPair in contexts) { - info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message); + if (contextPair.Value == null) + { + var context = contextPair.Key; + contexts[context] = CreateContextFromServiceProvider(appServices, context); + } } } - else + catch (Exception ex) { - info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection; - } + if (ex is OperationException) + { + throw; + } - var options = context.GetService(); - info.Options = options.BuildOptionsFragment().Trim(); + if (ex is TargetInvocationException) + { + ex = ex.InnerException!; + } - return info; + throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex); + } + + return contexts!; } - private static Func? FindContextFromRuntimeDbContextFactory(IServiceProvider appServices, Type contextType) + private static Func CreateContextFromServiceProvider(IServiceProvider appServices, Type contextType) { var factoryInterface = typeof(IDbContextFactory<>).MakeGenericType(contextType); - var service = appServices.GetService(factoryInterface); - return service == null - ? null + var factoryService = appServices.GetService(factoryInterface); + return factoryService == null + ? () => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, contextType) : () => (DbContext)factoryInterface .GetMethod(nameof(IDbContextFactory.CreateDbContext)) - !.Invoke(service, null)!; + !.Invoke(factoryService, null)!; } private Func? FindContextFactory(Type contextType) @@ -609,36 +658,37 @@ private DbContext CreateContextFromFactory(Type factory, Type contextType) private KeyValuePair> FindContextType(string? name) { - var types = FindContextTypes(); - - if (string.IsNullOrEmpty(name)) - { - if (types.Count == 0) + var types = FindContextTypes(name); + return !string.IsNullOrEmpty(name) + ? types.First() + : types.Count switch { - throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name)); - } - - if (types.Count == 1) - { - return types.First(); - } - - throw new OperationException(DesignStrings.MultipleContexts); - } + 0 => throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name)), + 1 => types.First(), + _ => throw new OperationException(DesignStrings.MultipleContexts) + }; + } - var candidates = FilterTypes(types, name, ignoreCase: true); + private Dictionary?> FilterTypes( + Dictionary?> types, + string name, + bool throwOnEmpty) + { + var candidates = FilterTypes(types, name, StringComparison.OrdinalIgnoreCase); if (candidates.Count == 0) { - throw new OperationException(DesignStrings.NoContextWithName(name)); + return throwOnEmpty + ? throw new OperationException(DesignStrings.NoContextWithName(name)) + : candidates; } if (candidates.Count == 1) { - return candidates.First(); + return candidates; } // Disambiguate using case - candidates = FilterTypes(candidates, name); + candidates = FilterTypes(candidates, name, StringComparison.Ordinal); if (candidates.Count == 0) { throw new OperationException(DesignStrings.MultipleContextsWithName(name)); @@ -646,7 +696,7 @@ private KeyValuePair> FindContextType(string? name) if (candidates.Count == 1) { - return candidates.First(); + return candidates; } // Allow selecting types in the default namespace @@ -658,21 +708,17 @@ private KeyValuePair> FindContextType(string? name) Check.DebugAssert(candidates.Count == 1, $"candidates.Count is {candidates.Count}"); - return candidates.First(); + return candidates; } - private static IDictionary> FilterTypes( - IDictionary> types, + private static Dictionary?> FilterTypes( + Dictionary?> types, string name, - bool ignoreCase = false) - { - var comparisonType = ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal; - - return types + StringComparison comparisonType) + => types .Where( t => string.Equals(t.Key.Name, name, comparisonType) || string.Equals(t.Key.FullName, name, comparisonType) || string.Equals(t.Key.AssemblyQualifiedName, name, comparisonType)) .ToDictionary(); - } } diff --git a/src/EFCore.Design/Properties/DesignStrings.Designer.cs b/src/EFCore.Design/Properties/DesignStrings.Designer.cs index 52142c48ce1..3a107cb980d 100644 --- a/src/EFCore.Design/Properties/DesignStrings.Designer.cs +++ b/src/EFCore.Design/Properties/DesignStrings.Designer.cs @@ -496,6 +496,12 @@ public static string NoContext(object? assembly) GetString("NoContext", nameof(assembly)), assembly); + /// + /// No type deriving from DbContext was found. Add [assembly: DbContext(typeof(*))] attribute for every context type used in this project. + /// + public static string NoContextsToOptimize + => GetString("NoContextsToOptimize"); + /// /// You must provide a DbContext.t4 file in order to scaffold using custom templates. /// @@ -614,6 +620,12 @@ public static string NotExistDatabase(object? name) GetString("NotExistDatabase", nameof(name)), name); + /// + /// No files were generated during the DbContext optimization. Ensure that the target project has code that uses DbContext and that the supplied options are correct. + /// + public static string OptimizeNoFilesGenerated + => GetString("OptimizeNoFilesGenerated"); + /// /// Changes have been made to the model since the last migration. Add a new migration. /// diff --git a/src/EFCore.Design/Properties/DesignStrings.resx b/src/EFCore.Design/Properties/DesignStrings.resx index 5029ecbe9cb..f4d95755480 100644 --- a/src/EFCore.Design/Properties/DesignStrings.resx +++ b/src/EFCore.Design/Properties/DesignStrings.resx @@ -312,6 +312,9 @@ Change your target project to the migrations project by using the Package Manage No DbContext was found in assembly '{assembly}'. Ensure that you're using the correct assembly and that the type is neither abstract nor generic. + + No type deriving from DbContext was found. Add [assembly: DbContext(typeof(*))] attribute for every context type used in this project. + You must provide a DbContext.t4 file in order to scaffold using custom templates. @@ -363,6 +366,9 @@ Change your target project to the migrations project by using the Package Manage Database '{name}' did not exist, no action was taken. + + No files were generated during the DbContext optimization. Ensure that the target project has code that uses DbContext and that the supplied options are correct. + Changes have been made to the model since the last migration. Add a new migration. diff --git a/src/dotnet-ef/Properties/Resources.Designer.cs b/src/dotnet-ef/Properties/Resources.Designer.cs index 91660680362..e11c02b8777 100644 --- a/src/dotnet-ef/Properties/Resources.Designer.cs +++ b/src/dotnet-ef/Properties/Resources.Designer.cs @@ -50,7 +50,7 @@ public static string ConnectionDescription => GetString("ConnectionDescription"); /// - /// The DbContext to use. + /// The DbContext to use. "*" can be used to run the command for all contexts found. This will also disable service discovery through the startup project if a corresponding IDesignTimeDbContextFactory implementation is found. /// public static string ContextDescription => GetString("ContextDescription"); diff --git a/src/dotnet-ef/Properties/Resources.resx b/src/dotnet-ef/Properties/Resources.resx index 7b9f580a4c5..4182eb6d82a 100644 --- a/src/dotnet-ef/Properties/Resources.resx +++ b/src/dotnet-ef/Properties/Resources.resx @@ -133,7 +133,7 @@ The connection string to the database. - The DbContext to use. + The DbContext to use. "*" can be used to run the command for all contexts found. This will also disable service discovery through the startup project if a corresponding IDesignTimeDbContextFactory implementation is found. The directory to put the DbContext file in. Paths are relative to the project directory. diff --git a/src/ef/Properties/Resources.Designer.cs b/src/ef/Properties/Resources.Designer.cs index 94c955124e0..32f9b4f25c4 100644 --- a/src/ef/Properties/Resources.Designer.cs +++ b/src/ef/Properties/Resources.Designer.cs @@ -64,7 +64,7 @@ public static string ConnectionDescription => GetString("ConnectionDescription"); /// - /// The DbContext to use. + /// The DbContext to use. "*" can be used to run the command for all contexts found. This will also disable service discovery through the startup project if a corresponding IDesignTimeDbContextFactory implementation is found. /// public static string ContextDescription => GetString("ContextDescription"); diff --git a/src/ef/Properties/Resources.resx b/src/ef/Properties/Resources.resx index f0b17981907..f8f33bb0175 100644 --- a/src/ef/Properties/Resources.resx +++ b/src/ef/Properties/Resources.resx @@ -139,7 +139,7 @@ The connection string to the database. - The DbContext to use. + The DbContext to use. "*" can be used to run the command for all contexts found. This will also disable service discovery through the startup project if a corresponding IDesignTimeDbContextFactory implementation is found. The directory to put the DbContext file in. Paths are relative to the project directory. diff --git a/test/EFCore.Design.Tests/Design/Internal/DatabaseOperationsTest.cs b/test/EFCore.Design.Tests/Design/Internal/DatabaseOperationsTest.cs index 2d94efd84e8..d5c7a100554 100644 --- a/test/EFCore.Design.Tests/Design/Internal/DatabaseOperationsTest.cs +++ b/test/EFCore.Design.Tests/Design/Internal/DatabaseOperationsTest.cs @@ -22,5 +22,5 @@ public void Can_pass_null_args() args: null); } - private class TestContext : DbContext; + public class TestContext : DbContext; } diff --git a/test/EFCore.Design.Tests/Design/Internal/DbContextOperationsTest.cs b/test/EFCore.Design.Tests/Design/Internal/DbContextOperationsTest.cs index 5a086d6127e..26b8a37c323 100644 --- a/test/EFCore.Design.Tests/Design/Internal/DbContextOperationsTest.cs +++ b/test/EFCore.Design.Tests/Design/Internal/DbContextOperationsTest.cs @@ -10,7 +10,7 @@ public class DbContextOperationsTest { [ConditionalFact] public void CreateContext_gets_service() - => CreateOperations(typeof(TestProgram)).CreateContext(typeof(TestContext).FullName); + => CreateOperations(typeof(TestProgram)).CreateContext(typeof(TestContext).FullName.ToLower()); [ConditionalFact] public void CreateContext_gets_service_without_AddDbContext() @@ -20,14 +20,125 @@ public void CreateContext_gets_service_without_AddDbContext() public void CreateContext_gets_service_when_context_factory_used() => CreateOperations(typeof(TestProgramWithContextFactory)).CreateContext(typeof(TestContextFromFactory).FullName); + [ConditionalFact] + public void CreateContext_throws_if_context_type_not_found() + => Assert.Equal( + DesignStrings.NoContextWithName(typeof(TestContextFromFactory).FullName), + Assert.Throws(() => CreateOperations(typeof(TestProgramRelationalBad)).CreateContext(typeof(TestContextFromFactory).FullName)).Message); + + [ConditionalFact] + public void CreateContext_throws_if_ambiguous_context_type_by_case() + { + var assembly = MockAssembly.Create(typeof(TestContext), typeof(Testcontext)); + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, + assembly, + assembly, + project: "", + projectDir: "", + rootNamespace: null, + language: "C#", + nullable: false, + /* args: */ [], + new TestAppServiceProviderFactory(assembly, reporter)); + + Assert.Equal( + DesignStrings.MultipleContextsWithName(typeof(TestContext).FullName.ToLower()), + Assert.Throws(() => operations.CreateContext(typeof(TestContext).FullName.ToLower())).Message); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + } + + [ConditionalFact] + public void CreateContext_throws_if_ambiguous_context_type_by_namespace() + { + var assembly = MockAssembly.Create(typeof(TestContext), typeof(DatabaseOperationsTest.TestContext)); + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, + assembly, + assembly, + project: "", + projectDir: "", + rootNamespace: null, + language: "C#", + nullable: false, + /* args: */ [], + new TestAppServiceProviderFactory(assembly, reporter)); + + Assert.Equal( + DesignStrings.MultipleContextsWithQualifiedName(nameof(TestContext)), + Assert.Throws(() => operations.CreateContext(nameof(TestContext))).Message); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + } + + [ConditionalFact] + public void CreateContext_throws_if_ambiguous_context_type() + { + var assembly = MockAssembly.Create(typeof(TestContext), typeof(Testcontext)); + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, + assembly, + assembly, + project: "", + projectDir: "", + rootNamespace: null, + language: "C#", + nullable: false, + /* args: */ [], + new TestAppServiceProviderFactory(assembly, reporter)); + + Assert.Equal( + DesignStrings.MultipleContexts, + Assert.Throws(() => operations.CreateContext(null)).Message); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + } + + [ConditionalFact] + public void CreateContext_throws_if_no_context_type() + { + var assembly = MockAssembly.Create(); + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, + assembly, + assembly, + project: "", + projectDir: "", + rootNamespace: null, + language: "C#", + nullable: false, + /* args: */ [], + new TestAppServiceProviderFactory(assembly, reporter)); + + Assert.Equal( + DesignStrings.NoContext(nameof(MockAssembly)), + Assert.Throws(() => operations.CreateContext(null)).Message); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + } + [ConditionalFact] public void Can_pass_null_args() { // Even though newer versions of the tools will pass an empty array // older versions of the tools can pass null args. var assembly = MockAssembly.Create(typeof(TestContext)); - _ = new TestDbContextOperations( - new TestOperationReporter(), + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, assembly, assembly, project: "", @@ -36,15 +147,20 @@ public void Can_pass_null_args() language: "C#", nullable: false, args: null, - new TestAppServiceProviderFactory(assembly)); + new TestAppServiceProviderFactory(assembly, reporter)); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); } [ConditionalFact] public void CreateContext_uses_exact_factory_method() { var assembly = MockAssembly.Create(typeof(BaseContext), typeof(DerivedContext), typeof(HierarchyContextFactory)); + var reporter = new TestOperationReporter(); var operations = new TestDbContextOperations( - new TestOperationReporter(), + reporter, assembly, assembly, project: "", @@ -53,21 +169,26 @@ public void CreateContext_uses_exact_factory_method() language: "C#", nullable: false, args: [], - new TestAppServiceProviderFactory(assembly)); + new TestAppServiceProviderFactory(assembly, reporter, throwOnCreate: true)); var baseContext = Assert.IsType(operations.CreateContext(nameof(BaseContext))); Assert.Equal(nameof(BaseContext), baseContext.FactoryUsed); var derivedContext = Assert.IsType(operations.CreateContext(nameof(DerivedContext))); Assert.Equal(nameof(DerivedContext), derivedContext.FactoryUsed); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); } [ConditionalFact] public void CreateAllContexts_creates_all_contexts() { var assembly = MockAssembly.Create(typeof(BaseContext), typeof(DerivedContext), typeof(HierarchyContextFactory)); + var reporter = new TestOperationReporter(); var operations = new TestDbContextOperations( - new TestOperationReporter(), + reporter, assembly, assembly, project: "", @@ -76,12 +197,70 @@ public void CreateAllContexts_creates_all_contexts() language: "C#", nullable: false, args: [], - new TestAppServiceProviderFactory(assembly)); + new TestAppServiceProviderFactory(assembly, reporter, throwOnCreate: true)); var contexts = operations.CreateAllContexts().ToList(); Assert.Collection(contexts, c => Assert.Equal(nameof(BaseContext), Assert.IsType(c).FactoryUsed), c => Assert.Equal(nameof(DerivedContext), Assert.IsType(c).FactoryUsed)); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + } + + [ConditionalFact] + public void Optimize_throws_when_no_contexts() + { + var assembly = MockAssembly.Create(); + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, + assembly, + assembly, + project: "", + projectDir: "", + rootNamespace: null, + language: "C#", + nullable: false, + args: [], + new TestAppServiceProviderFactory(assembly, reporter, throwOnCreate: true)); + + Assert.Equal( + DesignStrings.NoContextsToOptimize, + Assert.Throws(() => + operations.Optimize(null, null, contextTypeName: "*", null, scaffoldModel: true, precompileQueries: false)).Message); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + } + + [ConditionalFact] + public void Optimize_shows_warning_when_nothing_was_generated() + { + var assembly = MockAssembly.Create(typeof(DerivedContext)); + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, + assembly, + assembly, + project: "", + projectDir: "", + rootNamespace: null, + language: "C#", + nullable: false, + args: [], + new TestAppServiceProviderFactory(assembly, reporter, throwOnCreate: true)); + + operations.Optimize(null, null, contextTypeName: "*", null, scaffoldModel: true, precompileQueries: false); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + + Assert.Equal( + DesignStrings.OptimizeNoFilesGenerated, + Assert.Single(reporter.Messages.Where(m => m.Level == LogLevel.Warning)).Message); } [ConditionalFact] @@ -184,8 +363,9 @@ private static TestWebHost BuildWebHost(string[] args) private static TestDbContextOperations CreateOperations(Type testProgramType) { var assembly = MockAssembly.Create(testProgramType, typeof(TestContext)); - return new TestDbContextOperations( - new TestOperationReporter(), + var reporter = new TestOperationReporter(); + var operations = new TestDbContextOperations( + reporter, assembly, assembly, project: "", @@ -194,7 +374,13 @@ private static TestDbContextOperations CreateOperations(Type testProgramType) language: "C#", nullable: false, /* args: */ [], - new TestAppServiceProviderFactory(assembly)); + new TestAppServiceProviderFactory(assembly, reporter)); + + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Critical)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Error)); + Assert.Empty(reporter.Messages.Where(m => m.Level == LogLevel.Warning)); + + return operations; } private static TestWebHost CreateWebHost(Func configureProvider) @@ -230,6 +416,19 @@ public TestContextFromFactory(DbContextOptions options) } } + private class Testcontext : DbContext + { + public Testcontext() + { + throw new Exception("This isn't the constructor you're looking for."); + } + + public Testcontext(DbContextOptions options) + : base(options) + { + } + } + private class BaseContext(string factoryUsed) : DbContext { protected override void OnConfiguring(DbContextOptionsBuilder options) diff --git a/test/EFCore.Design.Tests/TestUtilities/TestAppServiceProviderFactory.cs b/test/EFCore.Design.Tests/TestUtilities/TestAppServiceProviderFactory.cs index 8c8cc5b037c..ed1f019b17d 100644 --- a/test/EFCore.Design.Tests/TestUtilities/TestAppServiceProviderFactory.cs +++ b/test/EFCore.Design.Tests/TestUtilities/TestAppServiceProviderFactory.cs @@ -5,4 +5,28 @@ namespace Microsoft.EntityFrameworkCore.TestUtilities; -public class TestAppServiceProviderFactory(Assembly startupAssembly, IOperationReporter reporter = null) : AppServiceProviderFactory(startupAssembly, reporter ?? new TestOperationReporter()); +public class TestAppServiceProviderFactory : AppServiceProviderFactory +{ + private readonly bool _throwOnCreate; + + public TestAppServiceProviderFactory(Assembly startupAssembly, bool throwOnCreate = false) + : this(startupAssembly, new TestOperationReporter(), throwOnCreate) + { + } + + public TestAppServiceProviderFactory(Assembly startupAssembly, TestOperationReporter reporter, bool throwOnCreate = false) + : base(startupAssembly, reporter) + { + TestOperationReporter = reporter; + _throwOnCreate = throwOnCreate; + } + + public TestOperationReporter TestOperationReporter { get; } + + public override IServiceProvider Create(string[] args) + { + Assert.False(_throwOnCreate, "Service provider shouldn't be used in this case."); + + return base.Create(args); + } +} diff --git a/test/EFCore.Specification.Tests/Query/NonSharedPrimitiveCollectionsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NonSharedPrimitiveCollectionsQueryTestBase.cs index efe6255a46d..bdbccf70ad0 100644 --- a/test/EFCore.Specification.Tests/Query/NonSharedPrimitiveCollectionsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NonSharedPrimitiveCollectionsQueryTestBase.cs @@ -231,8 +231,8 @@ public virtual async Task Project_collection_from_entity_type_with_owned() await using var context = contextFactory.CreateContext(); var results = await context.Set().Select(t => t.Ints).ToListAsync(); - Assert.True(results.Any(r => r?.SequenceEqual([1, 2]) == true)); - Assert.True(results.Any(r => r?.SequenceEqual([3, 4]) == true)); + Assert.Contains(results, r => r?.SequenceEqual([1, 2]) ?? false); + Assert.Contains(results, r => r?.SequenceEqual([3, 4]) ?? false); } private class TestEntityWithOwned