Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLite: implement sum and average aggregation for decimal #33721

Merged
merged 5 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ protected virtual SqlExpression VisitRegexp(
return regexpExpression.Update(match, pattern);
}

/// <inheritdoc/>
protected override SqlExpression VisitSqlFunction(
SqlFunctionExpression sqlFunctionExpression,
bool allowOptimizedExpansion,
out bool nullable)
{
var result = base.VisitSqlFunction(sqlFunctionExpression, allowOptimizedExpansion, out nullable);

if (result is SqlFunctionExpression resultFunctionExpression
&& resultFunctionExpression.IsBuiltIn
&& string.Equals(resultFunctionExpression.Name, "ef_sum", StringComparison.OrdinalIgnoreCase))
cincuranet marked this conversation as resolved.
Show resolved Hide resolved
{
nullable = false;

var sqlExpressionFactory = Dependencies.SqlExpressionFactory;
return sqlExpressionFactory.Coalesce(
result,
sqlExpressionFactory.Constant(0, resultFunctionExpression.TypeMapping),
resultFunctionExpression.TypeMapping);
}

return result;
}

#pragma warning disable EF1001
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var averageArgumentType = GetProviderType(averageSqlExpression);
if (averageArgumentType == typeof(decimal))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(
nameof(Queryable.Average), averageArgumentType.ShortDisplayName()));
averageSqlExpression = CombineTerms(source, averageSqlExpression);
return _sqlExpressionFactory.Function(
"ef_avg",
[averageSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
averageSqlExpression.Type,
averageSqlExpression.TypeMapping);
}

break;
Expand Down Expand Up @@ -100,8 +105,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var sumArgumentType = GetProviderType(sumSqlExpression);
if (sumArgumentType == typeof(decimal))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), sumArgumentType.ShortDisplayName()));
sumSqlExpression = CombineTerms(source, sumSqlExpression);
return _sqlExpressionFactory.Function(
"ef_sum",
[sumSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
sumSqlExpression.Type,
sumSqlExpression.TypeMapping);
}

break;
Expand All @@ -115,4 +126,21 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
=> expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type;

private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression)
{
if (enumerableExpression.Predicate != null)
{
sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(enumerableExpression.Predicate, sqlExpression) },
elseResult: null);
}

if (enumerableExpression.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

return sqlExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ private void InitializeDbConnection(DbConnection connection)
name: "ef_negate",
(decimal? m) => -m,
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_avg",
seed: (0m, 0ul),
((decimal sum, ulong count) acc, decimal? value) => value is null
? acc
: (acc.sum + value.Value, acc.count + 1),
((decimal sum, ulong count) acc) => acc.count == 0
? default(decimal?)
: acc.sum / acc.count,
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_sum",
seed: null,
(decimal? sum, decimal? value) => value is null
? sum
: sum is null ? value : sum.Value + value.Value,
isDeterministic: true);
}
else
{
Expand Down
22 changes: 10 additions & 12 deletions test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ public virtual void Cant_query_Max_of_converted_types()
}

[ConditionalFact]
public virtual void Cant_query_Average_of_converted_types()
public virtual void Can_query_Average_of_converted_types()
{
using var context = CreateContext();
context.Add(
Expand All @@ -958,15 +958,14 @@ public virtual void Cant_query_Average_of_converted_types()
context.SaveChanges();

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), typeof(decimal).ShortDisplayName()),
Assert.Throws<NotSupportedException>(
() => context.Set<BuiltInNullableDataTypes>()
.Where(e => e.PartitionId == 202)
.Average(e => e.TestNullableDecimal)).Message);
1.000000000000002m,
context.Set<BuiltInNullableDataTypes>()
.Where(e => e.PartitionId == 202)
.Average(e => e.TestNullableDecimal));
}

[ConditionalFact]
public virtual void Cant_query_Sum_of_converted_types()
public virtual void Can_query_Sum_of_converted_types()
{
using var context = CreateContext();
context.Add(
Expand All @@ -988,11 +987,10 @@ public virtual void Cant_query_Sum_of_converted_types()
context.SaveChanges();

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), typeof(decimal).ShortDisplayName()),
Assert.Throws<NotSupportedException>(
() => context.Set<BuiltInDataTypes>()
.Where(e => e.PartitionId == 203)
.Sum(e => e.TestDecimal)).Message);
2.000000000000002m,
context.Set<BuiltInDataTypes>()
.Where(e => e.PartitionId == 203)
.Sum(e => e.TestDecimal));
}

[ConditionalFact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,68 @@ INSERT INTO ZeroKey VALUES (NULL)
""");

public override async Task Average_with_cast()
=> Assert.Equal(
SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
(await Assert.ThrowsAsync<NotSupportedException>(base.Average_with_cast)).Message);
{
await base.Average_with_cast();

AssertSql(
"""
SELECT "p"."Id", "p"."DecimalColumn", "p"."DoubleColumn", "p"."FloatColumn", "p"."IntColumn", "p"."LongColumn", "p"."NullableDecimalColumn", "p"."NullableDoubleColumn", "p"."NullableFloatColumn", "p"."NullableIntColumn", "p"."NullableLongColumn", "p"."Price"
FROM "Prices" AS "p"
""",
//
"""
SELECT ef_avg("p"."Price")
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."IntColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."NullableIntColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."LongColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."NullableLongColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT CAST(AVG("p"."FloatColumn") AS REAL)
FROM "Prices" AS "p"
""",
//
"""
SELECT CAST(AVG("p"."NullableFloatColumn") AS REAL)
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG("p"."DoubleColumn")
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG("p"."NullableDoubleColumn")
FROM "Prices" AS "p"
""",
//
"""
SELECT ef_avg("p"."DecimalColumn")
FROM "Prices" AS "p"
""",
//
"""
SELECT ef_avg("p"."NullableDecimalColumn")
FROM "Prices" AS "p"
""");
}
}
17 changes: 13 additions & 4 deletions test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ public Ef6GroupBySqliteTest(Ef6GroupBySqliteFixture fixture, ITestOutputHelper t
}

public override async Task Average_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
(await Assert.ThrowsAsync<NotSupportedException>(
() => base.Average_Grouped_from_LINQ_101(async))).Message);
{
await base.Average_Grouped_from_LINQ_101(async);

AssertSql(
"""
SELECT "p"."Category", ef_avg("p"."UnitPrice") AS "AveragePrice"
FROM "ProductForLinq" AS "p"
GROUP BY "p"."Category"
""");
}

public override async Task Max_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
Expand Down Expand Up @@ -49,6 +55,9 @@ public override async Task Group_Join_from_LINQ_101(bool async)
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Group_Join_from_LINQ_101(async))).Message);

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class Ef6GroupBySqliteFixture : Ef6GroupByFixtureBase, ITestSqlLoggerFactory
{
public TestSqlLoggerFactory TestSqlLoggerFactory
Expand Down
Loading
Loading