From 6a58596c7e8091225a1695b44a9a3a1a3e1a1b88 Mon Sep 17 00:00:00 2001 From: Paul Middleton Date: Mon, 5 Mar 2018 00:03:27 -0600 Subject: [PATCH] Add support for table valued functions --- ...ntityFrameworkRelationalServicesBuilder.cs | 4 + .../RelationalModelValidator.cs | 3 +- src/EFCore.Relational/Metadata/IDbFunction.cs | 5 + .../Metadata/Internal/DbFunction.cs | 23 + .../Properties/RelationalStrings.Designer.cs | 24 + .../Properties/RelationalStrings.resx | 9 + ...ationalEntityQueryableExpressionVisitor.cs | 83 +- ...yQueryableExpressionVisitorDependencies.cs | 38 +- .../SqlTranslatingExpressionVisitor.cs | 14 + .../CrossJoinLateralOuterExpression.cs | 86 ++ .../Expressions/DbFunctionSourceExpression.cs | 162 ++++ .../Query/Expressions/SelectExpression.cs | 28 + .../TableValuedSqlFunctionExpression.cs | 97 +++ .../RelationalDbFunctionTransformer.cs | 56 ++ ...ationalIExpressionTranformationProvider.cs | 41 + .../RelationalResultOperatorHandler.cs | 29 +- .../RelationalDbFunctionSourceFactory.cs | 24 + .../RelationalQueryCompilationContext.cs | 5 + .../Query/RelationalQueryModelVisitor.cs | 28 +- .../Query/Sql/DefaultQuerySqlGenerator.cs | 27 + .../Query/Sql/ISqlExpressionVisitor.cs | 18 + .../SqlServerQueryCompilationContext.cs | 6 + .../Internal/SqlServerQuerySqlGenerator.cs | 35 + src/EFCore/DbContext.cs | 79 ++ .../EntityFrameworkServicesBuilder.cs | 7 +- .../Internal/DbFunctionSourceFactory.cs | 24 + .../Internal/IDbFunctionSourceFactory.cs | 18 + .../Query/Internal/QueryModelGenerator.cs | 8 +- .../Metadata/DbFunctionMetadataTests.cs | 17 + .../Query/UdfDbFunctionSqlServerTests.cs | 735 +++++++++++++++++- 30 files changed, 1683 insertions(+), 50 deletions(-) create mode 100644 src/EFCore.Relational/Query/Expressions/CrossJoinLateralOuterExpression.cs create mode 100644 src/EFCore.Relational/Query/Expressions/DbFunctionSourceExpression.cs create mode 100644 src/EFCore.Relational/Query/Expressions/TableValuedSqlFunctionExpression.cs create mode 100644 src/EFCore.Relational/Query/Internal/RelationalDbFunctionTransformer.cs create mode 100644 src/EFCore.Relational/Query/Internal/RelationalIExpressionTranformationProvider.cs create mode 100644 src/EFCore.Relational/Query/RelationalDbFunctionSourceFactory.cs create mode 100644 src/EFCore/Internal/DbFunctionSourceFactory.cs create mode 100644 src/EFCore/Internal/IDbFunctionSourceFactory.cs diff --git a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs index 2201625b91c..9936e4aaa4f 100644 --- a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs +++ b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Infrastructure.Internal; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal; using Microsoft.EntityFrameworkCore.Migrations; using Microsoft.EntityFrameworkCore.Migrations.Internal; @@ -21,6 +22,7 @@ using Microsoft.EntityFrameworkCore.Update.Internal; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.Extensions.DependencyInjection; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation; namespace Microsoft.EntityFrameworkCore.Infrastructure @@ -172,6 +174,8 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(); TryAdd(); TryAdd(); + TryAdd(); + TryAdd(); TryAdd(p => { diff --git a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs index be1a970acd6..092ce18b03e 100644 --- a/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs +++ b/src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs @@ -80,7 +80,8 @@ protected virtual void ValidateDbFunctions([NotNull] IModel model) if (dbFunction.Translation == null) { - if (RelationalDependencies.TypeMappingSource.FindMapping(methodInfo.ReturnType) == null) + if ((dbFunction.IsIQueryable && model.FindEntityType(dbFunction.MethodInfo.ReturnType.GetGenericArguments()[0]) == null) + && RelationalDependencies.TypeMappingSource.FindMapping(methodInfo.ReturnType) == null) { throw new InvalidOperationException( RelationalStrings.DbFunctionInvalidReturnType( diff --git a/src/EFCore.Relational/Metadata/IDbFunction.cs b/src/EFCore.Relational/Metadata/IDbFunction.cs index 4daecfd6fee..0fa643b157e 100644 --- a/src/EFCore.Relational/Metadata/IDbFunction.cs +++ b/src/EFCore.Relational/Metadata/IDbFunction.cs @@ -28,6 +28,11 @@ public interface IDbFunction /// MethodInfo MethodInfo { get; } + /// + /// Does this method return IQueryable + /// + bool IsIQueryable { get; } + /// /// A translation callback for performing custom translation of the method call into a SQL expression fragment. /// diff --git a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs index ba0af8057d2..444649f9895 100644 --- a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs +++ b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs @@ -82,6 +82,23 @@ private DbFunction( RelationalStrings.DbFunctionInvalidReturnType(methodInfo.DisplayName(), methodInfo.ReturnType.ShortDisplayName())); } + if (methodInfo.ReturnType.IsGenericType + && methodInfo.ReturnType.GetGenericTypeDefinition() == typeof(IQueryable<>)) + { + if (methodInfo.IsStatic) + { + throw new ArgumentException( + RelationalStrings.DbFunctionQueryableNotStatic(methodInfo.DisplayName())); + } + + IsIQueryable = true; + + if (model.FindEntityType(methodInfo.ReturnType.GetGenericArguments()[0]) == null) + { + model.AddQueryType(methodInfo.ReturnType.GetGenericArguments()[0]); + } + } + MethodInfo = methodInfo; _model = model; @@ -186,6 +203,12 @@ private void UpdateNameConfigurationSource(ConfigurationSource configurationSour /// public virtual Func, Expression> Translation { get; set; } + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public virtual bool IsIQueryable { get; set; } + /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs index 40f6cc09878..396d2d13c9d 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs +++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs @@ -893,6 +893,30 @@ public static string DbFunctionInvalidInstanceType([CanBeNull] object function, GetString("DbFunctionInvalidInstanceType", nameof(function), nameof(type)), function, type); + /// + /// The DbFunction '{function}' must return IQueryable. + /// + public static string DbFunctionTableValuedFunctionMustReturnIQueryable([CanBeNull] object function) + => string.Format( + GetString("DbFunctionTableValuedFunctionMustReturnIQueryable", nameof(function)), + function); + + /// + /// The DbFunction '{function}' is not registered with the model. + /// + public static string DbFunctionNotFound([CanBeNull] object function) + => string.Format( + GetString("DbFunctionNotFound", nameof(function)), + function); + + /// + /// IQueryable DbFunctions must be instance methods. '{function}' is static. + /// + public static string DbFunctionQueryableNotStatic([CanBeNull] object function) + => string.Format( + GetString("DbFunctionQueryableNotStatic", nameof(function)), + function); + /// /// An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection. /// diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx index a5466be1f9b..272633c97a0 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.resx +++ b/src/EFCore.Relational/Properties/RelationalStrings.resx @@ -431,6 +431,15 @@ The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported. + + The DbFunction '{function}' must return IQueryable. + + + The DbFunction '{function}' is not registered with the model. + + + IQueryable DbFunctions must be instance methods. '{function}' is static. + An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection. diff --git a/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitor.cs b/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitor.cs index 48bcb1666b4..b2648832c88 100644 --- a/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitor.cs @@ -33,6 +33,7 @@ public class RelationalEntityQueryableExpressionVisitor : EntityQueryableExpress private readonly IMaterializerFactory _materializerFactory; private readonly IShaperCommandContextFactory _shaperCommandContextFactory; private readonly IQuerySource _querySource; + private readonly ISqlTranslatingExpressionVisitorFactory _sqlTranslatingExpressionVisitorFactory; /// /// Creates a new instance of . @@ -53,6 +54,7 @@ public RelationalEntityQueryableExpressionVisitor( _materializerFactory = dependencies.MaterializerFactory; _shaperCommandContextFactory = dependencies.ShaperCommandContextFactory; _querySource = querySource; + _sqlTranslatingExpressionVisitorFactory = dependencies.SqlTranslatingExpressionVisitorFactory; } private new RelationalQueryModelVisitor QueryModelVisitor => (RelationalQueryModelVisitor)base.QueryModelVisitor; @@ -119,8 +121,15 @@ protected override Expression VisitMethodCall(MethodCallExpression node) { Check.NotNull(node, nameof(node)); - QueryModelVisitor - .BindMethodCallExpression( + var dbFunc = _model.Relational().FindDbFunction(node.Method); + + if (dbFunc != null && dbFunc.IsIQueryable) + { + return VisitDbFunctionSourceExpression(new DbFunctionSourceExpression(node, _model)); + } + else + { + QueryModelVisitor.BindMethodCallExpression( node, (property, querySource, selectExpression) => selectExpression.AddToProjection( @@ -128,7 +137,24 @@ protected override Expression VisitMethodCall(MethodCallExpression node) querySource), bindSubQueries: true); - return base.VisitMethodCall(node); + return base.VisitMethodCall(node); + } + } + + /// + /// Visits Extension nodes. + /// + /// The node being visited. + /// An expression to use in place of the node. + protected override Expression VisitExtension(Expression node) + { + switch (node) + { + case DbFunctionSourceExpression dbNode: + return VisitDbFunctionSourceExpression(dbNode); + default: + return base.VisitExtension(node); + } } /// @@ -238,6 +264,57 @@ var useQueryComposition Expression.Constant(shaper)); } + /// + /// todo + /// + /// todo + /// todo + protected Expression VisitDbFunctionSourceExpression([NotNull] DbFunctionSourceExpression dbFunctionSourceExpression) + { + var relationalQueryCompilationContext = QueryModelVisitor.QueryCompilationContext; + var selectExpression = _selectExpressionFactory.Create(relationalQueryCompilationContext); + + QueryModelVisitor.AddQuery(_querySource, selectExpression); + + var sqlTranslatingExpressionVisitor = _sqlTranslatingExpressionVisitorFactory.Create(QueryModelVisitor); + + var sqlFuncExpression = (SqlFunctionExpression)sqlTranslatingExpressionVisitor.Visit(dbFunctionSourceExpression); + + Func querySqlGeneratorFunc = selectExpression.CreateDefaultQuerySqlGenerator; + + Shaper shaper; + + if (dbFunctionSourceExpression.IsIQueryable) + { + var tableAlias + = relationalQueryCompilationContext.CreateUniqueTableAlias( + _querySource.HasGeneratedItemName() + ? dbFunctionSourceExpression.Name[0].ToString().ToLowerInvariant() + : _querySource.ItemName); + + selectExpression.AddTable(new TableValuedSqlFunctionExpression(sqlFuncExpression, _querySource, tableAlias)); + + var entityType = _model.FindEntityType(dbFunctionSourceExpression.ReturnType); + + shaper = CreateShaper(dbFunctionSourceExpression.ReturnType, entityType, selectExpression); + } + else + { + selectExpression.AddToProjection(sqlFuncExpression); + + shaper = new ValueBufferShaper(_querySource); + } + + return Expression.Call( + QueryModelVisitor.QueryCompilationContext.QueryMethodProvider // TODO: Don't use ShapedQuery when projecting + .ShapedQueryMethod + .MakeGenericMethod(shaper.Type), + EntityQueryModelVisitor.QueryContextParameter, + Expression.Constant(_shaperCommandContextFactory.Create(querySqlGeneratorFunc)), + Expression.Constant(shaper)); + } + + private Shaper CreateShaper(Type elementType, IEntityType entityType, SelectExpression selectExpression) { Shaper shaper; diff --git a/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitorDependencies.cs b/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitorDependencies.cs index 937d86d0448..6ce91143418 100644 --- a/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitorDependencies.cs +++ b/src/EFCore.Relational/Query/ExpressionVisitors/RelationalEntityQueryableExpressionVisitorDependencies.cs @@ -47,21 +47,25 @@ public sealed class RelationalEntityQueryableExpressionVisitorDependencies /// The select expression factory. /// The materializer factory. /// The shaper command context factory. + /// TODO. public RelationalEntityQueryableExpressionVisitorDependencies( [NotNull] IModel model, [NotNull] ISelectExpressionFactory selectExpressionFactory, [NotNull] IMaterializerFactory materializerFactory, - [NotNull] IShaperCommandContextFactory shaperCommandContextFactory) + [NotNull] IShaperCommandContextFactory shaperCommandContextFactory, + [NotNull] ISqlTranslatingExpressionVisitorFactory sqlTranslatingExpressionVisitorFactory) { Check.NotNull(model, nameof(model)); Check.NotNull(selectExpressionFactory, nameof(selectExpressionFactory)); Check.NotNull(materializerFactory, nameof(materializerFactory)); Check.NotNull(shaperCommandContextFactory, nameof(shaperCommandContextFactory)); + Check.NotNull(sqlTranslatingExpressionVisitorFactory, nameof(sqlTranslatingExpressionVisitorFactory)); Model = model; SelectExpressionFactory = selectExpressionFactory; MaterializerFactory = materializerFactory; ShaperCommandContextFactory = shaperCommandContextFactory; + SqlTranslatingExpressionVisitorFactory = sqlTranslatingExpressionVisitorFactory; } /// @@ -84,6 +88,13 @@ public RelationalEntityQueryableExpressionVisitorDependencies( /// public IShaperCommandContextFactory ShaperCommandContextFactory { get; } + /// + /// todo + /// + public ISqlTranslatingExpressionVisitorFactory SqlTranslatingExpressionVisitorFactory { get; } + + + /// /// Clones this dependency parameter object with one service replaced. /// @@ -94,7 +105,8 @@ public RelationalEntityQueryableExpressionVisitorDependencies With([NotNull] IMo model, SelectExpressionFactory, MaterializerFactory, - ShaperCommandContextFactory); + ShaperCommandContextFactory, + SqlTranslatingExpressionVisitorFactory); /// /// Clones this dependency parameter object with one service replaced. @@ -106,7 +118,8 @@ public RelationalEntityQueryableExpressionVisitorDependencies With([NotNull] ISe Model, selectExpressionFactory, MaterializerFactory, - ShaperCommandContextFactory); + ShaperCommandContextFactory, + SqlTranslatingExpressionVisitorFactory); /// /// Clones this dependency parameter object with one service replaced. @@ -118,7 +131,8 @@ public RelationalEntityQueryableExpressionVisitorDependencies With([NotNull] IMa Model, SelectExpressionFactory, materializerFactory, - ShaperCommandContextFactory); + ShaperCommandContextFactory, + SqlTranslatingExpressionVisitorFactory); /// /// Clones this dependency parameter object with one service replaced. @@ -130,6 +144,20 @@ public RelationalEntityQueryableExpressionVisitorDependencies With([NotNull] ISh Model, SelectExpressionFactory, MaterializerFactory, - shaperCommandContextFactory); + shaperCommandContextFactory, + SqlTranslatingExpressionVisitorFactory); + + /// + /// Clones this dependency parameter object with one service replaced. + /// + /// A replacement for the current dependency of this type. + /// A new parameter object with the given service replaced. + public RelationalEntityQueryableExpressionVisitorDependencies With([NotNull] ISqlTranslatingExpressionVisitorFactory sqlTranslatingExpressionVisitorFactory) + => new RelationalEntityQueryableExpressionVisitorDependencies( + Model, + SelectExpressionFactory, + MaterializerFactory, + ShaperCommandContextFactory, + sqlTranslatingExpressionVisitorFactory); } } diff --git a/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs index d80764199eb..a45d57ded06 100644 --- a/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs @@ -1104,6 +1104,20 @@ var equalityExpression ? new NullCompensatedExpression(newOperand) : nullCompensatedExpression; } + case DbFunctionSourceExpression dbFunctionExpression: + { + var newArguments = Visit(dbFunctionExpression.Arguments); + + if (newArguments.Any(a => a == null)) + { + return null; + } + + //TODO - can you custom translate here? + return //dbFunctionExpression.Translate(newArguments) + //?? + new SqlFunctionExpression(dbFunctionExpression.Name, dbFunctionExpression.UnwrappedType, dbFunctionExpression.Schema, newArguments); + } case DiscriminatorPredicateExpression discriminatorPredicateExpression: return new DiscriminatorPredicateExpression( base.VisitExtension(expression), discriminatorPredicateExpression.QuerySource); diff --git a/src/EFCore.Relational/Query/Expressions/CrossJoinLateralOuterExpression.cs b/src/EFCore.Relational/Query/Expressions/CrossJoinLateralOuterExpression.cs new file mode 100644 index 00000000000..fb111934a3b --- /dev/null +++ b/src/EFCore.Relational/Query/Expressions/CrossJoinLateralOuterExpression.cs @@ -0,0 +1,86 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq.Expressions; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Query.Sql; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Query.Expressions +{ + /// + /// Represents a SQL CROSS JOIN LATERAL OUTER expression. + /// + public class CrossJoinLateralOuterExpression : JoinExpressionBase + { + /// + /// Creates a new instance of CrossJoinLateralExpression. + /// + /// The target table expression. + public CrossJoinLateralOuterExpression([NotNull] TableExpressionBase tableExpression) + : base(Check.NotNull(tableExpression, nameof(tableExpression))) + { + } + + /// + /// Dispatches to the specific visit method for this node type. + /// + protected override Expression Accept(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + return visitor is ISqlExpressionVisitor specificVisitor + ? specificVisitor.VisitCrossJoinLateralOuter(this) + : base.Accept(visitor); + } + + /// + /// Tests if this object is considered equal to another. + /// + /// The object to compare with the current object. + /// + /// true if the objects are considered equal, false if they are not. + /// + public override bool Equals(object obj) + { + if (obj is null) + { + return false; + } + + if (ReferenceEquals(this, obj)) + { + return true; + } + + return obj.GetType() == GetType() && Equals((CrossJoinLateralExpression)obj); + } + + private bool Equals(CrossJoinLateralExpression other) + => string.Equals(Alias, other.Alias) + && Equals(QuerySource, other.QuerySource); + + /// + /// Returns a hash code for this object. + /// + /// + /// A hash code for this object. + /// + public override int GetHashCode() + { + unchecked + { + var hashCode = Alias?.GetHashCode() ?? 0; + hashCode = (hashCode * 397) ^ (QuerySource?.GetHashCode() ?? 0); + + return hashCode; + } + } + + /// + /// Creates a representation of the Expression. + /// + /// A representation of the Expression. + public override string ToString() => "CROSS JOIN LATERAL OUTER " + TableExpression; + } +} diff --git a/src/EFCore.Relational/Query/Expressions/DbFunctionSourceExpression.cs b/src/EFCore.Relational/Query/Expressions/DbFunctionSourceExpression.cs new file mode 100644 index 00000000000..2cbc342fe89 --- /dev/null +++ b/src/EFCore.Relational/Query/Expressions/DbFunctionSourceExpression.cs @@ -0,0 +1,162 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal; + +namespace Microsoft.EntityFrameworkCore.Query.Expressions +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public class DbFunctionSourceExpression : Expression + { + private readonly IDbFunction _dbFunction; + + /// + /// todo + /// + public override ExpressionType NodeType => ExpressionType.Extension; + + /// + /// todo + /// + public override Type Type { get; } + + /// + /// todo + /// + public virtual Type ReturnType { get; } + + /// + /// todo + /// + public virtual string Schema => _dbFunction.Schema; + + /// + /// todo + /// + public virtual Type UnwrappedType => Type.IsGenericType ? Type.GetGenericArguments()[0] : Type; + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public virtual string Name => _dbFunction.FunctionName; + + /// + /// todo + /// + public virtual bool IsIQueryable => _dbFunction.IsIQueryable; + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public virtual ReadOnlyCollection Arguments { get; [param: NotNull] set; } + + /// + /// todo + /// + /// todo + /// todo + public DbFunctionSourceExpression([NotNull] MethodCallExpression expression, [NotNull] IModel model) + { + _dbFunction = FindDbFunction(expression, model); + Arguments = expression.Arguments; + + if (expression.Method.ReturnType.IsGenericType) + { + //todo - add unit test + if (expression.Method.ReturnType.GetGenericTypeDefinition() != typeof(IQueryable<>)) + { + throw new InvalidOperationException( + RelationalStrings.DbFunctionTableValuedFunctionMustReturnIQueryable(_dbFunction.FunctionName)); + } + + //todo - should i be using the dbfunction return type here? If not do I have to verify the expression return type? + Type = expression.Method.ReturnType; + ReturnType = expression.Method.ReturnType.GetGenericArguments()[0]; + } + else + { + Type = typeof(IEnumerable<>).MakeGenericType(expression.Method.ReturnType); + ReturnType = expression.Method.ReturnType; + } + } + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public DbFunctionSourceExpression([NotNull] DbFunctionSourceExpression oldFuncExpression, [NotNull] ReadOnlyCollection newArguments) + { + Arguments = new ReadOnlyCollection(newArguments); + _dbFunction = oldFuncExpression._dbFunction; + ReturnType = oldFuncExpression.ReturnType; + Type = oldFuncExpression.Type; + } + + /// + /// todo + /// + /// todo + /// todo + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var newArguments = visitor.Visit(Arguments); + + if (visitor is ParameterExtractingExpressionVisitor) + { + newArguments = new ReadOnlyCollection(newArguments.Select(a => a is LambdaExpression l ? l.Body : a).ToList()); + } + + return newArguments != Arguments + ? new DbFunctionSourceExpression(this, newArguments) + : this; + } + + private IDbFunction FindDbFunction(MethodCallExpression exp, IModel model) + { + var method = exp.Method.DeclaringType.GetMethod( + exp.Method.Name, + exp.Method.GetParameters() + .Select(p => UnwrapParamterType(p.ParameterType)) + .ToArray()); + + var dbFunction = model.Relational().FindDbFunction(method); + + //todo - add unit test + if (dbFunction == null) + { + throw new InvalidOperationException( + RelationalStrings.DbFunctionNotFound(method.Name)); + } + + return dbFunction; + + Type UnwrapParamterType(Type paramType) + { + if (paramType.IsGenericType + && paramType.GetGenericTypeDefinition() == typeof(Expression<>)) + { + var expressionType = paramType.GetGenericArguments()[0]; + + if (expressionType.IsGenericType + && expressionType.GetGenericTypeDefinition() == typeof(Func<>)) + { + return expressionType.GetGenericArguments().Last(); + } + } + + return paramType; + } + } + } +} diff --git a/src/EFCore.Relational/Query/Expressions/SelectExpression.cs b/src/EFCore.Relational/Query/Expressions/SelectExpression.cs index 6fd4defc2dd..ebb246d2772 100644 --- a/src/EFCore.Relational/Query/Expressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/Expressions/SelectExpression.cs @@ -1007,6 +1007,34 @@ public virtual JoinExpressionBase AddCrossJoinLateral( return crossJoinLateralExpression; } + /// + /// Adds a SQL CROSS JOIN LATERAL to this SelectExpression. + /// + /// The target table expression. + /// A sequence of expressions that should be added to the projection. + public virtual JoinExpressionBase AddCrossJoinLateralOuter( + [NotNull] TableExpressionBase tableExpression, + [NotNull] IEnumerable projection) + { + Check.NotNull(tableExpression, nameof(tableExpression)); + Check.NotNull(projection, nameof(projection)); + + //todo - this seems very wrong - where is the right place to do this? By the caller? By the sql walker? + //for TVF we need to unwrap the inner select clause + if (tableExpression is SelectExpression s && s.Tables.First() is TableValuedSqlFunctionExpression) + { + tableExpression = s.Tables.First(); + projection = s.Projection; + } + + var crossJoinLateralOuterExpression = new CrossJoinLateralOuterExpression(tableExpression); + + _tables.Add(crossJoinLateralOuterExpression); + _projection.AddRange(projection); + + return crossJoinLateralOuterExpression; + } + /// /// Adds a SQL INNER JOIN to this SelectExpression. /// diff --git a/src/EFCore.Relational/Query/Expressions/TableValuedSqlFunctionExpression.cs b/src/EFCore.Relational/Query/Expressions/TableValuedSqlFunctionExpression.cs new file mode 100644 index 00000000000..3fcf2de82da --- /dev/null +++ b/src/EFCore.Relational/Query/Expressions/TableValuedSqlFunctionExpression.cs @@ -0,0 +1,97 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Query.Sql; +using Microsoft.EntityFrameworkCore.Utilities; +using Remotion.Linq.Clauses; + +namespace Microsoft.EntityFrameworkCore.Query.Expressions +{ + /// + /// Represents a SQL Table Valued Fuction. + /// + public class TableValuedSqlFunctionExpression : TableExpressionBase + { + private SqlFunctionExpression _sqlFunctionExpression; + + /// + /// todo + /// + public virtual SqlFunctionExpression SqlFunctionExpression => _sqlFunctionExpression; + + /// + /// todo + /// + /// todo + /// todo + /// todo + public TableValuedSqlFunctionExpression([NotNull] SqlFunctionExpression sqlFunction, [NotNull] IQuerySource querySource, [CanBeNull] string alias) + : this(sqlFunction.FunctionName, sqlFunction.Type, sqlFunction.Schema, sqlFunction.Arguments, querySource, alias) + { + + } + + /// + /// todo + /// + /// todo + /// todo + /// todo + /// todo + /// todo + /// todo + public TableValuedSqlFunctionExpression([NotNull] string functionName, + [NotNull] Type returnType, + [CanBeNull] string schema, + [NotNull] IEnumerable arguments, + [NotNull] IQuerySource querySource, + [CanBeNull]string alias) + : base(querySource, alias) + { + //TODO - make sure return type is of type IQueryable + //TODO - Do I even need this class or can I just use the SqlFunctionExpression? Thus far not much is happening in here + _sqlFunctionExpression = new SqlFunctionExpression(functionName, returnType, schema, arguments); + } + + /// + /// todo + /// + /// todo + public override string ToString() + { + return _sqlFunctionExpression.ToString(); + } + + /// + /// todo + /// + /// todo + /// todo + protected override Expression Accept(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + return visitor is ISqlExpressionVisitor specificVisitor + ? specificVisitor.VisitTableValuedSqlFunctionExpression(this) + : base.Accept(visitor); + } + + /// + /// todo + /// + /// todo + /// todo + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var newArguments = visitor.Visit(new ReadOnlyCollection(_sqlFunctionExpression.Arguments.ToList())); + + return newArguments != _sqlFunctionExpression.Arguments + ? new TableValuedSqlFunctionExpression(new SqlFunctionExpression(_sqlFunctionExpression.FunctionName, Type, _sqlFunctionExpression.Schema, newArguments), QuerySource, Alias) + : this; + } + } +} diff --git a/src/EFCore.Relational/Query/Internal/RelationalDbFunctionTransformer.cs b/src/EFCore.Relational/Query/Internal/RelationalDbFunctionTransformer.cs new file mode 100644 index 00000000000..ee94e6b6648 --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/RelationalDbFunctionTransformer.cs @@ -0,0 +1,56 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query.Expressions; +using Microsoft.EntityFrameworkCore.Utilities; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public class RelationalDbFunctionTransformer : IExpressionTransformer + { + private readonly IModel _model; + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public RelationalDbFunctionTransformer([NotNull] IModel model) + { + Check.NotNull(model, nameof(model)); + + _model = model; + } + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public Expression Transform(MethodCallExpression expression) + { + if (_model.Relational().FindDbFunction(expression.Method)?.IsIQueryable == true) + { + return new DbFunctionSourceExpression(expression, _model); + } + + return expression; + } + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public virtual ExpressionType[] SupportedExpressionTypes => new[] { ExpressionType.Call }; + } +} diff --git a/src/EFCore.Relational/Query/Internal/RelationalIExpressionTranformationProvider.cs b/src/EFCore.Relational/Query/Internal/RelationalIExpressionTranformationProvider.cs new file mode 100644 index 00000000000..4d276b658ed --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/RelationalIExpressionTranformationProvider.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq.Expressions; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Utilities; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; + +namespace Microsoft.EntityFrameworkCore.Query.Internal +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public class RelationalIExpressionTranformationProvider : IExpressionTranformationProvider + { + private readonly ExpressionTransformerRegistry _transformProvider; + + /// + /// todo + /// + /// todo + public RelationalIExpressionTranformationProvider([NotNull] IModel model) + { + Check.NotNull(model, nameof(model)); + _transformProvider = ExpressionTransformerRegistry.CreateDefault(); + _transformProvider.Register(new RelationalDbFunctionTransformer(model)); + } + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public IEnumerable GetTransformations(Expression expression) + { + return _transformProvider.GetTransformations(expression); + } + } +} diff --git a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs index a9ea489ea78..c516b4a6904 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs @@ -359,28 +359,31 @@ private static Expression HandleDefaultIfEmpty(HandlerContext handlerContext) return handlerContext.EvalOnClient(); } - var selectExpression = handlerContext.SelectExpression; + if (!(handlerContext.QueryModel.MainFromClause.FromExpression is DbFunctionSourceExpression && handlerContext.QueryModelVisitor.QueryCompilationContext.IsLateralJoinOuterSupported)) + { + var selectExpression = handlerContext.SelectExpression; - selectExpression.PushDownSubquery(); - selectExpression.ExplodeStarProjection(); + selectExpression.PushDownSubquery(); + selectExpression.ExplodeStarProjection(); - var subquery = selectExpression.Tables.Single(); + var subquery = selectExpression.Tables.Single(); - selectExpression.ClearTables(); + selectExpression.ClearTables(); - var emptySelectExpression = handlerContext.SelectExpressionFactory.Create(handlerContext.QueryModelVisitor.QueryCompilationContext, "empty"); - emptySelectExpression.AddToProjection(new AliasExpression("empty", Expression.Constant(null))); + var emptySelectExpression = handlerContext.SelectExpressionFactory.Create(handlerContext.QueryModelVisitor.QueryCompilationContext, "empty"); + emptySelectExpression.AddToProjection(new AliasExpression("empty", Expression.Constant(null))); - selectExpression.AddTable(emptySelectExpression); + selectExpression.AddTable(emptySelectExpression); - var leftOuterJoinExpression = new LeftOuterJoinExpression(subquery); - var constant1 = Expression.Constant(1); + var leftOuterJoinExpression = new LeftOuterJoinExpression(subquery); + var constant1 = Expression.Constant(1); - leftOuterJoinExpression.Predicate = Expression.Equal(constant1, constant1); + leftOuterJoinExpression.Predicate = Expression.Equal(constant1, constant1); - selectExpression.AddTable(leftOuterJoinExpression); + selectExpression.AddTable(leftOuterJoinExpression); - selectExpression.ProjectStarTable = subquery; + selectExpression.ProjectStarTable = subquery; + } handlerContext.QueryModelVisitor.Expression = new DefaultIfEmptyExpressionVisitor( diff --git a/src/EFCore.Relational/Query/RelationalDbFunctionSourceFactory.cs b/src/EFCore.Relational/Query/RelationalDbFunctionSourceFactory.cs new file mode 100644 index 00000000000..6e5f997c4b5 --- /dev/null +++ b/src/EFCore.Relational/Query/RelationalDbFunctionSourceFactory.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query.Expressions; + +namespace Microsoft.EntityFrameworkCore.Query +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public class RelationalDbFunctionSourceFactory : IDbFunctionSourceFactory + { + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model) + => new DbFunctionSourceExpression(methodCall, model); + } +} diff --git a/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs b/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs index eacae581283..434917806fd 100644 --- a/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs +++ b/src/EFCore.Relational/Query/RelationalQueryCompilationContext.cs @@ -78,6 +78,11 @@ var relationalQueryModelVisitor /// public virtual bool IsLateralJoinSupported => false; + /// + /// True if the current provider supports SQL OUTER LATERAL JOIN. + /// + public virtual bool IsLateralJoinOuterSupported => false; + /// /// Max length of the table alias supported by provider. /// diff --git a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs index af84a977ab8..b4894065162 100644 --- a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs @@ -1126,6 +1126,20 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que Check.NotNull(selectClause, nameof(selectClause)); Check.NotNull(queryModel, nameof(queryModel)); + if ((selectClause.Selector.TryGetReferencedQuerySource() as MainFromClause)?.FromExpression is DbFunctionSourceExpression d + && !d.IsIQueryable) + { + var readExp = BindReadValueMethod(selectClause.Selector.Type, CurrentParameter, 0); + + Expression = Expression.Call( + LinqOperatorProvider.Select + .MakeGenericMethod(CurrentParameter.Type, readExp.Type), + Expression, + Expression.Lambda(readExp, CurrentParameter)); + + return; + } + base.VisitSelectClause(selectClause, queryModel); if (Expression is MethodCallExpression methodCallExpression @@ -1473,11 +1487,19 @@ var innerShapedQuery outerSelectExpression.RemoveRangeFromProjection(previousProjectionCount); } + //todo this first check is too complex here var joinExpression = correlated - ? outerSelectExpression.AddCrossJoinLateral( - innerSelectExpression.Tables.First(), - innerSelectExpression.Projection) + ? QueryCompilationContext.IsLateralJoinOuterSupported + && innerShapedQuery?.Method.MethodIsClosedFormOf(LinqOperatorProvider.DefaultIfEmpty) == true + && innerSelectExpression.Tables.First() is SelectExpression s + && s.Tables.First() is TableValuedSqlFunctionExpression + ? outerSelectExpression.AddCrossJoinLateralOuter( + innerSelectExpression.Tables.First(), + innerSelectExpression.Projection) + : outerSelectExpression.AddCrossJoinLateral( + innerSelectExpression.Tables.First(), + innerSelectExpression.Projection) : outerSelectExpression.AddCrossJoin( innerSelectExpression.Tables.First(), innerSelectExpression.Projection); diff --git a/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs b/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs index c239a28e21b..7ae3083aadd 100644 --- a/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs @@ -846,6 +846,21 @@ public virtual Expression VisitCrossJoinLateral(CrossJoinLateralExpression cross return crossJoinLateralExpression; } + /// + /// Visit a CrossJoinLateralOuterExpression expression. + /// + /// The cross join lateral outer expression. + /// + /// An Expression. + /// + public virtual Expression VisitCrossJoinLateralOuter(CrossJoinLateralOuterExpression crossJoinLateralOuterExpression) + { + Check.NotNull(crossJoinLateralOuterExpression, nameof(crossJoinLateralOuterExpression)); + _relationalCommandBuilder.Append("CROSS JOIN LATERAL OUTER"); + Visit(crossJoinLateralOuterExpression.TableExpression); + return crossJoinLateralOuterExpression; + } + /// /// Visit a SqlFragmentExpression. /// @@ -1500,6 +1515,18 @@ protected virtual void GenerateFunctionCall( _typeMapping = parentTypeMapping; } + /// + /// Visits a todo. + /// + /// todo + /// + /// An Expression. + /// + public virtual Expression VisitTableValuedSqlFunctionExpression(TableValuedSqlFunctionExpression tableValuedSqlFunctionExpression) + { + return VisitSqlFunction(tableValuedSqlFunctionExpression.SqlFunctionExpression); + } + /// /// Visit a SQL ExplicitCastExpression. /// diff --git a/src/EFCore.Relational/Query/Sql/ISqlExpressionVisitor.cs b/src/EFCore.Relational/Query/Sql/ISqlExpressionVisitor.cs index 01932ab0074..1718e6731e4 100644 --- a/src/EFCore.Relational/Query/Sql/ISqlExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Sql/ISqlExpressionVisitor.cs @@ -93,6 +93,15 @@ public interface ISqlExpressionVisitor /// Expression VisitCrossJoinLateral([NotNull] CrossJoinLateralExpression crossJoinLateralExpression); + /// + /// Visit a CrossJoinLateralOuterExpression. + /// + /// The cross join lateral outer expression. + /// + /// An Expression. + /// + Expression VisitCrossJoinLateralOuter([NotNull] CrossJoinLateralOuterExpression crossJoinLateralOuterExpression); + /// /// Visit an InnerJoinExpression. /// @@ -182,5 +191,14 @@ public interface ISqlExpressionVisitor /// An Expression. /// Expression VisitColumnReference([NotNull] ColumnReferenceExpression columnReferenceExpression); + + /// + /// Visit a TableValuedSqlFunctionExpression. + /// + /// todo + /// + /// An Expression. + /// + Expression VisitTableValuedSqlFunctionExpression([NotNull] TableValuedSqlFunctionExpression tableValuedSqlFunctionExpression); } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs index 1d7d6c43331..0480f0c36bf 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs @@ -33,5 +33,11 @@ public SqlServerQueryCompilationContext( /// directly from your code. This API may change or be removed in future releases. /// public override bool IsLateralJoinSupported => true; + + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public override bool IsLateralJoinOuterSupported => true; } } diff --git a/src/EFCore.SqlServer/Query/Sql/Internal/SqlServerQuerySqlGenerator.cs b/src/EFCore.SqlServer/Query/Sql/Internal/SqlServerQuerySqlGenerator.cs index 0032cfaee02..215f6924402 100644 --- a/src/EFCore.SqlServer/Query/Sql/Internal/SqlServerQuerySqlGenerator.cs +++ b/src/EFCore.SqlServer/Query/Sql/Internal/SqlServerQuerySqlGenerator.cs @@ -73,6 +73,21 @@ public override Expression VisitCrossJoinLateral(CrossJoinLateralExpression cros return crossJoinLateralExpression; } + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public override Expression VisitCrossJoinLateralOuter(CrossJoinLateralOuterExpression crossJoinLateralOuterExpression) + { + Check.NotNull(crossJoinLateralOuterExpression, nameof(crossJoinLateralOuterExpression)); + + Sql.Append("OUTER APPLY "); + + Visit(crossJoinLateralOuterExpression.TableExpression); + + return crossJoinLateralOuterExpression; + } + /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. @@ -127,6 +142,26 @@ public override Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExp return base.VisitSqlFunction(sqlFunctionExpression); } + /// + /// Visits a TableValuedSqlFunctionExpression. + /// + /// The SQL function expression. + /// + /// An Expression. + /// + public override Expression VisitTableValuedSqlFunctionExpression(TableValuedSqlFunctionExpression tableValuedSqlFunctionExpression) + { + base.VisitTableValuedSqlFunctionExpression(tableValuedSqlFunctionExpression); + + if (tableValuedSqlFunctionExpression.Alias != null) + { + Sql.Append(" AS ") + .Append(SqlGenerator.DelimitIdentifier(tableValuedSqlFunctionExpression.Alias)); + } + + return tableValuedSqlFunctionExpression; + } + /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. diff --git a/src/EFCore/DbContext.cs b/src/EFCore/DbContext.cs index 62113f36323..e2d0a8f69c5 100644 --- a/src/EFCore/DbContext.cs +++ b/src/EFCore/DbContext.cs @@ -3,9 +3,11 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.ComponentModel; using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Reflection; using System.Threading; using System.Threading.Tasks; @@ -1412,6 +1414,83 @@ public virtual Task FindAsync([NotNull] object[] keyValues, Ca /// IServiceProvider IInfrastructure.Instance => InternalServiceProvider; + /// + /// todo + /// + /// todo + /// todo + /// todo + /// todo + protected virtual T ExecuteScalarMethod(Expression> dbFuncCall) + where U : DbContext + { + //todo - verify dbFuncCall contains a method call expression + var dbFuncFac = InternalServiceProvider.GetRequiredService(); + var resultsQuery = DbContextDependencies.QueryProvider.Execute(dbFuncFac.GenerateDbFunctionSource(dbFuncCall.Body as MethodCallExpression, Model)) as IEnumerable; + + var results = resultsQuery.ToList(); + + return results[0]; + //how am I going to get the dbFunction from the model here - I can't access FindDbFunction because it is in relational. + //maybe I need to pass a reference to the model and find it later? If I move DbFunctionSourceExpression into relational I can access it, but then how do I create DbFunctionSourceExpression. + //need some kind of factory..... + } + + /// + /// todo + /// + /// todo + /// todo + /// todo + /// todo + protected IQueryable ExecuteTableValuedFunction(Expression>> dbFuncCall) + where U : DbContext + { + var dbFuncFac = InternalServiceProvider.GetRequiredService(); + + //todo - verify dbFuncCall contains a method call expression + var resultsQuery = dbFuncFac.GenerateDbFunctionSource(dbFuncCall.Body as MethodCallExpression, Model); + + return DbContextDependencies.QueryProvider.CreateQuery(resultsQuery); + } + + /// + /// todo + /// + /// todo + /// todo + /// todo + /// todo + protected IQueryable ExecuteTableValuedFunction(MethodInfo callingMethod, params object[] methodParams) + { + var c = Expression.Call(Expression.Constant(this), + callingMethod, + methodParams.Select(mp => Expression.Constant(mp))); + + var dbFuncFac = InternalServiceProvider.GetRequiredService(); + var resultsQuery = dbFuncFac.GenerateDbFunctionSource(c, Model); + + return DbContextDependencies.QueryProvider.CreateQuery(resultsQuery); + + /* this.DbContextDependencies.QueryProvider.CreateQuery() + return (IQueryable) DbContextDependencies.QuerySource.CreateQuery(this, callingMethod.ReturnType.GetGenericArguments()[0]); + */ + /* var paramExps = methodParams.Select(mp => + { + if ((mp as Expression)?.NodeType == ExpressionType.Lambda) + return Expression.Invoke(mp as Expression); + + return Expression.Constant(mp); + }); + + + return QueryProvider.CreateQuery( + Expression.Call( + Expression.Constant(this), + callingMethod, + paramExps));*/ + } + #region Hidden System.Object members /// diff --git a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs index 4b470a19f9b..551cca3c970 100644 --- a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs +++ b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs @@ -24,6 +24,7 @@ using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Remotion.Linq.Parsing.ExpressionVisitors.Transformation; using Remotion.Linq.Parsing.ExpressionVisitors.TreeEvaluation; namespace Microsoft.EntityFrameworkCore.Infrastructure @@ -143,7 +144,9 @@ public static readonly IDictionary CoreServices { typeof(IPropertyListener), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) }, { typeof(IResettableService), new ServiceCharacteristics(ServiceLifetime.Scoped, multipleRegistrations: true) }, { typeof(ISingletonOptions), new ServiceCharacteristics(ServiceLifetime.Singleton, multipleRegistrations: true) }, - { typeof(IEvaluatableExpressionFilter), new ServiceCharacteristics(ServiceLifetime.Scoped) } + { typeof(IEvaluatableExpressionFilter), new ServiceCharacteristics(ServiceLifetime.Scoped) }, + { typeof(IExpressionTranformationProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, + { typeof(IDbFunctionSourceFactory), new ServiceCharacteristics(ServiceLifetime.Singleton) } }; /// @@ -274,6 +277,7 @@ public virtual EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(p => p.GetService()); TryAdd>(p => p.GetService); TryAdd(); + TryAdd(p => ExpressionTransformerRegistry.CreateDefault()); TryAdd(); TryAdd(); TryAdd(); @@ -282,6 +286,7 @@ public virtual EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(); TryAdd(); TryAdd(); + TryAdd(); ServiceCollectionMap .TryAddSingleton(new DiagnosticListener(DbLoggerCategory.Name)); diff --git a/src/EFCore/Internal/DbFunctionSourceFactory.cs b/src/EFCore/Internal/DbFunctionSourceFactory.cs new file mode 100644 index 00000000000..50bbcd1be36 --- /dev/null +++ b/src/EFCore/Internal/DbFunctionSourceFactory.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; +using Microsoft.EntityFrameworkCore.Metadata; + +namespace Microsoft.EntityFrameworkCore.Internal +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public class DbFunctionSourceFactory : IDbFunctionSourceFactory + { + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/EFCore/Internal/IDbFunctionSourceFactory.cs b/src/EFCore/Internal/IDbFunctionSourceFactory.cs new file mode 100644 index 00000000000..1d1658fedc3 --- /dev/null +++ b/src/EFCore/Internal/IDbFunctionSourceFactory.cs @@ -0,0 +1,18 @@ +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Metadata; + +namespace Microsoft.EntityFrameworkCore.Internal +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public interface IDbFunctionSourceFactory + { + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model); + } +} diff --git a/src/EFCore/Query/Internal/QueryModelGenerator.cs b/src/EFCore/Query/Internal/QueryModelGenerator.cs index dc9ff653e71..2925861e60a 100644 --- a/src/EFCore/Query/Internal/QueryModelGenerator.cs +++ b/src/EFCore/Query/Internal/QueryModelGenerator.cs @@ -23,6 +23,7 @@ public class QueryModelGenerator : IQueryModelGenerator { private readonly INodeTypeProvider _nodeTypeProvider; private readonly IEvaluatableExpressionFilter _evaluatableExpressionFilter; + private readonly IExpressionTranformationProvider _expressionTranformationProvider; /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used @@ -30,13 +31,16 @@ public class QueryModelGenerator : IQueryModelGenerator /// public QueryModelGenerator( [NotNull] INodeTypeProviderFactory nodeTypeProviderFactory, - [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter) + [NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter, + [NotNull] IExpressionTranformationProvider expressionTranformationProvider) { Check.NotNull(nodeTypeProviderFactory, nameof(nodeTypeProviderFactory)); Check.NotNull(evaluatableExpressionFilter, nameof(evaluatableExpressionFilter)); + Check.NotNull(expressionTranformationProvider, nameof(expressionTranformationProvider)); _nodeTypeProvider = nodeTypeProviderFactory.Create(); _evaluatableExpressionFilter = evaluatableExpressionFilter; + _expressionTranformationProvider = expressionTranformationProvider; } /// @@ -81,7 +85,7 @@ private QueryParser CreateQueryParser(INodeTypeProvider nodeTypeProvider) new IExpressionTreeProcessor[] { new PartialEvaluatingExpressionTreeProcessor(_evaluatableExpressionFilter), - new TransformingExpressionTreeProcessor(ExpressionTransformerRegistry.CreateDefault()) + new TransformingExpressionTreeProcessor(_expressionTranformationProvider) }))); } } diff --git a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs index 40b7c299473..ae3a75e3ef4 100644 --- a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs +++ b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Linq; using System.Linq.Expressions; using System.Reflection; using Microsoft.EntityFrameworkCore.Infrastructure; @@ -161,6 +162,7 @@ public static int DuplicateNameTest() public static MethodInfo MethodAmi = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.MethodA), new[] { typeof(string), typeof(int) }); public static MethodInfo MethodBmi = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.MethodB), new[] { typeof(string), typeof(int) }); public static MethodInfo MethodHmi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodH)); + public static MethodInfo MethodImi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodI)); public class TestMethods { @@ -195,6 +197,11 @@ public static int MethodH(T a, string b) { throw new Exception(); } + + public static IQueryable MethodI() + { + throw new Exception(); + } } [Fact] @@ -477,6 +484,16 @@ public virtual void Set_empty_function_name_throws() Assert.Equal(expectedMessage, Assert.Throws(() => modelBuilder.HasDbFunction(MethodAmi).HasName("")).Message); } + [Fact] + public virtual void Queryable_method_must_be_static() + { + var modelBuilder = GetModelBuilder(); + + var expectedMessage = RelationalStrings.DbFunctionQueryableNotStatic("TestMethods.MethodI"); + + Assert.Equal(expectedMessage, Assert.Throws(() => modelBuilder.HasDbFunction(MethodImi)).Message); + } + private ModelBuilder GetModelBuilder() { var conventionset = new ConventionSet(); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs index 06423b17486..b67f487b025 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs @@ -4,7 +4,9 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query.Expressions; using Microsoft.EntityFrameworkCore.TestUtilities; using Microsoft.Extensions.DependencyInjection; @@ -39,9 +41,16 @@ public class Order { public int Id { get; set; } public string Name { get; set; } - public int ItemCount { get; set; } + public int QuantitySold { get; set; } public DateTime OrderDate { get; set; } public Customer Customer { get; set; } + public Product Product { get; set; } + } + + public class Product + { + public int Id { get; set; } + public string Name { get; set; } } protected class UDFSqlContext : DbContext @@ -50,11 +59,14 @@ protected class UDFSqlContext : DbContext public DbSet Customers { get; set; } public DbSet Orders { get; set; } + public DbSet Products { get; set; } #endregion #region Function Stubs + #region Static Functions + public enum ReportingPeriod { Winter = 0, @@ -100,6 +112,8 @@ public static int CustomerOrderCountWithClientStatic(int customerId) return 1; case 4: return 0; + case 5: + return 0; default: throw new Exception(); } @@ -162,6 +176,8 @@ public int CustomerOrderCountWithClientInstance(int customerId) return 1; case 4: return 0; + case 5: + return 0; default: throw new Exception(); } @@ -198,6 +214,59 @@ public static string IdentityString(string s) throw new NotImplementedException(); } + public string SCHEMA_NAME() + { + //TODO - how to remove the generic params here? + return ExecuteScalarMethod(db => db.SCHEMA_NAME()); + } + + public int AddValues(int a, int b) + { + return ExecuteScalarMethod(db => db.AddValues(a, b)); + } + + public int AddValues(Expression> a, int b) + { + return ExecuteScalarMethod(db => db.AddValues(a, b)); + } + + #endregion + + #region Table Functions + + public class OrderByYear + { + public int? CustomerId { get; set; } + public int? Count { get; set; } + public int? Year { get; set; } + } + + public IQueryable GetCustomerOrderCountByYear(int customerId) + { + // return ExecuteTableValuedFunction(typeof(SqlServerDbFunctionsNorthwindContext).GetTypeInfo().GetDeclaredMethod(nameof(FindReportsForManager))); + + // return ExecuteTableValuedFunction(GetType().GetMethod(nameof(GetCustomerOrderCountByYear)), customerId ); + return ExecuteTableValuedFunction(db => db.GetCustomerOrderCountByYear(customerId)); + } + + public IQueryable GetCustomerOrderCountByYear(Expression> customerId) + { + return ExecuteTableValuedFunction(db => db.GetCustomerOrderCountByYear(customerId)); + } + + public class TopSellingProduct + { + public int? ProductId { get; set; } + public int? AmountSold { get; set; } + } + + public IQueryable GetTopTwoSellingProducts() + { + return ExecuteTableValuedFunction(db => db.GetTopTwoSellingProducts()); + } + + #endregion + #endregion public UDFSqlContext(DbContextOptions options) @@ -232,10 +301,21 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(DollarValueInstance))).HasName("DollarValue"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(AddValues), new[] { typeof(int), typeof(int) })); + // modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(AddValues), new[] { typeof(Expression>), typeof(int) })); + + var methodInfo2 = typeof(UDFSqlContext).GetMethod(nameof(MyCustomLengthInstance)); modelBuilder.HasDbFunction(methodInfo2) .HasTranslation(args => new SqlFunctionExpression("len", methodInfo2.ReturnType, args)); + + //Bootstrap + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(SCHEMA_NAME))).HasName("SCHEMA_NAME").HasSchema(""); + + //Table + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerOrderCountByYear), new[] { typeof(int) })); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetTopTwoSellingProducts))); } } @@ -250,7 +330,7 @@ private void Scalar_Function_Extension_Method_Static() { var len = context.Customers.Count(c => UDFSqlContext.IsDateStatic(c.FirstName) == false); - Assert.Equal(3, len); + Assert.Equal(4, len); AssertSql( @"SELECT COUNT(*) @@ -308,7 +388,7 @@ public void Scalar_Function_Constant_Parameter_Static() var custs = context.Customers.Select(c => UDFSqlContext.CustomerOrderCountStatic(customerId)).ToList(); - Assert.Equal(3, custs.Count); + Assert.Equal(4, custs.Count); AssertSql( @"@__customerId_0='1' @@ -630,7 +710,7 @@ public void Scalar_Nested_Function_Unwind_Client_Eval_Where_Static() } [Fact] - public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy_Static() + public void Scalar_Nested_Function_Unwind_Client_Eval_OrderBy_Static() { using (var context = CreateContext()) { @@ -638,8 +718,8 @@ public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy_Static() orderby UDFSqlContext.AddOneStatic(c.Id) select c.Id).ToList(); - Assert.Equal(3, results.Count); - Assert.True(results.SequenceEqual(Enumerable.Range(1, 3))); + Assert.Equal(4, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(1, 4))); AssertSql( @"SELECT [c].[Id] @@ -656,8 +736,8 @@ public void Scalar_Nested_Function_Unwind_Client_Eval_Select_Static() orderby c.Id select UDFSqlContext.AddOneStatic(c.Id)).ToList(); - Assert.Equal(3, results.Count); - Assert.True(results.SequenceEqual(Enumerable.Range(2, 3))); + Assert.Equal(4, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(2, 4))); AssertSql( @"SELECT [c].[Id] @@ -914,7 +994,7 @@ private void Scalar_Function_Extension_Method_Instance() { var len = context.Customers.Count(c => context.IsDateInstance(c.FirstName) == false); - Assert.Equal(3, len); + Assert.Equal(4, len); AssertSql( @"SELECT COUNT(*) @@ -972,7 +1052,7 @@ public void Scalar_Function_Constant_Parameter_Instance() var custs = context.Customers.Select(c => context.CustomerOrderCountInstance(customerId)).ToList(); - Assert.Equal(3, custs.Count); + Assert.Equal(4, custs.Count); AssertSql( @"@__customerId_1='1' @@ -1294,7 +1374,7 @@ public void Scalar_Nested_Function_Unwind_Client_Eval_Where_Instance() } [Fact] - public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy_Instance() + public void Scalar_Nested_Function_Unwind_Client_Eval_OrderBy_Instance() { using (var context = CreateContext()) { @@ -1302,8 +1382,8 @@ public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy_Instance() orderby context.AddOneInstance(c.Id) select c.Id).ToList(); - Assert.Equal(3, results.Count); - Assert.True(results.SequenceEqual(Enumerable.Range(1, 3))); + Assert.Equal(4, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(1, 4))); AssertSql( @"SELECT [c].[Id] @@ -1320,8 +1400,8 @@ public void Scalar_Nested_Function_Unwind_Client_Eval_Select_Instance() orderby c.Id select context.AddOneInstance(c.Id)).ToList(); - Assert.Equal(3, results.Count); - Assert.True(results.SequenceEqual(Enumerable.Range(2, 3))); + Assert.Equal(4, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(2, 4))); AssertSql( @"SELECT [c].[Id] @@ -1527,6 +1607,564 @@ FROM [Customers] AS [c] #endregion + #region BootStrap + + [Fact] + public void BootstrapScalarNoParams() + { + using (var context = CreateContext()) + { + var schame = context.SCHEMA_NAME(); + + Assert.Equal("dbo", schame); + + AssertSql( + @"SELECT SCHEMA_NAME()"); + } + } + + [Fact] + public void BootstrapScalarParams() + { + using (var context = CreateContext()) + { + var value = context.AddValues(1, 2); + + Assert.Equal(3, value); + + AssertSql(@"@__a_0='1' +@__b_1='2' + +SELECT [dbo].[AddValues](@__a_0, @__b_1)"); + } + } + + [Fact] + public void BootstrapScalarFuncParams() + { + using (var context = CreateContext()) + { + // var a1 = context.AddValues(1, 2); + // var a = context.AddValues(() => 1, 2); + //var x = 5; + //var value = context.AddValues(() => context.AddValues(x, 2), 2); + + var value = context.AddValues(() => context.AddValues(1, 2), 2); + + Assert.Equal(5, value); + + AssertSql(@"@__b_1='2' + +SELECT [dbo].[AddValues]([dbo].[AddValues](1, 2), @__b_1)"); + } + } + + [Fact] + public void BootstrapScalarFuncParamsWithVariable() + { + using (var context = CreateContext()) + { + var x = 5; + var value = context.AddValues(() => context.AddValues(x, 2), 2); + + Assert.Equal(9, value); + + AssertSql(@"@__x_1='5' +@__b_2='2' + +SELECT [dbo].[AddValues]([dbo].[AddValues](@__x_1, 2), @__b_2)"); + } + } + + [Fact] + public void BootstrapScalarFuncParamsConstant() + { + using (var context = CreateContext()) + { + var value = context.AddValues(() => 1, 2); + + Assert.Equal(3, value); + + AssertSql(@"@__b_0='2' + +SELECT [dbo].[AddValues](1, @__b_0)"); + } + } + #endregion + + #endregion + + #region Table Valued Tests + + [Fact] + public void TV_Function_Stand_Alone() + { + using (var context = CreateContext()) + { + var products = (from t in context.GetTopTwoSellingProducts() + orderby t.ProductId + select t).ToList(); + + Assert.Equal(2, products.Count); + Assert.Equal(1, products[0].ProductId); + Assert.Equal(27, products[0].AmountSold); + Assert.Equal(2, products[1].ProductId); + Assert.Equal(50, products[1].AmountSold); + + AssertSql(@"SELECT [t].[AmountSold], [t].[ProductId] +FROM [dbo].[GetTopTwoSellingProducts]() AS [t] +ORDER BY [t].[ProductId]"); + } + } + + [Fact] + public void TV_Function_Stand_Alone_Parameter() + { + using (var context = CreateContext()) + { + var orders = (from c in context.GetCustomerOrderCountByYear(1) + orderby c.Count descending + select c).ToList(); + + Assert.Equal(2, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2001, orders[1].Year); + + AssertSql(@"@__customerId_0='1' + +SELECT [c].[Count], [c].[CustomerId], [c].[Year] +FROM [dbo].[GetCustomerOrderCountByYear](@__customerId_0) AS [c] +ORDER BY [c].[Count] DESC"); + } + } + + [Fact] + public void TV_Function_Stand_Alone_Nested() + { + using (var context = CreateContext()) + { + var orders = (from r in context.GetCustomerOrderCountByYear(() => context.AddValues(-2, 3)) + orderby r.Count descending + select r).ToList(); + + Assert.Equal(2, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2001, orders[1].Year); + + AssertSql(@"SELECT [r].[Count], [r].[CustomerId], [r].[Year] +FROM [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues](-2, 3)) AS [r] +ORDER BY [r].[Count] DESC"); + } + } + + [Fact] + public void TV_Function_CrossApply_Correlated_Select_Anonymous() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id) + orderby c.Id, r.Year + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Equal(4, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2001, orders[1].Year); + Assert.Equal(2000, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + Assert.Equal(1, orders[0].Id); + Assert.Equal(1, orders[1].Id); + Assert.Equal(2, orders[2].Id); + Assert.Equal(3, orders[3].Id); + + AssertSql(@"SELECT [c].[Id], [c].[LastName], [r].[Year], [r].[Count] +FROM [Customers] AS [c] +CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [r] +ORDER BY [c].[Id], [r].[Year]"); + } + } + + [Fact] + public void Table_Function_CrossApply_Correlated_Select_Result() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id) + orderby r.Count descending, r.Year descending + select r).ToList(); + + Assert.Equal(4, orders.Count); + + Assert.Equal(4, orders.Count); + Assert.Equal(2, orders[0].Count); + Assert.Equal(2, orders[1].Count); + Assert.Equal(1, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2000, orders[1].Year); + Assert.Equal(2001, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + + AssertSql(@"SELECT [r].[Count], [r].[CustomerId], [r].[Year] +FROM [Customers] AS [c] +CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [r] +ORDER BY [r].[Count] DESC, [r].[Year] DESC"); + } + } + + [Fact] + public void TV_Function_CrossJoin_Not_Correlated() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(2) + where c.Id == 2 + orderby r.Count + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Equal(1, orders.Count); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + + AssertSql(@"SELECT [c].[Id], [c].[LastName], [r].[Year], [r].[Count] +FROM [Customers] AS [c] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear](2) AS [r] +WHERE [c].[Id] = 2 +ORDER BY [r].[Count]"); + } + } + + [Fact] + public void TV_Function_CrossJoin_Parameter() + { + using (var context = CreateContext()) + { + var custId = 2; + + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(custId) + where c.Id == custId + orderby r.Count + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Equal(1, orders.Count); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + + AssertSql(@"@__custId_1='2' +@__custId_2='2' + +SELECT [c].[Id], [c].[LastName], [r].[Year], [r].[Count] +FROM [Customers] AS [c] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear](@__custId_1) AS [r] +WHERE [c].[Id] = @__custId_2 +ORDER BY [r].[Count]"); + } + } + + [Fact] + public void TV_Function_Join() + { + //performing a join requires the method to have a body which calls ExecuteTableValuedFunction + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId + select new + { + p.Id, + p.Name, + r.AmountSold + }).OrderBy(p => p.Id).ToList(); + + Assert.Equal(2, products.Count); + Assert.Equal(1, products[0].Id); + Assert.Equal("Product1", products[0].Name); + Assert.Equal(27, products[0].AmountSold); + Assert.Equal(2, products[1].Id); + Assert.Equal("Product2", products[1].Name); + Assert.Equal(50, products[1].AmountSold); + + AssertSql(@"SELECT [p].[Id], [p].[Name], [r].[AmountSold] +FROM [Products] AS [p] +INNER JOIN [dbo].[GetTopTwoSellingProducts]() AS [r] ON [p].[Id] = [r].[ProductId] +ORDER BY [p].[Id]"); + + } + } + + [Fact] + public void TV_Function_LeftJoin_Select_Anonymous() + { + //performing a join requires the method to have a body which calls ExecuteTableValuedFunction + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId into joinTable + from j in joinTable.DefaultIfEmpty() + orderby p.Id descending + select new + { + p.Id, + p.Name, + j.AmountSold + }).ToList(); + + Assert.Equal(5, products.Count); + Assert.Equal(5, products[0].Id); + Assert.Equal(null, products[0].AmountSold); + Assert.Equal("Product5", products[0].Name); + Assert.Equal(4, products[1].Id); + Assert.Equal(null, products[1].AmountSold); + Assert.Equal("Product4", products[1].Name); + Assert.Equal(3, products[2].Id); + Assert.Equal(null, products[2].AmountSold); + Assert.Equal("Product3", products[2].Name); + Assert.Equal(2, products[3].Id); + Assert.Equal(50, products[3].AmountSold); + Assert.Equal("Product2", products[3].Name); + Assert.Equal(1, products[4].Id); + Assert.Equal(27, products[4].AmountSold); + Assert.Equal("Product1", products[4].Name); + + AssertSql(@"SELECT [p].[Id], [p].[Name], [r].[AmountSold] +FROM [Products] AS [p] +LEFT JOIN [dbo].[GetTopTwoSellingProducts]() AS [r] ON [p].[Id] = [r].[ProductId] +ORDER BY [p].[Id] DESC"); + } + } + + [Fact] + public void TV_Function_LeftJoin_Select_Result() + { + //performing a join requires the method to have a body which calls ExecuteTableValuedFunction + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopTwoSellingProducts() on p.Id equals r.ProductId into joinTable + from j in joinTable.DefaultIfEmpty() + orderby p.Id descending + select j).ToList(); + + Assert.Equal(5, products.Count); + Assert.Equal(null, products[0].ProductId); + Assert.Equal(null, products[0].AmountSold); + Assert.Equal(null, products[1].ProductId); + Assert.Equal(null, products[1].AmountSold); + Assert.Equal(null, products[2].ProductId); + Assert.Equal(null, products[2].AmountSold); + Assert.Equal(2, products[3].ProductId); + Assert.Equal(50, products[3].AmountSold); + Assert.Equal(1, products[4].ProductId); + Assert.Equal(27, products[4].AmountSold); + + AssertSql(@"SELECT [r].[AmountSold], [r].[ProductId] +FROM [Products] AS [p] +LEFT JOIN [dbo].[GetTopTwoSellingProducts]() AS [r] ON [p].[Id] = [r].[ProductId] +ORDER BY [p].[Id] DESC"); + + } + } + + [Fact] + public void TV_Function_OuterApply_Correlated_Select_TVF() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty() + orderby c.Id, r.Year + select r).ToList(); + + /* + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList();*/ + + Assert.Equal(5, orders.Count); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(1, orders[1].Count); + Assert.Equal(2, orders[2].Count); + Assert.Equal(1, orders[3].Count); + Assert.Null(orders[4].Count); + Assert.Equal(2000, orders[0].Year); + Assert.Equal(2001, orders[1].Year); + Assert.Equal(2000, orders[2].Year); + Assert.Equal(2001, orders[3].Year); + Assert.Null(orders[4].Year); + Assert.Equal(1, orders[0].CustomerId); + Assert.Equal(1, orders[1].CustomerId); + Assert.Equal(2, orders[2].CustomerId); + Assert.Equal(3, orders[3].CustomerId); + Assert.Null(orders[4].CustomerId); + + AssertSql(@"SELECT [g].[Count], [g].[CustomerId], [g].[Year] +FROM [Customers] AS [c] +OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [g] +ORDER BY [c].[Id], [g].[Year]"); + } + } + + [Fact] + public void TV_Function_Nested() + { + //TODO - this is selecting too many employee columns in the query + using (var context = CreateContext()) + { + var custId = 2; + + var orders = (from c in context.Customers + from r in context.GetCustomerOrderCountByYear(context.AddValues(1, 1)) + where c.Id == custId + orderby r.Year + select new + { + c.Id, + c.LastName, + r.Year, + r.Count + }).ToList(); + + Assert.Equal(1, orders.Count); + + Assert.Equal(2, orders[0].Count); + Assert.Equal(2000, orders[0].Year); + + AssertSql(@"@__custId_2='2' + +SELECT [c].[Id], [c].[LastName], [r].[Year], [r].[Count] +FROM [Customers] AS [c] +CROSS JOIN [dbo].[GetCustomerOrderCountByYear]([dbo].[AddValues](1, 1)) AS [r] +WHERE [c].[Id] = @__custId_2 +ORDER BY [r].[Year]"); + } + } + + /* + [Fact] + public void TV_Function_Join_Nested() + { + using (var context = CreateContext()) + { + var products = (from p in context.Products + join r in context.GetTopThreeSellingProductsForYear(() => context.GetBestYearEver()) on p.ProductID equals r.ProductId + select new + { + p.ProductID, + p.ProductName, + r.AmountSold + }).ToList(); + + Assert.Equal(3, products.Count); + Assert.Equal(659, products[0].AmountSold); + Assert.Equal(546, products[1].AmountSold); + Assert.Equal(542, products[2].AmountSold); + Assert.Equal("Konbu", products[0].ProductName); + Assert.Equal("Guaraná Fantástica", products[1].ProductName); + Assert.Equal("Camembert Pierrot", products[2].ProductName); + } + } + + //TODO - test that throw exceptions when parameter type mismatch between c# definition and sql function (wrong names, types (when not convertable) etc) + + [Fact] + public void TV_Function_OuterApply_Correlated_Select_Anonymous() + { + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + from o in context.GetLatestNOrdersForCustomer(2, c.CustomerID).DefaultIfEmpty() + select new + { + c.CustomerID, + o.OrderId, + o.OrderDate + }).ToList(); + + Assert.Equal(179, orders.Count); + } + } + + [Fact] + public void CrossJoin() + { + using (var context = CreateContext()) + { + var foo = (from c in context.Customers + //from p in context.Products + //select new { c, p }).ToList(); + select c).ToList(); + } + } + + [Fact] + public void TV_Function_OuterApply_Correlated_Select_Result() + { + //TODO - currently fails because EF tries to change track the result item "o" which is null due to the outer apply + //resolve once we figure out what kind of Type TVF return types are + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + where c.CustomerID == "FISSA" || c.CustomerID == "BOLID" + from o in context.GetLatestNOrdersForCustomer(2, c.CustomerID).DefaultIfEmpty() + select new { c, o }).ToList(); + + Assert.Equal(3, orders.Count); + } + } + + /* [Fact] + public void LeftOuterJoin() + { + //TODO - currently fails because EF tries to change track the result item "o" which is null due to the outer apply + //resolve once we figure out what kind of Type TVF return types are + using (var context = CreateContext()) + { + var orders = (from c in context.Customers + where c.CustomerID == "FISSA" || c.CustomerID == "BOLID" + join o in context.Orders on c.CustomerID equals o.CustomerID into temp + from t in temp.DefaultIfEmpty() + select t).ToList(); + + Assert.Equal(832, orders.Count); + } + }*/ + #endregion public class SqlServerUDFFixture : SharedStoreFixtureBase @@ -1603,19 +2241,72 @@ returns nvarchar(max) return @customerName; end"); - var order11 = new Order { Name = "Order11", ItemCount = 4, OrderDate = new DateTime(2000, 1, 20) }; - var order12 = new Order { Name = "Order12", ItemCount = 8, OrderDate = new DateTime(2000, 2, 21) }; - var order13 = new Order { Name = "Order13", ItemCount = 15, OrderDate = new DateTime(2000, 3, 20) }; - var order21 = new Order { Name = "Order21", ItemCount = 16, OrderDate = new DateTime(2000, 4, 21) }; - var order22 = new Order { Name = "Order22", ItemCount = 23, OrderDate = new DateTime(2000, 5, 20) }; - var order31 = new Order { Name = "Order31", ItemCount = 42, OrderDate = new DateTime(2000, 6, 21) }; + context.Database.ExecuteSqlCommand(@"create function [dbo].GetCustomerOrderCountByYear(@customerId int) + returns @reports table + ( + CustomerId int not null, + Count int not null, + Year int not null + ) + as + begin + + insert into @reports + select @customerId, count(id), year(orderDate) + from orders + where customerId = @customerId + group by customerId, year(orderDate) + order by year(orderDate) + + return + end"); + + context.Database.ExecuteSqlCommand(@"create function [dbo].GetTopTwoSellingProducts() + returns @products table + ( + ProductId int not null, + AmountSold int + ) + as + begin + + insert into @products + select top 2 ProductID, sum(quantitySold) as totalSold + from orders + group by ProductID + order by totalSold desc + + return + end"); + + context.Database.ExecuteSqlCommand(@"create function [dbo].[AddValues] (@a int, @b int) + returns int + as + begin + return @a + @b; + end"); + + var product1 = new Product { Name = "Product1" }; + var product2 = new Product { Name = "Product2" }; + var product3 = new Product { Name = "Product3" }; + var product4 = new Product { Name = "Product4" }; + var product5 = new Product { Name = "Product5" }; + + var order11 = new Order { Name = "Order11", QuantitySold = 4, OrderDate = new DateTime(2000, 1, 20), Product = product1 }; + var order12 = new Order { Name = "Order12", QuantitySold = 8, OrderDate = new DateTime(2000, 2, 21), Product = product2 }; + var order13 = new Order { Name = "Order13", QuantitySold = 15, OrderDate = new DateTime(2001, 3, 20), Product = product3 }; + var order21 = new Order { Name = "Order21", QuantitySold = 16, OrderDate = new DateTime(2000, 4, 21), Product = product4 }; + var order22 = new Order { Name = "Order22", QuantitySold = 23, OrderDate = new DateTime(2000, 5, 20), Product = product1 }; + var order31 = new Order { Name = "Order31", QuantitySold = 42, OrderDate = new DateTime(2001, 6, 21), Product = product2 }; var customer1 = new Customer { FirstName = "Customer", LastName = "One", Orders = new List { order11, order12, order13 } }; var customer2 = new Customer { FirstName = "Customer", LastName = "Two", Orders = new List { order21, order22 } }; var customer3 = new Customer { FirstName = "Customer", LastName = "Three", Orders = new List { order31 } }; + var customer4 = new Customer { FirstName = "Customer", LastName = "Four" }; - ((UDFSqlContext)context).Customers.AddRange(customer1, customer2, customer3); + ((UDFSqlContext)context).Customers.AddRange(customer1, customer2, customer3, customer4); ((UDFSqlContext)context).Orders.AddRange(order11, order12, order13, order21, order22, order31); + ((UDFSqlContext)context).Products.AddRange(product1, product2, product3, product4, product5); context.SaveChanges(); } }