diff --git a/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs b/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs index 62d6310de31..24077831d7a 100644 --- a/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs +++ b/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs @@ -34,7 +34,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) .HasTranslation(args => new SqlFunctionExpression("len", methodInfo.ReturnType, args)); modelBuilder.HasDbFunction(typeof(NorthwindRelationalContext) - .GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) })); + .GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) })) + .HasSchema(""); } public enum ReportingPeriod @@ -55,13 +56,13 @@ public static bool IsDate(string date) throw new Exception(); } - [DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")] + [DbFunction(FunctionName = "EmployeeOrderCount")] public static int EmployeeOrderCount(int employeeId) { throw new NotImplementedException(); } - [DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")] + [DbFunction(FunctionName = "EmployeeOrderCount")] public static int EmployeeOrderCountWithClient(int employeeId) { switch (employeeId) @@ -73,13 +74,13 @@ public static int EmployeeOrderCountWithClient(int employeeId) } } - [DbFunction(Schema = "dbo")] + [DbFunction] public static bool IsTopEmployee(int employeeId) { throw new NotImplementedException(); } - [DbFunction(Schema = "dbo")] + [DbFunction] public static int GetEmployeeWithMostOrdersAfterDate(DateTime? startDate) { throw new NotImplementedException(); diff --git a/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs b/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs new file mode 100644 index 00000000000..549c6d6d8f4 --- /dev/null +++ b/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal +{ + public class SqlServerDbFunctionConvention : IModelBuiltConvention + { + public virtual InternalModelBuilder Apply(InternalModelBuilder modelBuilder) + { + Check.NotNull(modelBuilder, nameof(modelBuilder)); + + if (!string.IsNullOrWhiteSpace(modelBuilder.Metadata.Relational().DefaultSchema)) + return modelBuilder; + + foreach (var dbFunction in modelBuilder.Metadata.Relational().DbFunctions) + { + if (dbFunction.Schema == null) + modelBuilder.Metadata.Relational().GetOrAddDbFunction(dbFunction.MethodInfo).Schema = "dbo"; + } + + return modelBuilder; + } + } +} diff --git a/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs b/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs index 49f66629007..89deeb2a7f6 100644 --- a/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs +++ b/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs @@ -57,6 +57,8 @@ public override ConventionSet AddConventions(ConventionSet conventionSet) conventionSet.PropertyAnnotationChangedConventions.Add(sqlServerIndexConvention); conventionSet.PropertyAnnotationChangedConventions.Add((SqlServerValueGeneratorConvention)valueGeneratorConvention); + conventionSet.ModelBuiltConventions.Add(new SqlServerDbFunctionConvention()); + return conventionSet; } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index f7b5ad336c8..a21b23a36d4 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -2141,6 +2141,118 @@ public class Details9202 public string Info { get; set; } } + #endregion + + #region Bug9214 + + [Fact] + public void Default_schema_applied_when_no_function_schema() + { + using (CreateDatabase9214()) + { + using (var context = new MyContext9214(_options)) + { + var result = context.Widgets.Where(w => w.Val == 1).Select(w => MyContext9214.AddOne(w.Val)).Single(); + + Assert.Equal(2, result); + + AssertSql( + @"SELECT TOP(2) [foo].AddOne([w].[Val]) +FROM [foo].[Widgets] AS [w] +WHERE [w].[Val] = 1"); + } + } + } + + [Fact] + public void Default_schema_function_schema_overrides() + { + using (CreateDatabase9214()) + { + using (var context = new MyContext9214(_options)) + { + var result = context.Widgets.Where(w => w.Val == 1).Select(w => MyContext9214.AddTwo(w.Val)).Single(); + + Assert.Equal(3, result); + + AssertSql( + @"SELECT TOP(2) [dbo].AddTwo([w].[Val]) +FROM [foo].[Widgets] AS [w] +WHERE [w].[Val] = 1"); + } + } + } + + private SqlServerTestStore CreateDatabase9214() + => CreateTestStore( + () => new MyContext9214(_options), + context => + { + var w1 = new Widget9214 { Val = 1 }; + var w2 = new Widget9214 { Val = 2 }; + var w3 = new Widget9214 { Val = 3 }; + context.Widgets.AddRange(w1, w2, w3); + context.SaveChanges(); + + context.Database.ExecuteSqlCommand(@"CREATE FUNCTION foo.AddOne (@num int) + RETURNS int + AS + BEGIN + return @num + 1 ; + END"); + + context.Database.ExecuteSqlCommand(@"CREATE FUNCTION dbo.AddTwo (@num int) + RETURNS int + AS + BEGIN + return @num + 2 ; + END"); + + ClearLog(); + }); + + public class MyContext9214 : DbContext + { + public DbSet Widgets { get; set; } + + public static int AddOne(int num) + { + throw new Exception(); + } + + public static int AddTwo(int num) + { + throw new Exception(); + } + + public static int AddThree(int num) + { + throw new Exception(); + } + + public MyContext9214(DbContextOptions options) + : base(options) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.HasDefaultSchema("foo"); + + modelBuilder.Entity().ToTable("Widgets", "foo"); + + modelBuilder.HasDbFunction(typeof(MyContext9214).GetMethod(nameof(AddOne))); + modelBuilder.HasDbFunction(typeof(MyContext9214).GetMethod(nameof(AddTwo))).HasSchema("dbo"); + } + } + + public class Widget9214 + { + public int Id { get; set; } + public int Val { get; set; } + } + + #endregion #region Bug9277 diff --git a/test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs b/test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs new file mode 100644 index 00000000000..47945c1a0d3 --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/SqlServerDbFunctionMetadataTests.cs @@ -0,0 +1,72 @@ +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Internal ; +using Xunit; + +namespace Microsoft.EntityFrameworkCore +{ + public class SqlServerDbFunctionMetadataTests + { + public class TestMethods + { + public static int Foo() + { + throw new Exception(); + } + } + + public static MethodInfo MethodFoo = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.Foo), new Type[] { }); + + + [Fact] + public virtual void DbFuction_defaults_schema_to_dbo() + { + var modelBuilder = GetModelBuilder(); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("dbo", dbFunction.Metadata.Schema); + } + + [Fact] + public virtual void DbFuction_set_schmea_is_not_overridden() + { + var modelBuilder = GetModelBuilder(); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo).HasSchema("abc"); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("abc", dbFunction.Metadata.Schema); + } + + [Fact] + public virtual void DbFuction_dbo_not_set_if_default_schema_is_set() + { + var modelBuilder = GetModelBuilder(); + + modelBuilder.HasDefaultSchema("qwerty"); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal(null, dbFunction.Metadata.Schema); + } + + private ModelBuilder GetModelBuilder() + { + var conventionset = new ConventionSet(); + + conventionset.ModelBuiltConventions.Add(new SqlServerDbFunctionConvention()); + + return new ModelBuilder(conventionset); + } + } +}