Skip to content

Commit

Permalink
Microsoft.Data.Sqlite: Lazy all the things!
Browse files Browse the repository at this point in the history
Fixes #17271
  • Loading branch information
bricelam committed Mar 3, 2020
1 parent 24e7c3f commit fcc17e8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
16 changes: 5 additions & 11 deletions src/Microsoft.Data.Sqlite.Core/SqliteCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ namespace Microsoft.Data.Sqlite
/// <seealso href="https://docs.microsoft.com/dotnet/standard/data/sqlite/async">Async Limitations</seealso>
public class SqliteCommand : DbCommand
{
private readonly Lazy<SqliteParameterCollection> _parameters = new Lazy<SqliteParameterCollection>(
() => new SqliteParameterCollection());
private SqliteParameterCollection _parameters;

private readonly List<sqlite3_stmt> _preparedStatements = new List<sqlite3_stmt>();
private SqliteConnection _connection;
Expand Down Expand Up @@ -165,7 +164,7 @@ protected override DbTransaction DbTransaction
/// <value>The collection of parameters used by the command.</value>
/// <seealso href="https://docs.microsoft.com/dotnet/standard/data/sqlite/parameters">Parameters</seealso>
public new virtual SqliteParameterCollection Parameters
=> _parameters.Value;
=> _parameters ??= new SqliteParameterCollection();

/// <summary>
/// Gets the collection of parameters used by the command.
Expand Down Expand Up @@ -326,12 +325,7 @@ private IEnumerable<sqlite3_stmt> GetStatements(Stopwatch timer)
? PrepareAndEnumerateStatements(timer)
: _preparedStatements)
{
var boundParams = 0;

if (_parameters.IsValueCreated)
{
boundParams = _parameters.Value.Bind(stmt);
}
var boundParams = _parameters?.Bind(stmt) ?? 0;

var expectedParams = sqlite3_bind_parameter_count(stmt);
if (expectedParams != boundParams)
Expand All @@ -341,8 +335,8 @@ private IEnumerable<sqlite3_stmt> GetStatements(Stopwatch timer)
{
var name = sqlite3_bind_parameter_name(stmt, i).utf8_to_string();

if (_parameters.IsValueCreated
&& !_parameters.Value.Cast<SqliteParameter>().Any(p => p.ParameterName == name))
if (_parameters != null
&& !_parameters.Cast<SqliteParameter>().Any(p => p.ParameterName == name))
{
unboundParams.Add(name);
}
Expand Down
51 changes: 31 additions & 20 deletions src/Microsoft.Data.Sqlite.Core/SqliteConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,14 @@ public partial class SqliteConnection : DbConnection

private readonly List<WeakReference<SqliteCommand>> _commands = new List<WeakReference<SqliteCommand>>();

private readonly Dictionary<string, (object state, strdelegate_collation collation)> _collations
= new Dictionary<string, (object, strdelegate_collation)>(StringComparer.OrdinalIgnoreCase);
private Dictionary<string, (object state, strdelegate_collation collation)> _collations;

private readonly Dictionary<(string name, int arity), (int flags, object state, delegate_function_scalar func)> _functions
= new Dictionary<(string, int), (int, object, delegate_function_scalar)>(FunctionsKeyComparer.Instance);
private Dictionary<(string name, int arity), (int flags, object state, delegate_function_scalar func)> _functions;

private readonly Dictionary<(string name, int arity), (int flags, object state, delegate_function_aggregate_step func_step,
delegate_function_aggregate_final func_final)> _aggregates
= new Dictionary<(string, int), (int, object, delegate_function_aggregate_step, delegate_function_aggregate_final)>(
FunctionsKeyComparer.Instance);
private Dictionary<(string name, int arity), (int flags, object state, delegate_function_aggregate_step func_step,
delegate_function_aggregate_final func_final)> _aggregates;

private readonly HashSet<(string file, string proc)> _extensions = new HashSet<(string, string)>();
private HashSet<(string file, string proc)> _extensions;

private string _connectionString;
private ConnectionState _state;
Expand Down Expand Up @@ -276,27 +272,37 @@ public override void Open()
this.ExecuteNonQuery("PRAGMA recursive_triggers = 1;");
}

foreach (var item in _collations)
if (_collations != null)
{
rc = sqlite3_create_collation(_db, item.Key, item.Value.state, item.Value.collation);
SqliteException.ThrowExceptionForRC(rc, _db);
foreach (var item in _collations)
{
rc = sqlite3_create_collation(_db, item.Key, item.Value.state, item.Value.collation);
SqliteException.ThrowExceptionForRC(rc, _db);
}
}

foreach (var item in _functions)
if (_functions != null)
{
rc = sqlite3_create_function(_db, item.Key.name, item.Key.arity, item.Value.state, item.Value.func);
SqliteException.ThrowExceptionForRC(rc, _db);
foreach (var item in _functions)
{
rc = sqlite3_create_function(_db, item.Key.name, item.Key.arity, item.Value.state, item.Value.func);
SqliteException.ThrowExceptionForRC(rc, _db);
}
}

foreach (var item in _aggregates)
if (_aggregates != null)
{
rc = sqlite3_create_function(
_db, item.Key.name, item.Key.arity, item.Value.state, item.Value.func_step, item.Value.func_final);
SqliteException.ThrowExceptionForRC(rc, _db);
foreach (var item in _aggregates)
{
rc = sqlite3_create_function(
_db, item.Key.name, item.Key.arity, item.Value.state, item.Value.func_step, item.Value.func_final);
SqliteException.ThrowExceptionForRC(rc, _db);
}
}

var extensionsEnabledForLoad = false;
if (_extensions.Count != 0)
if (_extensions != null
&& _extensions.Count != 0)
{
rc = sqlite3_enable_load_extension(_db, 1);
SqliteException.ThrowExceptionForRC(rc, _db);
Expand Down Expand Up @@ -449,6 +455,7 @@ public virtual void CreateCollation<T>(string name, T state, Func<T, string, str
SqliteException.ThrowExceptionForRC(rc, _db);
}

_collations ??= new Dictionary<string, (object, strdelegate_collation)>(StringComparer.OrdinalIgnoreCase);
_collations[name] = (state, collation);
}

Expand Down Expand Up @@ -546,6 +553,7 @@ public virtual void LoadExtension(string file, string proc = null)
}
}

_extensions ??= new HashSet<(string, string)>();
_extensions.Add((file, proc));
}

Expand Down Expand Up @@ -677,6 +685,7 @@ private void CreateFunctionCore<TState, TResult>(
SqliteException.ThrowExceptionForRC(rc, _db);
}

_functions ??= new Dictionary<(string, int), (int, object, delegate_function_scalar)>(FunctionsKeyComparer.Instance);
_functions[(name, arity)] = (flags, state, func);
}

Expand Down Expand Up @@ -771,6 +780,8 @@ private void CreateAggregateCore<TAccumulate, TResult>(
SqliteException.ThrowExceptionForRC(rc, _db);
}

_aggregates ??= new Dictionary<(string, int), (int, object, delegate_function_aggregate_step, delegate_function_aggregate_final)>(
FunctionsKeyComparer.Instance);
_aggregates[(name, arity)] = (flags, state, func_step, func_final);
}

Expand Down
17 changes: 10 additions & 7 deletions src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace Microsoft.Data.Sqlite
internal class SqliteDataRecord : SqliteValueReader, IDisposable
{
private readonly SqliteConnection _connection;
private readonly byte[][] _blobCache;
private readonly int?[] _typeCache;
private byte[][] _blobCache;
private int?[] _typeCache;
private bool _stepped;
private int? _rowidOrdinal;

Expand All @@ -25,8 +25,6 @@ public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connec
Handle = stmt;
HasRows = hasRows;
_connection = connection;
_blobCache = new byte[FieldCount][];
_typeCache = new int?[FieldCount];
}

public virtual object this[string name]
Expand Down Expand Up @@ -146,10 +144,11 @@ public virtual Type GetFieldType(int ordinal)
var sqliteType = GetSqliteType(ordinal);
if (sqliteType == SQLITE_NULL)
{
sqliteType = _typeCache[ordinal] ?? Sqlite3AffinityType(GetDataTypeName(ordinal));
sqliteType = _typeCache?[ordinal] ?? Sqlite3AffinityType(GetDataTypeName(ordinal));
}
else
{
_typeCache ??= new int?[FieldCount];
_typeCache[ordinal] = sqliteType;
}

Expand Down Expand Up @@ -317,7 +316,10 @@ public bool Read()
var rc = sqlite3_step(Handle);
SqliteException.ThrowExceptionForRC(rc, _connection.Handle);

Array.Clear(_blobCache, 0, _blobCache.Length);
if (_blobCache != null)
{
Array.Clear(_blobCache, 0, _blobCache.Length);
}

return rc != SQLITE_DONE;
}
Expand All @@ -334,10 +336,11 @@ private byte[] GetCachedBlob(int ordinal)
throw new ArgumentOutOfRangeException(nameof(ordinal), ordinal, message: null);
}

var blob = _blobCache[ordinal];
var blob = _blobCache?[ordinal];
if (blob == null)
{
blob = GetBlob(ordinal);
_blobCache ??= new byte[FieldCount][];
_blobCache[ordinal] = blob;
}

Expand Down

0 comments on commit fcc17e8

Please sign in to comment.