From 517d0523ca3f2931268a8d92e628e84de29e566c Mon Sep 17 00:00:00 2001 From: Ryan Kistner Date: Fri, 1 Dec 2017 00:41:59 -0700 Subject: [PATCH] Connection implementations shall issue the Disconnected callback before returning and must be in a destructing state. --- .../Networking/Steam3/TcpConnection.cs | 50 ++------ .../Networking/Steam3/UdpConnection.cs | 9 +- SteamKit2/SteamKit2/Steam/CMClient.cs | 120 ++++++++++-------- 3 files changed, 83 insertions(+), 96 deletions(-) diff --git a/SteamKit2/SteamKit2/Networking/Steam3/TcpConnection.cs b/SteamKit2/SteamKit2/Networking/Steam3/TcpConnection.cs index 672613836..e4f303e29 100644 --- a/SteamKit2/SteamKit2/Networking/Steam3/TcpConnection.cs +++ b/SteamKit2/SteamKit2/Networking/Steam3/TcpConnection.cs @@ -24,14 +24,11 @@ class TcpConnection : IConnection private BinaryWriter netWriter; private CancellationTokenSource cancellationToken; - private ManualResetEvent connectionFree; - private object netLock, connectLock; + private object netLock; public TcpConnection() { netLock = new object(); - connectLock = new object(); - connectionFree = new ManualResetEvent(true); } public event EventHandler NetMsgReceived; @@ -98,8 +95,6 @@ private void Release( bool userRequestedDisconnect ) } Disconnected?.Invoke( this, new DisconnectedEventArgs( userRequestedDisconnect ) ); - - connectionFree.Set(); } private void ConnectCompleted(bool success) @@ -184,46 +179,29 @@ private void TryConnect(object sender) /// Timeout in milliseconds public void Connect(EndPoint endPoint, int timeout) { - lock (connectLock) + lock ( netLock ) { - Disconnect(); - - connectionFree.Reset(); - - lock (netLock) - { - Debug.Assert(cancellationToken == null); - cancellationToken = new CancellationTokenSource(); + Debug.Assert( cancellationToken == null ); + cancellationToken = new CancellationTokenSource(); - socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - socket.ReceiveTimeout = timeout; - socket.SendTimeout = timeout; + socket = new Socket( AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp ); + socket.ReceiveTimeout = timeout; + socket.SendTimeout = timeout; - destination = endPoint; - DebugLog.WriteLine("TcpConnection", "Connecting to {0}...", destination); - TryConnect(timeout); - } + destination = endPoint; + DebugLog.WriteLine( "TcpConnection", "Connecting to {0}...", destination ); + TryConnect( timeout ); } + } public void Disconnect() { - lock (connectLock) + lock ( netLock ) { - lock (netLock) - { - if (cancellationToken != null) - { - cancellationToken.Cancel(); - } - else - { - // we already appear to be disconncted, nothing to wait for - return; - } - } + cancellationToken?.Cancel(); - connectionFree.WaitOne(); + Disconnected?.Invoke( this, new DisconnectedEventArgs( true ) ); } } diff --git a/SteamKit2/SteamKit2/Networking/Steam3/UdpConnection.cs b/SteamKit2/SteamKit2/Networking/Steam3/UdpConnection.cs index 664a8d68d..89c843642 100644 --- a/SteamKit2/SteamKit2/Networking/Steam3/UdpConnection.cs +++ b/SteamKit2/SteamKit2/Networking/Steam3/UdpConnection.cs @@ -56,7 +56,7 @@ private enum State private DateTime timeOut; private DateTime nextResend; - private uint sourceConnId = 512; + private static uint sourceConnId = 512; private uint remoteConnId; /// @@ -116,8 +116,6 @@ public UdpConnection() /// Timeout in milliseconds public void Connect(EndPoint endPoint, int timeout) { - Disconnect(); - outPackets = new List(); inPackets = new Dictionary(); @@ -159,11 +157,10 @@ public void Disconnect() SendSequenced(new UdpPacket(EUdpPacketType.Disconnect)); } - // Graceful shutdown allows for the connection to empty its queue of messages to send - netThread.Join(); - // Advance this the same way that steam does, when a socket gets reused. sourceConnId += 256; + + Disconnected?.Invoke( this, new DisconnectedEventArgs( true ) ); } /// diff --git a/SteamKit2/SteamKit2/Steam/CMClient.cs b/SteamKit2/SteamKit2/Steam/CMClient.cs index 6f446ce85..32674fbbd 100644 --- a/SteamKit2/SteamKit2/Steam/CMClient.cs +++ b/SteamKit2/SteamKit2/Steam/CMClient.cs @@ -38,7 +38,7 @@ public abstract class CMClient /// Returns the the local IP of this client. /// /// The local IP. - public IPAddress LocalIP => connection.GetLocalIP(); + public IPAddress LocalIP => connection?.GetLocalIP(); /// /// Gets the universe of this client. @@ -94,6 +94,8 @@ public abstract class CMClient internal bool ExpectDisconnection { get; set; } + // connection lock around the setup and tear down of the connection task + object connectionLock = new object(); CancellationTokenSource connectionCancellation; Task connectionSetupTask; IConnection connection; @@ -131,59 +133,62 @@ public CMClient( SteamConfiguration configuration ) /// The of the CM server to connect to. /// If null, SteamKit will randomly select a CM server from its internal list. /// - public void Connect( ServerRecord cmServer = null ) + public void Connect( ServerRecord cmServer = null ) { - this.Disconnect(); - Debug.Assert( connection == null ); + lock ( connectionLock ) + { + this.Disconnect(); + Debug.Assert( connection == null ); - var cancellation = new CancellationTokenSource(); - var token = cancellation.Token; - var oldCancellation = Interlocked.Exchange( ref connectionCancellation, cancellation ); - Debug.Assert( oldCancellation == null ); + Debug.Assert( connectionCancellation == null ); - ExpectDisconnection = false; + connectionCancellation = new CancellationTokenSource(); + var token = connectionCancellation.Token; - Task recordTask = null; + ExpectDisconnection = false; - if ( cmServer == null ) - { - recordTask = Servers.GetNextServerCandidateAsync( Configuration.ProtocolTypes ); - } - else - { - recordTask = Task.FromResult( cmServer ); - } + Task recordTask = null; - connectionSetupTask = recordTask.ContinueWith( t => - { - if ( token.IsCancellationRequested ) + if ( cmServer == null ) { - DebugLog.WriteLine( nameof(CMClient), "Connection cancelled before a server could be chosen." ); - OnClientDisconnected( userInitiated: true ); - return; + recordTask = Servers.GetNextServerCandidateAsync( Configuration.ProtocolTypes ); } - else if ( t.IsFaulted || t.IsCanceled ) + else { - DebugLog.WriteLine( nameof(CMClient), "Server record task threw exception: {0}", t.Exception ); - OnClientDisconnected( userInitiated: false ); - return; + recordTask = Task.FromResult( cmServer ); } - var record = t.Result; - - connection = CreateConnection( record.ProtocolTypes & Configuration.ProtocolTypes ); - connection.NetMsgReceived += NetMsgReceived; - connection.Connected += Connected; - connection.Disconnected += Disconnected; - connection.Connect( record.EndPoint, ( int )ConnectionTimeout.TotalMilliseconds ); - }, TaskContinuationOptions.ExecuteSynchronously).ContinueWith(t => - { - if ( t.IsFaulted ) + connectionSetupTask = recordTask.ContinueWith( t => { - DebugLog.WriteLine( nameof(CMClient), "Unhandled exception when attempting to connect to Steam: {0}", t.Exception ); - OnClientDisconnected( userInitiated: false ); - } - }, TaskContinuationOptions.ExecuteSynchronously); + if ( token.IsCancellationRequested ) + { + DebugLog.WriteLine( nameof( CMClient ), "Connection cancelled before a server could be chosen." ); + OnClientDisconnected( userInitiated: true ); + return; + } + else if ( t.IsFaulted || t.IsCanceled ) + { + DebugLog.WriteLine( nameof( CMClient ), "Server record task threw exception: {0}", t.Exception ); + OnClientDisconnected( userInitiated: false ); + return; + } + + var record = t.Result; + + connection = CreateConnection( record.ProtocolTypes & Configuration.ProtocolTypes ); + connection.NetMsgReceived += NetMsgReceived; + connection.Connected += Connected; + connection.Disconnected += Disconnected; + connection.Connect( record.EndPoint, ( int )ConnectionTimeout.TotalMilliseconds ); + }, TaskContinuationOptions.ExecuteSynchronously ).ContinueWith( t => + { + if ( t.IsFaulted ) + { + DebugLog.WriteLine( nameof( CMClient ), "Unhandled exception when attempting to connect to Steam: {0}", t.Exception ); + OnClientDisconnected( userInitiated: false ); + } + }, TaskContinuationOptions.ExecuteSynchronously ); + } } /// @@ -191,21 +196,28 @@ public void Connect( ServerRecord cmServer = null ) /// public void Disconnect() { - heartBeatFunc.Stop(); - - var cts = Interlocked.Exchange(ref connectionCancellation, null); - if (cts != null) + lock ( connectionLock ) { - cts.Cancel(); - cts.Dispose(); - } + heartBeatFunc.Stop(); - var task = Interlocked.Exchange(ref connectionSetupTask, null); - if ( task != null ) - { - task.GetAwaiter().GetResult(); + if ( connectionCancellation != null ) + { + connectionCancellation.Cancel(); + connectionCancellation.Dispose(); + connectionCancellation = null; + } + + if ( connectionSetupTask != null ) + { + // though it's ugly, we want to wait for the completion of this task and keep hold of the lock + connectionSetupTask.GetAwaiter().GetResult(); + connectionSetupTask = null; + } + + // Connection implementations are required to issue the Disconnected callback before Disconnect() returns + connection?.Disconnect(); + Debug.Assert( connection == null ); } - connection?.Disconnect(); } ///