From 1e29900734bf942bdbdd0faf469f2e882d7a4626 Mon Sep 17 00:00:00 2001 From: Paul Middleton Date: Mon, 31 Jul 2017 10:19:36 -0500 Subject: [PATCH] Default Sql Server db functions to DBO schema --- .../Northwind/NorthwindRelationalContext.cs | 11 +-- .../Internal/SqlServerDbFunctionConvention.cs | 31 ++++++++ .../SqlServerConventionSetBuilder.cs | 2 + .../SqlServerMetadataTests.cs | 72 +++++++++++++++++++ 4 files changed, 111 insertions(+), 5 deletions(-) create mode 100644 src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs create mode 100644 test/EFCore.SqlServer.FunctionalTests/SqlServerMetadataTests.cs diff --git a/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs b/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs index 62d6310de31..24077831d7a 100644 --- a/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs +++ b/src/EFCore.Relational.Specification.Tests/TestModels/Northwind/NorthwindRelationalContext.cs @@ -34,7 +34,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) .HasTranslation(args => new SqlFunctionExpression("len", methodInfo.ReturnType, args)); modelBuilder.HasDbFunction(typeof(NorthwindRelationalContext) - .GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) })); + .GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) })) + .HasSchema(""); } public enum ReportingPeriod @@ -55,13 +56,13 @@ public static bool IsDate(string date) throw new Exception(); } - [DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")] + [DbFunction(FunctionName = "EmployeeOrderCount")] public static int EmployeeOrderCount(int employeeId) { throw new NotImplementedException(); } - [DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")] + [DbFunction(FunctionName = "EmployeeOrderCount")] public static int EmployeeOrderCountWithClient(int employeeId) { switch (employeeId) @@ -73,13 +74,13 @@ public static int EmployeeOrderCountWithClient(int employeeId) } } - [DbFunction(Schema = "dbo")] + [DbFunction] public static bool IsTopEmployee(int employeeId) { throw new NotImplementedException(); } - [DbFunction(Schema = "dbo")] + [DbFunction] public static int GetEmployeeWithMostOrdersAfterDate(DateTime? startDate) { throw new NotImplementedException(); diff --git a/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs b/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs new file mode 100644 index 00000000000..5ec5591da34 --- /dev/null +++ b/src/EFCore.SqlServer/Metadata/Conventions/Internal/SqlServerDbFunctionConvention.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal +{ + public class SqlServerDbFunctionConvention : IModelBuiltConvention + { + public InternalModelBuilder Apply(InternalModelBuilder modelBuilder) + { + Check.NotNull(modelBuilder, nameof(modelBuilder)); + + if (!string.IsNullOrWhiteSpace(modelBuilder.Metadata.Relational().DefaultSchema)) + return modelBuilder; + + foreach (var dbFunction in modelBuilder.Metadata.Relational().DbFunctions) + { + if (dbFunction.Schema == null) + modelBuilder.Metadata.Relational().GetOrAddDbFunction(dbFunction.MethodInfo).Schema = "dbo"; + } + + return modelBuilder; + } + } +} diff --git a/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs b/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs index 49f66629007..89deeb2a7f6 100644 --- a/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs +++ b/src/EFCore.SqlServer/Metadata/Conventions/SqlServerConventionSetBuilder.cs @@ -57,6 +57,8 @@ public override ConventionSet AddConventions(ConventionSet conventionSet) conventionSet.PropertyAnnotationChangedConventions.Add(sqlServerIndexConvention); conventionSet.PropertyAnnotationChangedConventions.Add((SqlServerValueGeneratorConvention)valueGeneratorConvention); + conventionSet.ModelBuiltConventions.Add(new SqlServerDbFunctionConvention()); + return conventionSet; } diff --git a/test/EFCore.SqlServer.FunctionalTests/SqlServerMetadataTests.cs b/test/EFCore.SqlServer.FunctionalTests/SqlServerMetadataTests.cs new file mode 100644 index 00000000000..64087534e6e --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/SqlServerMetadataTests.cs @@ -0,0 +1,72 @@ +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Internal ; +using Xunit; + +namespace Microsoft.EntityFrameworkCore +{ + public class SqlServerMetadataTests + { + public class TestMethods + { + public static int Foo() + { + throw new Exception(); + } + } + + public static MethodInfo MethodFoo = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.Foo), new Type[] { }); + + + [Fact] + public virtual void DbFuction_defaults_schema_to_dbo() + { + var modelBuilder = GetModelBuilder(); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("dbo", dbFunction.Metadata.Schema); + } + + [Fact] + public virtual void DbFuction_set_schmea_is_not_overridden() + { + var modelBuilder = GetModelBuilder(); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo).HasSchema("abc"); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal("abc", dbFunction.Metadata.Schema); + } + + [Fact] + public virtual void DbFuction__dbo_not_set_if_default_schema_is_set() + { + var modelBuilder = GetModelBuilder(); + + modelBuilder.HasDefaultSchema("qwerty"); + + var dbFunction = modelBuilder.HasDbFunction(MethodFoo); + + ((Model)modelBuilder.Model).Validate(); + + Assert.Equal(null, dbFunction.Metadata.Schema); + } + + private ModelBuilder GetModelBuilder() + { + var conventionset = new ConventionSet(); + + conventionset.ModelBuiltConventions.Add(new SqlServerDbFunctionConvention()); + + return new ModelBuilder(conventionset); + } + } +}