From b58ec4049f45a1427fc8ee8d03074b1cba9045a6 Mon Sep 17 00:00:00 2001 From: Paul Middleton Date: Mon, 31 Jul 2017 10:19:36 -0500 Subject: [PATCH] Default Sql Server db functions to DBO schema --- .../Northwind/NorthwindRelationalContext.cs | 11 +- .../RelationalDbFunctionConvention.cs | 2 +- .../Metadata/Internal/DbFunction.cs | 12 +- .../Internal/SqlServerDbFunctionConvention.cs | 30 +++++ .../SqlServerConventionSetBuilder.cs | 2 + .../Metadata/DbFunctionMetadataTests.cs | 16 +++ .../Query/QueryBugsTest.cs | 112 ++++++++++++++++++ .../SqlServerDbFunctionMetadataTests.cs | 73 ++++++++++++ 8 files changed, 250 insertions(+), 8 deletions(-) create mode 100644 src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs create mode 100644 test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs diff --git a/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs b/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs index 62d6310de31..24077831d7a 100644 --- a/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs +++ b/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs @@ -34,7 +34,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) .HasTranslation(args => new SqlFunctionExpression("len", methodInfo.ReturnType, args)); modelBuilder.HasDbFunction(typeof(NorthwindRelationalContext) - .GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) })); + .GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) })) + .HasSchema(""); } public enum ReportingPeriod @@ -55,13 +56,13 @@ public static bool IsDate(string date) throw new Exception(); } - [DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")] + [DbFunction(FunctionName = "EmployeeOrderCount")] public static int EmployeeOrderCount(int employeeId) { throw new NotImplementedException(); } - [DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")] + [DbFunction(FunctionName = "EmployeeOrderCount")] public static int EmployeeOrderCountWithClient(int employeeId) { switch (employeeId) @@ -73,13 +74,13 @@ public static int EmployeeOrderCountWithClient(int employeeId) } } - [DbFunction(Schema = "dbo")] + [DbFunction] public static bool IsTopEmployee(int employeeId) { throw new NotImplementedException(); } - [DbFunction(Schema = "dbo")] + [DbFunction] public static int GetEmployeeWithMostOrdersAfterDate(DateTime? startDate) { throw new NotImplementedException(); diff --git a/src/EFCore.Relational/Metadata/Conventions/Internal/RelationalDbFunctionConvention.cs b/src/EFCore.Relational/Metadata/Conventions/Internal/RelationalDbFunctionConvention.cs index e9be2812b9b..03697c9bab4 100644 --- a/src/EFCore.Relational/Metadata/Conventions/Internal/RelationalDbFunctionConvention.cs +++ b/src/EFCore.Relational/Metadata/Conventions/Internal/RelationalDbFunctionConvention.cs @@ -34,7 +34,7 @@ public virtual Annotation Apply(InternalModelBuilder modelBuilder, string name, var dbFunctionAttribute = methodInfo.GetCustomAttributes().SingleOrDefault(); dbFunctionBuilder.HasName(dbFunctionAttribute?.FunctionName ?? methodInfo.Name, ConfigurationSource.Convention); - dbFunctionBuilder.HasSchema(dbFunctionAttribute?.Schema ?? modelBuilder.Metadata.Relational().DefaultSchema, ConfigurationSource.Convention); + dbFunctionBuilder.HasSchema(dbFunctionAttribute?.Schema, ConfigurationSource.Convention); } return annotation; diff --git a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs index 9885c82b305..e144d10bac9 100644 --- a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs +++ b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs @@ -47,6 +47,7 @@ public static DbFunction FindDbFunction( return model[BuildAnnotationName(annotationPrefix, methodInfo)] as DbFunction; } + private readonly IMutableModel _model; private string _schema; private string _functionName; @@ -81,7 +82,8 @@ private DbFunction( MethodInfo = methodInfo; - model[BuildAnnotationName(annotationPrefix, methodInfo)] = this; + _model = model; + _model[BuildAnnotationName(annotationPrefix, methodInfo)] = this; } /// @@ -102,13 +104,19 @@ public static IEnumerable GetDbFunctions([NotNull] IModel model, [N private static string BuildAnnotationName(string annotationPrefix, MethodBase methodBase) => $@"{annotationPrefix}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})"; + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public virtual string DefaultSchema { get; [param: CanBeNull] set;} + /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. /// public virtual string Schema { - get => _schema; + get => _schema ?? _model.Relational().DefaultSchema ?? DefaultSchema; set => SetSchema(value, ConfigurationSource.Explicit); } diff --git a/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs b/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs new file mode 100644 index 00000000000..f0948acce75 --- /dev/null +++ b/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal +{ + public class SqlServerDbFunctionConvention : IModelAnnotationChangedConvention + { + public virtual Annotation Apply(InternalModelBuilder modelBuilder, string name, Annotation annotation, Annotation oldAnnotation) + { + Check.NotNull(modelBuilder, nameof(modelBuilder)); + Check.NotNull(name, nameof(name)); + + if (name.StartsWith(RelationalAnnotationNames.DbFunction, StringComparison.OrdinalIgnoreCase) + && annotation != null + && oldAnnotation == null) + { + ((DbFunction)annotation.Value).DefaultSchema = "dbo"; + } + + return annotation; + } + } +} diff --git a/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs b/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs index 49f66629007..29d39fef12c 100644 --- a/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs +++ b/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs @@ -57,6 +57,8 @@ public override ConventionSet AddConventions(ConventionSet conventionSet) conventionSet.PropertyAnnotationChangedConventions.Add(sqlServerIndexConvention); conventionSet.PropertyAnnotationChangedConventions.Add((SqlServerValueGeneratorConvention)valueGeneratorConvention); + conventionSet.ModelAnnotationChangedConventions.Add(new SqlServerDbFunctionConvention()); + return conventionSet; } diff --git a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs index 3aa3d569c81..752eda78b3a 100644 --- a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs +++ b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs @@ -275,6 +275,22 @@ public void Adding_method_with_relational_schema_attribute_overrides() Assert.Equal("bar", dbFuncBuilder.Metadata.Schema); } + [Fact] + public void Changing_default_schema_is_detected_by_dbfunction() + { + var modelBuilder = GetModelBuilder(); + + modelBuilder.HasDefaultSchema("abc"); + + var dbFuncBuilder = modelBuilder.HasDbFunction(MethodAmi); + + Assert.Equal("abc", dbFuncBuilder.Metadata.Schema); + + modelBuilder.HasDefaultSchema("xyz"); + + Assert.Equal("xyz", dbFuncBuilder.Metadata.Schema); + } + [Fact] public void Add_method_generic_not_supported_throws() { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index 3762d9b9d2f..94f53281e45 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -2141,6 +2141,118 @@ public class Details9202 public string Info { get; set; } } + #endregion + + #region Bug9214 + + [Fact] + public void Default_schema_applied_when_no_function_schema() + { + using (CreateDatabase9214()) + { + using (var context = new MyContext9214(_options)) + { + var result = context.Widgets.Where(w => w.Val == 1).Select(w => MyContext9214.AddOne(w.Val)).Single(); + + Assert.Equal(2, result); + + AssertSql( + @"SELECT TOP(2) [foo].AddOne([w].[Val]) +FROM [foo].[Widgets] AS [w] +WHERE [w].[Val] = 1"); + } + } + } + + [Fact] + public void Default_schema_function_schema_overrides() + { + using (CreateDatabase9214()) + { + using (var context = new MyContext9214(_options)) + { + var result = context.Widgets.Where(w => w.Val == 1).Select(w => MyContext9214.AddTwo(w.Val)).Single(); + + Assert.Equal(3, result); + + AssertSql( + @"SELECT TOP(2) [dbo].AddTwo([w].[Val]) +FROM [foo].[Widgets] AS [w] +WHERE [w].[Val] = 1"); + } + } + } + + private SqlServerTestStore CreateDatabase9214() + => CreateTestStore( + () => new MyContext9214(_options), + context => + { + var w1 = new Widget9214 { Val = 1 }; + var w2 = new Widget9214 { Val = 2 }; + var w3 = new Widget9214 { Val = 3 }; + context.Widgets.AddRange(w1, w2, w3); + context.SaveChanges(); + + context.Database.ExecuteSqlCommand(@"CREATE FUNCTION foo.AddOne (@num int) + RETURNS int + AS + BEGIN + return @num + 1 ; + END"); + + context.Database.ExecuteSqlCommand(@"CREATE FUNCTION dbo.AddTwo (@num int) + RETURNS int + AS + BEGIN + return @num + 2 ; + END"); + + ClearLog(); + }); + + public class MyContext9214 : DbContext + { + public DbSet Widgets { get; set; } + + public static int AddOne(int num) + { + throw new Exception(); + } + + public static int AddTwo(int num) + { + throw new Exception(); + } + + public static int AddThree(int num) + { + throw new Exception(); + } + + public MyContext9214(DbContextOptions options) + : base(options) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.HasDefaultSchema("foo"); + + modelBuilder.Entity().ToTable("Widgets", "foo"); + + modelBuilder.HasDbFunction(typeof(MyContext9214).GetMethod(nameof(AddOne))); + modelBuilder.HasDbFunction(typeof(MyContext9214).GetMethod(nameof(AddTwo))).HasSchema("dbo"); + } + } + + public class Widget9214 + { + public int Id { get; set; } + public int Val { get; set; } + } + + #endregion #region Bug9277 diff --git a/test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs b/test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs new file mode 100644 index 00000000000..e047a3c8a6e --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs @@ -0,0 +1,73 @@ +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Internal ; +using Xunit; + +namespace Microsoft.EntityFrameworkCore +{ + public class SqlServerDbFunctionMetadataTests + { + public class TestMethods + { + public static int Foo() + { + throw new Exception(); + } + } + + public static MethodInfo MethodFoo = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.Foo), new Type[] { }); + + [Fact] + public virtual void DbFuction_defaults_schema_to_dbo_if_no_default_schema_or_set_schema() + { + var modelBuilder = GetModelBuilder(); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("dbo", dbFunction.Metadata.Schema); + } + + [Fact] + public virtual void DbFuction_set_schmea_is_not_overridden_by_default_or_dbo() + { + var modelBuilder = GetModelBuilder(); + + modelBuilder.HasDefaultSchema("qwerty"); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo).HasSchema("abc"); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("abc", dbFunction.Metadata.Schema); + } + + [Fact] + public virtual void DbFuction_default_schema_not_overridden_by_dbo() + { + var modelBuilder = GetModelBuilder(); + + modelBuilder.HasDefaultSchema("qwerty"); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("qwerty", dbFunction.Metadata.Schema); + } + + private ModelBuilder GetModelBuilder() + { + var conventionset = new ConventionSet(); + + conventionset.ModelAnnotationChangedConventions.Add(new SqlServerDbFunctionConvention()); + + return new ModelBuilder(conventionset); + } + } +}