Skip to content

Commit

Permalink
Query: Merge AsyncQueryingEnumerable into QueryingEnumerable
Browse files Browse the repository at this point in the history
It is just a wrapper to create enumerators. Enumerators can be created out of single Enumerable
  • Loading branch information
smitpatel committed Sep 30, 2019
1 parent c0dda49 commit 7da3faf
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 331 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query;
using Newtonsoft.Json.Linq;
Expand All @@ -12,7 +14,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
{
public partial class CosmosShapedQueryCompilingExpressionVisitor
{
private class QueryingEnumerable<T> : IEnumerable<T>
private class QueryingEnumerable<T> : IEnumerable<T>, IAsyncEnumerable<T>
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand All @@ -39,7 +41,8 @@ public QueryingEnumerable(
_contextType = contextType;
_logger = logger;
}

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
=> new AsyncEnumerator(this, cancellationToken);
public IEnumerator<T> GetEnumerator() => new Enumerator(this);
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

Expand Down Expand Up @@ -116,6 +119,78 @@ public void Dispose()

public void Reset() => throw new NotImplementedException();
}

private sealed class AsyncEnumerator : IAsyncEnumerator<T>
{
private IAsyncEnumerator<JObject> _enumerator;
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, JObject, T> _shaper;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;
private readonly CancellationToken _cancellationToken;

public AsyncEnumerator(QueryingEnumerable<T> queryingEnumerable, CancellationToken cancellationToken)
{
_cosmosQueryContext = queryingEnumerable._cosmosQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_sqlExpressionFactory = queryingEnumerable._sqlExpressionFactory;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
_cancellationToken = cancellationToken;
}

public T Current { get; private set; }

public async ValueTask<bool> MoveNextAsync()
{
try
{
using (_cosmosQueryContext.ConcurrencyDetector.EnterCriticalSection())
{
if (_enumerator == null)
{
var selectExpression = (SelectExpression)new InExpressionValuesExpandingExpressionVisitor(
_sqlExpressionFactory, _cosmosQueryContext.ParameterValues).Visit(_selectExpression);

_enumerator = _cosmosQueryContext.CosmosClient
.ExecuteSqlQueryAsync(
_selectExpression.Container,
_querySqlGeneratorFactory.Create().GetSqlQuery(
selectExpression, _cosmosQueryContext.ParameterValues))
.GetAsyncEnumerator(_cancellationToken);
}

var hasNext = await _enumerator.MoveNextAsync();

Current
= hasNext
? _shaper(_cosmosQueryContext, _enumerator.Current)
: default;

return hasNext;
}
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);

throw;
}
}

public ValueTask DisposeAsync()
{
_enumerator?.DisposeAsync();
_enumerator = null;

return default;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
jObjectParameter);

return Expression.New(
(IsAsync
? typeof(AsyncQueryingEnumerable<>)
: typeof(QueryingEnumerable<>)).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
Expression.Convert(QueryCompilationContext.QueryContextParameter, typeof(CosmosQueryContext)),
Expression.Constant(_sqlExpressionFactory),
Expression.Constant(_querySqlGeneratorFactory),
Expand Down
Loading

0 comments on commit 7da3faf

Please sign in to comment.