Skip to content

Commit

Permalink
feat(repeater): support unsafe headers violated RFC
Browse files Browse the repository at this point in the history
closes #85
  • Loading branch information
derevnjuk committed Dec 6, 2022
1 parent 7528641 commit d966c73
Show file tree
Hide file tree
Showing 9 changed files with 519 additions and 74 deletions.
26 changes: 26 additions & 0 deletions src/SecTester.Repeater/Extensions/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;

namespace SecTester.Repeater.Extensions;

internal static class EnumerableExtensions
{
public static void ForEach<T>(this IEnumerable<T> source, Action<T> action)
{
if (source == null)
{
throw new ArgumentNullException(nameof(source));
}

if (action == null)
{
throw new ArgumentNullException(nameof(action));
}

foreach (var item in source)
{
action(item);
}
}
}

22 changes: 22 additions & 0 deletions src/SecTester.Repeater/Extensions/ListExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;

namespace SecTester.Repeater.Extensions;

internal static class ListExtensions
{
public static int Replace<T>(this List<T> source, T newValue, Predicate<T> predicate)
{
if (source == null) throw new ArgumentNullException(nameof(source));
if (predicate == null) throw new ArgumentNullException(nameof(predicate));

var contentLenghtIdx = source.FindIndex(predicate);

if (contentLenghtIdx != -1)
{
source[contentLenghtIdx] = newValue;
}

return contentLenghtIdx;
}
}
99 changes: 60 additions & 39 deletions src/SecTester.Repeater/Runners/HttpRequestRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,31 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Mime;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using SecTester.Repeater.Bus;
using SecTester.Repeater.Extensions;

namespace SecTester.Repeater.Runners;

internal class HttpRequestRunner : RequestRunner
internal sealed class HttpRequestRunner : RequestRunner
{
private const string DefaultMimeType = "text/plain";

private const string ContentLengthFieldName = "Content-Length";
private const string ContentTypeFieldName = "Content-Type";
private readonly HashSet<string> _contentHeaders = new(StringComparer.OrdinalIgnoreCase)
{
ContentLengthFieldName, ContentTypeFieldName
};

private readonly IHttpClientFactory _httpClientFactory;
private readonly RequestRunnerOptions _options;

public HttpRequestRunner(RequestRunnerOptions options, IHttpClientFactory httpClientFactory)
public HttpRequestRunner(RequestRunnerOptions options, IHttpClientFactory? httpClientFactory)
{
_options = options ?? throw new ArgumentNullException(nameof(options));
_httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
Expand All @@ -38,62 +46,58 @@ public async Task<Response> Run(Request request)
}
catch (Exception err)
{
return await CreateRequestExecutingResult(err).ConfigureAwait(false);
return new RequestExecutingResult
{
Message = err.Message,
// TODO: use native errno codes instead
ErrorCode = err is SocketException exception ? Enum.GetName(typeof(SocketError), exception.SocketErrorCode) : null
};
}
}

private static Task<RequestExecutingResult> CreateRequestExecutingResult(Exception response)
{
return Task.FromResult(new RequestExecutingResult
{
Message = response.Message,
// TODO: use native errno codes instead
ErrorCode = response is SocketException exception ? Enum.GetName(typeof(SocketError), exception.SocketErrorCode) : null
});
}

private async Task<RequestExecutingResult> CreateRequestExecutingResult(HttpResponseMessage response)
{
var body = await TruncateResponseBody(response).ConfigureAwait(false);
var headers = AggregateHeaders(response, body.Length);
var headers = AggregateHeaders(response);

if (body != null)
{
var contentLength = new KeyValuePair<string, IEnumerable<string>>(ContentLengthFieldName, new[]
{
$"{body.Length}"
});
headers.Replace(contentLength, x => x.Key.Equals(ContentLengthFieldName, StringComparison.OrdinalIgnoreCase));
}

return new RequestExecutingResult
{
Headers = headers,
StatusCode = (int)response.StatusCode,
Body = Encoding.UTF8.GetString(body)
Body = body?.ToString() ?? ""
};
}

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> AggregateHeaders(HttpResponseMessage response, int contentLength)
private static List<KeyValuePair<string, IEnumerable<string>>> AggregateHeaders(HttpResponseMessage response)
{

var headers = response.Headers.ToList();
headers.AddRange(response.Content.Headers);

var contentLenghtIdx = headers.FindIndex(x => x.Key.Equals(ContentLengthFieldName, StringComparison.OrdinalIgnoreCase));
if (contentLenghtIdx != -1)
{
headers[contentLenghtIdx] = new KeyValuePair<string, IEnumerable<string>>(ContentLengthFieldName, new[]
{
$"{contentLength}"
});
}

return headers;
}

private async Task<byte[]> TruncateResponseBody(HttpResponseMessage response)
private async Task<TruncatedBody?> TruncateResponseBody(HttpResponseMessage response)
{
if (response.StatusCode == HttpStatusCode.NoContent || response.RequestMessage.Method == HttpMethod.Head || response.Content == null)
{
return Array.Empty<byte>();
return null;
}

var type = response.Content.Headers.ContentType.MediaType ?? DefaultMimeType;
var allowed = _options.AllowedMimes.Any(mime => type.Contains(mime));
var contentType = response.Content.Headers.ContentType;
var mimeType = contentType?.MediaType ?? DefaultMimeType;
var allowed = _options.AllowedMimes.Any(mime => mimeType.Contains(mime));

return await ParseResponseBody(response, allowed).ConfigureAwait(false);
var body = await ParseResponseBody(response, allowed).ConfigureAwait(false);

return new TruncatedBody(body, contentType?.CharSet);
}

private async Task<byte[]> ParseResponseBody(HttpResponseMessage response, bool allowed)
Expand All @@ -111,28 +115,45 @@ private async Task<byte[]> ParseResponseBody(HttpResponseMessage response, bool
return body;
}

private static HttpRequestMessage CreateHttpRequestMessage(Request request)
private HttpRequestMessage CreateHttpRequestMessage(Request request)
{
var content = new StringContent(request.Body ?? "", Encoding.Default);
var content = request.Body != null ? CreateHttpContent(request) : null;
var options = new HttpRequestMessage
{
RequestUri = request.Url,
Method = request.Method,
Content = content
};

foreach (var keyValuePair in request.Headers)
request.Headers
.Where(x => !_contentHeaders.Contains(x.Key))
.ForEach(x => options.Headers.TryAddWithoutValidation(x.Key, x.Value));

return options;
}

private static StringContent? CreateHttpContent(Request request)
{
var values = request.Headers
.Where(header => header.Key.Equals(ContentTypeFieldName, StringComparison.OrdinalIgnoreCase))
.SelectMany(header => header.Value).ToArray();

if (!values.Any())
{
options.Headers.Add(keyValuePair.Key, keyValuePair.Value);
return null;
}

return options;
var mime = new ContentType(string.Join(", ", values));
var encoding = !string.IsNullOrEmpty(mime.CharSet) ? Encoding.GetEncoding(mime.CharSet) : Encoding.Default;

return new StringContent(request.Body, encoding, mime.MediaType);
}

private async Task<HttpResponseMessage> Request(HttpRequestMessage options, CancellationToken cancellationToken = default)
{
using var httpClient = _httpClientFactory.CreateClient(nameof(HttpRequestRunner));
return await httpClient.SendAsync(options,
cancellationToken).ConfigureAwait(false);
return await httpClient.SendAsync(options, cancellationToken).ConfigureAwait(false);
}
}


23 changes: 23 additions & 0 deletions src/SecTester.Repeater/Runners/TruncatedBody.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System;
using System.Text;

namespace SecTester.Repeater.Runners;

internal sealed class TruncatedBody
{
public TruncatedBody(byte[] body, string? charSet = default)
{
Body = body;
Encoding = string.IsNullOrEmpty(charSet) ? Encoding.Default : Encoding.GetEncoding(charSet);
}

private Encoding Encoding { get; }
private byte[] Body { get; }
public int Length => Buffer.ByteLength(Body);

public override string ToString()
{
return Encoding.GetString(Body);
}
}

1 change: 0 additions & 1 deletion src/SecTester.Repeater/SecTester.Repeater.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
<ItemGroup>
<Folder Include="Api" />
<Folder Include="Internal" />
<Folder Include="Runners" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
namespace SecTester.Repeater.Tests.Extensions;

public class EnumerableExtensionsTests
{
private readonly Action<int> _action = Substitute.For<Action<int>>();

[Fact]
public void ForEach_SourceIsNotDefined_ThrowsError()
{
// act
var act = () => ((null as IEnumerable<int>)!).ForEach(_action);

// assert
act.Should().Throw<ArgumentNullException>().WithMessage("*source*");
}

[Fact]
public void ForEach_ActionIsNotDefined_ThrowsError()
{
// arrange
IEnumerable<int> list = new List<int>
{
1, 2, 3
};

// act
var act = () => list.ForEach(null!);

// assert
act.Should().Throw<ArgumentNullException>().WithMessage("*action*");
}

[Fact]
public void ForEach_IterateOverAllElements()
{
// arrange
IEnumerable<int> list = new List<int>
{
1, 2, 3
};

// act
list.ForEach(_action);

// assert
_action.Received(3)(Arg.Any<int>());
}

[Fact]
public void ForEach_ExecutesActionInCorrectOrder()
{
// arrange
var source = Enumerable.Range(1, 10);
var items = new List<int>();

// act
source.ForEach(x => items.Add(x));

// assert
items.Should().ContainInOrder(source);
}

[Fact]
public void ForEach_DoesNothingOnEmptyEnumerable()
{
// arrange
IEnumerable<int> list = new List<int>();

// act
list.ForEach(_action);

// assert
_action.DidNotReceive()(Arg.Any<int>());
}
}
Loading

0 comments on commit d966c73

Please sign in to comment.