Skip to content

Commit

Permalink
[dotnet] Refactor WebSocket communication for BiDi (#12614)
Browse files Browse the repository at this point in the history
* [dotnet] Refactor WebSocket communication for BiDi

Replaces the existing WebSocket communication mechanism with a more
robust one. Rather than immediately relying on event handlers to react to
events and command responses, it writes the incoming data to a queue
which is read from a different thread. This eliminates the issue where
the user might have multiple simultaneous sends or receives to the
WebSocket while their event handler is running. It also dispatches
incoming events on different threads for the same reason. This should
eliminate at least some of the issues surrounding socket communication
with bidirectional features, whether implemented using CDP or the
WebDriver BiDi protocol.

* Address review comments

* Removing use of System.Threading.Channels

* nitpick: fix XML doc comment

* Simplify WebSocket message queue processing code

* Omit added test from Firefox

* revert add of now unused nuget packages
  • Loading branch information
jimevans authored Aug 29, 2023
1 parent cbda4dd commit 739d177
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 152 deletions.
198 changes: 57 additions & 141 deletions dotnet/src/webdriver/DevTools/DevToolsSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.IO;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json;
Expand Down Expand Up @@ -50,15 +47,14 @@ public class DevToolsSession : IDevToolsSession
private bool isDisposed = false;
private string attachedTargetId;

private ClientWebSocket sessionSocket;
private WebSocketConnection connection;
private ConcurrentDictionary<long, DevToolsCommandData> pendingCommands = new ConcurrentDictionary<long, DevToolsCommandData>();
private readonly BlockingCollection<string> messageQueue = new BlockingCollection<string>();
private readonly Task messageQueueMonitorTask;
private long currentCommandId = 0;

private DevToolsDomains domains;

private CancellationTokenSource receiveCancellationToken;
private Task receiveTask;

/// <summary>
/// Initializes a new instance of the DevToolsSession class, using the specified WebSocket endpoint.
/// </summary>
Expand All @@ -76,6 +72,8 @@ public DevToolsSession(string endpointAddress)
{
this.websocketAddress = endpointAddress;
}
this.messageQueueMonitorTask = Task.Run(() => this.MonitorMessageQueue());
this.messageQueueMonitorTask.ConfigureAwait(false);
}

/// <summary>
Expand Down Expand Up @@ -213,15 +211,13 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains

var message = new DevToolsCommandData(Interlocked.Increment(ref this.currentCommandId), this.ActiveSessionId, commandName, commandParameters);

if (this.sessionSocket != null && this.sessionSocket.State == WebSocketState.Open)
if (this.connection != null && this.connection.IsActive)
{
LogTrace("Sending {0} {1}: {2}", message.CommandId, message.CommandName, commandParameters.ToString());

var contents = JsonConvert.SerializeObject(message);
var contentBuffer = Encoding.UTF8.GetBytes(contents);

string contents = JsonConvert.SerializeObject(message);
this.pendingCommands.TryAdd(message.CommandId, message);
await this.sessionSocket.SendAsync(new ArraySegment<byte>(contentBuffer), WebSocketMessageType.Text, true, cancellationToken);
await this.connection.SendData(contents);

var responseWasReceived = await Task.Run(() => message.SyncEvent.Wait(millisecondsTimeout.Value, cancellationToken));

Expand All @@ -230,8 +226,7 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
throw new InvalidOperationException($"A command response was not received: {commandName}");
}

DevToolsCommandData modified;
if (this.pendingCommands.TryRemove(message.CommandId, out modified))
if (this.pendingCommands.TryRemove(message.CommandId, out DevToolsCommandData modified))
{
if (modified.IsError)
{
Expand All @@ -256,10 +251,7 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
}
else
{
if (this.sessionSocket != null)
{
LogTrace("WebSocket is not connected (current state is {0}); not sending {1}", this.sessionSocket.State, message.CommandName);
}
LogTrace("WebSocket is not connected; not sending {0}", message.CommandName);
}

return null;
Expand Down Expand Up @@ -330,11 +322,7 @@ protected void Dispose(bool disposing)
{
this.Domains.Target.TargetDetached -= this.OnTargetDetached;
this.pendingCommands.Clear();
this.TerminateSocketConnection();

// Note: Canceling the receive task will dispose of
// the underlying ClientWebSocket instance.
this.CancelReceiveTask();
this.TerminateSocketConnection().GetAwaiter().GetResult();
}

this.isDisposed = true;
Expand Down Expand Up @@ -377,28 +365,6 @@ private async Task<int> InitializeProtocol(int requestedProtocolVersion)
return protocolVersion;
}

/// <summary>
/// Creates the client WebSocket, attempts to connect it to the configured
/// WebSocket address, and starts the background task that receives messages.
/// </summary>
/// <returns>A task that completes after the connection attempt has finished and the receive task has been started.</returns>
/// <exception cref="WebDriverException">Thrown when the connection attempt is canceled, which happens when the open-connection wait time elapses.</exception>
private async Task InitializeSocketConnection()
{
    LogTrace("Creating WebSocket");
    this.sessionSocket = new ClientWebSocket();
    this.sessionSocket.Options.KeepAliveInterval = TimeSpan.Zero;

    try
    {
        // The timeout source owns a timer, so dispose it as soon as the
        // connection attempt is finished instead of leaving it to the GC.
        using (var timeoutTokenSource = new CancellationTokenSource(this.openConnectionWaitTimeSpan))
        {
            await this.sessionSocket.ConnectAsync(new Uri(this.websocketAddress), timeoutTokenSource.Token);

            // Wait until the socket reports it is open or the timeout elapses.
            // Poll with a short delay rather than spinning in an empty loop,
            // which would needlessly consume a full CPU core while waiting.
            while (this.sessionSocket.State != WebSocketState.Open && !timeoutTokenSource.Token.IsCancellationRequested)
            {
                await Task.Delay(TimeSpan.FromMilliseconds(10));
            }
        }
    }
    catch (OperationCanceledException e)
    {
        throw new WebDriverException(string.Format(CultureInfo.InvariantCulture, "Could not establish WebSocket connection within {0} seconds.", this.openConnectionWaitTimeSpan.TotalSeconds), e);
    }

    LogTrace("WebSocket created; starting message listener");
    this.receiveCancellationToken = new CancellationTokenSource();
    this.receiveTask = Task.Run(() => ReceiveMessage().ConfigureAwait(false));
}

private async Task InitializeSession()
{
LogTrace("Creating session");
Expand Down Expand Up @@ -445,116 +411,56 @@ private void OnTargetDetached(object sender, TargetDetachedEventArgs e)
}
}

private void TerminateSocketConnection()
private async Task InitializeSocketConnection()
{
if (this.sessionSocket != null && this.sessionSocket.State == WebSocketState.Open)
{
var closeConnectionTokenSource = new CancellationTokenSource(this.closeConnectionWaitTimeSpan);
try
{
// Since Chromium-based DevTools does not respond to the close
// request with a correctly echoed WebSocket close packet, but
// rather just terminates the socket connection, so we have to
// catch the exception thrown when the socket is terminated
// unexpectedly. Also, because we are using async, waiting for
// the task to complete might throw a TaskCanceledException,
// which we should also catch. Additiionally, there are times
// when mulitple failure modes can be seen, which will throw an
// AggregateException, consolidating several exceptions into one,
// and this too must be caught. Finally, the call to CloseAsync
// will hang even though the connection is already severed.
// Wait for the task to complete for a short time (since we're
// restricted to localhost, the default of 2 seconds should be
// plenty; if not, change the initialization of the timeout),
// and if the task is still running, then we assume the connection
// is properly closed.
LogTrace("Sending socket close request");
Task closeTask = Task.Run(async () => await this.sessionSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, closeConnectionTokenSource.Token));
closeTask.Wait();
}
catch (WebSocketException)
{
}
catch (TaskCanceledException)
{
}
catch (AggregateException)
{
}
}
LogTrace("Creating WebSocket");
this.connection = new WebSocketConnection(this.openConnectionWaitTimeSpan, this.closeConnectionWaitTimeSpan);
connection.DataReceived += OnConnectionDataReceived;
await connection.Start(this.websocketAddress);
LogTrace("WebSocket created");
}

private void CancelReceiveTask()
private async Task TerminateSocketConnection()
{
if (this.receiveTask != null)
LogTrace("Closing WebSocket");
if (this.connection != null && this.connection.IsActive)
{
// Wait for the receive task to be completely exited (for
// whatever reason) before attempting to dispose it. Also
// note that canceling the receive task will dispose of the
// underlying WebSocket.
this.receiveCancellationToken.Cancel();
this.receiveTask.Wait();
this.receiveTask.Dispose();
this.receiveTask = null;
await this.connection.Stop();
await this.ShutdownMessageQueue();
}
LogTrace("WebSocket closed");
}

private async Task ReceiveMessage()
private async Task ShutdownMessageQueue()
{
var cancellationToken = this.receiveCancellationToken.Token;
try
{
var buffer = WebSocket.CreateClientBuffer(1024, 1024);
while (this.sessionSocket.State != WebSocketState.Closed && !cancellationToken.IsCancellationRequested)
{
WebSocketReceiveResult result = await this.sessionSocket.ReceiveAsync(buffer, cancellationToken);
if (!cancellationToken.IsCancellationRequested)
{
if (result.MessageType == WebSocketMessageType.Close && this.sessionSocket.State == WebSocketState.CloseReceived)
{
LogTrace("Got WebSocket close message from browser");
await this.sessionSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken);
}
}

if (this.sessionSocket.State == WebSocketState.Open && result.MessageType != WebSocketMessageType.Close)
{
using (var stream = new MemoryStream())
{
stream.Write(buffer.Array, 0, result.Count);
while (!result.EndOfMessage)
{
result = await this.sessionSocket.ReceiveAsync(buffer, cancellationToken);
stream.Write(buffer.Array, 0, result.Count);
}

stream.Seek(0, SeekOrigin.Begin);
using (var reader = new StreamReader(stream, Encoding.UTF8))
{
string message = reader.ReadToEnd();

// fire and forget
// TODO: we need implement some kind of queue
Task.Run(() => ProcessIncomingMessage(message));
}
}
}
}
}
catch (OperationCanceledException)
// The WebSocket connection is always closed before this method
// is called, so there will eventually be no more data written
// into the message queue, meaning this loop should be guaranteed
// to complete.
while (this.connection.IsActive)
{
await Task.Delay(TimeSpan.FromMilliseconds(10));
}
catch (WebSocketException)
{
}
finally

this.messageQueue.CompleteAdding();
await this.messageQueueMonitorTask;
}

private void MonitorMessageQueue()
{
// GetConsumingEnumerable blocks while BlockingCollection.IsCompleted
// is false (i.e., the collection is still able to be written to) and there
// are no items in the collection. Once any items are added to the collection,
// the method unblocks and we can process any items in the collection at that
// moment. Once IsCompleted is true, the method unblocks with no items returned
// in the IEnumerable, meaning the foreach loop will terminate gracefully.
foreach (string message in this.messageQueue.GetConsumingEnumerable())
{
this.sessionSocket.Dispose();
this.sessionSocket = null;
this.ProcessMessage(message);
}
}

private void ProcessIncomingMessage(string message)
private void ProcessMessage(string message)
{
var messageObject = JObject.Parse(message);

Expand Down Expand Up @@ -594,7 +500,12 @@ private void ProcessIncomingMessage(string message)

LogTrace("Recieved Event {0}: {1}", method, eventData.ToString());

OnDevToolsEventReceived(new DevToolsEventReceivedEventArgs(methodParts[0], methodParts[1], eventData));
// Dispatch the event on a new thread so that any event handlers
// responding to the event will not block this thread from processing
// DevTools commands that may be sent in the body of the attached
// event handler. If thread pool starvation seems to become a problem,
// we can switch to a channel-based queue.
Task.Run(() => OnDevToolsEventReceived(new DevToolsEventReceivedEventArgs(methodParts[0], methodParts[1], eventData)));

return;
}
Expand All @@ -610,6 +521,11 @@ private void OnDevToolsEventReceived(DevToolsEventReceivedEventArgs e)
}
}

/// <summary>
/// Handles the connection's DataReceived event by adding the received
/// data to the message queue.
/// </summary>
/// <param name="sender">The object that raised the event.</param>
/// <param name="e">The event arguments whose Data value is queued.</param>
private void OnConnectionDataReceived(object sender, WebSocketConnectionDataReceivedEventArgs e)
    => this.messageQueue.Add(e.Data);

private void LogTrace(string message, params object[] args)
{
if (LogMessage != null)
Expand Down
Loading

0 comments on commit 739d177

Please sign in to comment.