Skip to content

Commit

Permalink
Default Sql Server db functions to DBO schema
Browse files Browse the repository at this point in the history
  • Loading branch information
pmiddleton committed Aug 11, 2017
1 parent ead9056 commit b58ec40
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
.HasTranslation(args => new SqlFunctionExpression("len", methodInfo.ReturnType, args));

modelBuilder.HasDbFunction(typeof(NorthwindRelationalContext)
.GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) }));
.GetRuntimeMethod(nameof(IsDate), new[] { typeof(string) }))
.HasSchema("");
}

public enum ReportingPeriod
Expand All @@ -55,13 +56,13 @@ public static bool IsDate(string date)
throw new Exception();
}

[DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")]
[DbFunction(FunctionName = "EmployeeOrderCount")]
public static int EmployeeOrderCount(int employeeId)
{
throw new NotImplementedException();
}

[DbFunction(Schema = "dbo", FunctionName = "EmployeeOrderCount")]
[DbFunction(FunctionName = "EmployeeOrderCount")]
public static int EmployeeOrderCountWithClient(int employeeId)
{
switch (employeeId)
Expand All @@ -73,13 +74,13 @@ public static int EmployeeOrderCountWithClient(int employeeId)
}
}

[DbFunction(Schema = "dbo")]
[DbFunction]
public static bool IsTopEmployee(int employeeId)
{
throw new NotImplementedException();
}

[DbFunction(Schema = "dbo")]
[DbFunction]
public static int GetEmployeeWithMostOrdersAfterDate(DateTime? startDate)
{
throw new NotImplementedException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public virtual Annotation Apply(InternalModelBuilder modelBuilder, string name,
var dbFunctionAttribute = methodInfo.GetCustomAttributes<DbFunctionAttribute>().SingleOrDefault();

dbFunctionBuilder.HasName(dbFunctionAttribute?.FunctionName ?? methodInfo.Name, ConfigurationSource.Convention);
dbFunctionBuilder.HasSchema(dbFunctionAttribute?.Schema ?? modelBuilder.Metadata.Relational().DefaultSchema, ConfigurationSource.Convention);
dbFunctionBuilder.HasSchema(dbFunctionAttribute?.Schema, ConfigurationSource.Convention);
}

return annotation;
Expand Down
12 changes: 10 additions & 2 deletions src/EFCore.Relational/Metadata/Internal/DbFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public static DbFunction FindDbFunction(
return model[BuildAnnotationName(annotationPrefix, methodInfo)] as DbFunction;
}

private readonly IMutableModel _model;
private string _schema;
private string _functionName;

Expand Down Expand Up @@ -81,7 +82,8 @@ private DbFunction(

MethodInfo = methodInfo;

model[BuildAnnotationName(annotationPrefix, methodInfo)] = this;
_model = model;
_model[BuildAnnotationName(annotationPrefix, methodInfo)] = this;
}

/// <summary>
Expand All @@ -102,13 +104,19 @@ public static IEnumerable<IDbFunction> GetDbFunctions([NotNull] IModel model, [N
private static string BuildAnnotationName(string annotationPrefix, MethodBase methodBase)
=> $@"{annotationPrefix}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})";

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual string DefaultSchema { get; [param: CanBeNull] set;}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual string Schema
{
get => _schema;
get => _schema ?? _model.Relational().DefaultSchema ?? DefaultSchema;
set => SetSchema(value, ConfigurationSource.Explicit);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal
{
public class SqlServerDbFunctionConvention : IModelAnnotationChangedConvention
{
public virtual Annotation Apply(InternalModelBuilder modelBuilder, string name, Annotation annotation, Annotation oldAnnotation)
{
Check.NotNull(modelBuilder, nameof(modelBuilder));
Check.NotNull(name, nameof(name));

if (name.StartsWith(RelationalAnnotationNames.DbFunction, StringComparison.OrdinalIgnoreCase)
&& annotation != null
&& oldAnnotation == null)
{
((DbFunction)annotation.Value).DefaultSchema = "dbo";
}

return annotation;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public override ConventionSet AddConventions(ConventionSet conventionSet)
conventionSet.PropertyAnnotationChangedConventions.Add(sqlServerIndexConvention);
conventionSet.PropertyAnnotationChangedConventions.Add((SqlServerValueGeneratorConvention)valueGeneratorConvention);

conventionSet.ModelAnnotationChangedConventions.Add(new SqlServerDbFunctionConvention());

return conventionSet;
}

Expand Down
16 changes: 16 additions & 0 deletions test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,22 @@ public void Adding_method_with_relational_schema_attribute_overrides()
Assert.Equal("bar", dbFuncBuilder.Metadata.Schema);
}

[Fact]
public void Changing_default_schema_is_detected_by_dbfunction()
{
var modelBuilder = GetModelBuilder();

modelBuilder.HasDefaultSchema("abc");

var dbFuncBuilder = modelBuilder.HasDbFunction(MethodAmi);

Assert.Equal("abc", dbFuncBuilder.Metadata.Schema);

modelBuilder.HasDefaultSchema("xyz");

Assert.Equal("xyz", dbFuncBuilder.Metadata.Schema);
}

[Fact]
public void Add_method_generic_not_supported_throws()
{
Expand Down
112 changes: 112 additions & 0 deletions test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,118 @@ public class Details9202
public string Info { get; set; }
}

#endregion

#region Bug9214

[Fact]
public void Default_schema_applied_when_no_function_schema()
{
using (CreateDatabase9214())
{
using (var context = new MyContext9214(_options))
{
var result = context.Widgets.Where(w => w.Val == 1).Select(w => MyContext9214.AddOne(w.Val)).Single();

Assert.Equal(2, result);

AssertSql(
@"SELECT TOP(2) [foo].AddOne([w].[Val])
FROM [foo].[Widgets] AS [w]
WHERE [w].[Val] = 1");
}
}
}

[Fact]
public void Default_schema_function_schema_overrides()
{
using (CreateDatabase9214())
{
using (var context = new MyContext9214(_options))
{
var result = context.Widgets.Where(w => w.Val == 1).Select(w => MyContext9214.AddTwo(w.Val)).Single();

Assert.Equal(3, result);

AssertSql(
@"SELECT TOP(2) [dbo].AddTwo([w].[Val])
FROM [foo].[Widgets] AS [w]
WHERE [w].[Val] = 1");
}
}
}

private SqlServerTestStore CreateDatabase9214()
=> CreateTestStore(
() => new MyContext9214(_options),
context =>
{
var w1 = new Widget9214 { Val = 1 };
var w2 = new Widget9214 { Val = 2 };
var w3 = new Widget9214 { Val = 3 };
context.Widgets.AddRange(w1, w2, w3);
context.SaveChanges();

context.Database.ExecuteSqlCommand(@"CREATE FUNCTION foo.AddOne (@num int)
RETURNS int
AS
BEGIN
return @num + 1 ;
END");

context.Database.ExecuteSqlCommand(@"CREATE FUNCTION dbo.AddTwo (@num int)
RETURNS int
AS
BEGIN
return @num + 2 ;
END");

ClearLog();
});

public class MyContext9214 : DbContext
{
public DbSet<Widget9214> Widgets { get; set; }

public static int AddOne(int num)
{
throw new Exception();
}

public static int AddTwo(int num)
{
throw new Exception();
}

public static int AddThree(int num)
{
throw new Exception();
}

public MyContext9214(DbContextOptions options)
: base(options)
{
}

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
modelBuilder.HasDefaultSchema("foo");

modelBuilder.Entity<Widget9214>().ToTable("Widgets", "foo");

modelBuilder.HasDbFunction(typeof(MyContext9214).GetMethod(nameof(AddOne)));
modelBuilder.HasDbFunction(typeof(MyContext9214).GetMethod(nameof(AddTwo))).HasSchema("dbo");
}
}

public class Widget9214
{
public int Id { get; set; }
public int Val { get; set; }
}


#endregion

#region Bug9277
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal;
using Microsoft.EntityFrameworkCore.Metadata.Internal ;
using Xunit;

namespace Microsoft.EntityFrameworkCore
{
public class SqlServerDbFunctionMetadataTests
{
public class TestMethods
{
public static int Foo()
{
throw new Exception();
}
}

public static MethodInfo MethodFoo = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.Foo), new Type[] { });

[Fact]
public virtual void DbFuction_defaults_schema_to_dbo_if_no_default_schema_or_set_schema()
{
var modelBuilder = GetModelBuilder();

var dbFunction = modelBuilder.HasDbFunction(MethodFoo);

((Model)modelBuilder.Model).Validate();

Assert.Equal("dbo", dbFunction.Metadata.Schema);
}

[Fact]
public virtual void DbFuction_set_schmea_is_not_overridden_by_default_or_dbo()
{
var modelBuilder = GetModelBuilder();

modelBuilder.HasDefaultSchema("qwerty");

var dbFunction = modelBuilder.HasDbFunction(MethodFoo).HasSchema("abc");

((Model)modelBuilder.Model).Validate();

Assert.Equal("abc", dbFunction.Metadata.Schema);
}

[Fact]
public virtual void DbFuction_default_schema_not_overridden_by_dbo()
{
var modelBuilder = GetModelBuilder();

modelBuilder.HasDefaultSchema("qwerty");

var dbFunction = modelBuilder.HasDbFunction(MethodFoo);

((Model)modelBuilder.Model).Validate();

Assert.Equal("qwerty", dbFunction.Metadata.Schema);
}

private ModelBuilder GetModelBuilder()
{
var conventionset = new ConventionSet();

conventionset.ModelAnnotationChangedConventions.Add(new SqlServerDbFunctionConvention());

return new ModelBuilder(conventionset);
}
}
}

0 comments on commit b58ec40

Please sign in to comment.