Skip to content

Commit

Permalink
Add support for table valued functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pmiddleton committed Mar 14, 2018
1 parent 034fb68 commit 8f8c35b
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 66 deletions.
5 changes: 3 additions & 2 deletions src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1491,8 +1491,9 @@ var joinExpression
= correlated
? QueryCompilationContext.IsLateralJoinOuterSupported
&& innerShapedQuery?.Method.MethodIsClosedFormOf(LinqOperatorProvider.DefaultIfEmpty) == true
&& innerSelectExpression.Tables.First() is SelectExpression s
&& s.Tables.First() is TableValuedSqlFunctionExpression
&& ((innerSelectExpression.Tables.First() is SelectExpression s
&& s.Tables.First() is TableValuedSqlFunctionExpression)
|| innerSelectExpression.Tables.First() is TableValuedSqlFunctionExpression)
? outerSelectExpression.AddCrossJoinLateralOuter(
innerSelectExpression.Tables.First(),
innerSelectExpression.Projection)
Expand Down
14 changes: 9 additions & 5 deletions src/EFCore/DbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1421,9 +1421,11 @@ public virtual Task<TEntity> FindAsync<TEntity>([CanBeNull] object[] keyValues,
/// <typeparam name="U">todo</typeparam>
/// <param name="dbFuncCall">todo</param>
/// <returns>todo</returns>
protected virtual T ExecuteScalarMethod<U, T>(Expression<Func<U, T>> dbFuncCall)
protected virtual T ExecuteScalarMethod<U, T>([NotNull] Expression<Func<U, T>> dbFuncCall)
where U : DbContext
{
Check.NotNull(dbFuncCall, nameof(dbFuncCall));

//todo - verify dbFuncCall contains a method call expression
var dbFuncFac = InternalServiceProvider.GetRequiredService<IDbFunctionSourceFactory>();
var resultsQuery = DbContextDependencies.QueryProvider.Execute(dbFuncFac.GenerateDbFunctionSource(dbFuncCall.Body as MethodCallExpression, Model)) as IEnumerable<T>;
Expand All @@ -1443,9 +1445,11 @@ protected virtual T ExecuteScalarMethod<U, T>(Expression<Func<U, T>> dbFuncCall)
/// <typeparam name="T">todo</typeparam>
/// <param name="dbFuncCall">todo</param>
/// <returns>todo</returns>
protected IQueryable<T> ExecuteTableValuedFunction<U, T>(Expression<Func<U, IQueryable<T>>> dbFuncCall)
protected virtual IQueryable<T> ExecuteTableValuedFunction<U, T>([NotNull] Expression<Func<U, IQueryable<T>>> dbFuncCall)
where U : DbContext
{
Check.NotNull(dbFuncCall, nameof(dbFuncCall));

var dbFuncFac = InternalServiceProvider.GetRequiredService<IDbFunctionSourceFactory>();

//todo - verify dbFuncCall contains a method call expression
Expand All @@ -1454,14 +1458,14 @@ protected IQueryable<T> ExecuteTableValuedFunction<U, T>(Expression<Func<U, IQue
return DbContextDependencies.QueryProvider.CreateQuery<T>(resultsQuery);
}

/// <summary>
/* /// <summary>
/// todo
/// </summary>
/// <typeparam name="T">todo</typeparam>
/// <param name="callingMethod">todo</param>
/// <param name="methodParams">todo</param>
/// <returns>todo</returns>
protected IQueryable<T> ExecuteTableValuedFunction<T>(MethodInfo callingMethod, params object[] methodParams)
protected IQueryable<T> ExecuteTableValuedFunction<T>([NotNull] MethodInfo callingMethod, params object[] methodParams)
{
var c = Expression.Call(Expression.Constant(this),
callingMethod,
Expand Down Expand Up @@ -1489,7 +1493,7 @@ protected IQueryable<T> ExecuteTableValuedFunction<T>(MethodInfo callingMethod,
Expression.Constant(this),
callingMethod,
paramExps));*/
}
// }

#region Hidden System.Object members

Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/Internal/DbFunctionSourceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Microsoft.EntityFrameworkCore.Internal
Expand All @@ -16,7 +17,7 @@ public class DbFunctionSourceFactory : IDbFunctionSourceFactory
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model)
public virtual Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model)
{
throw new NotImplementedException();
}
Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/Internal/IDbFunctionSourceFactory.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Microsoft.EntityFrameworkCore.Internal
Expand All @@ -13,6 +14,6 @@ public interface IDbFunctionSourceFactory
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model);
Expression GenerateDbFunctionSource([NotNull] MethodCallExpression methodCall, [NotNull] IModel model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1811,7 +1811,75 @@ CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [r]
}

[Fact]
public void Table_Function_CrossApply_Correlated_Select_Result()
public void TV_Function_Select_In_Anonymous()
{
using (var context = CreateContext())
{
var cust = (from c in context.Customers
orderby c.Id
select new
{
// c.Id,
Prods = context.GetTopTwoSellingProducts().ToList()
}).ToList();
}
}

[Fact]
public void Apauldevtest()
{
using (var context = CreateContext())
{
var cust = (from c in context.Customers
select new
{
OrderCount = c.Orders.ToList()
}).ToList();

AssertSql(
@"SELECT TOP(2) [c].[LastName], [dbo].[CustomerOrderCount]([c].[Id]) AS [OrderCount]
FROM [Customers] AS [c]
WHERE [c].[Id] = 1");
}
}

[Fact]
public void TV_Function_Correlated_Select_In_Anonymous()
{
using (var context = CreateContext())
{
var cust = (from c in context.Customers
orderby c.Id
select new
{
c.Id,
c.LastName,
Orders = context.GetCustomerOrderCountByYear(c.Id).ToList()
}).ToList();

Assert.Equal(4, cust.Count);
Assert.Equal(2, cust[0].Orders[0].Count);
Assert.Equal(1, cust[0].Orders[0].Count);
Assert.Equal(2, cust[0].Orders[0].Count);
Assert.Equal(1, cust[0].Orders[0].Count);
Assert.Equal("One", cust[0].LastName);
Assert.Equal("Two", cust[1].LastName);
Assert.Equal("Three", cust[2].LastName);
Assert.Equal("Four", cust[3].LastName);
Assert.Equal(1, cust[0].Id);
Assert.Equal(1, cust[1].Id);
Assert.Equal(2, cust[2].Id);
Assert.Equal(3, cust[3].Id);

AssertSql(@"SELECT [c].[Id], [c].[LastName], [r].[Year], [r].[Count]
FROM [Customers] AS [c]
CROSS APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [r]
ORDER BY [c].[Id], [r].[Year]");
}
}

[Fact]
public void TV_Function_CrossApply_Correlated_Select_Result()
{
using (var context = CreateContext())
{
Expand Down Expand Up @@ -2018,15 +2086,6 @@ from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty()
orderby c.Id, r.Year
select r).ToList();

/*
select new
{
c.Id,
c.LastName,
r.Year,
r.Count
}).ToList();*/

Assert.Equal(5, orders.Count);

Assert.Equal(2, orders[0].Count);
Expand All @@ -2052,6 +2111,120 @@ OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [g]
}
}

[Fact]
public void TV_Function_OuterApply_Correlated_Select_DbSet()
{
using (var context = CreateContext())
{
var custs = (from c in context.Customers
from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty()
orderby c.Id, r.Year
select c).ToList();

Assert.Equal(5, custs.Count);

Assert.Equal(1, custs[0].Id);
Assert.Equal(1, custs[1].Id);
Assert.Equal(2, custs[2].Id);
Assert.Equal(3, custs[3].Id);
Assert.Equal(4, custs[4].Id);
Assert.Equal("One", custs[0].LastName);
Assert.Equal("One", custs[1].LastName);
Assert.Equal("Two", custs[2].LastName);
Assert.Equal("Three", custs[3].LastName);
Assert.Equal("Four", custs[4].LastName);

AssertSql(@"SELECT [c].[Id], [c].[FirstName], [c].[LastName]
FROM [Customers] AS [c]
OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [g]
ORDER BY [c].[Id], [g].[Year]");
}
}

[Fact]
public void TV_Function_OuterApply_Correlated_Select_Anonymous()
{
using (var context = CreateContext())
{
var orders = (from c in context.Customers
from r in context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty()
orderby c.Id, r.Year
select new
{
c.Id,
c.LastName,
r.Year,
r.Count
}).ToList();

Assert.Equal(5, orders.Count);

Assert.Equal(1, orders[0].Id);
Assert.Equal(1, orders[1].Id);
Assert.Equal(2, orders[2].Id);
Assert.Equal(3, orders[3].Id);
Assert.Equal(4, orders[4].Id);
Assert.Equal("One", orders[0].LastName);
Assert.Equal("One", orders[1].LastName);
Assert.Equal("Two", orders[2].LastName);
Assert.Equal("Three", orders[3].LastName);
Assert.Equal("Four", orders[4].LastName);
Assert.Equal(2, orders[0].Count);
Assert.Equal(1, orders[1].Count);
Assert.Equal(2, orders[2].Count);
Assert.Equal(1, orders[3].Count);
Assert.Null(orders[4].Count);
Assert.Equal(2000, orders[0].Year);
Assert.Equal(2001, orders[1].Year);
Assert.Equal(2000, orders[2].Year);
Assert.Equal(2001, orders[3].Year);

AssertSql(@"SELECT [c].[Id], [c].[LastName], [g].[Year], [g].[Count]
FROM [Customers] AS [c]
OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [g]
ORDER BY [c].[Id], [g].[Year]");
}
}

[Fact]
public void TV_Function_OuterApply_Correlated_Select_Anonymous_Has_Outer()
{
using (var context = CreateContext())
{
var cust = (from c in context.Customers
orderby c.Id
select new
{
c.Id,
c.LastName,
OrderCount = context.GetCustomerOrderCountByYear(c.Id).DefaultIfEmpty().Select(r => r.Count).ToList()


}).ToList();

Assert.Equal(4, cust.Count);

Assert.Equal(1, cust[0].Id);
Assert.Equal(2, cust[1].Id);
Assert.Equal(3, cust[2].Id);
Assert.Equal(4, cust[3].Id);
Assert.Equal("One", cust[0].LastName);
Assert.Equal("Two", cust[1].LastName);
Assert.Equal("Three", cust[2].LastName);
Assert.Equal("Four", cust[3].LastName);
Assert.Equal(2, cust[0].OrderCount[0]);
Assert.Equal(2, cust[1].OrderCount[0]);
Assert.Equal(1, cust[2].OrderCount[0]);
Assert.Null(cust[3].OrderCount);


AssertSql(@"SELECT [c].[Id], [c].[LastName], [g].[Year], [g].[Count]
FROM [Customers] AS [c]
OUTER APPLY [dbo].[GetCustomerOrderCountByYear]([c].[Id]) AS [g]
ORDER BY [c].[Id], [g].[Year]");
}
}

[Fact]
public void TV_Function_Nested()
{
Expand Down Expand Up @@ -2112,53 +2285,8 @@ join r in context.GetTopThreeSellingProductsForYear(() => context.GetBestYearEve
}
}
//TODO - test that throw exceptions when parameter type mismatch between c# definition and sql function (wrong names, types (when not convertable) etc)
[Fact]
public void TV_Function_OuterApply_Correlated_Select_Anonymous()
{
using (var context = CreateContext())
{
var orders = (from c in context.Customers
from o in context.GetLatestNOrdersForCustomer(2, c.CustomerID).DefaultIfEmpty()
select new
{
c.CustomerID,
o.OrderId,
o.OrderDate
}).ToList();
Assert.Equal(179, orders.Count);
}
}
[Fact]
public void CrossJoin()
{
using (var context = CreateContext())
{
var foo = (from c in context.Customers
//from p in context.Products
//select new { c, p }).ToList();
select c).ToList();
}
}
[Fact]
public void TV_Function_OuterApply_Correlated_Select_Result()
{
//TODO - currently fails because EF tries to change track the result item "o" which is null due to the outer apply
//resolve once we figure out what kind of Type TVF return types are
using (var context = CreateContext())
{
var orders = (from c in context.Customers
where c.CustomerID == "FISSA" || c.CustomerID == "BOLID"
from o in context.GetLatestNOrdersForCustomer(2, c.CustomerID).DefaultIfEmpty()
select new { c, o }).ToList();
Assert.Equal(3, orders.Count);
}
}
/* [Fact]
public void LeftOuterJoin()
Expand Down

0 comments on commit 8f8c35b

Please sign in to comment.