Skip to content

Commit

Permalink
Work on user-defined functions
Browse files Browse the repository at this point in the history
Function quoting logic was changed (for now): all function names are
quoted just like any other identifier (the default EF Core behavior
is only to quote schema-less functions). This meant changing all
expression translators to generate lower-case names (lower() instead
of LOWER()), and also to add a hardcoded list of functions which
EF Core generates in upper-case but which shouldn't get quoted
(e.g. SUM()).

See:
* dotnet/efcore#8507
* dotnet/efcore#12044
* dotnet/efcore#12757
* dotnet/efcore#9558
* dotnet/efcore#9303
  • Loading branch information
roji committed Aug 3, 2018
1 parent f50d2b3 commit 8f1c4c1
Show file tree
Hide file tree
Showing 19 changed files with 989 additions and 84 deletions.
6 changes: 3 additions & 3 deletions src/EFCore.PG.NodaTime/NodaTimeMemberTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Expression TranslateDateTime(MemberExpression e)

if (n == "Date")
{
return new SqlFunctionExpression("DATE_TRUNC", e.Type, new[]
return new SqlFunctionExpression("date_trunc", e.Type, new[]
{
Expression.Constant("day"),
e.Expression
Expand All @@ -84,14 +84,14 @@ static Expression GetDatePartExpression(MemberExpression e, string partName, boo
// DATE_PART returns doubles, which we floor and cast into ints
// This also gets rid of sub-second components when retrieving seconds

var result = new SqlFunctionExpression("DATE_PART", typeof(double), new[]
var result = new SqlFunctionExpression("date_part", typeof(double), new[]
{
Expression.Constant(partName),
e.Expression
});

if (needsFloor)
result = new SqlFunctionExpression("FLOOR", typeof(double), new[] { result });
result = new SqlFunctionExpression("floor", typeof(double), new[] { result });

return new ExplicitCastExpression(result, typeof(int));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)

return new CustomBinaryExpression(
methodCallExpression.Object,
new PgFunctionExpression("MAKE_INTERVAL", typeof(TimeSpan), new Dictionary<string, Expression>
new PgFunctionExpression("make_interval", typeof(TimeSpan), new Dictionary<string, Expression>
{
{ datePart, amountToAdd }
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public virtual Expression Translate(MemberExpression e)
return GetDatePartExpression(e, "dow");

case nameof(DateTime.Date):
return new SqlFunctionExpression("DATE_TRUNC", e.Type, new[]
return new SqlFunctionExpression("date_trunc", e.Type, new[]
{
Expression.Constant("day"),
e.Expression
Expand All @@ -102,9 +102,9 @@ static Expression GetDatePartExpression(MemberExpression e, string partName)
// DATE_PART returns doubles, which we floor and cast into ints
// This also gets rid of sub-second components when retrieving seconds
new ExplicitCastExpression(
new SqlFunctionExpression("FLOOR", typeof(double), new[]
new SqlFunctionExpression("floor", typeof(double), new[]
{
new SqlFunctionExpression("DATE_PART", typeof(double), new[]
new SqlFunctionExpression("date_part", typeof(double), new[]
{
Expression.Constant(partName),
e.Expression
Expand All @@ -116,9 +116,9 @@ static Expression GetDatePartExpression(MemberExpression e, string partName)
Expression TranslateStatic(MemberExpression e)
{
if (e.Member.Equals(Now))
return new SqlFunctionExpression("NOW", e.Type);
return new SqlFunctionExpression("now", e.Type);
if (e.Member.Equals(UtcNow))
return new AtTimeZoneExpression(new SqlFunctionExpression("NOW", e.Type), "UTC");
return new AtTimeZoneExpression(new SqlFunctionExpression("now", e.Type), "UTC");
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,38 @@ public class NpgsqlMathTranslator : IMethodCallTranslator
{
static readonly Dictionary<MethodInfo, string> _supportedMethodTranslations = new Dictionary<MethodInfo, string>
{
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(decimal) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(double) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(float) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(int) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(long) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(sbyte) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(short) }), "ABS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(decimal) }), "CEILING" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(double) }), "CEILING" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(decimal) }), "FLOOR" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(double) }), "FLOOR" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) }), "POWER" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Exp), new[] { typeof(double) }), "EXP" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log10), new[] { typeof(double) }), "LOG" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double) }), "LN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(decimal) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(double) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(float) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(int) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(long) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(sbyte) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(short) }), "abs" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(decimal) }), "ceiling" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(double) }), "ceiling" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(decimal) }), "floor" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(double) }), "floor" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) }), "power" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Exp), new[] { typeof(double) }), "exp" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log10), new[] { typeof(double) }), "log" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double) }), "ln" },
// Disabled because PG only has log(x,y) for numeric
//{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double), typeof(double) }), "LOG" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), new[] { typeof(double) }), "SQRT" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Acos), new[] { typeof(double) }), "ACOS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Asin), new[] { typeof(double) }), "ASIN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan), new[] { typeof(double) }), "ATAN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) }), "ATAN2" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Cos), new[] { typeof(double) }), "COS" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sin), new[] { typeof(double) }), "SIN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Tan), new[] { typeof(double) }), "TAN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(decimal) }), "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(double) }), "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(float) }), "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(int) }), "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(long) }), "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(sbyte) }), "SIGN" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(short) }), "SIGN" }
//{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double), typeof(double) }), "log" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), new[] { typeof(double) }), "sqrt" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Acos), new[] { typeof(double) }), "acos" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Asin), new[] { typeof(double) }), "asin" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan), new[] { typeof(double) }), "atan" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) }), "atan2" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Cos), new[] { typeof(double) }), "cos" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sin), new[] { typeof(double) }), "sin" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Tan), new[] { typeof(double) }), "tan" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(decimal) }), "sign" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(double) }), "sign" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(float) }), "sign" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(int) }), "sign" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(long) }), "sign" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(sbyte) }), "sign" },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(short) }), "sign" }
};

static readonly IEnumerable<MethodInfo> _truncateMethodInfos = new[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
}

return Expression.GreaterThan(
new SqlFunctionExpression("STRPOS", typeof(int), new[]
new SqlFunctionExpression("strpos", typeof(int), new[]
{
methodCallExpression.Object,
argument0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
=> ReferenceEquals(methodCallExpression.Method, _methodInfo)
? Expression.Equal(
new SqlFunctionExpression(
"RIGHT",
"right",
// ReSharper disable once PossibleNullReferenceException
methodCallExpression.Object.Type,
new[]
{
methodCallExpression.Object,
new SqlFunctionExpression("LENGTH", typeof(int), new[] { methodCallExpression.Arguments[0] })
new SqlFunctionExpression("length", typeof(int), new[] { methodCallExpression.Arguments[0] })
}
),
methodCallExpression.Arguments[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public virtual Expression Translate([NotNull] MethodCallExpression methodCallExp

return Expression.Subtract(
new SqlFunctionExpression(
"STRPOS",
"strpos",
methodCallExpression.Type,
new[] { methodCallExpression.Object }.Concat(methodCallExpression.Arguments)),
Expression.Constant(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public virtual Expression Translate([NotNull] MemberExpression memberExpression)
=> (memberExpression.Expression != null)
&& (memberExpression.Expression.Type == typeof(string))
&& (memberExpression.Member.Name == nameof(string.Length))
? new SqlFunctionExpression("LENGTH", memberExpression.Type, new[] { memberExpression.Expression })
? new SqlFunctionExpression("length", memberExpression.Type, new[] { memberExpression.Expression })
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static readonly MethodInfo _methodInfo
public virtual Expression Translate([NotNull] MethodCallExpression methodCallExpression)
=> _methodInfo.Equals(methodCallExpression.Method)
? new SqlFunctionExpression(
"REPLACE",
"replace",
methodCallExpression.Type,
new[] { methodCallExpression.Object }.Concat(methodCallExpression.Arguments))
: null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ public virtual Expression Translate(MethodCallExpression e)
// but then add another test to filter out false positives.
var pattern = e.Arguments[0];

Expression leftExpr = new SqlFunctionExpression("LEFT", typeof(string), new[]
Expression leftExpr = new SqlFunctionExpression("left", typeof(string), new[]
{
e.Object,
new SqlFunctionExpression("LENGTH", typeof(int), new[] { pattern }),
new SqlFunctionExpression("length", typeof(int), new[] { pattern }),
});

// If StartsWith is being invoked on a citext, the LEFT() function above will return a reglar text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class NpgsqlStringSubstringTranslator : IMethodCallTranslator
public virtual Expression Translate(MethodCallExpression methodCallExpression)
=> methodCallExpression.Method.Equals(_methodInfo)
? new SqlFunctionExpression(
"SUBSTRING",
"substring",
methodCallExpression.Type,
new[]
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte
public class NpgsqlStringToLowerTranslator : ParameterlessInstanceMethodCallTranslator
{
public NpgsqlStringToLowerTranslator()
: base(typeof(string), "ToLower", "LOWER")
: base(typeof(string), "ToLower", "lower")
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte
public class NpgsqlStringToUpperTranslator : ParameterlessInstanceMethodCallTranslator
{
public NpgsqlStringToUpperTranslator()
: base(typeof(string), "ToUpper", "UPPER")
: base(typeof(string), "ToUpper", "upper")
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)

var trimChar = (char)constantTrimChars.Value;
return new SqlFunctionExpression(
"RTRIM",
"rtrim",
typeof(string),
new[]
{
Expand Down Expand Up @@ -91,7 +91,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
{
// Trim whitespace
return new SqlFunctionExpression(
"REGEXP_REPLACE",
"regexp_replace",
typeof(string),
new[]
{
Expand All @@ -102,7 +102,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
}

return new SqlFunctionExpression(
"RTRIM",
"rtrim",
typeof(string),
new[]
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
return null; // Don't translate if trim chars isn't a constant
var trimChar = (char)constantTrimChars.Value;
return new SqlFunctionExpression(
"LTRIM",
"ltrim",
typeof(string),
new[]
{
Expand All @@ -87,7 +87,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
{
// Trim whitespace
return new SqlFunctionExpression(
"REGEXP_REPLACE",
"regexp_replace",
typeof(string),
new[]
{
Expand All @@ -98,7 +98,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
}

return new SqlFunctionExpression(
"LTRIM",
"ltrim",
typeof(string),
new[]
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
{
// Note that PostgreSQL TRIM() does spaces only, not all whitespace, so we use a regex
return new SqlFunctionExpression(
"REGEXP_REPLACE",
"regexp_replace",
typeof(string),
new[]
{
Expand All @@ -44,7 +44,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
return null;

return new SqlFunctionExpression(
"BTRIM",
"btrim",
typeof(string),
new[]
{
Expand All @@ -60,7 +60,7 @@ public virtual Expression Translate(MethodCallExpression methodCallExpression)
return null;

return new SqlFunctionExpression(
"BTRIM",
"btrim",
typeof(string),
new[]
{
Expand Down
31 changes: 31 additions & 0 deletions src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

using System;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Text.RegularExpressions;
using JetBrains.Annotations;
Expand Down Expand Up @@ -108,6 +109,36 @@ public override Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExp
return expr;
}

protected override void GenerateSqlFunctionName(SqlFunctionExpression sqlFunctionExpression)
{
if (sqlFunctionExpression.Instance != null)
{
Visit(sqlFunctionExpression.Instance);

Sql.Append(".");
}

if (string.IsNullOrWhiteSpace(sqlFunctionExpression.Schema))
{
// Special exception for some built-in functions which should never get quoted
if (NonQuotableFunctions.Contains(sqlFunctionExpression.FunctionName))
{
Sql.Append(sqlFunctionExpression.FunctionName);
return;
}
}
else
{
Sql
.Append(SqlGenerator.DelimitIdentifier(sqlFunctionExpression.Schema))
.Append(".");
}

Sql.Append(SqlGenerator.DelimitIdentifier(sqlFunctionExpression.FunctionName));
}

static readonly string[] NonQuotableFunctions = { "COUNT", "AVG", "SUM", "MAX", "MIN" };

protected override Expression VisitBinary(BinaryExpression expression)
{
switch (expression.NodeType)
Expand Down
Loading

0 comments on commit 8f1c4c1

Please sign in to comment.