Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experimental mars IO task scheduler #1543

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,9 @@
<Reference Include="System.Collections.NonGeneric" />
<Reference Include="System.Memory" />
</ItemGroup>
<ItemGroup>
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITaskScheduler.cs" />
</ItemGroup>
<ItemGroup>
<EmbeddedResource Include="Resources\Microsoft.Data.SqlClient.SqlMetaData.xml">
<LogicalName>Microsoft.Data.SqlClient.SqlMetaData.xml</LogicalName>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
Expand All @@ -16,16 +17,20 @@ internal class SNIMarsConnection
{
private const string s_className = nameof(SNIMarsConnection);

private static QueuedTaskScheduler s_scheduler;

private readonly Guid _connectionId = Guid.NewGuid();
private readonly Dictionary<int, SNIMarsHandle> _sessions = new Dictionary<int, SNIMarsHandle>();
private readonly byte[] _headerBytes = new byte[SNISMUXHeader.HEADER_LENGTH];
private readonly SNISMUXHeader _currentHeader = new SNISMUXHeader();
private SNIHandle _lowerHandle;
private ushort _nextSessionId = 0;
private int _nextSessionId;
private int _currentHeaderByteCount = 0;
private int _dataBytesLeft = 0;
private SNIPacket _currentPacket;



/// <summary>
/// Connection ID
/// </summary>
Expand All @@ -45,6 +50,8 @@ public Guid ConnectionId
/// <param name="lowerHandle">Lower handle</param>
public SNIMarsConnection(SNIHandle lowerHandle)
{

_nextSessionId = -1;
_lowerHandle = lowerHandle;
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Created MARS Session Id {0}", args0: ConnectionId);
_lowerHandle.SetAsyncCallbacks(HandleReceiveComplete, HandleSendComplete);
Expand All @@ -54,7 +61,8 @@ public SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
{
lock (this)
{
ushort sessionId = _nextSessionId++;
ushort sessionId = unchecked((ushort)(Interlocked.Increment(ref _nextSessionId) % ushort.MaxValue));

SNIMarsHandle handle = new SNIMarsHandle(this, sessionId, callbackObject, async);
_sessions.Add(sessionId, handle);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, SNI MARS Handle Id {1}, created new MARS Session {2}", args0: ConnectionId, args1: handle?.ConnectionId, args2: sessionId);
Expand All @@ -68,25 +76,42 @@ public SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
/// <returns></returns>
public uint StartReceive()
{
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className);
try
using (TrySNIEventScope.Create(nameof(SNIMarsConnection)))
{
if (LocalAppContextSwitches.UseExperimentalMARSThreading
//#if NETCOREAPP31_AND_ABOVE
// && ThreadPool.PendingWorkItemCount > 0
//#endif
)
{
LazyInitializer.EnsureInitialized(ref s_scheduler, () => new QueuedTaskScheduler(10, "MARSIOScheduler", useForegroundThreads: false, ThreadPriority.Normal));

// will start an async task on the scheduler and immediatley return so this await is safe
return s_scheduler.Factory.StartNew(StartAsyncReceiveLoopForConnection, this).GetAwaiter().GetResult();
}
else
{
return StartAsyncReceiveLoopForConnection(this);
}
}

static uint StartAsyncReceiveLoopForConnection(object state)
{
SNIMarsConnection connection = (SNIMarsConnection)state;
SNIPacket packet = null;

if (ReceiveAsync(ref packet) == TdsEnums.SNI_SUCCESS_IO_PENDING)
if (connection.ReceiveAsync(ref packet) == TdsEnums.SNI_SUCCESS_IO_PENDING)
{
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, Success IO pending.", args0: ConnectionId);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Success IO pending.", args0: connection.ConnectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "MARS Session Id {0}, Connection not usable.", args0: ConnectionId);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "MARS Session Id {0}, Connection not usable.", args0: connection.ConnectionId);
return SNICommon.ReportSNIError(SNIProviders.SMUX_PROV, 0, SNICommon.ConnNotUsableError, Strings.SNI_ERROR_19);
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
};
}



/// <summary>
/// Send a packet synchronously
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
state: callback,
CancellationToken.None,
TaskContinuationOptions.DenyChildAttach,
TaskScheduler.Default
TaskScheduler.Current // specifically continue on the current scheduler because we may override it for mars
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// Provides a TaskScheduler that provides control over priorities, fairness, and the underlying threads utilized.
/// </summary>
[DebuggerDisplay("Id={Id}, Queues={DebugQueueCount}, ScheduledTasks = {DebugTaskCount}")]
internal sealed class QueuedTaskScheduler : TaskScheduler, IDisposable
{
/// <summary>Cancellation token used for disposal.</summary>
private readonly CancellationTokenSource _disposeCancellation = new CancellationTokenSource();
/// <summary>
/// The maximum allowed concurrency level of this scheduler. If custom threads are
/// used, this represents the number of created threads.
/// </summary>
private readonly int _concurrencyLevel;
/// <summary>Whether we're processing tasks on the current thread.</summary>
private static readonly ThreadLocal<bool> s_taskProcessingThread = new ThreadLocal<bool>();

/// <summary>The threads used by the scheduler to process work.</summary>
private readonly Thread[] _threads;
/// <summary>The collection of tasks to be executed on our custom threads.</summary>
private readonly BlockingCollection<Task> _blockingTaskQueue;

private readonly TaskFactory _factory;

/// <summary>Initializes the scheduler.</summary>
/// <param name="threadCount">The number of threads to create and use for processing work items.</param>
public QueuedTaskScheduler(int threadCount) : this(threadCount, string.Empty, false, ThreadPriority.Normal, 0, null, null) { }

/// <summary>Initializes the scheduler.</summary>
/// <param name="threadCount">The number of threads to create and use for processing work items.</param>
/// <param name="threadName">The name to use for each of the created threads.</param>
/// <param name="useForegroundThreads">A Boolean value that indicates whether to use foreground threads instead of background.</param>
/// <param name="threadPriority">The priority to assign to each thread.</param>
/// <param name="threadMaxStackSize">The stack size to use for each thread.</param>
/// <param name="threadInit">An initialization routine to run on each thread.</param>
/// <param name="threadFinally">A finalization routine to run on each thread.</param>
public QueuedTaskScheduler(
int threadCount,
string threadName = "",
bool useForegroundThreads = false,
ThreadPriority threadPriority = ThreadPriority.Normal,
int threadMaxStackSize = 0,
Action threadInit = null,
Action threadFinally = null)
{
// Validates arguments (some validation is left up to the Thread type itself).
// If the thread count is 0, default to the number of logical processors.
if (threadCount < 0)
throw new ArgumentOutOfRangeException(nameof(threadCount));
else if (threadCount == 0)
_concurrencyLevel = Environment.ProcessorCount;
else
_concurrencyLevel = threadCount;

// Initialize the queue used for storing tasks
_blockingTaskQueue = new BlockingCollection<Task>();

// Create all of the threads
_threads = new Thread[threadCount];
for (int i = 0; i < threadCount; i++)
{
_threads[i] = new Thread(() => DispatchLoop(threadInit, threadFinally), threadMaxStackSize)
{
Priority = threadPriority,
IsBackground = !useForegroundThreads,
};
if (threadName != null)
_threads[i].Name = threadName + " (" + i + ")";
}

_factory = new TaskFactory(this);

// Start all of the threads
foreach (var thread in _threads)
thread.Start();
}

public TaskFactory Factory => _factory;

/// <summary>The dispatch loop run by all threads in this scheduler.</summary>
/// <param name="threadInit">An initialization routine to run when the thread begins.</param>
/// <param name="threadFinally">A finalization routine to run before the thread ends.</param>
private void DispatchLoop(Action threadInit, Action threadFinally)
{
s_taskProcessingThread.Value = true;
threadInit?.Invoke();
try
{
// If the scheduler is disposed, the cancellation token will be set and
// we'll receive an OperationCanceledException. That OCE should not crash the process.
try
{
// If a thread abort occurs, we'll try to reset it and continue running.
while (true)
{
try
{
// For each task queued to the scheduler, try to execute it.
foreach (var task in _blockingTaskQueue.GetConsumingEnumerable(_disposeCancellation.Token))
{
// If the task is not null, that means it was queued to this scheduler directly.
// Run it.
if (task != null)
{
bool tried = TryExecuteTask(task);
}
}
}
catch (ThreadAbortException)
{
// If we received a thread abort, and that thread abort was due to shutting down
// or unloading, let it pass through. Otherwise, reset the abort so we can
// continue processing work items.
if (!Environment.HasShutdownStarted && !AppDomain.CurrentDomain.IsFinalizingForUnload())
{
Thread.ResetAbort();
}
}
}
}
catch (OperationCanceledException) { }
}
finally
{
// Run a cleanup routine if there was one
threadFinally?.Invoke();
s_taskProcessingThread.Value = false;
}
}

/// <summary>Queues a task to the scheduler.</summary>
/// <param name="task">The task to be queued.</param>
protected override void QueueTask(Task task)
{
// If we've been disposed, no one should be queueing
if (_disposeCancellation.IsCancellationRequested)
{
throw new ObjectDisposedException(GetType().Name);
}
_blockingTaskQueue.Add(task);
}

/// <summary>Tries to execute a task synchronously on the current thread.</summary>
/// <param name="task">The task to execute.</param>
/// <param name="taskWasPreviouslyQueued">Whether the task was previously queued.</param>
/// <returns>true if the task was executed; otherwise, false.</returns>
protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) =>
// If we're already running tasks on this threads, enable inlining
false; // s_taskProcessingThread.Value && TryExecuteTask(task);

/// <summary>Gets the tasks scheduled to this scheduler.</summary>
/// <returns>An enumerable of all tasks queued to this scheduler.</returns>
/// <remarks>This does not include the tasks on sub-schedulers. Those will be retrieved by the debugger separately.</remarks>
protected override IEnumerable<Task> GetScheduledTasks()
{
// Get all of the tasks, filtering out nulls, which are just placeholders
// for tasks in other sub-schedulers
return _blockingTaskQueue.Where(t => t != null).ToList();
}

/// <summary>Gets the maximum concurrency level to use when processing tasks.</summary>
public override int MaximumConcurrencyLevel => _concurrencyLevel;

/// <summary>Initiates shutdown of the scheduler.</summary>
public void Dispose() => _disposeCancellation.Cancel();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ internal static partial class LocalAppContextSwitches
internal const string LegacyRowVersionNullString = @"Switch.Microsoft.Data.SqlClient.LegacyRowVersionNullBehavior";
internal const string UseSystemDefaultSecureProtocolsString = @"Switch.Microsoft.Data.SqlClient.UseSystemDefaultSecureProtocols";
internal const string SuppressInsecureTLSWarningString = @"Switch.Microsoft.Data.SqlClient.SuppressInsecureTLSWarning";
internal const string UseExperimentalMARSThreadingString = @"Switch.Microsoft.Data.SqlClient.UseExperimentalMARSThreading";

private static bool s_makeReadAsyncBlocking;
private static bool? s_LegacyRowVersionNullBehavior;
private static bool? s_UseSystemDefaultSecureProtocols;
private static bool? s_SuppressInsecureTLSWarning;
private static bool? s_SuppressInsecureTLSWarning;
private static bool? s_useExperimentalMARSThreading;

#if !NETFRAMEWORK

#if NETCOREAPP31_AND_ABOVE
static LocalAppContextSwitches()
{
IAppContextSwitchOverridesSection appContextSwitch = AppConfigManager.FetchConfigurationSection<AppContextSwitchOverridesSection>(AppContextSwitchOverridesSection.Name);
Expand Down Expand Up @@ -95,5 +98,19 @@ public static bool UseSystemDefaultSecureProtocols
return s_UseSystemDefaultSecureProtocols.Value;
}
}

public static bool UseExperimentalMARSThreading
{
get
{
if (s_useExperimentalMARSThreading is null)
{
bool result;
result = AppContext.TryGetSwitch(UseExperimentalMARSThreadingString, out result) ? result : false;
s_useExperimentalMARSThreading = result;
}
return s_useExperimentalMARSThreading.Value;
}
}
}
}