Skip to content

Commit

Permalink
Server List: Treat each protocol separately
Browse files Browse the repository at this point in the history
SmartCMServerList now returns a new ServerRecord which represents an endpoint and only one possible protocol that the original ServerRecord supported

Internally it maintains a different last-bad-connection time per protocol.

Fixes #417
  • Loading branch information
yaakov-h committed Jul 29, 2017
1 parent 5323403 commit 2027498
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 53 deletions.
20 changes: 20 additions & 0 deletions SteamKit2/SteamKit2/Networking/Steam3/ProtocolType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


using System;
using System.Collections.Generic;

namespace SteamKit2
{
Expand Down Expand Up @@ -39,5 +40,24 @@ static class ProtocolTypesExtensions
{
public static bool HasFlagsFast(this ProtocolTypes self, ProtocolTypes flags)
=> (self & flags) > 0;

internal static IEnumerable<ProtocolTypes> GetFlags(this ProtocolTypes self)
{
if (self.HasFlagsFast(ProtocolTypes.Tcp))
{
yield return ProtocolTypes.Tcp;
}

if (self.HasFlagsFast(ProtocolTypes.Udp))
{
yield return ProtocolTypes.Udp;
}

if (self.HasFlagsFast(ProtocolTypes.WebSocket))
{
yield return ProtocolTypes.WebSocket;
}

}
}
}
6 changes: 3 additions & 3 deletions SteamKit2/SteamKit2/Steam/CMClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ void NetMsgReceived( object sender, NetMsgEventArgs e )

void Connected( object sender, EventArgs e )
{
Servers.TryMark( connection.CurrentEndPoint, ServerQuality.Good );
Servers.TryMark( connection.CurrentEndPoint, connection.ProtocolTypes, ServerQuality.Good );

IsConnected = true;
OnClientConnected();
Expand All @@ -371,7 +371,7 @@ void Disconnected( object sender, DisconnectedEventArgs e )

if ( !e.UserInitiated )
{
Servers.TryMark( connection.CurrentEndPoint, ServerQuality.Bad );
Servers.TryMark( connection.CurrentEndPoint, connection.ProtocolTypes, ServerQuality.Bad );
}

SessionID = null;
Expand Down Expand Up @@ -514,7 +514,7 @@ void HandleLoggedOff( IPacketMsg packetMsg )

if ( logoffResult == EResult.TryAnotherCM || logoffResult == EResult.ServiceUnavailable )
{
Servers.TryMark( connection.CurrentEndPoint, ServerQuality.Bad );
Servers.TryMark( connection.CurrentEndPoint, connection.ProtocolTypes, ServerQuality.Bad );
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion SteamKit2/SteamKit2/Steam/Discovery/ServerRecord.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace SteamKit2.Discovery
/// </summary>
public class ServerRecord
{
ServerRecord(EndPoint endPoint, ProtocolTypes protocolTypes)
internal ServerRecord(EndPoint endPoint, ProtocolTypes protocolTypes)
{
EndPoint = endPoint ?? throw new ArgumentNullException(nameof(endPoint));
ProtocolTypes = protocolTypes;
Expand Down
91 changes: 61 additions & 30 deletions SteamKit2/SteamKit2/Steam/Discovery/SmartCMServerList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,28 @@ public enum ServerQuality
/// </summary>
public class SmartCMServerList
{
[DebuggerDisplay("ServerInfo ({EndPoint}, Bad: {LastBadConnectionDateTimeUtc.HasValue})")]
[DebuggerDisplay("ServerInfo ({Record.EndPoint})")]
class ServerInfo
{
public ServerInfo( ServerRecord record )
{
Record = record;
LastBadConnectionTimeMap = new Dictionary<ProtocolTypes, DateTime?>();

foreach ( var protocolType in record.ProtocolTypes.GetFlags() )
{
ResetLastBadConnectionTime( protocolType );
}
}

public ServerRecord Record { get; set; }
public DateTime? LastBadConnectionDateTimeUtc { get; set; }
public Dictionary<ProtocolTypes, DateTime?> LastBadConnectionTimeMap { get; set; }

public void ResetLastBadConnectionTime( ProtocolTypes protocol )
=> LastBadConnectionTimeMap[ protocol ] = null;

public void SetLastBadConnectionTime( ProtocolTypes protocol, DateTime dateTime )
=> LastBadConnectionTimeMap[ protocol ] = dateTime;
}

/// <summary>
Expand Down Expand Up @@ -133,9 +150,13 @@ public void ResetOldScores()
{
foreach ( var serverInfo in servers )
{
if ( serverInfo.LastBadConnectionDateTimeUtc.HasValue && ( serverInfo.LastBadConnectionDateTimeUtc.Value + BadConnectionMemoryTimeSpan < DateTime.UtcNow ) )
foreach ( var protocolType in serverInfo.LastBadConnectionTimeMap.Keys )
{
serverInfo.LastBadConnectionDateTimeUtc = null;
var lastBadConnectionTime = serverInfo.LastBadConnectionTimeMap[ protocolType ];
if ( lastBadConnectionTime.HasValue && lastBadConnectionTime.Value + BadConnectionMemoryTimeSpan < DateTime.UtcNow )
{
serverInfo.LastBadConnectionTimeMap[ protocolType ] = null;
}
}
}
}
Expand All @@ -162,9 +183,9 @@ public void ReplaceList( IEnumerable<ServerRecord> endpointList )
}
}

void AddCore(ServerRecord endPoint )
void AddCore( ServerRecord endPoint )
{
var info = new ServerInfo { Record = endPoint };
var info = new ServerInfo( endPoint );

servers.Add( info );
}
Expand All @@ -178,12 +199,15 @@ public void ResetBadServers()
{
foreach ( var server in servers )
{
server.LastBadConnectionDateTimeUtc = null;
foreach ( var protocolType in server.LastBadConnectionTimeMap.Keys )
{
server.ResetLastBadConnectionTime( protocolType );
}
}
}
}

internal bool TryMark( EndPoint endPoint, ServerQuality quality )
internal bool TryMark( EndPoint endPoint, ProtocolTypes protocolTypes, ServerQuality quality )
{
lock ( listLock )
{
Expand All @@ -192,30 +216,32 @@ internal bool TryMark( EndPoint endPoint, ServerQuality quality )
{
return false;
}
MarkServerCore( serverInfo, quality );
MarkServerCore( serverInfo, protocolTypes, quality );
return true;
}
}

void MarkServerCore( ServerInfo serverInfo, ServerQuality quality )
void MarkServerCore( ServerInfo serverInfo, ProtocolTypes protocolTypes, ServerQuality quality )
{

switch ( quality )
foreach ( var protocol in protocolTypes.GetFlags() )
{
case ServerQuality.Good:
switch ( quality )
{
serverInfo.LastBadConnectionDateTimeUtc = null;
break;
}
case ServerQuality.Good:
{
serverInfo.ResetLastBadConnectionTime( protocol );
break;
}

case ServerQuality.Bad:
{
serverInfo.LastBadConnectionDateTimeUtc = DateTime.UtcNow;
break;
}
case ServerQuality.Bad:
{
serverInfo.SetLastBadConnectionTime( protocol, DateTime.UtcNow );
break;
}

default:
throw new ArgumentOutOfRangeException( "quality" );
default:
throw new ArgumentOutOfRangeException( "quality" );
}
}
}

Expand All @@ -232,20 +258,25 @@ private ServerRecord GetNextServerCandidateInternal( ProtocolTypes supportedProt
// isn't a problem.
ResetOldScores();

var serverInfo = servers
.Where( s => s.Record.ProtocolTypes.HasFlagsFast( supportedProtocolTypes ) )
.Select( (s, index) => new { Record = s.Record, IsBad = s.LastBadConnectionDateTimeUtc.HasValue, Index = index } )
.OrderBy( x => x.IsBad )
.ThenBy( x => x.Index )
.FirstOrDefault();
var query =
from o in servers.Select((server, index) => new { server, index })
let server = o.server
let index = o.index
where server.Record.ProtocolTypes.HasFlagsFast( supportedProtocolTypes )
from protocol in server.LastBadConnectionTimeMap.Keys
where supportedProtocolTypes.HasFlagsFast( protocol )
let lastBadConnectionTime = server.LastBadConnectionTimeMap[ protocol ]
orderby lastBadConnectionTime.HasValue, index
select new { Record = server.Record, IsBad = lastBadConnectionTime.HasValue, Index = index, Protocol = protocol };
var serverInfo = query.FirstOrDefault();

if ( serverInfo == null )
{
return null;
}

DebugWrite( $"Next server candidiate: {serverInfo.Record.EndPoint} ({serverInfo.Record.ProtocolTypes})" );
return serverInfo.Record;
return new ServerRecord( serverInfo.Record.EndPoint, serverInfo.Protocol );
}
}

Expand Down
2 changes: 0 additions & 2 deletions SteamKit2/SteamKit2/Steam/SteamClient/SteamConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ public SteamConfiguration()

/// <summary>
/// The supported protocol types to use when attempting to connect to Steam.
/// If <see cref="ProtocolTypes.Tcp"/> and <see cref="ProtocolTypes.Udp"/> are both specified, TCP will take precedence
/// and UDP will not be used.
/// </summary>
public ProtocolTypes ProtocolTypes { get; set; } = ProtocolTypes.Tcp;

Expand Down
75 changes: 58 additions & 17 deletions SteamKit2/Tests/SmartCMServerListFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public void GetNextServerCandidate_ReturnsServer_IfListHasServers()
serverList.ReplaceList( new List<ServerRecord>() { record } );

var nextRecord = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
Assert.Equal( record, nextRecord );
Assert.Equal( record.EndPoint, nextRecord.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, nextRecord.ProtocolTypes );
}

[Fact]
Expand All @@ -73,33 +74,36 @@ public void GetNextServerCandidate_ReturnsServer_IfListHasServers_EvenIfAllServe

var record = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27015 ) );
serverList.ReplaceList( new List<ServerRecord>() { record } );
serverList.TryMark( record.EndPoint, ServerQuality.Bad );
serverList.TryMark( record.EndPoint, record.ProtocolTypes, ServerQuality.Bad );

var nextRecord = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
Assert.Equal( record, nextRecord );
Assert.Equal( record.EndPoint, nextRecord.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, nextRecord.ProtocolTypes );
}

[Fact]
public void GetNextServerCandidate_IsBiasedTowardsServerOrdering()
{
serverList.GetAllEndPoints();

var goodRecord = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27015 ) );
var neutralRecord = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27016 ) );
var badRecord = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27017 ) );

serverList.ReplaceList( new List<ServerRecord>() { badRecord, neutralRecord, goodRecord } );

serverList.TryMark( badRecord.EndPoint, ServerQuality.Bad );
serverList.TryMark( goodRecord.EndPoint, ServerQuality.Good );
serverList.TryMark( badRecord.EndPoint, badRecord.ProtocolTypes, ServerQuality.Bad );
serverList.TryMark( goodRecord.EndPoint, goodRecord.ProtocolTypes, ServerQuality.Good );

var nextRecord = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
Assert.Equal( neutralRecord, nextRecord );
Assert.Equal( neutralRecord.EndPoint, nextRecord.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, nextRecord.ProtocolTypes );

serverList.TryMark( badRecord.EndPoint, ServerQuality.Good);
serverList.TryMark( badRecord.EndPoint, badRecord.ProtocolTypes, ServerQuality.Good);

nextRecord = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
Assert.Equal( badRecord, nextRecord );
Assert.Equal( badRecord.EndPoint, nextRecord.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, nextRecord.ProtocolTypes );
}

[Fact]
Expand All @@ -116,10 +120,12 @@ public void GetNextServerCandidate_OnlyReturnsMatchingServerOfType()
Assert.Null( endPoint );

endPoint = serverList.GetNextServerCandidate( ProtocolTypes.WebSocket );
Assert.Same( record, endPoint );
Assert.Equal( record.EndPoint, endPoint.EndPoint );
Assert.Equal( ProtocolTypes.WebSocket, endPoint.ProtocolTypes );

endPoint = serverList.GetNextServerCandidate( ProtocolTypes.All );
Assert.Same(record, endPoint);
Assert.Equal( record.EndPoint, endPoint.EndPoint );
Assert.Equal( ProtocolTypes.WebSocket, endPoint.ProtocolTypes );

record = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27015 ) );
serverList.ReplaceList( new List<ServerRecord>() { record } );
Expand All @@ -128,16 +134,20 @@ public void GetNextServerCandidate_OnlyReturnsMatchingServerOfType()
Assert.Null( endPoint );

endPoint = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
Assert.Same( record, endPoint );
Assert.Equal( record.EndPoint, endPoint.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, endPoint.ProtocolTypes );

endPoint = serverList.GetNextServerCandidate( ProtocolTypes.Udp);
Assert.Same( record, endPoint );
Assert.Equal( record.EndPoint, endPoint.EndPoint );
Assert.Equal( ProtocolTypes.Udp, endPoint.ProtocolTypes );

endPoint = serverList.GetNextServerCandidate( ProtocolTypes.Tcp | ProtocolTypes.Udp );
Assert.Same( record, endPoint );
Assert.Equal( record.EndPoint, endPoint.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, endPoint.ProtocolTypes );

endPoint = serverList.GetNextServerCandidate( ProtocolTypes.All );
Assert.Same(record, endPoint);
Assert.Equal( record.EndPoint, endPoint.EndPoint );
Assert.Equal( ProtocolTypes.Tcp, endPoint.ProtocolTypes );
}

[Fact]
Expand All @@ -146,7 +156,7 @@ public void TryMark_ReturnsTrue_IfServerInList()
var record = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27015 ));
serverList.ReplaceList( new List<ServerRecord>() { record } );

var marked = serverList.TryMark( record.EndPoint, ServerQuality.Good );
var marked = serverList.TryMark( record.EndPoint, record.ProtocolTypes, ServerQuality.Good );
Assert.True( marked );
}

Expand All @@ -156,8 +166,39 @@ public void TryMark_ReturnsFalse_IfServerNotInList()
var record = ServerRecord.CreateSocketServer( new IPEndPoint( IPAddress.Loopback, 27015 ) );
serverList.ReplaceList( new List<ServerRecord>() { record } );

var marked = serverList.TryMark( new IPEndPoint( IPAddress.Loopback, 27016 ), ServerQuality.Good );
var marked = serverList.TryMark( new IPEndPoint( IPAddress.Loopback, 27016 ), record.ProtocolTypes, ServerQuality.Good );
Assert.False( marked );
}

[Fact]
public void TreatsProtocolsForSameServerIndividiaully()
{
var record1 = ServerRecord.CreateServer( IPAddress.Loopback.ToString(), 27015, ProtocolTypes.Tcp | ProtocolTypes.Udp );
var record2 = ServerRecord.CreateServer( IPAddress.Loopback.ToString(), 27016, ProtocolTypes.Tcp | ProtocolTypes.Udp );

serverList.ReplaceList( new[] { record1, record2 } );

var nextTcp = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
var nextUdp = serverList.GetNextServerCandidate( ProtocolTypes.Udp );

Assert.Equal( record1.EndPoint, nextTcp.EndPoint );
Assert.Equal( record1.EndPoint, nextUdp.EndPoint );

serverList.TryMark( record1.EndPoint, ProtocolTypes.Tcp, ServerQuality.Bad );

nextTcp = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
nextUdp = serverList.GetNextServerCandidate( ProtocolTypes.Udp );

Assert.Equal( record2.EndPoint, nextTcp.EndPoint );
Assert.Equal( record1.EndPoint, nextUdp.EndPoint );

serverList.TryMark( record1.EndPoint, ProtocolTypes.Udp, ServerQuality.Bad );

nextTcp = serverList.GetNextServerCandidate( ProtocolTypes.Tcp );
nextUdp = serverList.GetNextServerCandidate( ProtocolTypes.Udp );

Assert.Equal( record2.EndPoint, nextTcp.EndPoint );
Assert.Equal( record2.EndPoint, nextUdp.EndPoint );
}
}
}

0 comments on commit 2027498

Please sign in to comment.