diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/AppRunner.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/AppRunner.cs index 279ab12f631..a61fb14cd27 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/AppRunner.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/AppRunner.cs @@ -58,7 +58,7 @@ public sealed class AppRunner : IAsyncDisposable public int ExitCode => _adapter.ExitCode; - public int ProcessId => _adapter.ProcessId; + public Task ProcessIdTask => _adapter.ProcessIdTask; /// /// Name of the scenario to run in the application. diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/LoggingRunnerAdapter.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/LoggingRunnerAdapter.cs index bda63809b07..033900c2d1b 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/LoggingRunnerAdapter.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/Runners/LoggingRunnerAdapter.cs @@ -15,6 +15,8 @@ public sealed class LoggingRunnerAdapter : IAsyncDisposable { private readonly CancellationTokenSource _cancellation = new(); private readonly ITestOutputHelper _outputHelper; + private readonly TaskCompletionSource _processIdSource = + new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly DotNetRunner _runner; private readonly List _standardErrorLines = new(); private readonly List _standardOutputLines = new(); @@ -30,8 +32,7 @@ public sealed class LoggingRunnerAdapter : IAsyncDisposable public int ExitCode => _exitCode.HasValue ? _exitCode.Value : throw new InvalidOperationException("Must call WaitForExitAsync before getting exit code."); - public int ProcessId => _processId.HasValue ? - _processId.Value : throw new InvalidOperationException("Process was not started."); + public Task ProcessIdTask => _processIdSource.Task; public event Action ReceivedStandardErrorLine; @@ -56,6 +57,8 @@ public async ValueTask DisposeAsync() _cancellation.Cancel(); + _processIdSource.TrySetCanceled(_cancellation.Token); + // Shutdown the runner _outputHelper.WriteLine("Stopping..."); _runner.ForceClose(); @@ -96,11 +99,15 @@ public async Task StartAsync(CancellationToken token) } _outputHelper.WriteLine("End Environment:"); - _outputHelper.WriteLine("Starting..."); - await _runner.StartAsync(token).ConfigureAwait(false); + using (var _ = token.Register(() => _processIdSource.TrySetCanceled(token))) + { + _outputHelper.WriteLine("Starting..."); + await _runner.StartAsync(token).ConfigureAwait(false); + } - _processId = _runner.ProcessId; _outputHelper.WriteLine("Process ID: {0}", _runner.ProcessId); + _processId = _runner.ProcessId; + _processIdSource.TrySetResult(_runner.ProcessId); _standardErrorTask = ReadLinesAsync(_runner.StandardError, _standardErrorLines, ReceivedStandardErrorLine, _cancellation.Token); _standardOutputTask = ReadLinesAsync(_runner.StandardOutput, _standardOutputLines, ReceivedStandardOutputLine, _cancellation.Token); diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/TaskExtensions.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/TaskExtensions.cs index 9ba05e008e3..7be1b4c8cc9 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/TaskExtensions.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.TestCommon/TaskExtensions.cs @@ -60,5 +60,12 @@ public static async Task WithCancellation(this Task task, CancellationToken toke localTokenSource.Cancel(); } } + + public static async Task WithCancellation(this Task task, CancellationToken token) + { + await WithCancellation((Task)task, token); + + return task.Result; + } } } diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/DumpTests.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/DumpTests.cs index cfad516dd64..ed88aab6670 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/DumpTests.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/DumpTests.cs @@ -58,10 +58,12 @@ public Task DumpTest(DiagnosticPortConnectionMode mode, DumpType type) TestAppScenarios.AsyncWait.Name, appValidate: async (runner, client) => { - ProcessInfo processInfo = await client.GetProcessAsync(runner.ProcessId); + int processId = await runner.ProcessIdTask; + + ProcessInfo processInfo = await client.GetProcessAsync(processId); Assert.NotNull(processInfo); - using ResponseStreamHolder holder = await client.CaptureDumpAsync(runner.ProcessId, type); + using ResponseStreamHolder holder = await client.CaptureDumpAsync(processId, type); Assert.NotNull(holder); byte[] headerBuffer = new byte[64]; diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/EgressTests.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/EgressTests.cs index c7baf9b40c7..27057c2a28c 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/EgressTests.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/EgressTests.cs @@ -55,7 +55,9 @@ await ScenarioRunner.SingleTarget( TestAppScenarios.AsyncWait.Name, appValidate: async (appRunner, apiClient) => { - OperationResponse response = await apiClient.EgressTraceAsync(appRunner.ProcessId, durationSeconds: 5, FileProviderName); + int processId = await appRunner.ProcessIdTask; + + OperationResponse response = await apiClient.EgressTraceAsync(processId, durationSeconds: 5, FileProviderName); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); OperationStatusResponse operationResult = await apiClient.PollOperationToCompletion(response.OperationUri); @@ -81,7 +83,9 @@ await ScenarioRunner.SingleTarget( TestAppScenarios.AsyncWait.Name, appValidate: async (appRunner, apiClient) => { - OperationResponse response = await apiClient.EgressTraceAsync(appRunner.ProcessId, durationSeconds: -1, FileProviderName); + int processId = await appRunner.ProcessIdTask; + + OperationResponse response = await apiClient.EgressTraceAsync(processId, durationSeconds: -1, FileProviderName); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); OperationStatusResponse operationResult = await apiClient.GetOperationStatus(response.OperationUri); @@ -113,8 +117,10 @@ await ScenarioRunner.SingleTarget( TestAppScenarios.AsyncWait.Name, appValidate: async (appRunner, apiClient) => { - OperationResponse response1 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId); - OperationResponse response2 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId, delay: false); + int processId = await appRunner.ProcessIdTask; + + OperationResponse response1 = await EgressTraceWithDelay(apiClient, processId); + OperationResponse response2 = await EgressTraceWithDelay(apiClient, processId, delay: false); await CancelEgressOperation(apiClient, response2); List result = await apiClient.GetOperations(); @@ -146,18 +152,20 @@ await ScenarioRunner.SingleTarget( TestAppScenarios.AsyncWait.Name, appValidate: async (appRunner, apiClient) => { - OperationResponse response1 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId); - OperationResponse response2 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId); - OperationResponse response3 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId); + int processId = await appRunner.ProcessIdTask; - ValidationProblemDetailsException ex = await Assert.ThrowsAsync(() => EgressTraceWithDelay(apiClient, appRunner.ProcessId)); + OperationResponse response1 = await EgressTraceWithDelay(apiClient, processId); + OperationResponse response2 = await EgressTraceWithDelay(apiClient, processId); + OperationResponse response3 = await EgressTraceWithDelay(apiClient, processId); + + ValidationProblemDetailsException ex = await Assert.ThrowsAsync(() => EgressTraceWithDelay(apiClient, processId)); Assert.Equal(HttpStatusCode.TooManyRequests, ex.StatusCode); Assert.Equal((int)HttpStatusCode.TooManyRequests, ex.Details.Status.GetValueOrDefault()); await CancelEgressOperation(apiClient, response1); await CancelEgressOperation(apiClient, response2); - OperationResponse response4 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId, delay: false); + OperationResponse response4 = await EgressTraceWithDelay(apiClient, processId, delay: false); await CancelEgressOperation(apiClient, response3); await CancelEgressOperation(apiClient, response4); @@ -180,25 +188,28 @@ await ScenarioRunner.SingleTarget( TestAppScenarios.AsyncWait.Name, appValidate: async (appRunner, apiClient) => { - OperationResponse response1 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId); - OperationResponse response3 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId); - using HttpResponseMessage traceDirect1 = await TraceWithDelay(apiClient, appRunner.ProcessId); + int processId = await appRunner.ProcessIdTask; + + OperationResponse response1 = await EgressTraceWithDelay(apiClient, processId); + OperationResponse response3 = await EgressTraceWithDelay(apiClient, processId); + using HttpResponseMessage traceDirect1 = await TraceWithDelay(apiClient, processId); Assert.Equal(HttpStatusCode.OK, traceDirect1.StatusCode); - ValidationProblemDetailsException ex = await Assert.ThrowsAsync(() => EgressTraceWithDelay(apiClient, appRunner.ProcessId, delay: false)); + ValidationProblemDetailsException ex = await Assert.ThrowsAsync( + () => EgressTraceWithDelay(apiClient, processId, delay: false)); Assert.Equal(HttpStatusCode.TooManyRequests, ex.StatusCode); - using HttpResponseMessage traceDirect = await TraceWithDelay(apiClient, appRunner.ProcessId, delay: false); + using HttpResponseMessage traceDirect = await TraceWithDelay(apiClient, processId, delay: false); Assert.Equal(HttpStatusCode.TooManyRequests, traceDirect.StatusCode); //Validate that the failure from a direct call (handled by middleware) //matches the failure produces by egress operations (handled by the Mvc ActionResult stack) - using HttpResponseMessage egressDirect = await EgressDirect(apiClient, appRunner.ProcessId); + using HttpResponseMessage egressDirect = await EgressDirect(apiClient, processId); Assert.Equal(HttpStatusCode.TooManyRequests, egressDirect.StatusCode); Assert.Equal(await egressDirect.Content.ReadAsStringAsync(), await traceDirect.Content.ReadAsStringAsync()); await CancelEgressOperation(apiClient, response1); - OperationResponse response4 = await EgressTraceWithDelay(apiClient, appRunner.ProcessId, delay: false); + OperationResponse response4 = await EgressTraceWithDelay(apiClient, processId, delay: false); await CancelEgressOperation(apiClient, response3); await CancelEgressOperation(apiClient, response4); @@ -224,19 +235,21 @@ await ScenarioRunner.SingleTarget( TestAppScenarios.AsyncWait.Name, appValidate: async (appRunner, appClient) => { - ProcessInfo processInfo = await appClient.GetProcessAsync(appRunner.ProcessId); + int processId = await appRunner.ProcessIdTask; + + ProcessInfo processInfo = await appClient.GetProcessAsync(processId); Assert.NotNull(processInfo); // Dump Error Check ValidationProblemDetailsException validationProblemDetailsExceptionDumps = await Assert.ThrowsAsync( - () => appClient.CaptureDumpAsync(appRunner.ProcessId, DumpType.Mini)); + () => appClient.CaptureDumpAsync(processId, DumpType.Mini)); Assert.Equal(HttpStatusCode.BadRequest, validationProblemDetailsExceptionDumps.StatusCode); Assert.Equal(StatusCodes.Status400BadRequest, validationProblemDetailsExceptionDumps.Details.Status); Assert.Equal(DisabledHTTPEgressErrorMessage, validationProblemDetailsExceptionDumps.Message); // Logs Error Check ValidationProblemDetailsException validationProblemDetailsExceptionLogs = await Assert.ThrowsAsync( - () => appClient.CaptureLogsAsync(appRunner.ProcessId, TestTimeouts.LogsDuration, LogLevel.None, LogFormat.NDJson)); + () => appClient.CaptureLogsAsync(processId, TestTimeouts.LogsDuration, LogLevel.None, LogFormat.NDJson)); Assert.Equal(HttpStatusCode.BadRequest, validationProblemDetailsExceptionLogs.StatusCode); Assert.Equal(StatusCodes.Status400BadRequest, validationProblemDetailsExceptionLogs.Details.Status); Assert.Equal(DisabledHTTPEgressErrorMessage, validationProblemDetailsExceptionLogs.Message); diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/LogsTests.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/LogsTests.cs index 0ca4f409782..9244dd7fd7d 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/LogsTests.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/LogsTests.cs @@ -183,7 +183,7 @@ public Task LogsDefaultLevelNoneNotSupportedViaQueryTest(DiagnosticPortConnectio async () => { using ResponseStreamHolder _ = await client.CaptureLogsAsync( - runner.ProcessId, + await runner.ProcessIdTask, TestTimeouts.LogsDuration, LogLevel.None, logFormat); @@ -219,7 +219,7 @@ public Task LogsDefaultLevelNoneNotSupportedViaBodyTest(DiagnosticPortConnection async () => { using ResponseStreamHolder _ = await client.CaptureLogsAsync( - runner.ProcessId, + await runner.ProcessIdTask, TestTimeouts.LogsDuration, new LogsConfiguration() { LogLevel = LogLevel.None }, logFormat); @@ -406,11 +406,16 @@ private Task ValidateLogsAsync( _httpClientFactory, mode, TestAppScenarios.Logger.Name, - appValidate: (runner, client) => ValidateResponseStream( - runner, - client.CaptureLogsAsync(runner.ProcessId, TestTimeouts.LogsDuration, logLevel, logFormat), - callback, - logFormat)); + appValidate: async (runner, client) => + await ValidateResponseStream( + runner, + client.CaptureLogsAsync( + await runner.ProcessIdTask, + TestTimeouts.LogsDuration, + logLevel, + logFormat), + callback, + logFormat)); } private Task ValidateLogsAsync( @@ -424,11 +429,16 @@ private Task ValidateLogsAsync( _httpClientFactory, mode, TestAppScenarios.Logger.Name, - appValidate: (runner, client) => ValidateResponseStream( - runner, - client.CaptureLogsAsync(runner.ProcessId, TestTimeouts.LogsDuration, configuration, logFormat), - callback, - logFormat)); + appValidate: async (runner, client) => + await ValidateResponseStream( + runner, + client.CaptureLogsAsync( + await runner.ProcessIdTask, + TestTimeouts.LogsDuration, + configuration, + logFormat), + callback, + logFormat)); } private async Task ValidateResponseStream(AppRunner runner, Task holderTask, Func, Task> callback, LogFormat logFormat) diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/ProcessTests.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/ProcessTests.cs index 5e6f1d76c25..9eeb480dc05 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/ProcessTests.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/ProcessTests.cs @@ -56,14 +56,16 @@ public Task SingleProcessIdentificationTest(DiagnosticPortConnectionMode mode) TestAppScenarios.AsyncWait.Name, appValidate: async (runner, client) => { + int processId = await runner.ProcessIdTask; + // GET /processes and filter to just the single process IEnumerable identifiers = await client.GetProcessesWithRetryAsync( _outputHelper, - new[] { runner.ProcessId }); + new[] { processId }); Assert.NotNull(identifiers); Assert.Single(identifiers); - await VerifyProcessAsync(client, identifiers, runner.ProcessId, expectedEnvVarValue); + await VerifyProcessAsync(client, identifiers, processId, expectedEnvVarValue); await runner.SendCommandAsync(TestAppScenarios.AsyncWait.Commands.Continue); }, @@ -129,7 +131,7 @@ await appRunners.ExecuteAsync(async () => IList unmatchedPids = new List(); foreach (AppRunner runner in appRunners) { - unmatchedPids.Add(runner.ProcessId); + unmatchedPids.Add(await runner.ProcessIdTask); } // Query for process identifiers @@ -202,7 +204,7 @@ await appRunners.ExecuteAsync(async () => { Assert.True(runner.Environment.TryGetValue(ExpectedEnvVarName, out string expectedEnvVarValue)); - await VerifyProcessAsync(apiClient, identifiers, runner.ProcessId, expectedEnvVarValue); + await VerifyProcessAsync(apiClient, identifiers, await runner.ProcessIdTask, expectedEnvVarValue); await runner.SendCommandAsync(TestAppScenarios.AsyncWait.Commands.Continue); } @@ -218,9 +220,15 @@ await appRunners.ExecuteAsync(async () => Assert.NotNull(identifiers); // Verify none of the apps are reported + List runnerProcessIds = new(appCount); for (int i = 0; i < appCount; i++) { - Assert.Null(identifiers.FirstOrDefault(p => p.Pid == appRunners[i].ProcessId)); + runnerProcessIds.Add(await appRunners[i].ProcessIdTask); + } + + foreach (ProcessIdentifier identifier in identifiers) + { + Assert.DoesNotContain(identifier.Pid, runnerProcessIds); } } diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/Runners/ScenarioRunner.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/Runners/ScenarioRunner.cs index 51ee104ad7f..a211e3d61d1 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/Runners/ScenarioRunner.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.FunctionalTests/Runners/ScenarioRunner.cs @@ -62,7 +62,7 @@ await appRunner.ExecuteAsync(async () => if (null != postAppValidate) { - await postAppValidate(apiClient, appRunner.ProcessId); + await postAppValidate(apiClient, await appRunner.ProcessIdTask); } } } diff --git a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.UnitTests/EndpointInfoSourceTests.cs b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.UnitTests/EndpointInfoSourceTests.cs index 625890a1c36..d0f7e19fa03 100644 --- a/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.UnitTests/EndpointInfoSourceTests.cs +++ b/src/Tests/Microsoft.Diagnostics.Monitoring.Tool.UnitTests/EndpointInfoSourceTests.cs @@ -118,7 +118,8 @@ public async Task ServerSourceAddRemoveSingleConnectionTest(TargetFrameworkMonik AppRunner runner = CreateAppRunner(transportName, appTfm); - Task newEndpointInfoTask = callback.WaitForNewEndpointInfoAsync(runner, CommonTestTimeouts.StartProcess); + using CancellationTokenSource cancellation = new(CommonTestTimeouts.StartProcess); + Task newEndpointInfoTask = callback.WaitForNewEndpointInfoAsync(runner, cancellation.Token); await runner.ExecuteAsync(async () => { @@ -130,7 +131,7 @@ await runner.ExecuteAsync(async () => Assert.NotNull(endpointInfo.CommandLine); Assert.NotNull(endpointInfo.OperatingSystem); Assert.NotNull(endpointInfo.ProcessArchitecture); - VerifyConnection(runner, endpointInfo); + await VerifyConnectionAsync(runner, endpointInfo); await runner.SendCommandAsync(TestAppScenarios.AsyncWait.Commands.Continue); }); @@ -162,10 +163,11 @@ public async Task ServerSourceAddRemoveMultipleConnectionTest(TargetFrameworkMon Task[] newEndpointInfoTasks = new Task[appCount]; // Start all app instances + using CancellationTokenSource cancellation = new(CommonTestTimeouts.StartProcess); for (int i = 0; i < appCount; i++) { runners[i] = CreateAppRunner(transportName, appTfm, appId: i + 1); - newEndpointInfoTasks[i] = callback.WaitForNewEndpointInfoAsync(runners[i], CommonTestTimeouts.StartProcess); + newEndpointInfoTasks[i] = callback.WaitForNewEndpointInfoAsync(runners[i], cancellation.Token); } await runners.ExecuteAsync(async () => @@ -180,13 +182,15 @@ await runners.ExecuteAsync(async () => for (int i = 0; i < appCount; i++) { - IEndpointInfo endpointInfo = endpointInfos.FirstOrDefault(info => info.ProcessId == runners[i].ProcessId); + int processId = await runners[i].ProcessIdTask; + + IEndpointInfo endpointInfo = endpointInfos.FirstOrDefault(info => info.ProcessId == processId); Assert.NotNull(endpointInfo); Assert.NotNull(endpointInfo.CommandLine); Assert.NotNull(endpointInfo.OperatingSystem); Assert.NotNull(endpointInfo.ProcessArchitecture); - VerifyConnection(runners[i], endpointInfo); + await VerifyConnectionAsync(runners[i], endpointInfo); await runners[i].SendCommandAsync(TestAppScenarios.AsyncWait.Commands.Continue); } @@ -242,11 +246,11 @@ private async Task> GetEndpointInfoAsync(ServerEndpoi /// /// Verifies basic information on the connection and that it matches the target process from the runner. /// - private static void VerifyConnection(AppRunner runner, IEndpointInfo endpointInfo) + private static async Task VerifyConnectionAsync(AppRunner runner, IEndpointInfo endpointInfo) { Assert.NotNull(runner); Assert.NotNull(endpointInfo); - Assert.Equal(runner.ProcessId, endpointInfo.ProcessId); + Assert.Equal(await runner.ProcessIdTask, endpointInfo.ProcessId); Assert.NotEqual(Guid.Empty, endpointInfo.RuntimeInstanceCookie); Assert.NotNull(endpointInfo.Endpoint); } @@ -254,70 +258,94 @@ private static void VerifyConnection(AppRunner runner, IEndpointInfo endpointInf private sealed class ServerEndpointInfoCallback : IEndpointInfoSourceCallbacks { private readonly ITestOutputHelper _outputHelper; - private readonly List<(AppRunner Runner, TaskCompletionSource CompletionSource)> _addedEndpointInfoSources = new(); + /// + /// Use to protect the completion list from mutation while processing + /// callbacks from it. The processing is done in an async method with async + /// calls, which are not allowed in a lock, thus use SemaphoreSlim. + /// + private readonly SemaphoreSlim _completionEntriesSemaphore = new(1); + private readonly List _completionEntries = new(); public ServerEndpointInfoCallback(ITestOutputHelper outputHelper) { _outputHelper = outputHelper; } - public async Task WaitForNewEndpointInfoAsync(AppRunner runner, TimeSpan timeout) + public async Task WaitForNewEndpointInfoAsync(AppRunner runner, CancellationToken token) { - TaskCompletionSource addedEndpointInfoSource = new(TaskCreationOptions.RunContinuationsAsynchronously); - using CancellationTokenSource timeoutCancellation = new(); - var token = timeoutCancellation.Token; - using var _ = token.Register(() => addedEndpointInfoSource.TrySetCanceled(token)); + CompletionEntry entry = new(runner); + using var _ = token.Register(() => entry.CompletionSource.TrySetCanceled(token)); - lock (_addedEndpointInfoSources) + await _completionEntriesSemaphore.WaitAsync(token); + try { - _addedEndpointInfoSources.Add(new (runner, addedEndpointInfoSource)); + _completionEntries.Add(entry); _outputHelper.WriteLine($"[Wait] Register App{runner.AppId}"); } + finally + { + _completionEntriesSemaphore.Release(); + } _outputHelper.WriteLine($"[Wait] Wait for App{runner.AppId} notification"); - timeoutCancellation.CancelAfter(timeout); - IEndpointInfo endpointInfo = await addedEndpointInfoSource.Task; + IEndpointInfo endpointInfo = await entry.CompletionSource.Task; _outputHelper.WriteLine($"[Wait] Received App{runner.AppId} notification"); return endpointInfo; } - public Task OnBeforeResumeAsync(IEndpointInfo endpointInfo, CancellationToken cancellationToken) + public Task OnBeforeResumeAsync(IEndpointInfo endpointInfo, CancellationToken token) { return Task.CompletedTask; } - public void OnAddedEndpointInfo(IEndpointInfo info) + public async Task OnAddedEndpointInfoAsync(IEndpointInfo info, CancellationToken token) { _outputHelper.WriteLine($"[Source] Added: {ToOutputString(info)}"); - - lock (_addedEndpointInfoSources) + + await _completionEntriesSemaphore.WaitAsync(token); + try { _outputHelper.WriteLine($"[Source] Start notifications for process {info.ProcessId}"); - foreach (var sourceTuple in _addedEndpointInfoSources.ToList()) + // Create a mapping of the process ID tasks to the completion entries + IDictionary, CompletionEntry> map = new Dictionary, CompletionEntry>(_completionEntries.Count); + foreach (CompletionEntry entry in _completionEntries) + { + map.Add(entry.Runner.ProcessIdTask.WithCancellation(token), entry); + } + + while (map.Count > 0) { - AppRunner runner = sourceTuple.Runner; - _outputHelper.WriteLine($"[Source] Checking App{runner.AppId}"); - try + // Wait for any of the process ID tasks to complete. + Task completedTask = await Task.WhenAny(map.Keys); + + map.Remove(completedTask, out CompletionEntry entry); + + _outputHelper.WriteLine($"[Source] Checking App{entry.Runner.AppId}"); + + if (completedTask.IsCompletedSuccessfully) { - if (info.ProcessId == runner.ProcessId) + // If the process ID matches the one that was reported via the callback, + // then signal its completion source. + if (info.ProcessId == completedTask.Result) { - _outputHelper.WriteLine($"[Source] Notifying App{runner.AppId}"); - sourceTuple.CompletionSource.TrySetResult(info); - _addedEndpointInfoSources.Remove(sourceTuple); + _outputHelper.WriteLine($"[Source] Notifying App{entry.Runner.AppId}"); + entry.CompletionSource.TrySetResult(info); + + _completionEntries.Remove(entry); + break; } } - catch (InvalidOperationException) - { - // Thrown if app runner hasn't started process yet. - _outputHelper.WriteLine($"[Source] App{runner.AppId} has not start yet."); - } } _outputHelper.WriteLine($"[Source] Finished notifications for process {info.ProcessId}"); } + finally + { + _completionEntriesSemaphore.Release(); + } } public void OnRemovedEndpointInfo(IEndpointInfo info) @@ -329,6 +357,19 @@ private static string ToOutputString(IEndpointInfo info) { return FormattableString.Invariant($"PID={info.ProcessId}, Cookie={info.RuntimeInstanceCookie}"); } + + private sealed class CompletionEntry + { + public CompletionEntry(AppRunner runner) + { + Runner = runner; + CompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + + public AppRunner Runner { get; } + + public TaskCompletionSource CompletionSource { get; } + } } } } diff --git a/src/Tools/dotnet-monitor/EndpointInfo/IEndpointInfoSourceCallbacks.cs b/src/Tools/dotnet-monitor/EndpointInfo/IEndpointInfoSourceCallbacks.cs index c270f3048ef..1f5c7fd070c 100644 --- a/src/Tools/dotnet-monitor/EndpointInfo/IEndpointInfoSourceCallbacks.cs +++ b/src/Tools/dotnet-monitor/EndpointInfo/IEndpointInfoSourceCallbacks.cs @@ -15,7 +15,7 @@ internal interface IEndpointInfoSourceCallbacks { Task OnBeforeResumeAsync(IEndpointInfo endpointInfo, CancellationToken cancellationToken); - void OnAddedEndpointInfo(IEndpointInfo endpointInfo); + Task OnAddedEndpointInfoAsync(IEndpointInfo endpointInfo, CancellationToken cancellationToken); void OnRemovedEndpointInfo(IEndpointInfo endpointInfo); } diff --git a/src/Tools/dotnet-monitor/EndpointInfo/ServerEndpointInfoSource.cs b/src/Tools/dotnet-monitor/EndpointInfo/ServerEndpointInfoSource.cs index 1d598f9dedc..2ce057aefbb 100644 --- a/src/Tools/dotnet-monitor/EndpointInfo/ServerEndpointInfoSource.cs +++ b/src/Tools/dotnet-monitor/EndpointInfo/ServerEndpointInfoSource.cs @@ -243,7 +243,7 @@ private async Task ResumeAndQueueEndpointInfo(IpcEndpointInfo info, Cancellation foreach (IEndpointInfoSourceCallbacks callback in _callbacks) { - callback.OnAddedEndpointInfo(endpointInfo); + await callback.OnAddedEndpointInfoAsync(endpointInfo, token).ConfigureAwait(false); } } finally