Skip to content

Commit

Permalink
Query: Add support for instance based SqlFunction
Browse files Browse the repository at this point in the history
Part of #10109
  • Loading branch information
smitpatel committed Nov 27, 2017
1 parent 82339be commit 3c57461
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/EFCore.Relational/Metadata/Internal/DbFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ private static string BuildAnnotationName(string annotationPrefix, MethodBase me
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual string DefaultSchema { get; [param: CanBeNull] set;}
public virtual string DefaultSchema { get; [param: CanBeNull] set; }

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
Expand Down
56 changes: 50 additions & 6 deletions src/EFCore.Relational/Query/Expressions/SqlFunctionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.EntityFrameworkCore.Query.Expressions
/// <summary>
/// Represents a SQL function call expression.
/// </summary>
[DebuggerDisplay("{this.FunctionName}({string.Join(\", \", this.Arguments)})")]
[DebuggerDisplay("{ToString()}")]
public class SqlFunctionExpression : Expression
{
private readonly ReadOnlyCollection<Expression> _arguments;
Expand All @@ -34,7 +34,7 @@ public SqlFunctionExpression(
}

/// <summary>
/// Initializes a new instance of the Microsoft.EntityFrameworkCore.Query.Expressions.SqlFunctionExpression class.
/// Initializes a new instance of the <see cref="SqlFunctionExpression" /> class.
/// </summary>
/// <param name="functionName"> Name of the function. </param>
/// <param name="returnType"> The return type. </param>
Expand All @@ -51,16 +51,42 @@ public SqlFunctionExpression(
/// Initializes a new instance of the <see cref="SqlFunctionExpression" /> class.
/// </summary>
/// <param name="functionName"> Name of the function. </param>
/// <param name="returnType"> The return type. </param>
/// ///
/// <param name="schema"> The schema this function exists in if any. </param>
/// <param name="returnType"> The return type. </param>
/// <param name="arguments"> The arguments. </param>
public SqlFunctionExpression(
[NotNull] string functionName,
[NotNull] Type returnType,
[CanBeNull] string schema,
[NotNull] IEnumerable<Expression> arguments)
: this(/*instance*/ null, functionName, schema, returnType, arguments)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SqlFunctionExpression" /> class.
/// </summary>
/// <param name="instance"> The instance on which the function is called. </param>
/// <param name="functionName"> Name of the function. </param>
/// <param name="returnType"> The return type. </param>
/// <param name="arguments"> The arguments. </param>
public SqlFunctionExpression(
[NotNull] Expression instance,
[NotNull] string functionName,
[NotNull] Type returnType,
[NotNull] IEnumerable<Expression> arguments)
: this(instance, functionName, /*schema*/ null, returnType, arguments)
{
}

private SqlFunctionExpression(
[CanBeNull] Expression instance,
[NotNull] string functionName,
[CanBeNull] string schema,
[NotNull] Type returnType,
[NotNull] IEnumerable<Expression> arguments)
{
Instance = instance;
FunctionName = functionName;
Type = returnType;
Schema = schema;
Expand All @@ -83,6 +109,11 @@ public SqlFunctionExpression(
/// </value>
public virtual string Schema { get; }

/// <summary>
/// The instance.
/// </summary>
public virtual Expression Instance { get; }

/// <summary>
/// The arguments.
/// </summary>
Expand Down Expand Up @@ -129,10 +160,11 @@ protected override Expression Accept(ExpressionVisitor visitor)
/// </remarks>
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var newInstance = Instance != null ? visitor.Visit(Instance) : null;
var newArguments = visitor.VisitAndConvert(_arguments, "VisitChildren");

return newArguments != _arguments
? new SqlFunctionExpression(FunctionName, Type, newArguments)
return newInstance != Instance || newArguments != _arguments
? new SqlFunctionExpression(newInstance, FunctionName, Schema, Type, newArguments)
: this;
}

Expand Down Expand Up @@ -161,6 +193,8 @@ public override bool Equals(object obj)
private bool Equals(SqlFunctionExpression other)
=> Type == other.Type
&& string.Equals(FunctionName, other.FunctionName)
&& string.Equals(Schema, other.Schema)
&& Instance.Equals(other.Instance)
&& _arguments.SequenceEqual(other._arguments);

/// <summary>
Expand All @@ -174,10 +208,20 @@ public override int GetHashCode()
unchecked
{
var hashCode = _arguments.Aggregate(0, (current, argument) => current + ((current * 397) ^ argument.GetHashCode()));
hashCode = (hashCode * 397) ^ (Instance?.GetHashCode() ?? 0);
hashCode = (hashCode * 397) ^ FunctionName.GetHashCode();
hashCode = (hashCode * 397) ^ (Schema?.GetHashCode() ?? 0);
hashCode = (hashCode * 397) ^ Type.GetHashCode();
return hashCode;
}
}

/// <summary>
/// Creates a <see cref="string" /> representation of the Expression.
/// </summary>
/// <returns>A <see cref="string" /> representation of the Expression.</returns>
public override string ToString()
=> (Instance != null ? Instance + "." : Schema != null ? Schema + "." : "") +
$"{FunctionName}({string.Join("", "", Arguments)}";
}
}
37 changes: 26 additions & 11 deletions src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,6 @@ protected virtual void GenerateFromSql(
}

break;

}

if (substitutions != null)
Expand Down Expand Up @@ -1335,10 +1334,29 @@ public virtual Expression VisitLike(LikeExpression likeExpression)
/// </returns>
public virtual Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExpression)
{
GenerateFunctionCall(
sqlFunctionExpression.FunctionName,
sqlFunctionExpression.Arguments,
sqlFunctionExpression.Schema);
var parentTypeMapping = _typeMapping;
_typeMapping = null;

if (sqlFunctionExpression.Instance != null)
{
Visit(sqlFunctionExpression.Instance);
_relationalCommandBuilder.Append(".");
}
else if (!string.IsNullOrWhiteSpace(sqlFunctionExpression.Schema))
{
_relationalCommandBuilder
.Append(SqlGenerator.DelimitIdentifier(sqlFunctionExpression.Schema))
.Append(".");
}

_relationalCommandBuilder.Append(sqlFunctionExpression.FunctionName);
_relationalCommandBuilder.Append("(");

_typeMapping = null;
GenerateList(sqlFunctionExpression.Arguments);

_relationalCommandBuilder.Append(")");
_typeMapping = parentTypeMapping;

return sqlFunctionExpression;
}
Expand All @@ -1349,8 +1367,10 @@ public virtual Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExpr
/// <param name="functionName">The function name</param>
/// <param name="arguments">The function arguments</param>
/// <param name="schema">The function schema</param>
[Obsolete("Override VisitSqlFunction method instead.")]
protected virtual void GenerateFunctionCall(
[NotNull] string functionName, [NotNull] IReadOnlyList<Expression> arguments,
[NotNull] string functionName,
[NotNull] IReadOnlyList<Expression> arguments,
[CanBeNull] string schema = null)
{
Check.NotEmpty(functionName, nameof(functionName));
Expand Down Expand Up @@ -1439,16 +1459,13 @@ protected override Expression VisitUnary(UnaryExpression expression)
if (expression.Operand is ExistsExpression)
{
_relationalCommandBuilder.Append("NOT ");

Visit(expression.Operand);

return expression;
}

_relationalCommandBuilder.Append("NOT (");

Visit(expression.Operand);

_relationalCommandBuilder.Append(")");

return expression;
Expand All @@ -1460,11 +1477,9 @@ protected override Expression VisitUnary(UnaryExpression expression)

case ExpressionType.Negate:
_relationalCommandBuilder.Append("-");

Visit(expression.Operand);

return expression;

}

return base.VisitUnary(expression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public override Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExp
if (sqlFunctionExpression.FunctionName == "COUNT"
&& sqlFunctionExpression.Type == typeof(long))
{
GenerateFunctionCall("COUNT_BIG", sqlFunctionExpression.Arguments);
Visit(new SqlFunctionExpression("COUNT_BIG", typeof(long), sqlFunctionExpression.Arguments));

return sqlFunctionExpression;
}
Expand All @@ -125,12 +125,10 @@ protected override void GenerateProjection(Expression projection)
}

private Expression ExplicitCastToBool(Expression expression)
{
return (expression as BinaryExpression)?.NodeType == ExpressionType.Coalesce
&& expression.Type.UnwrapNullableType() == typeof(bool)
=> (expression as BinaryExpression)?.NodeType == ExpressionType.Coalesce
&& expression.Type.UnwrapNullableType() == typeof(bool)
? new ExplicitCastExpression(expression, expression.Type)
: expression;
}

private class RowNumberPagingExpressionVisitor : ExpressionVisitorBase
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.EntityFrameworkCore.Query
{
public class UdfDbFunctionSqlServerTests : IClassFixture<UdfDbFunctionSqlServerTests.SqlServerUDFFixture>
{
public UdfDbFunctionSqlServerTests(SqlServerUDFFixture fixture)
public UdfDbFunctionSqlServerTests(SqlServerUDFFixture fixture, ITestOutputHelper testOutputHelper)
{
Fixture = fixture;
Fixture.TestSqlLoggerFactory.Clear();
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

private SqlServerUDFFixture Fixture { get; }
Expand Down Expand Up @@ -1524,7 +1526,7 @@ protected override void Seed(DbContext context)
returns int
as
begin
return (select count(id) from orders where customerId = @customerId);
return (select count(id) from orders where customerId = @customerId);
end");

context.Database.ExecuteSqlCommand(@"create function[dbo].[StarValue] (@starCount int, @value nvarchar(max))
Expand All @@ -1545,28 +1547,28 @@ returns nvarchar(max)
returns DateTime
as
begin
return '1998-01-01'
return '1998-01-01'
end");

context.Database.ExecuteSqlCommand(@"create function [dbo].[GetCustomerWithMostOrdersAfterDate] (@searchDate Date)
returns int
as
begin
return (select top 1 customerId
from orders
where orderDate > @searchDate
group by CustomerId
order by count(id) desc)
return (select top 1 customerId
from orders
where orderDate > @searchDate
group by CustomerId
order by count(id) desc)
end");

context.Database.ExecuteSqlCommand(@"create function [dbo].[IsTopCustomer] (@customerId int)
returns bit
as
begin
if(@customerId = 1)
return 1
return 0
if(@customerId = 1)
return 1
return 0
end");

var order11 = new Order { Name = "Order11", ItemCount = 4, OrderDate = new DateTime(2000, 1, 20) };
Expand Down

0 comments on commit 3c57461

Please sign in to comment.