Skip to content

Commit

Permalink
Async queries
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed Apr 27, 2019
1 parent 2b99804 commit 1dba152
Show file tree
Hide file tree
Showing 17 changed files with 327 additions and 429 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace Microsoft.EntityFrameworkCore.InMemory.Query.Pipeline
{
public class InMemoryShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingExpressionVisitor
{
public InMemoryShapedQueryCompilingExpressionVisitor(IEntityMaterializerSource entityMaterializerSource, bool trackQueryResults)
: base(entityMaterializerSource, trackQueryResults)
public InMemoryShapedQueryCompilingExpressionVisitor(IEntityMaterializerSource entityMaterializerSource, bool trackQueryResults, bool async)
: base(entityMaterializerSource, trackQueryResults, async)
{
}

Expand Down Expand Up @@ -56,11 +56,6 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
(InMemoryQueryExpression)shapedQueryExpression.QueryExpression)
.Visit(shaperLambda.Body);

newBody = ReplacingExpressionVisitor.Replace(
MoveNextMarker,
Expression.Assign(hasNextParameter, Expression.Call(enumeratorParameter, _enumeratorMoveNextMethodInfo)),
newBody);

newBody = ReplacingExpressionVisitor.Replace(
InMemoryQueryExpression.ValueBufferParameter,
Expression.MakeMemberAccess(enumeratorParameter, _enumeratorCurrent),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ public InMemoryShapedQueryCompilingExpressionVisitorFactory(IEntityMaterializerS

public ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext2 queryCompilationContext)
{
return new InMemoryShapedQueryCompilingExpressionVisitor(_entityMaterializerSource, queryCompilationContext.TrackQueryResults);
return new InMemoryShapedQueryCompilingExpressionVisitor(
_entityMaterializerSource,
queryCompilationContext.TrackQueryResults,
queryCompilationContext.Async);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext2 que
return new RelationalShapedQueryCompilingExpressionVisitor(
_entityMaterializerSource,
_querySqlGeneratorFactory,
queryCompilationContext.TrackQueryResults);
queryCompilationContext.TrackQueryResults,
queryCompilationContext.Async);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
Expand All @@ -25,8 +27,9 @@ public class RelationalShapedQueryCompilingExpressionVisitor : ShapedQueryCompil
public RelationalShapedQueryCompilingExpressionVisitor(
IEntityMaterializerSource entityMaterializerSource,
IQuerySqlGeneratorFactory2 querySqlGeneratorFactory,
bool trackQueryResults)
: base(entityMaterializerSource, trackQueryResults)
bool trackQueryResults,
bool async)
: base(entityMaterializerSource, trackQueryResults, async)
{
_querySqlGeneratorFactory = querySqlGeneratorFactory;
}
Expand All @@ -37,29 +40,23 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s

var selectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;

var hasNextParameter = Expression.Parameter(typeof(bool).MakeByRefType(), "hasNext");

var newBody = new RelationalProjectionBindingRemovingExpressionVisitor(selectExpression)
.Visit(shaperLambda.Body);

newBody = ReplacingExpressionVisitor.Replace(
MoveNextMarker,
Expression.Assign(hasNextParameter, Expression.Call(
RelationalProjectionBindingRemovingExpressionVisitor.DataReaderParameter, _dbDataReaderReadMethodInfo)),
newBody);
shaperLambda = Expression.Lambda(
newBody,
QueryCompilationContext2.QueryContextParameter,
RelationalProjectionBindingRemovingExpressionVisitor.DataReaderParameter);

shaperLambda = (LambdaExpression)_createLambdaMethodInfo.MakeGenericMethod(newBody.Type)
.Invoke(
null,
new object[]
{
newBody,
new [] {
QueryCompilationContext2.QueryContextParameter,
RelationalProjectionBindingRemovingExpressionVisitor.DataReaderParameter,
hasNextParameter
}
});
if (Async)
{
return Expression.New(
typeof(AsyncQueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType.GetGenericArguments().Single()).GetConstructors()[0],
Expression.Convert(QueryCompilationContext2.QueryContextParameter, typeof(RelationalQueryContext)),
Expression.Constant(_querySqlGeneratorFactory.Create()),
Expression.Constant(selectExpression),
Expression.Constant(shaperLambda.Compile()));
}

return Expression.New(
typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
Expand All @@ -69,34 +66,108 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
Expression.Constant(shaperLambda.Compile()));
}

private static readonly MethodInfo _dbDataReaderReadMethodInfo
= typeof(DbDataReader).GetTypeInfo()
.GetRuntimeMethod(nameof(DbDataReader.Read), new Type[] { });
private class AsyncQueryingEnumerable<T> : IAsyncEnumerable<T>
{
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, Task<T>> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;

private delegate T Shaper<T>(QueryContext queryContext, DbDataReader dataReader, out bool hasNext);
public AsyncQueryingEnumerable(RelationalQueryContext relationalQueryContext,
QuerySqlGenerator querySqlGenerator,
SelectExpression selectExpression,
Func<QueryContext, DbDataReader, Task<T>> shaper)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGenerator = querySqlGenerator;
_selectExpression = selectExpression;
_shaper = shaper;
}

private static readonly MethodInfo _createLambdaMethodInfo
= typeof(RelationalShapedQueryCompilingExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(CreateLambda));
public IAsyncEnumerator<T> GetEnumerator()
{
return new AsyncEnumerator(this);
}

private static LambdaExpression CreateLambda<T>(Expression body, ParameterExpression[] parameters)
{
return Expression.Lambda<Shaper<T>>(
body,
parameters);
private sealed class AsyncEnumerator : IAsyncEnumerator<T>
{
private RelationalDataReader _dataReader;
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, Task<T>> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;

public AsyncEnumerator(AsyncQueryingEnumerable<T> queryingEnumerable)
{
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGenerator = queryingEnumerable._querySqlGenerator;
}

public T Current { get; private set; }

public void Dispose()
{
_dataReader?.Dispose();
_dataReader = null;
_relationalQueryContext.Connection.Close();
}

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
if (_dataReader == null)
{
await _relationalQueryContext.Connection.OpenAsync(cancellationToken);

try
{
var relationalCommand = _querySqlGenerator
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= await relationalCommand.ExecuteReaderAsync(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger,
cancellationToken);
}
catch
{
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}

var hasNext = await _dataReader.ReadAsync(cancellationToken);

Current
= hasNext
? await _shaper(_relationalQueryContext, _dataReader.DbDataReader)
: default;

return hasNext;
}
}
}

private class QueryingEnumerable<T> : IEnumerable<T>
{
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Shaper<T> _shaper;
private readonly Func<QueryContext, DbDataReader, T> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;

public QueryingEnumerable(RelationalQueryContext relationalQueryContext,
QuerySqlGenerator querySqlGenerator,
SelectExpression selectExpression,
Shaper<T> shaper)
Func<QueryContext, DbDataReader, T> shaper)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGenerator = querySqlGenerator;
Expand All @@ -110,10 +181,9 @@ public QueryingEnumerable(RelationalQueryContext relationalQueryContext,
private sealed class Enumerator : IEnumerator<T>
{
private RelationalDataReader _dataReader;
private bool _hasNext;
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Shaper<T> _shaper;
private readonly Func<QueryContext, DbDataReader, T> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;

public Enumerator(QueryingEnumerable<T> queryingEnumerable)
Expand Down Expand Up @@ -154,8 +224,6 @@ public bool MoveNext()
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_hasNext = _dataReader.Read();
}
catch
{
Expand All @@ -167,11 +235,11 @@ public bool MoveNext()
}
}

var hasNext = _hasNext;
var hasNext = _dataReader.Read();

Current
= hasNext
? _shaper(_relationalQueryContext, _dataReader.DbDataReader, out _hasNext)
? _shaper(_relationalQueryContext, _dataReader.DbDataReader)
: default;

return hasNext;
Expand Down
Loading

0 comments on commit 1dba152

Please sign in to comment.