Skip to content

Commit

Permalink
feat(repeater): implement WS request runner (#91)
Browse files Browse the repository at this point in the history
closes #85
  • Loading branch information
derevnjuk authored Dec 8, 2022
1 parent 7a43088 commit 224d154
Show file tree
Hide file tree
Showing 19 changed files with 973 additions and 54 deletions.
24 changes: 16 additions & 8 deletions src/SecTester.Repeater/Extensions/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -13,17 +14,23 @@ public static class ServiceCollectionExtensions
{
public static IServiceCollection AddSecTesterRepeater(this IServiceCollection collection)
{
return AddSecTesterRepeater(collection, () => new RequestRunnerOptions());
return AddSecTesterRepeater(collection, new RequestRunnerOptions());
}

public static IServiceCollection AddSecTesterRepeater(this IServiceCollection collection, Func<RequestRunnerOptions> configure)
public static IServiceCollection AddSecTesterRepeater(this IServiceCollection collection, RequestRunnerOptions options)
{
return collection
.AddScoped<RepeaterFactory, DefaultRepeaterFactory>()
.AddScoped<RequestExecutingEventHandler>()
.AddScoped(_ => configure())
.AddScoped(_ => options)
.AddScoped<Repeaters, DefaultRepeaters>()
.AddScoped<TimerProvider, SystemTimerProvider>()
.AddScoped<WebSocketFactory, DefaultWebSocketFactory>()
.AddScoped<RequestRunner, HttpRequestRunner>()
.AddScoped<RequestRunner, WsRequestRunner>()
.AddScoped<RequestRunnerResolver>(sp =>
protocol => sp.GetServices<RequestRunner>().FirstOrDefault(x => x.Protocol == protocol)
)
.AddHttpClientForHttpRequestRunner();
}

Expand Down Expand Up @@ -60,12 +67,13 @@ private static void ConfigureHttpClient(IServiceProvider sp, HttpClient client)
client.DefaultRequestHeaders.Add(keyValuePair.Key, keyValuePair.Value);
}

if (!config.ReuseConnection)
if (config.ReuseConnection)
{
return;
client.DefaultRequestHeaders.Add("Connection", "keep-alive");
client.DefaultRequestHeaders.Add("Keep-Alive", config.Timeout.ToString());
}

client.DefaultRequestHeaders.Add("Connection", "keep-alive");
client.DefaultRequestHeaders.Add("Keep-Alive", config.Timeout.ToString());
}
}



35 changes: 35 additions & 0 deletions src/SecTester.Repeater/Runners/DefaultWebSocketFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Net;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

namespace SecTester.Repeater.Runners;

public class DefaultWebSocketFactory : WebSocketFactory
{
private readonly RequestRunnerOptions _options;

public DefaultWebSocketFactory(RequestRunnerOptions options)
{
_options = options ?? throw new ArgumentNullException(nameof(options));
}

public async Task<WebSocket> CreateWebSocket(Uri uri, CancellationToken cancellationToken = default)
{
var proxy = _options.ProxyUrl is not null ? new WebProxy(_options.ProxyUrl) : null;
// TODO: disable certs validation. For details see https://github.com/dotnet/runtime/issues/18696
var client = new ClientWebSocket
{
Options =
{
Proxy = proxy, KeepAliveInterval = _options.Timeout
}
};

await client.ConnectAsync(uri, cancellationToken).ConfigureAwait(false);

return client;
}
}

4 changes: 2 additions & 2 deletions src/SecTester.Repeater/Runners/HttpRequestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private static List<KeyValuePair<string, IEnumerable<string>>> AggregateHeaders(
return headers;
}

private async Task<TruncatedBody?> TruncateResponseBody(HttpResponseMessage response)
private async Task<ResponseBody?> TruncateResponseBody(HttpResponseMessage response)
{
if (response.StatusCode == HttpStatusCode.NoContent || response.RequestMessage.Method == HttpMethod.Head || response.Content == null)
{
Expand All @@ -97,7 +97,7 @@ private static List<KeyValuePair<string, IEnumerable<string>>> AggregateHeaders(

var body = await ParseResponseBody(response, allowed).ConfigureAwait(false);

return new TruncatedBody(body, contentType?.CharSet);
return new ResponseBody(body, contentType?.CharSet);
}

private async Task<byte[]> ParseResponseBody(HttpResponseMessage response, bool allowed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

namespace SecTester.Repeater.Runners;

internal sealed class TruncatedBody
internal class ResponseBody
{
public TruncatedBody(byte[] body, string? charSet = default)
public ResponseBody(byte[] body, string? charSet = default)
{
Body = body;
Encoding = string.IsNullOrEmpty(charSet) ? Encoding.Default : Encoding.GetEncoding(charSet);
}

private Encoding Encoding { get; }
private byte[] Body { get; }
protected byte[] Body { get; }
public int Length => Buffer.ByteLength(Body);

public override string ToString()
Expand Down
12 changes: 12 additions & 0 deletions src/SecTester.Repeater/Runners/WebSocketFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

namespace SecTester.Repeater.Runners;

public interface WebSocketFactory
{
public Task<WebSocket> CreateWebSocket(Uri uri, CancellationToken cancellationToken = default);
}

17 changes: 17 additions & 0 deletions src/SecTester.Repeater/Runners/WebSocketResponseBody.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System.Net.WebSockets;

namespace SecTester.Repeater.Runners;

internal sealed class WebSocketResponseBody : ResponseBody
{
public WebSocketResponseBody(byte[] body, WebSocketCloseStatus? statusCode = default, string? statusDescription = default) : base(body)
{
StatusCode = statusCode;
StatusDescription = statusDescription;
}

public WebSocketCloseStatus? StatusCode { get; }
public string? StatusDescription { get; }
}


167 changes: 167 additions & 0 deletions src/SecTester.Repeater/Runners/WsRequestRunner.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using SecTester.Core.Extensions;
using SecTester.Repeater.Bus;

namespace SecTester.Repeater.Runners;

internal sealed class WsRequestRunner : RequestRunner
{
private const WebSocketCloseStatus DefaultStatusCode = WebSocketCloseStatus.NormalClosure;
private const int MaxBufferSize = 1024 * 4;
private readonly SemaphoreSlim _lock = new(1, 1);

private readonly RequestRunnerOptions _options;
private readonly WebSocketFactory _webSocketFactory;

public WsRequestRunner(RequestRunnerOptions options, WebSocketFactory webSocketFactory)
{
_options = options ?? throw new ArgumentNullException(nameof(options));
_webSocketFactory = webSocketFactory ?? throw new ArgumentNullException(nameof(webSocketFactory));
}

public Protocol Protocol => Protocol.Ws;

public async Task<Response> Run(Request request)
{
using var cts = new CancellationTokenSource(_options.Timeout);

WebSocket? client = null;

try
{
client = await _webSocketFactory.CreateWebSocket(request.Url, cts.Token).ConfigureAwait(false);

var result = await SendAndRetrieve(client, request, cts.Token).ConfigureAwait(false);

return CreateRequestExecutingResult(client, result);
}
catch (Exception err)
{
return CreateRequestExecutingResult(err);
}
finally
{
if (client != null)
{
await CloseSocket(client, cts.Token).ConfigureAwait(false);
}
}
}

private async Task<WebSocketResponseBody> SendAndRetrieve(WebSocket client, Request request, CancellationToken cancellationToken)
{
var message = BuildMessage(request);

await Send(client, message, cancellationToken).ConfigureAwait(false);

return await Consume(request, client, cancellationToken).ConfigureAwait(false);
}

private async Task Send(WebSocket client, ArraySegment<byte> message, CancellationToken cancellationToken)
{
using var _ = await _lock.LockAsync(cancellationToken).ConfigureAwait(false);
await client.SendAsync(message, WebSocketMessageType.Text, true, cancellationToken).ConfigureAwait(false);
}

private static async Task CloseSocket(WebSocket client, CancellationToken cancellationToken)
{
try
{
if (!client.CloseStatus.HasValue)
{
await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken).ConfigureAwait(false);
}

client.Dispose();
}
catch
{
// noop
}
}

private static async IAsyncEnumerable<WebSocketResponseBody> ConsumeMessage(WebSocket client,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
using var stream = new MemoryStream();
var buffer = new ArraySegment<byte>(new byte[MaxBufferSize]);

while (!client.CloseStatus.HasValue)
{
var result = await client.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);

if (buffer.Array != null)
{
await stream.WriteAsync(buffer.Array, buffer.Offset, result.Count, cancellationToken).ConfigureAwait(false);
}

if (!result.CloseStatus.HasValue && !result.EndOfMessage)
{
continue;
}

stream.Seek(0, SeekOrigin.Begin);
yield return new WebSocketResponseBody(stream.ToArray(), result.CloseStatus, result.CloseStatusDescription);
}
}

private static ValueTask<WebSocketResponseBody> Consume(Request request, WebSocket client,
CancellationToken cancellationToken)
{
return ConsumeMessage(client, cancellationToken)
.FirstAsync(r => request.CorrelationIdRegex is null || request.CorrelationIdRegex.IsMatch(r.ToString()), cancellationToken);
}

private static RequestExecutingResult CreateRequestExecutingResult(WebSocket client, WebSocketResponseBody result)
{
var closeStatus = result.StatusCode ?? client.CloseStatus ?? DefaultStatusCode;
var statusDescription = result.StatusDescription ?? client.CloseStatusDescription;

return new RequestExecutingResult
{
Protocol = Protocol.Ws,
Message = statusDescription,
StatusCode = (int)closeStatus,
Body = result.ToString()
};
}

private static RequestExecutingResult CreateRequestExecutingResult(Exception exception)
{
var errorCode = GetErrorCode(exception);

return new RequestExecutingResult
{
Protocol = Protocol.Ws,
Message = exception.Message.TrimEnd(),
ErrorCode = errorCode
};
}

private static string? GetErrorCode(Exception err)
{
// TODO: use native errno codes instead
return err switch
{
WebSocketException exception => Enum.GetName(typeof(WebSocketError), exception.WebSocketErrorCode),
_ => null
};
}

private static ArraySegment<byte> BuildMessage(Request message)
{
var buffer = Encoding.Default.GetBytes(message.Body ?? "");
return new ArraySegment<byte>(buffer);
}
}



7 changes: 7 additions & 0 deletions src/SecTester.Repeater/SecTester.Repeater.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
<ProjectReference Include="..\SecTester.Core\SecTester.Core.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>

<ItemGroup>
<Folder Include="Api" />
<Folder Include="Bus" />
<Folder Include="Extensions" />
<Folder Include="Internal" />
<Folder Include="Runners" />
</ItemGroup>

<ItemGroup>
Expand Down
21 changes: 10 additions & 11 deletions src/SecTester.Repeater/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
"Microsoft.NETCore.Platforms": "1.1.0"
}
},
"System.Linq.Async": {
"type": "Direct",
"requested": "[6.0.1, )",
"resolved": "6.0.1",
"contentHash": "0YhHcaroWpQ9UCot3Pizah7ryAzQhNvobLMSxeDIGmnXfkQn8u5owvpOH0K6EVB+z9L7u6Cc4W17Br/+jyttEQ==",
"dependencies": {
"Microsoft.Bcl.AsyncInterfaces": "6.0.0"
}
},
"Macross.Json.Extensions": {
"type": "Transitive",
"resolved": "3.0.0",
Expand Down Expand Up @@ -291,7 +300,7 @@
"Microsoft.Extensions.DependencyInjection.Abstractions": "[6.0.0, )",
"Microsoft.Extensions.Http": "[6.0.0, )",
"RabbitMQ.Client": "[6.4.0, )",
"SecTester.Core": "[0.18.0, )",
"SecTester.Core": "[0.26.0, )",
"System.Text.Json": "[6.0.0, )",
"System.Threading.RateLimiting": "[7.0.0, )"
}
Expand All @@ -303,16 +312,6 @@
"Microsoft.Extensions.Logging": "[6.0.0, )",
"Microsoft.Extensions.Logging.Console": "[6.0.0, )"
}
},
"sectester.scan": {
"type": "Project",
"dependencies": {
"Macross.Json.Extensions": "[3.0.0, )",
"Microsoft.Extensions.DependencyInjection.Abstractions": "[6.0.0, )",
"SecTester.Bus": "[0.18.0, )",
"SecTester.Core": "[0.18.0, )",
"System.Text.Json": "[6.0.0, )"
}
}
}
}
Expand Down
Loading

0 comments on commit 224d154

Please sign in to comment.