Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trs79 master #967

Merged
merged 9 commits into from
Jul 2, 2019
67 changes: 41 additions & 26 deletions src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1535,33 +1535,42 @@ public override ISqlFragment Visit(DbGroupByExpression e)
var member = members.Current;
var alias = QuoteIdentifier(member.Name);

Debug.Assert(aggregate.Arguments.Count == 1);
var translatedAggregateArgument = aggregate.Arguments[0].Accept(this);
var finalArgs = new List<object>();

object aggregateArgument;

if (needsInnerQuery)
for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++)
{
//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(alias);
aggregateArgument = wrappingAggregateArgument;
var argument = aggregate.Arguments[childIndex];
var translatedAggregateArgument = argument.Accept(this);

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(alias);
}
else
{
aggregateArgument = translatedAggregateArgument;
object aggregateArgument;

if (needsInnerQuery)
{
var argAlias = QuoteIdentifier(member.Name + "_" + childIndex);

//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(argAlias);
aggregateArgument = wrappingAggregateArgument;

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(argAlias);
}
else
{
aggregateArgument = translatedAggregateArgument;
}

finalArgs.Add(aggregateArgument);
}

ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument);
ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs);

result.Select.Append(separator);
result.Select.AppendLine();
Expand Down Expand Up @@ -2756,8 +2765,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e)
// Aggregates are not visited by the normal visitor walk.
// </summary>
// <param name="aggregate"> The aggregate go be translated </param>
// <param name="aggregateArgument"> The translated aggregate argument </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument)
// <param name="aggregateArguments"> The translated aggregate arguments </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList<object> aggregateArguments)
{
var aggregateResult = new SqlBuilder();
var functionAggregate = aggregate as DbFunctionAggregate;
Expand Down Expand Up @@ -2788,7 +2797,13 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate
aggregateResult.Append("DISTINCT ");
}

aggregateResult.Append(aggregateArgument);
string separator = String.Empty;
foreach (var arg in aggregateArguments)
{
aggregateResult.Append(separator);
aggregateResult.Append(arg);
separator = ", ";
}

aggregateResult.Append(")");
return aggregateResult;
Expand Down Expand Up @@ -4334,7 +4349,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList<DbAggregate> aggregate
{
foreach (var aggregate in aggregates)
{
Debug.Assert(aggregate.Arguments.Count == 1);
Debug.Assert(aggregate.Arguments.Count >= 1);
if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName))
{
return true;
Expand Down
67 changes: 41 additions & 26 deletions src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1310,33 +1310,42 @@ public override ISqlFragment Visit(DbGroupByExpression e)
var member = members.Current;
var alias = QuoteIdentifier(member.Name);

Debug.Assert(aggregate.Arguments.Count == 1);
var translatedAggregateArgument = aggregate.Arguments[0].Accept(this);
var finalArgs = new List<object>();

object aggregateArgument;

if (needsInnerQuery)
for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++)
{
//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(alias);
aggregateArgument = wrappingAggregateArgument;
var argument = aggregate.Arguments[childIndex];
var translatedAggregateArgument = argument.Accept(this);

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(alias);
}
else
{
aggregateArgument = translatedAggregateArgument;
object aggregateArgument;

if (needsInnerQuery)
{
var argAlias = QuoteIdentifier(member.Name + "_" + childIndex);

//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(argAlias);
aggregateArgument = wrappingAggregateArgument;

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(argAlias);
}
else
{
aggregateArgument = translatedAggregateArgument;
}

finalArgs.Add(aggregateArgument);
}

ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument);
ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs);

result.Select.Append(separator);
result.Select.AppendLine();
Expand Down Expand Up @@ -2103,8 +2112,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e)
// Aggregates are not visited by the normal visitor walk.
// </summary>
// <param name="aggregate"> The aggreate go be translated </param>
// <param name="aggregateArgument"> The translated aggregate argument </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument)
// <param name="aggregateArguments"> The translated aggregate arguments </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList<object> aggregateArguments)
{
var aggregateFunction = new SqlBuilder();
var aggregateResult = new SqlBuilder();
Expand Down Expand Up @@ -2134,7 +2143,13 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate
throw ADP1.NotSupported(EntityRes.GetString(EntityRes.DistinctAggregatesNotSupported));
}

aggregateResult.Append(aggregateArgument);
string separator = String.Empty;
foreach (var arg in aggregateArguments)
{
aggregateResult.Append(separator);
aggregateResult.Append(arg);
separator = ", ";
}

aggregateResult.Append(")");

Expand Down Expand Up @@ -4508,7 +4523,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList<DbAggregate> aggregate
{
foreach (var aggregate in aggregates)
{
Debug.Assert(aggregate.Arguments.Count == 1);
Debug.Assert(aggregate.Arguments.Count >= 1);
if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName))
{
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal DbAggregate(TypeUsage resultType, DbExpressionList arguments)
{
DebugCheck.NotNull(resultType);
DebugCheck.NotNull(arguments);
Debug.Assert(arguments.Count == 1, "DbAggregate requires a single argument");
Debug.Assert(arguments.Count >= 1, "DbAggregate requires at least one argument");

_type = resultType;
_args = arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,17 @@ protected virtual DbFunctionAggregate VisitFunctionAggregate(DbFunctionAggregate
var newFunction = VisitFunction(aggregate.Function);
var newArguments = VisitExpressionList(aggregate.Arguments);

Debug.Assert(newArguments.Count == 1, "Function aggregate had more than one argument?");

if (!ReferenceEquals(aggregate.Function, newFunction)
||
!ReferenceEquals(aggregate.Arguments, newArguments))
{
if (aggregate.Distinct)
{
result = CqtBuilder.AggregateDistinct(newFunction, newArguments[0]);
result = CqtBuilder.AggregateDistinct(newFunction, newArguments);
}
else
{
result = CqtBuilder.Aggregate(newFunction, newArguments[0]);
result = CqtBuilder.Aggregate(newFunction, newArguments);
}
}
}
Expand All @@ -212,7 +210,6 @@ protected virtual DbGroupAggregate VisitGroupAggregate(DbGroupAggregate aggregat
if (aggregate != null)
{
var newArguments = VisitExpressionList(aggregate.Arguments);
Debug.Assert(newArguments.Count == 1, "Group aggregate had more than one argument?");

if (!ReferenceEquals(aggregate.Arguments, newArguments))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,55 @@ private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function,
return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct);
}

/// <summary>
/// Creates a new <see cref="T:System.Data.Entity.Core.Common.CommandTrees.DbFunctionAggregate" />.
/// </summary>
/// <returns>A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value false.</returns>
/// <param name="function">The function that defines the aggregate operation.</param>
/// <param name="arguments">The argument over which the aggregate function should be calculated.</param>
/// <exception cref="T:System.ArgumentNullException">function or argument null.</exception>
/// <exception cref="T:System.ArgumentException">function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function.</exception>
public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerable<DbExpression> arguments)
{
Check.NotNull(function, "function");
Check.NotNull(arguments, "argument");

if (arguments.Any() == false)
{
throw new ArgumentNullException("arguments");
}

return CreateFunctionAggregate(function, arguments, false);
}

/// <summary>
/// Creates a new <see cref="T:System.Data.Entity.Core.Common.CommandTrees.DbFunctionAggregate" /> that is applied in a distinct fashion.
/// </summary>
/// <returns>A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value true.</returns>
/// <param name="function">The function that defines the aggregate operation.</param>
/// <param name="arguments">The arguments over which the aggregate function should be calculated.</param>
/// <exception cref="T:System.ArgumentNullException">function or argument is null.</exception>
/// <exception cref="T:System.ArgumentException">function is not an aggregate function, or the result type of argument is not equal or promotable to the parameter type of function.</exception>
public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IEnumerable<DbExpression> arguments)
{
Check.NotNull(function, "function");
Check.NotNull(arguments, "argument");

if (arguments.Any() == false)
{
throw new ArgumentNullException("arguments");
}

return CreateFunctionAggregate(function, arguments, true);
}

private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, IEnumerable<DbExpression> arguments, bool isDistinct)
{
var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, arguments);
var resultType = function.ReturnParameter.TypeUsage;
return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct);
}

/// <summary>
/// Creates a new <see cref="DbGroupAggregate" /> over the specified argument
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ public override void Visit(DbGroupByExpression e)
if (ga != null)
{
_key.Append("GA(");
Debug.Assert(ga.Arguments.Count == 1, "Group aggregate must have one argument.");
Debug.Assert(ga.Arguments.Count >= 1, "Group aggregate must have at least one argument.");
ga.Arguments[0].Accept(this);
_key.Append(')');
}
Expand Down
10 changes: 5 additions & 5 deletions src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1373,10 +1373,10 @@ private static bool TryConvertAsFunctionAggregate(
"argument types resolved for the collection aggregate calls must match");
}

//
// Aggregate functions must have at least one argument, and the first argument must be of collection edmType
//
// Aggregate functions can have only one argument and of collection edmType
//
Debug.Assert((1 == functionType.Parameters.Count), "(1 == functionType.Parameters.Count)");
Debug.Assert((1 <= functionType.Parameters.Count), "(1 <= functionType.Parameters.Count)");
// we only support monadic aggregate functions
Debug.Assert(
TypeSemantics.IsCollectionType(functionType.Parameters[0].TypeUsage), "functionType.Parameters[0].Type is CollectionType");
Expand All @@ -1394,11 +1394,11 @@ private static bool TryConvertAsFunctionAggregate(
if (methodExpr.DistinctKind
== DistinctKind.Distinct)
{
functionAggregate = functionType.AggregateDistinct(args[0]);
functionAggregate = functionType.AggregateDistinct(args);
}
else
{
functionAggregate = functionType.Aggregate(args[0]);
functionAggregate = functionType.Aggregate(args);
}

//
Expand Down
16 changes: 12 additions & 4 deletions src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private void Process()
[SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters",
MessageId = "System.Data.Entity.Core.Query.PlanCompiler.PlanCompiler.Assert(System.Boolean,System.String)")]
private void TryProcessCandidate(
KeyValuePair<Node, Node> candidate,
KeyValuePair<Node, List<Node>> candidate,
GroupAggregateVarInfo groupAggregateVarInfo)
{
IList<Node> functionAncestors;
Expand All @@ -100,15 +100,23 @@ private void TryProcessCandidate(
// Remap the template from referencing the groupAggregate var to reference the input to
// the group by into
//
var argumentNode = OpCopier.Copy(m_command, candidate.Value);
var dictionary = new Dictionary<Var, Var>(1);
dictionary.Add(groupAggregateVarInfo.GroupAggregateVar, inputVar);
var remapper = new VarRemapper(m_command, dictionary);
remapper.RemapSubtree(argumentNode);

var argNodes = new List<Node>(candidate.Value.Count);

foreach (var argumentNode in candidate.Value)
{
var argumentNodeCopy = OpCopier.Copy(m_command, argumentNode);
remapper.RemapSubtree(argumentNodeCopy);

argNodes.Add(argumentNodeCopy);
}

var newFunctionDefiningNode = m_command.CreateNode(
m_command.CreateAggregateOp(functionOp.Function, false),
argumentNode);
argNodes);

Var newFunctionVar;
var varDefNode = m_command.CreateVarDefNode(newFunctionDefiningNode, out newFunctionVar);
Expand Down
Loading