Skip to content

Commit

Permalink
Query: Fix potential re-use of QuerySqlGen from multiple threads
Browse files Browse the repository at this point in the history
Due to internal state, it causes corruption of generated Sql command
The fix is to pass factory and generate SqlGen when we are enumerating.

Also fixed logging when connection opening failed in query
Fix build break
  • Loading branch information
smitpatel committed May 31, 2019
1 parent e9e44e8 commit b139c00
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ public bool MoveNext()
{
try
{

if (_enumerator == null)
{
_enumerator = _innerEnumerable.GetEnumerator();
Expand Down Expand Up @@ -261,7 +260,6 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
try
{

if (_enumerator == null)
{
_enumerator = _innerEnumerable.GetEnumerator();
Expand All @@ -274,7 +272,8 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
: default;

return hasNext;
} catch (Exception exception)
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
return Expression.New(
typeof(AsyncQueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType.GetGenericArguments().Single()).GetConstructors()[0],
Expression.Convert(QueryCompilationContext2.QueryContextParameter, typeof(RelationalQueryContext)),
Expression.Constant(_querySqlGeneratorFactory.Create()),
Expression.Constant(_querySqlGeneratorFactory),
Expression.Constant(selectExpression),
Expression.Constant(shaperLambda.Compile()),
Expression.Constant(_contextType),
Expand All @@ -74,7 +74,7 @@ protected override Expression VisitShapedQueryExpression(ShapedQueryExpression s
return Expression.New(
typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
Expression.Convert(QueryCompilationContext2.QueryContextParameter, typeof(RelationalQueryContext)),
Expression.Constant(_querySqlGeneratorFactory.Create()),
Expression.Constant(_querySqlGeneratorFactory),
Expression.Constant(selectExpression),
Expression.Constant(shaperLambda.Compile()),
Expression.Constant(_contextType),
Expand Down Expand Up @@ -216,20 +216,20 @@ private class AsyncQueryingEnumerable<T> : IAsyncEnumerable<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, Task<T>> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

public AsyncQueryingEnumerable(
RelationalQueryContext relationalQueryContext,
QuerySqlGenerator querySqlGenerator,
IQuerySqlGeneratorFactory2 querySqlGeneratorFactory,
SelectExpression selectExpression,
Func<QueryContext, DbDataReader, Task<T>> shaper,
Type contextType,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGenerator = querySqlGenerator;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
_selectExpression = selectExpression;
_shaper = shaper;
_contextType = contextType;
Expand All @@ -247,7 +247,7 @@ private sealed class AsyncEnumerator : IAsyncEnumerator<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, Task<T>> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

Expand All @@ -256,7 +256,7 @@ public AsyncEnumerator(AsyncQueryingEnumerable<T> queryingEnumerable)
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGenerator = queryingEnumerable._querySqlGenerator;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
}
Expand All @@ -272,38 +272,37 @@ public void Dispose()

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
if (_dataReader == null)
try
{
await _relationalQueryContext.Connection.OpenAsync(cancellationToken);

try
if (_dataReader == null)
{
var relationalCommand = _querySqlGenerator
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= await relationalCommand.ExecuteReaderAsync(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger,
cancellationToken);
await _relationalQueryContext.Connection.OpenAsync(cancellationToken);

try
{
var relationalCommand = _querySqlGeneratorFactory.Create()
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= await relationalCommand.ExecuteReaderAsync(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger,
cancellationToken);
}
catch (Exception)
{
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}

try
{
var hasNext = await _dataReader.ReadAsync(cancellationToken);

Current
Expand All @@ -315,7 +314,6 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
}
catch (Exception exception)
{

_logger.QueryIterationFailed(_contextType, exception);

throw;
Expand All @@ -329,19 +327,19 @@ private class QueryingEnumerable<T> : IEnumerable<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, T> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

public QueryingEnumerable(RelationalQueryContext relationalQueryContext,
QuerySqlGenerator querySqlGenerator,
IQuerySqlGeneratorFactory2 querySqlGeneratorFactory,
SelectExpression selectExpression,
Func<QueryContext, DbDataReader, T> shaper,
Type contextType,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
_relationalQueryContext = relationalQueryContext;
_querySqlGenerator = querySqlGenerator;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
_selectExpression = selectExpression;
_shaper = shaper;
_contextType = contextType;
Expand All @@ -357,7 +355,7 @@ private sealed class Enumerator : IEnumerator<T>
private readonly RelationalQueryContext _relationalQueryContext;
private readonly SelectExpression _selectExpression;
private readonly Func<QueryContext, DbDataReader, T> _shaper;
private readonly QuerySqlGenerator _querySqlGenerator;
private readonly IQuerySqlGeneratorFactory2 _querySqlGeneratorFactory;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

Expand All @@ -366,7 +364,7 @@ public Enumerator(QueryingEnumerable<T> queryingEnumerable)
_relationalQueryContext = queryingEnumerable._relationalQueryContext;
_shaper = queryingEnumerable._shaper;
_selectExpression = queryingEnumerable._selectExpression;
_querySqlGenerator = queryingEnumerable._querySqlGenerator;
_querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory;
_contextType = queryingEnumerable._contextType;
_logger = queryingEnumerable._logger;
}
Expand All @@ -384,37 +382,36 @@ public void Dispose()

public bool MoveNext()
{
if (_dataReader == null)
try
{
_relationalQueryContext.Connection.Open();

try
if (_dataReader == null)
{
var relationalCommand = _querySqlGenerator
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= relationalCommand.ExecuteReader(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);
_relationalQueryContext.Connection.Open();

try
{
var relationalCommand = _querySqlGeneratorFactory.Create()
.GetCommand(
_selectExpression,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);

_dataReader
= relationalCommand.ExecuteReader(
_relationalQueryContext.Connection,
_relationalQueryContext.ParameterValues,
_relationalQueryContext.CommandLogger);
}
catch (Exception)
{
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}
catch (Exception exception)
{
_logger.QueryIterationFailed(_contextType, exception);
// If failure happens creating the data reader, then it won't be available to
// handle closing the connection, so do it explicitly here to preserve ref counting.
_relationalQueryContext.Connection.Close();

throw;
}
}

try
{
var hasNext = _dataReader.Read();

Current
Expand All @@ -426,7 +423,6 @@ public bool MoveNext()
}
catch (Exception exception)
{

_logger.QueryIterationFailed(_contextType, exception);

throw;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ public override bool Equals(object obj)

private bool Equals(SqlConstantExpression sqlConstantExpression)
=> base.Equals(sqlConstantExpression)
&& Value?.Equals(sqlConstantExpression.Value) == true;
&& (Value == null
? sqlConstantExpression.Value == null
: Value.Equals(sqlConstantExpression.Value));

public override int GetHashCode()
{
Expand Down

0 comments on commit b139c00

Please sign in to comment.