From 912402d8cc7e7202028f33e331fe13ee496db896 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Tue, 18 Oct 2016 12:39:07 -0600 Subject: [PATCH 1/9] Removed the limitations on using only one argument for custom aggregate functions. The first argument does still need to be a collection type, however. --- .../SqlGen/SqlGenerator.cs | 100 ++++++++++-------- .../SqlGen/SqlGenerator.cs | 74 +++++++------ .../CommandTrees/DefaultExpressionVisitor.cs | 6 +- .../ExpressionBuilder/DbExpressionBuilder.cs | 25 +++-- .../Core/Common/EntitySql/SemanticAnalyzer.cs | 8 +- .../Query/PlanCompiler/AggregatePushdown.cs | 21 ++-- .../Core/Query/PlanCompiler/CTreeGenerator.cs | 22 ++-- .../GroupAggregateRefComputingVisitor.cs | 55 ++++++---- .../PlanCompiler/GroupAggregateVarInfo.cs | 6 +- .../Query/PlanCompiler/PlanCompilerUtil.cs | 6 +- 10 files changed, 188 insertions(+), 135 deletions(-) diff --git a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs index c8155135ee..12b181e235 100644 --- a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs @@ -1532,46 +1532,52 @@ public override ISqlFragment Visit(DbGroupByExpression e) foreach (var aggregate in e.Aggregates) { - var member = members.Current; - var alias = QuoteIdentifier(member.Name); - - Debug.Assert(aggregate.Arguments.Count == 1); - var translatedAggregateArgument = aggregate.Arguments[0].Accept(this); - - object aggregateArgument; - - if (needsInnerQuery) - { - //In this case the argument to the aggratete is reference to the one projected out by the - // inner query - var wrappingAggregateArgument = new SqlBuilder(); - wrappingAggregateArgument.Append(fromSymbol); - wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); - aggregateArgument = wrappingAggregateArgument; - - innerQuery.Select.Append(separator); - innerQuery.Select.AppendLine(); - innerQuery.Select.Append(translatedAggregateArgument); - innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); - } - else - { - aggregateArgument = translatedAggregateArgument; - } - - ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument); - - result.Select.Append(separator); - result.Select.AppendLine(); - result.Select.Append(aggregateResult); - result.Select.Append(" AS "); - result.Select.Append(alias); - - separator = ", "; - members.MoveNext(); - } + var member = members.Current; + var alias = QuoteIdentifier(member.Name); + + var finalArgs = new List(); + + foreach (var argument in aggregate.Arguments) + { + var translatedAggregateArgument = argument.Accept(this); + + object aggregateArgument; + + if (needsInnerQuery) + { + //In this case the argument to the aggratete is reference to the one projected out by the + // inner query + var wrappingAggregateArgument = new SqlBuilder(); + wrappingAggregateArgument.Append(fromSymbol); + wrappingAggregateArgument.Append("."); + wrappingAggregateArgument.Append(alias); + aggregateArgument = wrappingAggregateArgument; + + innerQuery.Select.Append(separator); + innerQuery.Select.AppendLine(); + innerQuery.Select.Append(translatedAggregateArgument); + innerQuery.Select.Append(" AS "); + innerQuery.Select.Append(alias); + } + else + { + aggregateArgument = translatedAggregateArgument; + } + + finalArgs.Add(aggregateArgument); + } + + ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); + + result.Select.Append(separator); + result.Select.AppendLine(); + result.Select.Append(aggregateResult); + result.Select.Append(" AS "); + result.Select.Append(alias); + + separator = ", "; + members.MoveNext(); + } } symbolTable.ExitScope(); @@ -2756,8 +2762,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e) // Aggregates are not visited by the normal visitor walk. // // The aggregate go be translated - // The translated aggregate argument - private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument) + // The translated aggregate arguments + private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList aggregateArguments) { var aggregateResult = new SqlBuilder(); var functionAggregate = aggregate as DbFunctionAggregate; @@ -2788,9 +2794,15 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate aggregateResult.Append("DISTINCT "); } - aggregateResult.Append(aggregateArgument); + string separator = String.Empty; + foreach (var arg in aggregateArguments) + { + aggregateResult.Append(separator); + aggregateResult.Append(arg); + separator = ", "; + } - aggregateResult.Append(")"); + aggregateResult.Append(")"); return aggregateResult; } diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index 705a76a24a..c94b93c865 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -1310,33 +1310,39 @@ public override ISqlFragment Visit(DbGroupByExpression e) var member = members.Current; var alias = QuoteIdentifier(member.Name); - Debug.Assert(aggregate.Arguments.Count == 1); - var translatedAggregateArgument = aggregate.Arguments[0].Accept(this); - - object aggregateArgument; - - if (needsInnerQuery) - { - //In this case the argument to the aggratete is reference to the one projected out by the - // inner query - var wrappingAggregateArgument = new SqlBuilder(); - wrappingAggregateArgument.Append(fromSymbol); - wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); - aggregateArgument = wrappingAggregateArgument; - - innerQuery.Select.Append(separator); - innerQuery.Select.AppendLine(); - innerQuery.Select.Append(translatedAggregateArgument); - innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); - } - else - { - aggregateArgument = translatedAggregateArgument; - } - - ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument); + var finalArgs = new List(); + + foreach (var argument in aggregate.Arguments) + { + var translatedAggregateArgument = argument.Accept(this); + + object aggregateArgument; + + if (needsInnerQuery) + { + //In this case the argument to the aggratete is reference to the one projected out by the + // inner query + var wrappingAggregateArgument = new SqlBuilder(); + wrappingAggregateArgument.Append(fromSymbol); + wrappingAggregateArgument.Append("."); + wrappingAggregateArgument.Append(alias); + aggregateArgument = wrappingAggregateArgument; + + innerQuery.Select.Append(separator); + innerQuery.Select.AppendLine(); + innerQuery.Select.Append(translatedAggregateArgument); + innerQuery.Select.Append(" AS "); + innerQuery.Select.Append(alias); + } + else + { + aggregateArgument = translatedAggregateArgument; + } + + finalArgs.Add(aggregateArgument); + } + + ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); result.Select.Append(separator); result.Select.AppendLine(); @@ -2103,8 +2109,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e) // Aggregates are not visited by the normal visitor walk. // // The aggreate go be translated - // The translated aggregate argument - private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument) + // The translated aggregate arguments + private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList aggregateArguments) { var aggregateFunction = new SqlBuilder(); var aggregateResult = new SqlBuilder(); @@ -2134,9 +2140,15 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate throw ADP1.NotSupported(EntityRes.GetString(EntityRes.DistinctAggregatesNotSupported)); } - aggregateResult.Append(aggregateArgument); + string separator = String.Empty; + foreach (var arg in aggregateArguments) + { + aggregateResult.Append(separator); + aggregateResult.Append(arg); + separator = ", "; + } - aggregateResult.Append(")"); + aggregateResult.Append(")"); if (fCast) { diff --git a/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs b/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs index 58546e2b9a..94b084bf69 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs @@ -184,19 +184,17 @@ protected virtual DbFunctionAggregate VisitFunctionAggregate(DbFunctionAggregate var newFunction = VisitFunction(aggregate.Function); var newArguments = VisitExpressionList(aggregate.Arguments); - Debug.Assert(newArguments.Count == 1, "Function aggregate had more than one argument?"); - if (!ReferenceEquals(aggregate.Function, newFunction) || !ReferenceEquals(aggregate.Arguments, newArguments)) { if (aggregate.Distinct) { - result = CqtBuilder.AggregateDistinct(newFunction, newArguments[0]); + result = CqtBuilder.AggregateDistinct(newFunction, newArguments); } else { - result = CqtBuilder.Aggregate(newFunction, newArguments[0]); + result = CqtBuilder.Aggregate(newFunction, newArguments); } } } diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index f63c8c5945..7081cef379 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -208,12 +208,15 @@ public static DbGroupExpressionBinding GroupBindAs(this DbExpression input, stri /// The argument over which the aggregate function should be calculated. /// function or argument null. /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. - public static DbFunctionAggregate Aggregate(this EdmFunction function, DbExpression argument) + public static DbFunctionAggregate Aggregate(this EdmFunction function, IList arguments) { Check.NotNull(function, "function"); - Check.NotNull(argument, "argument"); + if (arguments.Count == 0) + { + throw new ArgumentNullException("arguments"); + } - return CreateFunctionAggregate(function, argument, false); + return CreateFunctionAggregate(function, arguments, false); } /// @@ -221,20 +224,22 @@ public static DbFunctionAggregate Aggregate(this EdmFunction function, DbExpress /// /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value true. /// The function that defines the aggregate operation. - /// The argument over which the aggregate function should be calculated. + /// The arguments over which the aggregate function should be calculated. /// function or argument is null. - /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. - public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, DbExpression argument) + /// function is not an aggregate function, or the result type of argument is not equal or promotable to the parameter type of function. + public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IList arguments) { Check.NotNull(function, "function"); - Check.NotNull(argument, "argument"); + if (arguments.Count == 0) { + throw new ArgumentNullException("arguments"); + } - return CreateFunctionAggregate(function, argument, true); + return CreateFunctionAggregate(function, arguments, true); } - private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, DbExpression argument, bool isDistinct) + private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, IList arguments, bool isDistinct) { - var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, new[] { argument }); + var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, arguments); var resultType = function.ReturnParameter.TypeUsage; return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct); } diff --git a/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs b/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs index 24c941e3bc..91bd1ca5fd 100644 --- a/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs +++ b/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs @@ -1373,10 +1373,6 @@ private static bool TryConvertAsFunctionAggregate( "argument types resolved for the collection aggregate calls must match"); } - // - // Aggregate functions can have only one argument and of collection edmType - // - Debug.Assert((1 == functionType.Parameters.Count), "(1 == functionType.Parameters.Count)"); // we only support monadic aggregate functions Debug.Assert( TypeSemantics.IsCollectionType(functionType.Parameters[0].TypeUsage), "functionType.Parameters[0].Type is CollectionType"); @@ -1394,11 +1390,11 @@ private static bool TryConvertAsFunctionAggregate( if (methodExpr.DistinctKind == DistinctKind.Distinct) { - functionAggregate = functionType.AggregateDistinct(args[0]); + functionAggregate = functionType.AggregateDistinct(args); } else { - functionAggregate = functionType.Aggregate(args[0]); + functionAggregate = functionType.Aggregate(args); } // diff --git a/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs b/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs index 9a39acfbf3..ebced17245 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs @@ -76,7 +76,7 @@ private void Process() [SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters", MessageId = "System.Data.Entity.Core.Query.PlanCompiler.PlanCompiler.Assert(System.Boolean,System.String)")] private void TryProcessCandidate( - KeyValuePair candidate, + KeyValuePair> candidate, GroupAggregateVarInfo groupAggregateVarInfo) { IList functionAncestors; @@ -100,17 +100,24 @@ private void TryProcessCandidate( // Remap the template from referencing the groupAggregate var to reference the input to // the group by into // - var argumentNode = OpCopier.Copy(m_command, candidate.Value); var dictionary = new Dictionary(1); dictionary.Add(groupAggregateVarInfo.GroupAggregateVar, inputVar); var remapper = new VarRemapper(m_command, dictionary); - remapper.RemapSubtree(argumentNode); - var newFunctionDefiningNode = m_command.CreateNode( - m_command.CreateAggregateOp(functionOp.Function, false), - argumentNode); + var argNodes = new List(candidate.Value.Count); - Var newFunctionVar; + foreach (var argumentNode in candidate.Value) { + var argumentNodeCopy = OpCopier.Copy(m_command, argumentNode); + remapper.RemapSubtree(argumentNodeCopy); + + argNodes.Add(argumentNodeCopy); + } + + var newFunctionDefiningNode = m_command.CreateNode( + m_command.CreateAggregateOp(functionOp.Function, false), + argNodes); + + Var newFunctionVar; var varDefNode = m_command.CreateVarDefNode(newFunctionDefiningNode, out newFunctionVar); // Add the new aggregate to the list of aggregates diff --git a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs index 1ee8b5bb12..979361a6b5 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs @@ -1886,24 +1886,32 @@ public override DbExpression Visit(GroupByOp op, Node n) PlanCompiler.Assert(aggRootNode.Op is VarDefListOp, "Invalid Aggregates VarDefListOp Node encountered in GroupByOp"); foreach (var aggVarDefNode in aggRootNode.Children) { - var aggVarDef = aggVarDefNode.Op as VarDefOp; + var aggVarDef = aggVarDefNode.Op as VarDefOp; PlanCompiler.Assert(aggVarDef != null, "Non-VarDefOp Node encountered as child of Aggregates VarDefListOp Node"); var aggVar = aggVarDef.Var; - PlanCompiler.Assert(aggVar is ComputedVar, "Non-ComputedVar encountered in Aggregate VarDefOp"); + PlanCompiler.Assert(aggVar is ComputedVar, "Non-ComputedVar encountered in Aggregate VarDefOp"); var aggOpNode = aggVarDefNode.Child0; - var aggDef = VisitNode(aggOpNode.Child0); - var funcAggOp = aggOpNode.Op as AggregateOp; - PlanCompiler.Assert(funcAggOp != null, "Non-Aggregate Node encountered as child of Aggregate VarDefOp Node"); + + // Loop through arguments + var args = new List(); + + foreach (var argumentNode in aggOpNode.Children) { + var aggDef = VisitNode(argumentNode); + args.Add(aggDef); + } + + var funcAggOp = aggOpNode.Op as AggregateOp; + PlanCompiler.Assert(funcAggOp != null, "Non-Aggregate Node encountered as child of Aggregate VarDefOp Node"); DbFunctionAggregate newFuncAgg; if (funcAggOp.IsDistinctAggregate) { - newFuncAgg = funcAggOp.AggFunc.AggregateDistinct(aggDef); + newFuncAgg = funcAggOp.AggFunc.AggregateDistinct(args); } else { - newFuncAgg = funcAggOp.AggFunc.Aggregate(aggDef); + newFuncAgg = funcAggOp.AggFunc.Aggregate(args); } PlanCompiler.Assert(outputAggVars.Contains(aggVar), "Defined aggregate Var not in Output Aggregate Vars list?"); diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs index d00b897765..ab49252c65 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs @@ -149,25 +149,42 @@ public override void Visit(UnnestOp op, Node n) MessageId = "System.Data.Entity.Core.Query.PlanCompiler.PlanCompiler.Assert(System.Boolean,System.String)")] public override void Visit(FunctionOp op, Node n) { - VisitDefault(n); - if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) - { - return; - } - PlanCompiler.Assert(n.Children.Count == 1, "Aggregate Function must have one argument"); - - GroupAggregateVarInfo referencedGroupAggregateVarInfo; - Node templateNode; - bool isUnnested; - if (GroupAggregateVarComputationTranslator.TryTranslateOverGroupAggregateVar( - n.Child0, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, - out isUnnested) - && - (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) - { - referencedGroupAggregateVarInfo.CandidateAggregateNodes.Add(new KeyValuePair(n, templateNode)); - } - } + VisitDefault(n); + if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) + { + return; + } + GroupAggregateVarInfo referencedGroupAggregateVarInfo; + GroupAggregateVarInfo referencedGroupAggregateVarInfoTracker = null; + + Node templateNode; + bool isUnnested; + + var list = new List(); + + foreach (var argument in n.Children) + { + if (GroupAggregateVarComputationTranslator.TryTranslateOverGroupAggregateVar( + argument, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, + out isUnnested) + && + (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) + { + + referencedGroupAggregateVarInfoTracker = referencedGroupAggregateVarInfo; + list.Add(templateNode); + } + else + { + list.Add(argument); + } + } + + if (referencedGroupAggregateVarInfoTracker != null) + { + referencedGroupAggregateVarInfoTracker.CandidateAggregateNodes.Add(new KeyValuePair>(n, list)); + } + } #endregion diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs index 89c6757ea6..d5428e9530 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs @@ -15,7 +15,7 @@ internal class GroupAggregateVarInfo #region Private Fields private readonly Node _definingGroupByNode; - private HashSet> _candidateAggregateNodes; + private HashSet>> _candidateAggregateNodes; private readonly Var _groupAggregateVar; #endregion @@ -43,13 +43,13 @@ internal GroupAggregateVarInfo(Node defingingGroupNode, Var groupAggregateVar) // A valid candidate has an argument that does not have any external references // except for the group aggregate corresponding to the DefiningGroupNode. // - internal HashSet> CandidateAggregateNodes + internal HashSet>> CandidateAggregateNodes { get { if (_candidateAggregateNodes == null) { - _candidateAggregateNodes = new HashSet>(); + _candidateAggregateNodes = new HashSet>>(); } return _candidateAggregateNodes; } diff --git a/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs b/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs index bbd2c91fe4..72eea48537 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs @@ -60,8 +60,7 @@ internal static bool IsRowTypeCaseOpWithNullability(CaseOp op, Node n, out bool // // Is this function a collection aggregate function. It is, if - // - it has exactly one child - // - that child is a collection type + // - the first child is a collection type // - and the function has been marked with the aggregate attribute // // the function op @@ -69,8 +68,7 @@ internal static bool IsRowTypeCaseOpWithNullability(CaseOp op, Node n, out bool // true, if this was a collection aggregate function internal static bool IsCollectionAggregateFunction(FunctionOp op, Node n) { - return ((n.Children.Count == 1) && - TypeSemantics.IsCollectionType(n.Child0.Op.Type) && + return (TypeSemantics.IsCollectionType(n.Child0.Op.Type) && TypeSemantics.IsAggregateFunction(op.Function)); } From 1381b66ccfed61b3cbba97be10a67ae006dbe957 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Thu, 20 Oct 2016 14:22:56 -0600 Subject: [PATCH 2/9] Incorporating suggestions from code review and replacing tabs with spaces to fix indentation --- .../SqlGen/SqlGenerator.cs | 108 +++++++++--------- .../SqlGen/SqlGenerator.cs | 82 ++++++------- .../ExpressionBuilder/DbExpressionBuilder.cs | 23 ++-- .../Query/PlanCompiler/AggregatePushdown.cs | 21 ++-- .../Core/Query/PlanCompiler/CTreeGenerator.cs | 18 +-- .../GroupAggregateRefComputingVisitor.cs | 72 ++++++------ 6 files changed, 160 insertions(+), 164 deletions(-) diff --git a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs index 12b181e235..30b1dedaa1 100644 --- a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs @@ -1532,52 +1532,52 @@ public override ISqlFragment Visit(DbGroupByExpression e) foreach (var aggregate in e.Aggregates) { - var member = members.Current; - var alias = QuoteIdentifier(member.Name); - - var finalArgs = new List(); - - foreach (var argument in aggregate.Arguments) - { - var translatedAggregateArgument = argument.Accept(this); - - object aggregateArgument; - - if (needsInnerQuery) - { - //In this case the argument to the aggratete is reference to the one projected out by the - // inner query - var wrappingAggregateArgument = new SqlBuilder(); - wrappingAggregateArgument.Append(fromSymbol); - wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); - aggregateArgument = wrappingAggregateArgument; - - innerQuery.Select.Append(separator); - innerQuery.Select.AppendLine(); - innerQuery.Select.Append(translatedAggregateArgument); - innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); - } - else - { - aggregateArgument = translatedAggregateArgument; - } - - finalArgs.Add(aggregateArgument); - } - - ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); - - result.Select.Append(separator); - result.Select.AppendLine(); - result.Select.Append(aggregateResult); - result.Select.Append(" AS "); - result.Select.Append(alias); - - separator = ", "; - members.MoveNext(); - } + var member = members.Current; + var alias = QuoteIdentifier(member.Name); + + var finalArgs = new List(); + + foreach (var argument in aggregate.Arguments) + { + var translatedAggregateArgument = argument.Accept(this); + + object aggregateArgument; + + if (needsInnerQuery) + { + //In this case the argument to the aggratete is reference to the one projected out by the + // inner query + var wrappingAggregateArgument = new SqlBuilder(); + wrappingAggregateArgument.Append(fromSymbol); + wrappingAggregateArgument.Append("."); + wrappingAggregateArgument.Append(alias); + aggregateArgument = wrappingAggregateArgument; + + innerQuery.Select.Append(separator); + innerQuery.Select.AppendLine(); + innerQuery.Select.Append(translatedAggregateArgument); + innerQuery.Select.Append(" AS "); + innerQuery.Select.Append(alias); + } + else + { + aggregateArgument = translatedAggregateArgument; + } + + finalArgs.Add(aggregateArgument); + } + + ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); + + result.Select.Append(separator); + result.Select.AppendLine(); + result.Select.Append(aggregateResult); + result.Select.Append(" AS "); + result.Select.Append(alias); + + separator = ", "; + members.MoveNext(); + } } symbolTable.ExitScope(); @@ -2794,15 +2794,15 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList ag aggregateResult.Append("DISTINCT "); } - string separator = String.Empty; - foreach (var arg in aggregateArguments) - { - aggregateResult.Append(separator); - aggregateResult.Append(arg); - separator = ", "; - } + string separator = String.Empty; + foreach (var arg in aggregateArguments) + { + aggregateResult.Append(separator); + aggregateResult.Append(arg); + separator = ", "; + } - aggregateResult.Append(")"); + aggregateResult.Append(")"); return aggregateResult; } diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index c94b93c865..a8d0d62a55 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -1310,39 +1310,39 @@ public override ISqlFragment Visit(DbGroupByExpression e) var member = members.Current; var alias = QuoteIdentifier(member.Name); - var finalArgs = new List(); - - foreach (var argument in aggregate.Arguments) - { - var translatedAggregateArgument = argument.Accept(this); - - object aggregateArgument; - - if (needsInnerQuery) - { - //In this case the argument to the aggratete is reference to the one projected out by the - // inner query - var wrappingAggregateArgument = new SqlBuilder(); - wrappingAggregateArgument.Append(fromSymbol); - wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); - aggregateArgument = wrappingAggregateArgument; - - innerQuery.Select.Append(separator); - innerQuery.Select.AppendLine(); - innerQuery.Select.Append(translatedAggregateArgument); - innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); - } - else - { - aggregateArgument = translatedAggregateArgument; - } - - finalArgs.Add(aggregateArgument); - } - - ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); + var finalArgs = new List(); + + foreach (var argument in aggregate.Arguments) + { + var translatedAggregateArgument = argument.Accept(this); + + object aggregateArgument; + + if (needsInnerQuery) + { + //In this case the argument to the aggratete is reference to the one projected out by the + // inner query + var wrappingAggregateArgument = new SqlBuilder(); + wrappingAggregateArgument.Append(fromSymbol); + wrappingAggregateArgument.Append("."); + wrappingAggregateArgument.Append(alias); + aggregateArgument = wrappingAggregateArgument; + + innerQuery.Select.Append(separator); + innerQuery.Select.AppendLine(); + innerQuery.Select.Append(translatedAggregateArgument); + innerQuery.Select.Append(" AS "); + innerQuery.Select.Append(alias); + } + else + { + aggregateArgument = translatedAggregateArgument; + } + + finalArgs.Add(aggregateArgument); + } + + ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); result.Select.Append(separator); result.Select.AppendLine(); @@ -2140,15 +2140,15 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList ag throw ADP1.NotSupported(EntityRes.GetString(EntityRes.DistinctAggregatesNotSupported)); } - string separator = String.Empty; - foreach (var arg in aggregateArguments) - { - aggregateResult.Append(separator); - aggregateResult.Append(arg); - separator = ", "; - } + string separator = String.Empty; + foreach (var arg in aggregateArguments) + { + aggregateResult.Append(separator); + aggregateResult.Append(arg); + separator = ", "; + } - aggregateResult.Append(")"); + aggregateResult.Append(")"); if (fCast) { diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index 7081cef379..47c3b5b15a 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -208,15 +208,15 @@ public static DbGroupExpressionBinding GroupBindAs(this DbExpression input, stri /// The argument over which the aggregate function should be calculated. /// function or argument null. /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. - public static DbFunctionAggregate Aggregate(this EdmFunction function, IList arguments) + public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerable arguments) { Check.NotNull(function, "function"); - if (arguments.Count == 0) - { - throw new ArgumentNullException("arguments"); - } + if (arguments?.Any() == false) + { + throw new ArgumentNullException("arguments"); + } - return CreateFunctionAggregate(function, arguments, false); + return CreateFunctionAggregate(function, arguments, false); } /// @@ -227,17 +227,18 @@ public static DbFunctionAggregate Aggregate(this EdmFunction function, IListThe arguments over which the aggregate function should be calculated. /// function or argument is null. /// function is not an aggregate function, or the result type of argument is not equal or promotable to the parameter type of function. - public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IList arguments) + public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IEnumerable arguments) { Check.NotNull(function, "function"); - if (arguments.Count == 0) { - throw new ArgumentNullException("arguments"); - } + if (arguments?.Any() == false) + { + throw new ArgumentNullException("arguments"); + } return CreateFunctionAggregate(function, arguments, true); } - private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, IList arguments, bool isDistinct) + private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, IEnumerable arguments, bool isDistinct) { var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, arguments); var resultType = function.ReturnParameter.TypeUsage; diff --git a/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs b/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs index ebced17245..fdf89eb627 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs @@ -104,20 +104,21 @@ private void TryProcessCandidate( dictionary.Add(groupAggregateVarInfo.GroupAggregateVar, inputVar); var remapper = new VarRemapper(m_command, dictionary); - var argNodes = new List(candidate.Value.Count); + var argNodes = new List(candidate.Value.Count); - foreach (var argumentNode in candidate.Value) { - var argumentNodeCopy = OpCopier.Copy(m_command, argumentNode); - remapper.RemapSubtree(argumentNodeCopy); + foreach (var argumentNode in candidate.Value) + { + var argumentNodeCopy = OpCopier.Copy(m_command, argumentNode); + remapper.RemapSubtree(argumentNodeCopy); - argNodes.Add(argumentNodeCopy); - } + argNodes.Add(argumentNodeCopy); + } - var newFunctionDefiningNode = m_command.CreateNode( - m_command.CreateAggregateOp(functionOp.Function, false), - argNodes); + var newFunctionDefiningNode = m_command.CreateNode( + m_command.CreateAggregateOp(functionOp.Function, false), + argNodes); - Var newFunctionVar; + Var newFunctionVar; var varDefNode = m_command.CreateVarDefNode(newFunctionDefiningNode, out newFunctionVar); // Add the new aggregate to the list of aggregates diff --git a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs index 979361a6b5..7b4dda39dd 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs @@ -12,6 +12,7 @@ namespace System.Data.Entity.Core.Query.PlanCompiler using System.Data.Entity.Resources; using System.Diagnostics.CodeAnalysis; using System.Globalization; + using Linq; using SortKey = System.Data.Entity.Core.Query.InternalTrees.SortKey; [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling")] @@ -1886,24 +1887,17 @@ public override DbExpression Visit(GroupByOp op, Node n) PlanCompiler.Assert(aggRootNode.Op is VarDefListOp, "Invalid Aggregates VarDefListOp Node encountered in GroupByOp"); foreach (var aggVarDefNode in aggRootNode.Children) { - var aggVarDef = aggVarDefNode.Op as VarDefOp; + var aggVarDef = aggVarDefNode.Op as VarDefOp; PlanCompiler.Assert(aggVarDef != null, "Non-VarDefOp Node encountered as child of Aggregates VarDefListOp Node"); var aggVar = aggVarDef.Var; - PlanCompiler.Assert(aggVar is ComputedVar, "Non-ComputedVar encountered in Aggregate VarDefOp"); + PlanCompiler.Assert(aggVar is ComputedVar, "Non-ComputedVar encountered in Aggregate VarDefOp"); var aggOpNode = aggVarDefNode.Child0; + var args = aggOpNode.Children.Select(argumentNode => VisitNode(argumentNode)); - // Loop through arguments - var args = new List(); - - foreach (var argumentNode in aggOpNode.Children) { - var aggDef = VisitNode(argumentNode); - args.Add(aggDef); - } - - var funcAggOp = aggOpNode.Op as AggregateOp; - PlanCompiler.Assert(funcAggOp != null, "Non-Aggregate Node encountered as child of Aggregate VarDefOp Node"); + var funcAggOp = aggOpNode.Op as AggregateOp; + PlanCompiler.Assert(funcAggOp != null, "Non-Aggregate Node encountered as child of Aggregate VarDefOp Node"); DbFunctionAggregate newFuncAgg; if (funcAggOp.IsDistinctAggregate) { diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs index ab49252c65..5725c4b1c4 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs @@ -149,42 +149,42 @@ public override void Visit(UnnestOp op, Node n) MessageId = "System.Data.Entity.Core.Query.PlanCompiler.PlanCompiler.Assert(System.Boolean,System.String)")] public override void Visit(FunctionOp op, Node n) { - VisitDefault(n); - if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) - { - return; - } - GroupAggregateVarInfo referencedGroupAggregateVarInfo; - GroupAggregateVarInfo referencedGroupAggregateVarInfoTracker = null; - - Node templateNode; - bool isUnnested; - - var list = new List(); - - foreach (var argument in n.Children) - { - if (GroupAggregateVarComputationTranslator.TryTranslateOverGroupAggregateVar( - argument, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, - out isUnnested) - && - (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) - { - - referencedGroupAggregateVarInfoTracker = referencedGroupAggregateVarInfo; - list.Add(templateNode); - } - else - { - list.Add(argument); - } - } - - if (referencedGroupAggregateVarInfoTracker != null) - { - referencedGroupAggregateVarInfoTracker.CandidateAggregateNodes.Add(new KeyValuePair>(n, list)); - } - } + VisitDefault(n); + if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) + { + return; + } + GroupAggregateVarInfo referencedGroupAggregateVarInfo; + GroupAggregateVarInfo referencedGroupAggregateVarInfoTracker = null; + + Node templateNode; + bool isUnnested; + + var list = new List(); + + foreach (var argument in n.Children) + { + if (GroupAggregateVarComputationTranslator.TryTranslateOverGroupAggregateVar( + argument, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, + out isUnnested) + && + (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) + { + + referencedGroupAggregateVarInfoTracker = referencedGroupAggregateVarInfo; + list.Add(templateNode); + } + else + { + list.Add(argument); + } + } + + if (referencedGroupAggregateVarInfoTracker != null) + { + referencedGroupAggregateVarInfoTracker.CandidateAggregateNodes.Add(new KeyValuePair>(n, list)); + } + } #endregion From 4e03d6e833ea5b1e84d819b9cddd73fa75cd9f16 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Thu, 27 Oct 2016 09:50:21 -0600 Subject: [PATCH 3/9] Fixed aliases for multiple aggregate arguments --- src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs | 9 ++++++--- .../SqlGen/SqlGenerator.cs | 10 ++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs index 30b1dedaa1..71b2c04a5c 100644 --- a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs @@ -1537,27 +1537,30 @@ public override ISqlFragment Visit(DbGroupByExpression e) var finalArgs = new List(); - foreach (var argument in aggregate.Arguments) + for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++) { + var argument = aggregate.Arguments[childIndex]; var translatedAggregateArgument = argument.Accept(this); object aggregateArgument; if (needsInnerQuery) { + var argAlias = QuoteIdentifier(member.Name + "_" + childIndex); + //In this case the argument to the aggratete is reference to the one projected out by the // inner query var wrappingAggregateArgument = new SqlBuilder(); wrappingAggregateArgument.Append(fromSymbol); wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); + wrappingAggregateArgument.Append(argAlias); aggregateArgument = wrappingAggregateArgument; innerQuery.Select.Append(separator); innerQuery.Select.AppendLine(); innerQuery.Select.Append(translatedAggregateArgument); innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); + innerQuery.Select.Append(argAlias); } else { diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index a8d0d62a55..054b8e94cb 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -1312,27 +1312,29 @@ public override ISqlFragment Visit(DbGroupByExpression e) var finalArgs = new List(); - foreach (var argument in aggregate.Arguments) - { + for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++) { + var argument = aggregate.Arguments[childIndex]; var translatedAggregateArgument = argument.Accept(this); object aggregateArgument; if (needsInnerQuery) { + var argAlias = QuoteIdentifier(member.Name + "_" + childIndex); + //In this case the argument to the aggratete is reference to the one projected out by the // inner query var wrappingAggregateArgument = new SqlBuilder(); wrappingAggregateArgument.Append(fromSymbol); wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); + wrappingAggregateArgument.Append(argAlias); aggregateArgument = wrappingAggregateArgument; innerQuery.Select.Append(separator); innerQuery.Select.AppendLine(); innerQuery.Select.Append(translatedAggregateArgument); innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); + innerQuery.Select.Append(argAlias); } else { From 66a4d0ba03b660fed57d697428412ab708b55899 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Fri, 1 Sep 2017 14:28:32 -0600 Subject: [PATCH 4/9] Code review changes: - Added overloads to not break public API - Updated summary text - Disable optimization for aggregate functions with multiple arguments - Simplified LINQ expression --- .../SqlGen/SqlGenerator.cs | 3 +- .../ExpressionBuilder/DbExpressionBuilder.cs | 36 ++++++++++++++++ .../Core/Query/PlanCompiler/CTreeGenerator.cs | 2 +- .../GroupAggregateRefComputingVisitor.cs | 41 +++++++------------ .../PlanCompiler/GroupAggregateVarInfo.cs | 4 +- 5 files changed, 55 insertions(+), 31 deletions(-) diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index 054b8e94cb..36a5049e8e 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -1312,7 +1312,8 @@ public override ISqlFragment Visit(DbGroupByExpression e) var finalArgs = new List(); - for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++) { + for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++) + { var argument = aggregate.Arguments[childIndex]; var translatedAggregateArgument = argument.Accept(this); diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index 47c3b5b15a..da6ad018d9 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -200,6 +200,42 @@ public static DbGroupExpressionBinding GroupBindAs(this DbExpression input, stri #region Aggregates and SortClauses are required only for Binding-based method support - replaced by OrderBy[Descending]/ThenBy[Descending] and Aggregate[Distinct] methods in new API + /// + /// Creates a new . + /// + /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value false. + /// The function that defines the aggregate operation. + /// The argument over which the aggregate function should be calculated. + /// function or argument null. + /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. + public static DbFunctionAggregate Aggregate(this EdmFunction function, DbExpression argument) { + Check.NotNull(function, "function"); + Check.NotNull(argument, "argument"); + + return CreateFunctionAggregate(function, argument, false); + } + + /// + /// Creates a new that is applied in a distinct fashion. + /// + /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value true. + /// The function that defines the aggregate operation. + /// The argument over which the aggregate function should be calculated. + /// function or argument is null. + /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. + public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, DbExpression argument) { + Check.NotNull(function, "function"); + Check.NotNull(argument, "argument"); + + return CreateFunctionAggregate(function, argument, true); + } + + private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, DbExpression argument, bool isDistinct) { + var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, new[] { argument }); + var resultType = function.ReturnParameter.TypeUsage; + return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct); + } + /// /// Creates a new . /// diff --git a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs index 7b4dda39dd..df2f23919c 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs @@ -1894,7 +1894,7 @@ public override DbExpression Visit(GroupByOp op, Node n) PlanCompiler.Assert(aggVar is ComputedVar, "Non-ComputedVar encountered in Aggregate VarDefOp"); var aggOpNode = aggVarDefNode.Child0; - var args = aggOpNode.Children.Select(argumentNode => VisitNode(argumentNode)); + var args = aggOpNode.Children.Select(VisitNode); var funcAggOp = aggOpNode.Op as AggregateOp; PlanCompiler.Assert(funcAggOp != null, "Non-Aggregate Node encountered as child of Aggregate VarDefOp Node"); diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs index 5725c4b1c4..302f84eade 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs @@ -150,40 +150,27 @@ public override void Visit(UnnestOp op, Node n) public override void Visit(FunctionOp op, Node n) { VisitDefault(n); - if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) - { + if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) { return; } - GroupAggregateVarInfo referencedGroupAggregateVarInfo; - GroupAggregateVarInfo referencedGroupAggregateVarInfoTracker = null; - - Node templateNode; - bool isUnnested; - var list = new List(); + if (n.Children.Count > 1) { + return; + } - foreach (var argument in n.Children) - { - if (GroupAggregateVarComputationTranslator.TryTranslateOverGroupAggregateVar( - argument, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, - out isUnnested) - && - (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) - { + var argumentNode = n.Child0; - referencedGroupAggregateVarInfoTracker = referencedGroupAggregateVarInfo; - list.Add(templateNode); - } - else - { - list.Add(argument); - } + GroupAggregateVarInfo referencedGroupAggregateVarInfo; + Node templateNode; + bool isUnnested; + if (GroupAggregateVarComputationTranslator.TryTranslateOverGroupAggregateVar( + n.Child0, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, + out isUnnested) + && + (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) { + referencedGroupAggregateVarInfo.CandidateAggregateNodes.Add(new KeyValuePair>(n, new List { templateNode })); } - if (referencedGroupAggregateVarInfoTracker != null) - { - referencedGroupAggregateVarInfoTracker.CandidateAggregateNodes.Add(new KeyValuePair>(n, list)); - } } #endregion diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs index d5428e9530..0fe7b84f68 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs @@ -39,8 +39,8 @@ internal GroupAggregateVarInfo(Node defingingGroupNode, Var groupAggregateVar) // // Each key value pair represents a candidate aggregate. // The key is the function aggregate subtree and the value is a 'template' of translation of the - // function aggregate's argument over the var representing the group aggregate. - // A valid candidate has an argument that does not have any external references + // function aggregate's arguments over the var representing the group aggregate. + // A valid candidate has arguments that do not have any external references // except for the group aggregate corresponding to the DefiningGroupNode. // internal HashSet>> CandidateAggregateNodes From d4b2c665b569c1ae22d82e9dbfa099315df049e6 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Fri, 12 Jan 2018 10:53:33 -0700 Subject: [PATCH 5/9] Added unit test for aggregate function with more than one argument, and removed more code that was enforcing only one argument --- .../SqlGen/SqlGenerator.cs | 1 - .../SqlGen/SqlGenerator.cs | 1 - .../Core/Common/CommandTrees/DbAggregate.cs | 1 - .../CommandTrees/DefaultExpressionVisitor.cs | 1 - .../CommandTrees/Internal/ExpressionKeyGen.cs | 1 - .../Core/SchemaObjectModel/Function.cs | 15 ++---------- src/EntityFramework/Properties/Resources.cs | 9 ------- src/EntityFramework/Properties/Resources.resx | 3 --- .../Query/GroupAggregateTests.cs | 24 ++++++++++++++++++- .../CommandTrees/DbExpressionBuilderTests.cs | 4 ++-- 10 files changed, 27 insertions(+), 33 deletions(-) diff --git a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs index 71b2c04a5c..81f507c717 100644 --- a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs @@ -4349,7 +4349,6 @@ private static bool GroupByAggregatesNeedInnerQuery(IList aggregate { foreach (var aggregate in aggregates) { - Debug.Assert(aggregate.Arguments.Count == 1); if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName)) { return true; diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index 36a5049e8e..63b02399ba 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -4523,7 +4523,6 @@ private static bool GroupByAggregatesNeedInnerQuery(IList aggregate { foreach (var aggregate in aggregates) { - Debug.Assert(aggregate.Arguments.Count == 1); if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName)) { return true; diff --git a/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs b/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs index d3f61fe936..1bfc520270 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs @@ -18,7 +18,6 @@ internal DbAggregate(TypeUsage resultType, DbExpressionList arguments) { DebugCheck.NotNull(resultType); DebugCheck.NotNull(arguments); - Debug.Assert(arguments.Count == 1, "DbAggregate requires a single argument"); _type = resultType; _args = arguments; diff --git a/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs b/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs index 94b084bf69..702f36a95a 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs @@ -210,7 +210,6 @@ protected virtual DbGroupAggregate VisitGroupAggregate(DbGroupAggregate aggregat if (aggregate != null) { var newArguments = VisitExpressionList(aggregate.Arguments); - Debug.Assert(newArguments.Count == 1, "Group aggregate had more than one argument?"); if (!ReferenceEquals(aggregate.Arguments, newArguments)) { diff --git a/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs b/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs index ade88d7ee2..792512235a 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs @@ -766,7 +766,6 @@ public override void Visit(DbGroupByExpression e) if (ga != null) { _key.Append("GA("); - Debug.Assert(ga.Arguments.Count == 1, "Group aggregate must have one argument."); ga.Arguments[0].Accept(this); _key.Append(')'); } diff --git a/src/EntityFramework/Core/SchemaObjectModel/Function.cs b/src/EntityFramework/Core/SchemaObjectModel/Function.cs index 7d2b560d7e..e72b2c9e72 100644 --- a/src/EntityFramework/Core/SchemaObjectModel/Function.cs +++ b/src/EntityFramework/Core/SchemaObjectModel/Function.cs @@ -416,21 +416,10 @@ internal override void Validate() { if (IsAggregate) { - // Make sure that the function has exactly one parameter and that takes - // a collection type - if (Parameters.Count != 1) - { - AddError( - ErrorCode.InvalidNumberOfParametersForAggregateFunction, - EdmSchemaErrorSeverity.Error, - this, - Strings.InvalidNumberOfParametersForAggregateFunction(FQName)); - } - else if (Parameters.GetElementAt(0).CollectionKind + // Make sure that the takes a collection type + if (Parameters.GetElementAt(0).CollectionKind == CollectionKind.None) { - // Since we have already checked that there should be exactly one parameter, it should be safe to get the - // first parameter for the function var param = Parameters.GetElementAt(0); AddError( diff --git a/src/EntityFramework/Properties/Resources.cs b/src/EntityFramework/Properties/Resources.cs index 4a0d663c28..3456eebbd2 100644 --- a/src/EntityFramework/Properties/Resources.cs +++ b/src/EntityFramework/Properties/Resources.cs @@ -4326,14 +4326,6 @@ internal static string Metadata_General_Error get { return EntityRes.GetString(EntityRes.Metadata_General_Error); } } - // - // A string like "Error in Function '{0}'. Aggregate Functions should take exactly one input parameter." - // - internal static string InvalidNumberOfParametersForAggregateFunction(object p0) - { - return EntityRes.GetString(EntityRes.InvalidNumberOfParametersForAggregateFunction, p0); - } - // // A string like "Type of parameter '{0}' in function '{1}' is not valid. The aggregate function parameter type must be of CollectionType." // @@ -16273,7 +16265,6 @@ internal sealed class EntityRes internal const string Validator_UnsupportedEnumUnderlyingType = "Validator_UnsupportedEnumUnderlyingType"; internal const string ExtraInfo = "ExtraInfo"; internal const string Metadata_General_Error = "Metadata_General_Error"; - internal const string InvalidNumberOfParametersForAggregateFunction = "InvalidNumberOfParametersForAggregateFunction"; internal const string InvalidParameterTypeForAggregateFunction = "InvalidParameterTypeForAggregateFunction"; internal const string InvalidSchemaEncountered = "InvalidSchemaEncountered"; internal const string SystemNamespaceEncountered = "SystemNamespaceEncountered"; diff --git a/src/EntityFramework/Properties/Resources.resx b/src/EntityFramework/Properties/Resources.resx index 95f25b07de..423ae139f0 100644 --- a/src/EntityFramework/Properties/Resources.resx +++ b/src/EntityFramework/Properties/Resources.resx @@ -1884,9 +1884,6 @@ Inconsistent metadata error - - Error in Function '{0}'. Aggregate Functions should take exactly one input parameter. - Type of parameter '{0}' in function '{1}' is not valid. The aggregate function parameter type must be of CollectionType. diff --git a/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs b/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs index 0a304600af..991bba069e 100644 --- a/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs +++ b/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs @@ -1294,7 +1294,29 @@ GROUP BY [Extent1].[ProductName] QueryTestHelpers.VerifyQuery(query, workspace, expectedSql); } - public class CodePlex2160 : FunctionalTestBase + [Fact] + public void Using_ssdl_defined_aggregate_function_multiple_parameters() { + var query = + @"select gkey, [ĎefauľtNamėspacĕ].Store.F_UniqueCount(C.Address.Region, C.Address.City) +FROM ProductContainer.Customers as C +Group By C.Address.Region as gkey"; + + var expectedSql = + @"SELECT +1 AS [C1], +[GroupBy1].[K1] AS [Region], +[GroupBy1].[A1] AS [C2] +FROM ( SELECT + [Extent1].[Region] AS [K1], + [dbo].[F_UniqueCount]([Extent1].[Region], [Extent1].[City]) AS [A1] + FROM [dbo].[Customers] AS [Extent1] + GROUP BY [Extent1].[Region] +) AS [GroupBy1]"; + + QueryTestHelpers.VerifyQuery(query, workspace, expectedSql); + } + + public class CodePlex2160 : FunctionalTestBase { public class Foo { diff --git a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs index 2c7b47240c..1fb451e7a5 100644 --- a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs +++ b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs @@ -2182,14 +2182,14 @@ public void SetClause_can_create_for_property() public void Null_check_Aggregate() { Assert.Throws(() => DbExpressionBuilder.Aggregate(null, DbExpressionBuilder.True)); - Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), null)); + Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), (DbExpression)null)); } [Fact] public void Null_check_AggregateDistinct() { Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(null, DbExpressionBuilder.True)); - Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), null)); + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), (DbExpression)null)); } [Fact] From ea3a5190a294a9cbaccc75804802acaab5741479 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Fri, 12 Jan 2018 12:00:23 -0700 Subject: [PATCH 6/9] Added more null-check unit tests for aggregate function with more than one argument --- .../ExpressionBuilder/DbExpressionBuilder.cs | 8 ++++++-- .../Common/CommandTrees/DbExpressionBuilderTests.cs | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index da6ad018d9..79e992af79 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -247,7 +247,9 @@ private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerable arguments) { Check.NotNull(function, "function"); - if (arguments?.Any() == false) + Check.NotNull(arguments, "argument"); + + if (arguments.Any() == false) { throw new ArgumentNullException("arguments"); } @@ -266,7 +268,9 @@ public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerab public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IEnumerable arguments) { Check.NotNull(function, "function"); - if (arguments?.Any() == false) + Check.NotNull(arguments, "argument"); + + if (arguments.Any() == false) { throw new ArgumentNullException("arguments"); } diff --git a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs index 1fb451e7a5..5a8860bb6e 100644 --- a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs +++ b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs @@ -2178,6 +2178,19 @@ public void SetClause_can_create_for_property() #region Null checks + [Fact] + public void Null_check_Aggregate_MultipleArguments() { + Assert.Throws(() => DbExpressionBuilder.Aggregate(null, new List { DbExpressionBuilder.True })); + Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), (IEnumerable)null)); + Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), new List())); + } + + [Fact] + public void Null_check_AggregateDistinct_MultipleArguments() { + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(null, new List { DbExpressionBuilder.True })); + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), (IEnumerable)null)); + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), new List())); + } [Fact] public void Null_check_Aggregate() { From 7b29cf7be4ee91d060772a54aa1499443ef266f7 Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Tue, 11 Dec 2018 12:56:35 -0700 Subject: [PATCH 7/9] - Added back in some checks on aggregate parameter count I had previously removed (now checks for greater then or equal to 1, instead of just 1) - More formatting cleanup - Added missing aggregate function in Ssdl --- .../SqlGen/SqlGenerator.cs | 1 + .../SqlGen/SqlGenerator.cs | 1 + .../Core/Common/CommandTrees/DbAggregate.cs | 1 + .../ExpressionBuilder/DbExpressionBuilder.cs | 9 ++++++--- .../CommandTrees/Internal/ExpressionKeyGen.cs | 1 + .../Core/Common/EntitySql/SemanticAnalyzer.cs | 4 ++++ .../GroupAggregateRefComputingVisitor.cs | 10 ++++++---- .../Core/Query/PlanCompiler/PlanCompilerUtil.cs | 4 +++- .../Core/SchemaObjectModel/Function.cs | 15 +++++++++++++-- src/EntityFramework/Properties/Resources.cs | 9 +++++++++ src/EntityFramework/Properties/Resources.resx | 3 +++ .../LinqToEntities/GroupByOptimizationTests.cs | 12 ++++++------ .../Query/ModelDefinedFunctionTests.cs | 4 ++-- .../FunctionalTests/Query/ProductModel.cs | 4 ++++ .../CommandTrees/DbExpressionBuilderTests.cs | 6 ++++-- 15 files changed, 64 insertions(+), 20 deletions(-) diff --git a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs index 81f507c717..34775622cd 100644 --- a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs @@ -4349,6 +4349,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList aggregate { foreach (var aggregate in aggregates) { + Debug.Assert(aggregate.Arguments.Count >= 1); if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName)) { return true; diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index 63b02399ba..c2b9b8cb57 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -4523,6 +4523,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList aggregate { foreach (var aggregate in aggregates) { + Debug.Assert(aggregate.Arguments.Count >= 1); if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName)) { return true; diff --git a/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs b/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs index 1bfc520270..a56dcc6d35 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs @@ -18,6 +18,7 @@ internal DbAggregate(TypeUsage resultType, DbExpressionList arguments) { DebugCheck.NotNull(resultType); DebugCheck.NotNull(arguments); + Debug.Assert(arguments.Count >= 1, "DbAggregate requires at least one argument"); _type = resultType; _args = arguments; diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index 79e992af79..ccb4c619f4 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -208,7 +208,8 @@ public static DbGroupExpressionBinding GroupBindAs(this DbExpression input, stri /// The argument over which the aggregate function should be calculated. /// function or argument null. /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. - public static DbFunctionAggregate Aggregate(this EdmFunction function, DbExpression argument) { + public static DbFunctionAggregate Aggregate(this EdmFunction function, DbExpression argument) + { Check.NotNull(function, "function"); Check.NotNull(argument, "argument"); @@ -223,14 +224,16 @@ public static DbFunctionAggregate Aggregate(this EdmFunction function, DbExpress /// The argument over which the aggregate function should be calculated. /// function or argument is null. /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. - public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, DbExpression argument) { + public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, DbExpression argument) + { Check.NotNull(function, "function"); Check.NotNull(argument, "argument"); return CreateFunctionAggregate(function, argument, true); } - private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, DbExpression argument, bool isDistinct) { + private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, DbExpression argument, bool isDistinct) + { var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, new[] { argument }); var resultType = function.ReturnParameter.TypeUsage; return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct); diff --git a/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs b/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs index 792512235a..0551284521 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs @@ -766,6 +766,7 @@ public override void Visit(DbGroupByExpression e) if (ga != null) { _key.Append("GA("); + Debug.Assert(ga.Arguments.Count >= 1, "Group aggregate must have at least one argument."); ga.Arguments[0].Accept(this); _key.Append(')'); } diff --git a/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs b/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs index 91bd1ca5fd..532fa33219 100644 --- a/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs +++ b/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs @@ -1373,6 +1373,10 @@ private static bool TryConvertAsFunctionAggregate( "argument types resolved for the collection aggregate calls must match"); } + // + // Aggregate functions must have at least one argument, and the first argument must be of collection edmType + // + Debug.Assert((1 <= functionType.Parameters.Count), "(1 <= functionType.Parameters.Count)"); // we only support monadic aggregate functions Debug.Assert( TypeSemantics.IsCollectionType(functionType.Parameters[0].TypeUsage), "functionType.Parameters[0].Type is CollectionType"); diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs index 302f84eade..d7f0368adc 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs @@ -150,11 +150,13 @@ public override void Visit(UnnestOp op, Node n) public override void Visit(FunctionOp op, Node n) { VisitDefault(n); - if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) { + if (!PlanCompilerUtil.IsCollectionAggregateFunction(op, n)) + { return; } - if (n.Children.Count > 1) { + if (n.Children.Count > 1) + { return; } @@ -167,10 +169,10 @@ public override void Visit(FunctionOp op, Node n) n.Child0, false, _command, _groupAggregateVarInfoManager, out referencedGroupAggregateVarInfo, out templateNode, out isUnnested) && - (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) { + (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) + { referencedGroupAggregateVarInfo.CandidateAggregateNodes.Add(new KeyValuePair>(n, new List { templateNode })); } - } #endregion diff --git a/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs b/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs index 72eea48537..34dd53eca3 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs @@ -60,6 +60,7 @@ internal static bool IsRowTypeCaseOpWithNullability(CaseOp op, Node n, out bool // // Is this function a collection aggregate function. It is, if + // - it has at least one child // - the first child is a collection type // - and the function has been marked with the aggregate attribute // @@ -68,7 +69,8 @@ internal static bool IsRowTypeCaseOpWithNullability(CaseOp op, Node n, out bool // true, if this was a collection aggregate function internal static bool IsCollectionAggregateFunction(FunctionOp op, Node n) { - return (TypeSemantics.IsCollectionType(n.Child0.Op.Type) && + return ((n.Children.Count >= 1) && + TypeSemantics.IsCollectionType(n.Child0.Op.Type) && TypeSemantics.IsAggregateFunction(op.Function)); } diff --git a/src/EntityFramework/Core/SchemaObjectModel/Function.cs b/src/EntityFramework/Core/SchemaObjectModel/Function.cs index e72b2c9e72..ee0dadd3e1 100644 --- a/src/EntityFramework/Core/SchemaObjectModel/Function.cs +++ b/src/EntityFramework/Core/SchemaObjectModel/Function.cs @@ -416,10 +416,21 @@ internal override void Validate() { if (IsAggregate) { - // Make sure that the takes a collection type - if (Parameters.GetElementAt(0).CollectionKind + // Make sure that the function has at least one parameter and the first parameter takes + // a collection type + if (Parameters.Count == 0) + { + AddError( + ErrorCode.InvalidNumberOfParametersForAggregateFunction, + EdmSchemaErrorSeverity.Error, + this, + Strings.InvalidNumberOfParametersForAggregateFunction(FQName)); + } + else if (Parameters.GetElementAt(0).CollectionKind == CollectionKind.None) { + // Since we have already checked that there should be at least one parameter, it should be safe to get the + // first parameter for the function var param = Parameters.GetElementAt(0); AddError( diff --git a/src/EntityFramework/Properties/Resources.cs b/src/EntityFramework/Properties/Resources.cs index 3456eebbd2..4a0d663c28 100644 --- a/src/EntityFramework/Properties/Resources.cs +++ b/src/EntityFramework/Properties/Resources.cs @@ -4326,6 +4326,14 @@ internal static string Metadata_General_Error get { return EntityRes.GetString(EntityRes.Metadata_General_Error); } } + // + // A string like "Error in Function '{0}'. Aggregate Functions should take exactly one input parameter." + // + internal static string InvalidNumberOfParametersForAggregateFunction(object p0) + { + return EntityRes.GetString(EntityRes.InvalidNumberOfParametersForAggregateFunction, p0); + } + // // A string like "Type of parameter '{0}' in function '{1}' is not valid. The aggregate function parameter type must be of CollectionType." // @@ -16265,6 +16273,7 @@ internal sealed class EntityRes internal const string Validator_UnsupportedEnumUnderlyingType = "Validator_UnsupportedEnumUnderlyingType"; internal const string ExtraInfo = "ExtraInfo"; internal const string Metadata_General_Error = "Metadata_General_Error"; + internal const string InvalidNumberOfParametersForAggregateFunction = "InvalidNumberOfParametersForAggregateFunction"; internal const string InvalidParameterTypeForAggregateFunction = "InvalidParameterTypeForAggregateFunction"; internal const string InvalidSchemaEncountered = "InvalidSchemaEncountered"; internal const string SystemNamespaceEncountered = "SystemNamespaceEncountered"; diff --git a/src/EntityFramework/Properties/Resources.resx b/src/EntityFramework/Properties/Resources.resx index 423ae139f0..ff53755a3f 100644 --- a/src/EntityFramework/Properties/Resources.resx +++ b/src/EntityFramework/Properties/Resources.resx @@ -1884,6 +1884,9 @@ Inconsistent metadata error + + Error in Function '{0}'. Aggregate Functions should take at least one input parameter. + Type of parameter '{0}' in function '{1}' is not valid. The aggregate function parameter type must be of CollectionType. diff --git a/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs b/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs index 70708a2a8c..efbe129d1a 100644 --- a/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs +++ b/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs @@ -136,12 +136,12 @@ public void GroupBy_is_optimized_when_projecting_anonymous_type_containing_group [GroupBy1].[A2] AS [C3] FROM ( SELECT [Extent1].[K1] AS [K1], - MAX([Extent1].[A1]) AS [A1], - MIN([Extent1].[A2]) AS [A2] + MAX([Extent1].[A1_0]) AS [A1], + MIN([Extent1].[A2_0]) AS [A2] FROM ( SELECT [Extent1].[FirstName] AS [K1], - [Extent1].[Id] AS [A1], - [Extent1].[Id] + 2 AS [A2] + [Extent1].[Id] AS [A1_0], + [Extent1].[Id] + 2 AS [A2_0] FROM [dbo].[ArubaOwners] AS [Extent1] ) AS [Extent1] GROUP BY [K1] @@ -221,10 +221,10 @@ public void GroupBy_is_optimized_when_projecting_function_aggregate_with_express [GroupBy1].[A1] AS [C1] FROM ( SELECT [Extent1].[K1] AS [K1], - MAX([Extent1].[A1]) AS [A1] + MAX([Extent1].[A1_0]) AS [A1] FROM ( SELECT [Extent1].[FirstName] AS [K1], - [Extent1].[Id] * 2 AS [A1] + [Extent1].[Id] * 2 AS [A1_0] FROM [dbo].[ArubaOwners] AS [Extent1] ) AS [Extent1] GROUP BY [K1] diff --git a/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs b/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs index 41c347c90d..9def1f1fff 100644 --- a/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs +++ b/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs @@ -107,9 +107,9 @@ public void Function_returning_collection_of_scalars() 1 AS [C1], [GroupBy1].[A1] AS [C2] FROM ( SELECT - MAX([Filter1].[A1]) AS [A1] + MAX([Filter1].[A1_0]) AS [A1] FROM ( SELECT - [Extent1].[ProductID] - 3 AS [A1] + [Extent1].[ProductID] - 3 AS [A1_0] FROM [dbo].[Products] AS [Extent1] WHERE [Extent1].[Discontinued] IN (0,1) ) AS [Filter1] diff --git a/test/EntityFramework/FunctionalTests/Query/ProductModel.cs b/test/EntityFramework/FunctionalTests/Query/ProductModel.cs index 425c1d227a..8fc1714386 100644 --- a/test/EntityFramework/FunctionalTests/Query/ProductModel.cs +++ b/test/EntityFramework/FunctionalTests/Query/ProductModel.cs @@ -45,6 +45,10 @@ public static class ProductModel + + + + "; public const string Msl = diff --git a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs index 5a8860bb6e..87d8023a43 100644 --- a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs +++ b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs @@ -2179,14 +2179,16 @@ public void SetClause_can_create_for_property() #region Null checks [Fact] - public void Null_check_Aggregate_MultipleArguments() { + public void Null_check_Aggregate_MultipleArguments() + { Assert.Throws(() => DbExpressionBuilder.Aggregate(null, new List { DbExpressionBuilder.True })); Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), (IEnumerable)null)); Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), new List())); } [Fact] - public void Null_check_AggregateDistinct_MultipleArguments() { + public void Null_check_AggregateDistinct_MultipleArguments() + { Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(null, new List { DbExpressionBuilder.True })); Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), (IEnumerable)null)); Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), new List())); From 992c060a263c33e475ce5b65a37d27ac65fef2df Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Tue, 11 Dec 2018 16:20:01 -0700 Subject: [PATCH 8/9] Updating comment --- src/EntityFramework/Properties/Resources.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/EntityFramework/Properties/Resources.cs b/src/EntityFramework/Properties/Resources.cs index 4a0d663c28..efdcc5877b 100644 --- a/src/EntityFramework/Properties/Resources.cs +++ b/src/EntityFramework/Properties/Resources.cs @@ -4327,7 +4327,7 @@ internal static string Metadata_General_Error } // - // A string like "Error in Function '{0}'. Aggregate Functions should take exactly one input parameter." + // A string like "Error in Function '{0}'. Aggregate Functions should take at least one input parameter." // internal static string InvalidNumberOfParametersForAggregateFunction(object p0) { From d1eae8d69cf4def56a8164dee00cc2d1174a5d2b Mon Sep 17 00:00:00 2001 From: Tim Stowell Date: Thu, 27 Jun 2019 10:10:17 -0600 Subject: [PATCH 9/9] Fix parameter name in comments --- .../CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index ccb4c619f4..68fd43b293 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -244,7 +244,7 @@ private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, /// /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value false. /// The function that defines the aggregate operation. - /// The argument over which the aggregate function should be calculated. + /// The argument over which the aggregate function should be calculated. /// function or argument null. /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerable arguments) @@ -265,7 +265,7 @@ public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerab /// /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value true. /// The function that defines the aggregate operation. - /// The arguments over which the aggregate function should be calculated. + /// The arguments over which the aggregate function should be calculated. /// function or argument is null. /// function is not an aggregate function, or the result type of argument is not equal or promotable to the parameter type of function. public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IEnumerable arguments)