diff --git a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs index c8155135ee..34775622cd 100644 --- a/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs @@ -1535,33 +1535,42 @@ public override ISqlFragment Visit(DbGroupByExpression e) var member = members.Current; var alias = QuoteIdentifier(member.Name); - Debug.Assert(aggregate.Arguments.Count == 1); - var translatedAggregateArgument = aggregate.Arguments[0].Accept(this); + var finalArgs = new List(); - object aggregateArgument; - - if (needsInnerQuery) + for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++) { - //In this case the argument to the aggratete is reference to the one projected out by the - // inner query - var wrappingAggregateArgument = new SqlBuilder(); - wrappingAggregateArgument.Append(fromSymbol); - wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); - aggregateArgument = wrappingAggregateArgument; + var argument = aggregate.Arguments[childIndex]; + var translatedAggregateArgument = argument.Accept(this); - innerQuery.Select.Append(separator); - innerQuery.Select.AppendLine(); - innerQuery.Select.Append(translatedAggregateArgument); - innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); - } - else - { - aggregateArgument = translatedAggregateArgument; + object aggregateArgument; + + if (needsInnerQuery) + { + var argAlias = QuoteIdentifier(member.Name + "_" + childIndex); + + //In this case the argument to the aggratete is reference to the one projected out by the + // inner query + var wrappingAggregateArgument = new SqlBuilder(); + wrappingAggregateArgument.Append(fromSymbol); + wrappingAggregateArgument.Append("."); + wrappingAggregateArgument.Append(argAlias); + aggregateArgument = wrappingAggregateArgument; + + innerQuery.Select.Append(separator); + innerQuery.Select.AppendLine(); + innerQuery.Select.Append(translatedAggregateArgument); + innerQuery.Select.Append(" AS "); + innerQuery.Select.Append(argAlias); + } + else + { + aggregateArgument = translatedAggregateArgument; + } + + finalArgs.Add(aggregateArgument); } - ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument); + ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); result.Select.Append(separator); result.Select.AppendLine(); @@ -2756,8 +2765,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e) // Aggregates are not visited by the normal visitor walk. // // The aggregate go be translated - // The translated aggregate argument - private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument) + // The translated aggregate arguments + private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList aggregateArguments) { var aggregateResult = new SqlBuilder(); var functionAggregate = aggregate as DbFunctionAggregate; @@ -2788,7 +2797,13 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate aggregateResult.Append("DISTINCT "); } - aggregateResult.Append(aggregateArgument); + string separator = String.Empty; + foreach (var arg in aggregateArguments) + { + aggregateResult.Append(separator); + aggregateResult.Append(arg); + separator = ", "; + } aggregateResult.Append(")"); return aggregateResult; @@ -4334,7 +4349,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList aggregate { foreach (var aggregate in aggregates) { - Debug.Assert(aggregate.Arguments.Count == 1); + Debug.Assert(aggregate.Arguments.Count >= 1); if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName)) { return true; diff --git a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs index 705a76a24a..c2b9b8cb57 100644 --- a/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs +++ b/src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs @@ -1310,33 +1310,42 @@ public override ISqlFragment Visit(DbGroupByExpression e) var member = members.Current; var alias = QuoteIdentifier(member.Name); - Debug.Assert(aggregate.Arguments.Count == 1); - var translatedAggregateArgument = aggregate.Arguments[0].Accept(this); + var finalArgs = new List(); - object aggregateArgument; - - if (needsInnerQuery) + for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++) { - //In this case the argument to the aggratete is reference to the one projected out by the - // inner query - var wrappingAggregateArgument = new SqlBuilder(); - wrappingAggregateArgument.Append(fromSymbol); - wrappingAggregateArgument.Append("."); - wrappingAggregateArgument.Append(alias); - aggregateArgument = wrappingAggregateArgument; + var argument = aggregate.Arguments[childIndex]; + var translatedAggregateArgument = argument.Accept(this); - innerQuery.Select.Append(separator); - innerQuery.Select.AppendLine(); - innerQuery.Select.Append(translatedAggregateArgument); - innerQuery.Select.Append(" AS "); - innerQuery.Select.Append(alias); - } - else - { - aggregateArgument = translatedAggregateArgument; + object aggregateArgument; + + if (needsInnerQuery) + { + var argAlias = QuoteIdentifier(member.Name + "_" + childIndex); + + //In this case the argument to the aggratete is reference to the one projected out by the + // inner query + var wrappingAggregateArgument = new SqlBuilder(); + wrappingAggregateArgument.Append(fromSymbol); + wrappingAggregateArgument.Append("."); + wrappingAggregateArgument.Append(argAlias); + aggregateArgument = wrappingAggregateArgument; + + innerQuery.Select.Append(separator); + innerQuery.Select.AppendLine(); + innerQuery.Select.Append(translatedAggregateArgument); + innerQuery.Select.Append(" AS "); + innerQuery.Select.Append(argAlias); + } + else + { + aggregateArgument = translatedAggregateArgument; + } + + finalArgs.Add(aggregateArgument); } - ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument); + ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs); result.Select.Append(separator); result.Select.AppendLine(); @@ -2103,8 +2112,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e) // Aggregates are not visited by the normal visitor walk. // // The aggreate go be translated - // The translated aggregate argument - private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument) + // The translated aggregate arguments + private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList aggregateArguments) { var aggregateFunction = new SqlBuilder(); var aggregateResult = new SqlBuilder(); @@ -2134,7 +2143,13 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate throw ADP1.NotSupported(EntityRes.GetString(EntityRes.DistinctAggregatesNotSupported)); } - aggregateResult.Append(aggregateArgument); + string separator = String.Empty; + foreach (var arg in aggregateArguments) + { + aggregateResult.Append(separator); + aggregateResult.Append(arg); + separator = ", "; + } aggregateResult.Append(")"); @@ -4508,7 +4523,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList aggregate { foreach (var aggregate in aggregates) { - Debug.Assert(aggregate.Arguments.Count == 1); + Debug.Assert(aggregate.Arguments.Count >= 1); if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName)) { return true; diff --git a/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs b/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs index d3f61fe936..a56dcc6d35 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/DbAggregate.cs @@ -18,7 +18,7 @@ internal DbAggregate(TypeUsage resultType, DbExpressionList arguments) { DebugCheck.NotNull(resultType); DebugCheck.NotNull(arguments); - Debug.Assert(arguments.Count == 1, "DbAggregate requires a single argument"); + Debug.Assert(arguments.Count >= 1, "DbAggregate requires at least one argument"); _type = resultType; _args = arguments; diff --git a/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs b/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs index 58546e2b9a..702f36a95a 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/DefaultExpressionVisitor.cs @@ -184,19 +184,17 @@ protected virtual DbFunctionAggregate VisitFunctionAggregate(DbFunctionAggregate var newFunction = VisitFunction(aggregate.Function); var newArguments = VisitExpressionList(aggregate.Arguments); - Debug.Assert(newArguments.Count == 1, "Function aggregate had more than one argument?"); - if (!ReferenceEquals(aggregate.Function, newFunction) || !ReferenceEquals(aggregate.Arguments, newArguments)) { if (aggregate.Distinct) { - result = CqtBuilder.AggregateDistinct(newFunction, newArguments[0]); + result = CqtBuilder.AggregateDistinct(newFunction, newArguments); } else { - result = CqtBuilder.Aggregate(newFunction, newArguments[0]); + result = CqtBuilder.Aggregate(newFunction, newArguments); } } } @@ -212,7 +210,6 @@ protected virtual DbGroupAggregate VisitGroupAggregate(DbGroupAggregate aggregat if (aggregate != null) { var newArguments = VisitExpressionList(aggregate.Arguments); - Debug.Assert(newArguments.Count == 1, "Group aggregate had more than one argument?"); if (!ReferenceEquals(aggregate.Arguments, newArguments)) { diff --git a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs index f63c8c5945..68fd43b293 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/ExpressionBuilder/DbExpressionBuilder.cs @@ -239,6 +239,55 @@ private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct); } + /// + /// Creates a new . + /// + /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value false. + /// The function that defines the aggregate operation. + /// The argument over which the aggregate function should be calculated. + /// function or argument null. + /// function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function. + public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerable arguments) + { + Check.NotNull(function, "function"); + Check.NotNull(arguments, "argument"); + + if (arguments.Any() == false) + { + throw new ArgumentNullException("arguments"); + } + + return CreateFunctionAggregate(function, arguments, false); + } + + /// + /// Creates a new that is applied in a distinct fashion. + /// + /// A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value true. + /// The function that defines the aggregate operation. + /// The arguments over which the aggregate function should be calculated. + /// function or argument is null. + /// function is not an aggregate function, or the result type of argument is not equal or promotable to the parameter type of function. + public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IEnumerable arguments) + { + Check.NotNull(function, "function"); + Check.NotNull(arguments, "argument"); + + if (arguments.Any() == false) + { + throw new ArgumentNullException("arguments"); + } + + return CreateFunctionAggregate(function, arguments, true); + } + + private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, IEnumerable arguments, bool isDistinct) + { + var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, arguments); + var resultType = function.ReturnParameter.TypeUsage; + return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct); + } + /// /// Creates a new over the specified argument /// diff --git a/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs b/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs index ade88d7ee2..0551284521 100644 --- a/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs +++ b/src/EntityFramework/Core/Common/CommandTrees/Internal/ExpressionKeyGen.cs @@ -766,7 +766,7 @@ public override void Visit(DbGroupByExpression e) if (ga != null) { _key.Append("GA("); - Debug.Assert(ga.Arguments.Count == 1, "Group aggregate must have one argument."); + Debug.Assert(ga.Arguments.Count >= 1, "Group aggregate must have at least one argument."); ga.Arguments[0].Accept(this); _key.Append(')'); } diff --git a/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs b/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs index 24c941e3bc..532fa33219 100644 --- a/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs +++ b/src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs @@ -1373,10 +1373,10 @@ private static bool TryConvertAsFunctionAggregate( "argument types resolved for the collection aggregate calls must match"); } + // + // Aggregate functions must have at least one argument, and the first argument must be of collection edmType // - // Aggregate functions can have only one argument and of collection edmType - // - Debug.Assert((1 == functionType.Parameters.Count), "(1 == functionType.Parameters.Count)"); + Debug.Assert((1 <= functionType.Parameters.Count), "(1 <= functionType.Parameters.Count)"); // we only support monadic aggregate functions Debug.Assert( TypeSemantics.IsCollectionType(functionType.Parameters[0].TypeUsage), "functionType.Parameters[0].Type is CollectionType"); @@ -1394,11 +1394,11 @@ private static bool TryConvertAsFunctionAggregate( if (methodExpr.DistinctKind == DistinctKind.Distinct) { - functionAggregate = functionType.AggregateDistinct(args[0]); + functionAggregate = functionType.AggregateDistinct(args); } else { - functionAggregate = functionType.Aggregate(args[0]); + functionAggregate = functionType.Aggregate(args); } // diff --git a/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs b/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs index 9a39acfbf3..fdf89eb627 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs @@ -76,7 +76,7 @@ private void Process() [SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters", MessageId = "System.Data.Entity.Core.Query.PlanCompiler.PlanCompiler.Assert(System.Boolean,System.String)")] private void TryProcessCandidate( - KeyValuePair candidate, + KeyValuePair> candidate, GroupAggregateVarInfo groupAggregateVarInfo) { IList functionAncestors; @@ -100,15 +100,23 @@ private void TryProcessCandidate( // Remap the template from referencing the groupAggregate var to reference the input to // the group by into // - var argumentNode = OpCopier.Copy(m_command, candidate.Value); var dictionary = new Dictionary(1); dictionary.Add(groupAggregateVarInfo.GroupAggregateVar, inputVar); var remapper = new VarRemapper(m_command, dictionary); - remapper.RemapSubtree(argumentNode); + + var argNodes = new List(candidate.Value.Count); + + foreach (var argumentNode in candidate.Value) + { + var argumentNodeCopy = OpCopier.Copy(m_command, argumentNode); + remapper.RemapSubtree(argumentNodeCopy); + + argNodes.Add(argumentNodeCopy); + } var newFunctionDefiningNode = m_command.CreateNode( m_command.CreateAggregateOp(functionOp.Function, false), - argumentNode); + argNodes); Var newFunctionVar; var varDefNode = m_command.CreateVarDefNode(newFunctionDefiningNode, out newFunctionVar); diff --git a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs index 1ee8b5bb12..df2f23919c 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/CTreeGenerator.cs @@ -12,6 +12,7 @@ namespace System.Data.Entity.Core.Query.PlanCompiler using System.Data.Entity.Resources; using System.Diagnostics.CodeAnalysis; using System.Globalization; + using Linq; using SortKey = System.Data.Entity.Core.Query.InternalTrees.SortKey; [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling")] @@ -1893,17 +1894,18 @@ public override DbExpression Visit(GroupByOp op, Node n) PlanCompiler.Assert(aggVar is ComputedVar, "Non-ComputedVar encountered in Aggregate VarDefOp"); var aggOpNode = aggVarDefNode.Child0; - var aggDef = VisitNode(aggOpNode.Child0); + var args = aggOpNode.Children.Select(VisitNode); + var funcAggOp = aggOpNode.Op as AggregateOp; PlanCompiler.Assert(funcAggOp != null, "Non-Aggregate Node encountered as child of Aggregate VarDefOp Node"); DbFunctionAggregate newFuncAgg; if (funcAggOp.IsDistinctAggregate) { - newFuncAgg = funcAggOp.AggFunc.AggregateDistinct(aggDef); + newFuncAgg = funcAggOp.AggFunc.AggregateDistinct(args); } else { - newFuncAgg = funcAggOp.AggFunc.Aggregate(aggDef); + newFuncAgg = funcAggOp.AggFunc.Aggregate(args); } PlanCompiler.Assert(outputAggVars.Contains(aggVar), "Defined aggregate Var not in Output Aggregate Vars list?"); diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs index d00b897765..d7f0368adc 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateRefComputingVisitor.cs @@ -154,7 +154,13 @@ public override void Visit(FunctionOp op, Node n) { return; } - PlanCompiler.Assert(n.Children.Count == 1, "Aggregate Function must have one argument"); + + if (n.Children.Count > 1) + { + return; + } + + var argumentNode = n.Child0; GroupAggregateVarInfo referencedGroupAggregateVarInfo; Node templateNode; @@ -164,8 +170,8 @@ public override void Visit(FunctionOp op, Node n) out isUnnested) && (isUnnested || AggregatePushdownUtil.IsVarRefOverGivenVar(templateNode, referencedGroupAggregateVarInfo.GroupAggregateVar))) - { - referencedGroupAggregateVarInfo.CandidateAggregateNodes.Add(new KeyValuePair(n, templateNode)); + { + referencedGroupAggregateVarInfo.CandidateAggregateNodes.Add(new KeyValuePair>(n, new List { templateNode })); } } diff --git a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs index 89c6757ea6..0fe7b84f68 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/GroupAggregateVarInfo.cs @@ -15,7 +15,7 @@ internal class GroupAggregateVarInfo #region Private Fields private readonly Node _definingGroupByNode; - private HashSet> _candidateAggregateNodes; + private HashSet>> _candidateAggregateNodes; private readonly Var _groupAggregateVar; #endregion @@ -39,17 +39,17 @@ internal GroupAggregateVarInfo(Node defingingGroupNode, Var groupAggregateVar) // // Each key value pair represents a candidate aggregate. // The key is the function aggregate subtree and the value is a 'template' of translation of the - // function aggregate's argument over the var representing the group aggregate. - // A valid candidate has an argument that does not have any external references + // function aggregate's arguments over the var representing the group aggregate. + // A valid candidate has arguments that do not have any external references // except for the group aggregate corresponding to the DefiningGroupNode. // - internal HashSet> CandidateAggregateNodes + internal HashSet>> CandidateAggregateNodes { get { if (_candidateAggregateNodes == null) { - _candidateAggregateNodes = new HashSet>(); + _candidateAggregateNodes = new HashSet>>(); } return _candidateAggregateNodes; } diff --git a/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs b/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs index bbd2c91fe4..34dd53eca3 100644 --- a/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs +++ b/src/EntityFramework/Core/Query/PlanCompiler/PlanCompilerUtil.cs @@ -60,8 +60,8 @@ internal static bool IsRowTypeCaseOpWithNullability(CaseOp op, Node n, out bool // // Is this function a collection aggregate function. It is, if - // - it has exactly one child - // - that child is a collection type + // - it has at least one child + // - the first child is a collection type // - and the function has been marked with the aggregate attribute // // the function op @@ -69,7 +69,7 @@ internal static bool IsRowTypeCaseOpWithNullability(CaseOp op, Node n, out bool // true, if this was a collection aggregate function internal static bool IsCollectionAggregateFunction(FunctionOp op, Node n) { - return ((n.Children.Count == 1) && + return ((n.Children.Count >= 1) && TypeSemantics.IsCollectionType(n.Child0.Op.Type) && TypeSemantics.IsAggregateFunction(op.Function)); } diff --git a/src/EntityFramework/Core/SchemaObjectModel/Function.cs b/src/EntityFramework/Core/SchemaObjectModel/Function.cs index 7d2b560d7e..ee0dadd3e1 100644 --- a/src/EntityFramework/Core/SchemaObjectModel/Function.cs +++ b/src/EntityFramework/Core/SchemaObjectModel/Function.cs @@ -416,9 +416,9 @@ internal override void Validate() { if (IsAggregate) { - // Make sure that the function has exactly one parameter and that takes + // Make sure that the function has at least one parameter and the first parameter takes // a collection type - if (Parameters.Count != 1) + if (Parameters.Count == 0) { AddError( ErrorCode.InvalidNumberOfParametersForAggregateFunction, @@ -429,7 +429,7 @@ internal override void Validate() else if (Parameters.GetElementAt(0).CollectionKind == CollectionKind.None) { - // Since we have already checked that there should be exactly one parameter, it should be safe to get the + // Since we have already checked that there should be at least one parameter, it should be safe to get the // first parameter for the function var param = Parameters.GetElementAt(0); diff --git a/src/EntityFramework/Properties/Resources.cs b/src/EntityFramework/Properties/Resources.cs index 4a0d663c28..efdcc5877b 100644 --- a/src/EntityFramework/Properties/Resources.cs +++ b/src/EntityFramework/Properties/Resources.cs @@ -4327,7 +4327,7 @@ internal static string Metadata_General_Error } // - // A string like "Error in Function '{0}'. Aggregate Functions should take exactly one input parameter." + // A string like "Error in Function '{0}'. Aggregate Functions should take at least one input parameter." // internal static string InvalidNumberOfParametersForAggregateFunction(object p0) { diff --git a/src/EntityFramework/Properties/Resources.resx b/src/EntityFramework/Properties/Resources.resx index 95f25b07de..ff53755a3f 100644 --- a/src/EntityFramework/Properties/Resources.resx +++ b/src/EntityFramework/Properties/Resources.resx @@ -1885,7 +1885,7 @@ Inconsistent metadata error - Error in Function '{0}'. Aggregate Functions should take exactly one input parameter. + Error in Function '{0}'. Aggregate Functions should take at least one input parameter. Type of parameter '{0}' in function '{1}' is not valid. The aggregate function parameter type must be of CollectionType. diff --git a/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs b/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs index 0a304600af..991bba069e 100644 --- a/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs +++ b/test/EntityFramework/FunctionalTests/Query/GroupAggregateTests.cs @@ -1294,7 +1294,29 @@ GROUP BY [Extent1].[ProductName] QueryTestHelpers.VerifyQuery(query, workspace, expectedSql); } - public class CodePlex2160 : FunctionalTestBase + [Fact] + public void Using_ssdl_defined_aggregate_function_multiple_parameters() { + var query = + @"select gkey, [ĎefauľtNamėspacĕ].Store.F_UniqueCount(C.Address.Region, C.Address.City) +FROM ProductContainer.Customers as C +Group By C.Address.Region as gkey"; + + var expectedSql = + @"SELECT +1 AS [C1], +[GroupBy1].[K1] AS [Region], +[GroupBy1].[A1] AS [C2] +FROM ( SELECT + [Extent1].[Region] AS [K1], + [dbo].[F_UniqueCount]([Extent1].[Region], [Extent1].[City]) AS [A1] + FROM [dbo].[Customers] AS [Extent1] + GROUP BY [Extent1].[Region] +) AS [GroupBy1]"; + + QueryTestHelpers.VerifyQuery(query, workspace, expectedSql); + } + + public class CodePlex2160 : FunctionalTestBase { public class Foo { diff --git a/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs b/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs index 70708a2a8c..efbe129d1a 100644 --- a/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs +++ b/test/EntityFramework/FunctionalTests/Query/LinqToEntities/GroupByOptimizationTests.cs @@ -136,12 +136,12 @@ public void GroupBy_is_optimized_when_projecting_anonymous_type_containing_group [GroupBy1].[A2] AS [C3] FROM ( SELECT [Extent1].[K1] AS [K1], - MAX([Extent1].[A1]) AS [A1], - MIN([Extent1].[A2]) AS [A2] + MAX([Extent1].[A1_0]) AS [A1], + MIN([Extent1].[A2_0]) AS [A2] FROM ( SELECT [Extent1].[FirstName] AS [K1], - [Extent1].[Id] AS [A1], - [Extent1].[Id] + 2 AS [A2] + [Extent1].[Id] AS [A1_0], + [Extent1].[Id] + 2 AS [A2_0] FROM [dbo].[ArubaOwners] AS [Extent1] ) AS [Extent1] GROUP BY [K1] @@ -221,10 +221,10 @@ public void GroupBy_is_optimized_when_projecting_function_aggregate_with_express [GroupBy1].[A1] AS [C1] FROM ( SELECT [Extent1].[K1] AS [K1], - MAX([Extent1].[A1]) AS [A1] + MAX([Extent1].[A1_0]) AS [A1] FROM ( SELECT [Extent1].[FirstName] AS [K1], - [Extent1].[Id] * 2 AS [A1] + [Extent1].[Id] * 2 AS [A1_0] FROM [dbo].[ArubaOwners] AS [Extent1] ) AS [Extent1] GROUP BY [K1] diff --git a/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs b/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs index 41c347c90d..9def1f1fff 100644 --- a/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs +++ b/test/EntityFramework/FunctionalTests/Query/ModelDefinedFunctionTests.cs @@ -107,9 +107,9 @@ public void Function_returning_collection_of_scalars() 1 AS [C1], [GroupBy1].[A1] AS [C2] FROM ( SELECT - MAX([Filter1].[A1]) AS [A1] + MAX([Filter1].[A1_0]) AS [A1] FROM ( SELECT - [Extent1].[ProductID] - 3 AS [A1] + [Extent1].[ProductID] - 3 AS [A1_0] FROM [dbo].[Products] AS [Extent1] WHERE [Extent1].[Discontinued] IN (0,1) ) AS [Filter1] diff --git a/test/EntityFramework/FunctionalTests/Query/ProductModel.cs b/test/EntityFramework/FunctionalTests/Query/ProductModel.cs index 425c1d227a..8fc1714386 100644 --- a/test/EntityFramework/FunctionalTests/Query/ProductModel.cs +++ b/test/EntityFramework/FunctionalTests/Query/ProductModel.cs @@ -45,6 +45,10 @@ public static class ProductModel + + + + "; public const string Msl = diff --git a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs index 2c7b47240c..87d8023a43 100644 --- a/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs +++ b/test/EntityFramework/UnitTests/Core/Common/CommandTrees/DbExpressionBuilderTests.cs @@ -2178,18 +2178,33 @@ public void SetClause_can_create_for_property() #region Null checks + [Fact] + public void Null_check_Aggregate_MultipleArguments() + { + Assert.Throws(() => DbExpressionBuilder.Aggregate(null, new List { DbExpressionBuilder.True })); + Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), (IEnumerable)null)); + Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), new List())); + } + + [Fact] + public void Null_check_AggregateDistinct_MultipleArguments() + { + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(null, new List { DbExpressionBuilder.True })); + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), (IEnumerable)null)); + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), new List())); + } [Fact] public void Null_check_Aggregate() { Assert.Throws(() => DbExpressionBuilder.Aggregate(null, DbExpressionBuilder.True)); - Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), null)); + Assert.Throws(() => DbExpressionBuilder.Aggregate(new EdmFunction("F", "N", DataSpace.SSpace), (DbExpression)null)); } [Fact] public void Null_check_AggregateDistinct() { Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(null, DbExpressionBuilder.True)); - Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), null)); + Assert.Throws(() => DbExpressionBuilder.AggregateDistinct(new EdmFunction("F", "N", DataSpace.SSpace), (DbExpression)null)); } [Fact]