From d966c7300072b6c010743592f43c016844910de8 Mon Sep 17 00:00:00 2001 From: Artem Derevnjuk Date: Wed, 7 Dec 2022 00:05:51 +0400 Subject: [PATCH] feat(repeater): support unsafe headers violated RFC closes #85 --- .../Extensions/EnumerableExtensions.cs | 26 +++ .../Extensions/ListExtensions.cs | 22 ++ .../Runners/HttpRequestRunner.cs | 99 ++++---- .../Runners/TruncatedBody.cs | 23 ++ .../SecTester.Repeater.csproj | 1 - .../Extensions/EnumerableExtensionsTests.cs | 75 ++++++ .../Extensions/ListExtensionsTests.cs | 131 +++++++++++ .../Runners/HttpRequestRunnerTests.cs | 214 +++++++++++++++--- test/SecTester.Repeater.Tests/Usings.cs | 2 +- 9 files changed, 519 insertions(+), 74 deletions(-) create mode 100644 src/SecTester.Repeater/Extensions/EnumerableExtensions.cs create mode 100644 src/SecTester.Repeater/Extensions/ListExtensions.cs create mode 100644 src/SecTester.Repeater/Runners/TruncatedBody.cs create mode 100644 test/SecTester.Repeater.Tests/Extensions/EnumerableExtensionsTests.cs create mode 100644 test/SecTester.Repeater.Tests/Extensions/ListExtensionsTests.cs diff --git a/src/SecTester.Repeater/Extensions/EnumerableExtensions.cs b/src/SecTester.Repeater/Extensions/EnumerableExtensions.cs new file mode 100644 index 0000000..39e7548 --- /dev/null +++ b/src/SecTester.Repeater/Extensions/EnumerableExtensions.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; + +namespace SecTester.Repeater.Extensions; + +internal static class EnumerableExtensions +{ + public static void ForEach(this IEnumerable source, Action action) + { + if (source == null) + { + throw new ArgumentNullException(nameof(source)); + } + + if (action == null) + { + throw new ArgumentNullException(nameof(action)); + } + + foreach (var item in source) + { + action(item); + } + } +} + diff --git a/src/SecTester.Repeater/Extensions/ListExtensions.cs b/src/SecTester.Repeater/Extensions/ListExtensions.cs new file mode 100644 index 0000000..fc267d3 --- /dev/null +++ b/src/SecTester.Repeater/Extensions/ListExtensions.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; + +namespace SecTester.Repeater.Extensions; + +internal static class ListExtensions +{ + public static int Replace(this List source, T newValue, Predicate predicate) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + if (predicate == null) throw new ArgumentNullException(nameof(predicate)); + + var contentLenghtIdx = source.FindIndex(predicate); + + if (contentLenghtIdx != -1) + { + source[contentLenghtIdx] = newValue; + } + + return contentLenghtIdx; + } +} diff --git a/src/SecTester.Repeater/Runners/HttpRequestRunner.cs b/src/SecTester.Repeater/Runners/HttpRequestRunner.cs index 950fd16..b1c8b44 100644 --- a/src/SecTester.Repeater/Runners/HttpRequestRunner.cs +++ b/src/SecTester.Repeater/Runners/HttpRequestRunner.cs @@ -3,23 +3,31 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Mime; using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; using SecTester.Repeater.Bus; +using SecTester.Repeater.Extensions; namespace SecTester.Repeater.Runners; -internal class HttpRequestRunner : RequestRunner +internal sealed class HttpRequestRunner : RequestRunner { private const string DefaultMimeType = "text/plain"; private const string ContentLengthFieldName = "Content-Length"; + private const string ContentTypeFieldName = "Content-Type"; + private readonly HashSet _contentHeaders = new(StringComparer.OrdinalIgnoreCase) + { + ContentLengthFieldName, ContentTypeFieldName + }; + private readonly IHttpClientFactory _httpClientFactory; private readonly RequestRunnerOptions _options; - public HttpRequestRunner(RequestRunnerOptions options, IHttpClientFactory httpClientFactory) + public HttpRequestRunner(RequestRunnerOptions options, IHttpClientFactory? httpClientFactory) { _options = options ?? throw new ArgumentNullException(nameof(options)); _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory)); @@ -38,62 +46,58 @@ public async Task Run(Request request) } catch (Exception err) { - return await CreateRequestExecutingResult(err).ConfigureAwait(false); + return new RequestExecutingResult + { + Message = err.Message, + // TODO: use native errno codes instead + ErrorCode = err is SocketException exception ? Enum.GetName(typeof(SocketError), exception.SocketErrorCode) : null + }; } } - private static Task CreateRequestExecutingResult(Exception response) - { - return Task.FromResult(new RequestExecutingResult - { - Message = response.Message, - // TODO: use native errno codes instead - ErrorCode = response is SocketException exception ? Enum.GetName(typeof(SocketError), exception.SocketErrorCode) : null - }); - } - private async Task CreateRequestExecutingResult(HttpResponseMessage response) { var body = await TruncateResponseBody(response).ConfigureAwait(false); - var headers = AggregateHeaders(response, body.Length); + var headers = AggregateHeaders(response); + + if (body != null) + { + var contentLength = new KeyValuePair>(ContentLengthFieldName, new[] + { + $"{body.Length}" + }); + headers.Replace(contentLength, x => x.Key.Equals(ContentLengthFieldName, StringComparison.OrdinalIgnoreCase)); + } return new RequestExecutingResult { Headers = headers, StatusCode = (int)response.StatusCode, - Body = Encoding.UTF8.GetString(body) + Body = body?.ToString() ?? "" }; } - private static IEnumerable>> AggregateHeaders(HttpResponseMessage response, int contentLength) + private static List>> AggregateHeaders(HttpResponseMessage response) { - var headers = response.Headers.ToList(); headers.AddRange(response.Content.Headers); - - var contentLenghtIdx = headers.FindIndex(x => x.Key.Equals(ContentLengthFieldName, StringComparison.OrdinalIgnoreCase)); - if (contentLenghtIdx != -1) - { - headers[contentLenghtIdx] = new KeyValuePair>(ContentLengthFieldName, new[] - { - $"{contentLength}" - }); - } - return headers; } - private async Task TruncateResponseBody(HttpResponseMessage response) + private async Task TruncateResponseBody(HttpResponseMessage response) { if (response.StatusCode == HttpStatusCode.NoContent || response.RequestMessage.Method == HttpMethod.Head || response.Content == null) { - return Array.Empty(); + return null; } - var type = response.Content.Headers.ContentType.MediaType ?? DefaultMimeType; - var allowed = _options.AllowedMimes.Any(mime => type.Contains(mime)); + var contentType = response.Content.Headers.ContentType; + var mimeType = contentType?.MediaType ?? DefaultMimeType; + var allowed = _options.AllowedMimes.Any(mime => mimeType.Contains(mime)); - return await ParseResponseBody(response, allowed).ConfigureAwait(false); + var body = await ParseResponseBody(response, allowed).ConfigureAwait(false); + + return new TruncatedBody(body, contentType?.CharSet); } private async Task ParseResponseBody(HttpResponseMessage response, bool allowed) @@ -111,9 +115,9 @@ private async Task ParseResponseBody(HttpResponseMessage response, bool return body; } - private static HttpRequestMessage CreateHttpRequestMessage(Request request) + private HttpRequestMessage CreateHttpRequestMessage(Request request) { - var content = new StringContent(request.Body ?? "", Encoding.Default); + var content = request.Body != null ? CreateHttpContent(request) : null; var options = new HttpRequestMessage { RequestUri = request.Url, @@ -121,18 +125,35 @@ private static HttpRequestMessage CreateHttpRequestMessage(Request request) Content = content }; - foreach (var keyValuePair in request.Headers) + request.Headers + .Where(x => !_contentHeaders.Contains(x.Key)) + .ForEach(x => options.Headers.TryAddWithoutValidation(x.Key, x.Value)); + + return options; + } + + private static StringContent? CreateHttpContent(Request request) + { + var values = request.Headers + .Where(header => header.Key.Equals(ContentTypeFieldName, StringComparison.OrdinalIgnoreCase)) + .SelectMany(header => header.Value).ToArray(); + + if (!values.Any()) { - options.Headers.Add(keyValuePair.Key, keyValuePair.Value); + return null; } - return options; + var mime = new ContentType(string.Join(", ", values)); + var encoding = !string.IsNullOrEmpty(mime.CharSet) ? Encoding.GetEncoding(mime.CharSet) : Encoding.Default; + + return new StringContent(request.Body, encoding, mime.MediaType); } private async Task Request(HttpRequestMessage options, CancellationToken cancellationToken = default) { using var httpClient = _httpClientFactory.CreateClient(nameof(HttpRequestRunner)); - return await httpClient.SendAsync(options, - cancellationToken).ConfigureAwait(false); + return await httpClient.SendAsync(options, cancellationToken).ConfigureAwait(false); } } + + diff --git a/src/SecTester.Repeater/Runners/TruncatedBody.cs b/src/SecTester.Repeater/Runners/TruncatedBody.cs new file mode 100644 index 0000000..2cefcf4 --- /dev/null +++ b/src/SecTester.Repeater/Runners/TruncatedBody.cs @@ -0,0 +1,23 @@ +using System; +using System.Text; + +namespace SecTester.Repeater.Runners; + +internal sealed class TruncatedBody +{ + public TruncatedBody(byte[] body, string? charSet = default) + { + Body = body; + Encoding = string.IsNullOrEmpty(charSet) ? Encoding.Default : Encoding.GetEncoding(charSet); + } + + private Encoding Encoding { get; } + private byte[] Body { get; } + public int Length => Buffer.ByteLength(Body); + + public override string ToString() + { + return Encoding.GetString(Body); + } +} + diff --git a/src/SecTester.Repeater/SecTester.Repeater.csproj b/src/SecTester.Repeater/SecTester.Repeater.csproj index d223ec6..35960a4 100644 --- a/src/SecTester.Repeater/SecTester.Repeater.csproj +++ b/src/SecTester.Repeater/SecTester.Repeater.csproj @@ -12,7 +12,6 @@ - diff --git a/test/SecTester.Repeater.Tests/Extensions/EnumerableExtensionsTests.cs b/test/SecTester.Repeater.Tests/Extensions/EnumerableExtensionsTests.cs new file mode 100644 index 0000000..e08c4cd --- /dev/null +++ b/test/SecTester.Repeater.Tests/Extensions/EnumerableExtensionsTests.cs @@ -0,0 +1,75 @@ +namespace SecTester.Repeater.Tests.Extensions; + +public class EnumerableExtensionsTests +{ + private readonly Action _action = Substitute.For>(); + + [Fact] + public void ForEach_SourceIsNotDefined_ThrowsError() + { + // act + var act = () => ((null as IEnumerable)!).ForEach(_action); + + // assert + act.Should().Throw().WithMessage("*source*"); + } + + [Fact] + public void ForEach_ActionIsNotDefined_ThrowsError() + { + // arrange + IEnumerable list = new List + { + 1, 2, 3 + }; + + // act + var act = () => list.ForEach(null!); + + // assert + act.Should().Throw().WithMessage("*action*"); + } + + [Fact] + public void ForEach_IterateOverAllElements() + { + // arrange + IEnumerable list = new List + { + 1, 2, 3 + }; + + // act + list.ForEach(_action); + + // assert + _action.Received(3)(Arg.Any()); + } + + [Fact] + public void ForEach_ExecutesActionInCorrectOrder() + { + // arrange + var source = Enumerable.Range(1, 10); + var items = new List(); + + // act + source.ForEach(x => items.Add(x)); + + // assert + items.Should().ContainInOrder(source); + } + + [Fact] + public void ForEach_DoesNothingOnEmptyEnumerable() + { + // arrange + IEnumerable list = new List(); + + // act + list.ForEach(_action); + + // assert + _action.DidNotReceive()(Arg.Any()); + } +} diff --git a/test/SecTester.Repeater.Tests/Extensions/ListExtensionsTests.cs b/test/SecTester.Repeater.Tests/Extensions/ListExtensionsTests.cs new file mode 100644 index 0000000..1c36ed1 --- /dev/null +++ b/test/SecTester.Repeater.Tests/Extensions/ListExtensionsTests.cs @@ -0,0 +1,131 @@ +namespace SecTester.Repeater.Tests.Extensions; + +public class ListExtensionsTests +{ + [Fact] + public void Replace_SourceIsNotDefined_ThrowsError() + { + // act + var act = () => ((null as List)!).Replace(4, x => x == 2); + + // assert + act.Should().Throw().WithMessage("*source*"); + } + + [Fact] + public void Replace_PredicateIsNotDefined_ThrowsError() + { + // arrange + var list = new List + { + 1, 2, 3 + }; + + // act + var act = () => list.Replace(4, null!); + + // assert + act.Should().Throw().WithMessage("*predicate*"); + } + + [Fact] + public void Replace_NewValueIsNull_ReplacesItem() + { + // arrange + var list = new List + { + 1, 2, 3 + }; + + // act + list.Replace(null, x => x == 1); + + // assert + list.Should().BeEquivalentTo(new int?[] { null, 2, 3 }); + } + + [Fact] + public void Replace_ItemFound_ReplacesItemInList() + { + // arrange + var list = new List + { + 1, 2, 3 + }; + + // act + list.Replace(4, x => x == 2); + + // assert + list.Should().Contain(4); + } + + [Fact] + public void Replace_ItemFound_ReturnsIndexOfReplacedItem() + { + // arrange + var list = new List + { + 1, 2, 3 + }; + var expected = list.IndexOf(2); + + // act + var result = list.Replace(4, x => x == 2); + + // assert + result.Should().Be(expected); + } + + [Fact] + public void Replace_ItemNotFound_DoesNothing() + { + // arrange + var list = new List + { + 1, 2, 3 + }; + + // act + list.Replace(4, x => x == 5); + + // assert + list.Should().BeEquivalentTo(new[] + { + 1, 2, 3 + }); + } + + [Fact] + public void Replace_ItemNotFound_ReturnsNegativeIndex() + { + // arrange + var list = new List + { + 1, 2, 3 + }; + + // act + var result = list.Replace(4, x => x == 5); + + // assert + result.Should().Be(-1); + } + + [Fact] + public void Replace_ItemFound_ReplacesFirstMatchingItem() + { + // arrange + var list = new List + { + 1, 2, 1 + }; + var expected = new List { 4, 2, 1 }; + + // act + list.Replace(4, x => x == 1); + + // assert + list.Should().BeEquivalentTo(expected); + } +} diff --git a/test/SecTester.Repeater.Tests/Runners/HttpRequestRunnerTests.cs b/test/SecTester.Repeater.Tests/Runners/HttpRequestRunnerTests.cs index 6902c99..4e4ea8b 100644 --- a/test/SecTester.Repeater.Tests/Runners/HttpRequestRunnerTests.cs +++ b/test/SecTester.Repeater.Tests/Runners/HttpRequestRunnerTests.cs @@ -4,42 +4,47 @@ public class HttpRequestRunnerTests : IDisposable { private const string Url = "https://example.com"; private const string JsonContentType = "application/json"; + private const string HtmlContentType = "text/html"; + private const string HtmlContentTypeWithCharSet = $"{HtmlContentType}; charset=utf-16"; private const string CustomContentType = "application/x-custom"; - private const string Content = @"{""foo"":""bar""}"; + private const string CustomContentTypeWithUtf8CharSet = $"{CustomContentType}; charset=utf-8"; + private const string JsonContent = @"{""foo"":""bar""}"; + private const string HtmlBody = ""; private const string HeaderFieldValue = "test-header-value"; - private const string HeaderFieldName = "testHeader"; + private const string HeaderFieldName = "X-Test-Header"; private const string ContentLengthFieldName = "Content-Length"; private const string ContentTypeFieldName = "Content-Type"; + private const string HostFieldName = "Host"; + private const string InvalidHostHeaderValue = "\0example.com\n"; private readonly IHttpClientFactory _httpClientFactory = Substitute.For(); private readonly MockHttpMessageHandler _mockHttp = new(); - private readonly RequestRunnerOptions _options = new() - { - MaxContentLength = 1 - }; - - private readonly HttpRequestRunner _sut; - - public HttpRequestRunnerTests() - { - _sut = new HttpRequestRunner(_options, _httpClientFactory); - _httpClientFactory.CreateClient(Arg.Any()).Returns(_mockHttp.ToHttpClient()); - } - public void Dispose() { _httpClientFactory.ClearSubstitute(); _mockHttp.Clear(); + _mockHttp.Dispose(); GC.SuppressFinalize(this); } + private HttpRequestRunner CreateSut(RequestRunnerOptions? options = default) + { + _httpClientFactory.CreateClient(Arg.Any()).Returns(_mockHttp.ToHttpClient()); + return new HttpRequestRunner(options ?? new RequestRunnerOptions(), _httpClientFactory); + } + [Fact] - public async Task Run_PerformAnHttpRequest() + public async Task Run_ReturnsResult_WhenRequestIsSuccessful() { // arrange + var sut = CreateSut(); var headers = new[] { + new KeyValuePair>(ContentTypeFieldName, new[] + { + JsonContentType + }), new KeyValuePair>(HeaderFieldName, new[] { HeaderFieldValue @@ -47,20 +52,96 @@ public async Task Run_PerformAnHttpRequest() }; var request = new RequestExecutingEvent(new Uri(Url)) { + Method = HttpMethod.Patch, + Body = JsonContent, Headers = headers }; - _mockHttp.Expect(Url).WithHeaders(headers.Select(x => new KeyValuePair(x.Key, string.Join(";", x.Value)))) - .Respond(HttpStatusCode.OK, JsonContentType, Content); + _mockHttp.Expect(Url) + .WithContent(JsonContent) + .WithHeaders($"{HeaderFieldName}: {HeaderFieldValue}") + .With(message => message.Method.Equals(HttpMethod.Patch)) + .With(message => (bool)message.Content?.Headers.ContentType?.MediaType?.StartsWith(JsonContentType, StringComparison.OrdinalIgnoreCase)) + .Respond(HttpStatusCode.OK, JsonContentType, JsonContent); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert _mockHttp.VerifyNoOutstandingExpectation(); result.Should().BeEquivalentTo(new { StatusCode = (int)HttpStatusCode.OK, - Body = Content + Body = JsonContent + }); + } + + [Fact] + public async Task Run_ReturnsResultWithDecodedBody() + { + // arrange + var sut = CreateSut(); + var encoding = Encoding.GetEncoding("utf-16"); + var expectedByteLength = Buffer.ByteLength(encoding.GetBytes(HtmlBody)); + var request = new RequestExecutingEvent(new Uri(Url)); + var content = new StringContent(HtmlBody, encoding, HtmlContentType); + _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, content); + + // act + var result = await sut.Run(request); + + // assert + _mockHttp.VerifyNoOutstandingExpectation(); + result.Should().BeEquivalentTo(new + { + Headers = new[] { new KeyValuePair(ContentTypeFieldName, new[] { HtmlContentTypeWithCharSet }), new KeyValuePair(ContentLengthFieldName, new[] { $"{expectedByteLength}" }) }, + Body = HtmlBody + }, options => options.ExcludingMissingMembers().IncludingNestedObjects()); + } + + [Fact] + public async Task Run_ReturnsResultWithError_WhenRequestTimesOut() + { + // arrange + var sut = CreateSut(new RequestRunnerOptions + { + Timeout = TimeSpan.Zero + }); + var request = new RequestExecutingEvent(new Uri(Url)); + _mockHttp.Expect(Url) + .Respond(async () => + { + await Task.Delay(5); + + return new HttpResponseMessage(HttpStatusCode.OK); + }); + + // act + var result = await sut.Run(request); + + // assert + _mockHttp.VerifyNoOutstandingExpectation(); + result.Should().BeEquivalentTo(new + { + Message = "The operation was canceled." + }); + } + + [Fact] + public async Task Run_MaxContentLengthIsLessThan0_SkipsTruncating() + { + // arrange + var sut = CreateSut(new RequestRunnerOptions { MaxContentLength = -1 }); + var request = new RequestExecutingEvent(new Uri(Url)); + var body = string.Concat(Enumerable.Repeat("x", 5)); + _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, CustomContentType, body); + + // act + var result = await sut.Run(request); + + // assert + result.Should().BeEquivalentTo(new + { + Body = body }); } @@ -68,11 +149,12 @@ public async Task Run_PerformAnHttpRequest() public async Task Run_NoContentStatusReceived_SkipsTruncating() { // arrange + var sut = CreateSut(); var request = new RequestExecutingEvent(new Uri(Url)); _mockHttp.Expect(Url).Respond(HttpStatusCode.NoContent); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert result.Should().BeEquivalentTo(new @@ -85,14 +167,15 @@ public async Task Run_NoContentStatusReceived_SkipsTruncating() public async Task Run_HeadMethodUsed_SkipsTruncating() { // arrange + var sut = CreateSut(); var request = new RequestExecutingEvent(new Uri(Url)) { Method = HttpMethod.Head }; - _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, JsonContentType, Content); + _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, JsonContentType, JsonContent); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert result.Should().BeEquivalentTo(new @@ -106,17 +189,18 @@ public async Task Run_HeadMethodUsed_SkipsTruncating() public async Task Run_AllowedMimeReceived_SkipsTruncating() { // arrange + var sut = CreateSut(); var request = new RequestExecutingEvent(new Uri(Url)); - _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, JsonContentType, Content); + _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, JsonContentType, JsonContent); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert result.Should().BeEquivalentTo(new { StatusCode = (int)HttpStatusCode.OK, - Body = Content + Body = JsonContent }); } @@ -124,15 +208,20 @@ public async Task Run_AllowedMimeReceived_SkipsTruncating() public async Task Run_NotAllowedMimeReceived_TruncatesBody() { // arrange + var options = new RequestRunnerOptions + { + MaxContentLength = 1 + }; + var sut = CreateSut(options); var headers = new[] { new KeyValuePair>(ContentTypeFieldName, new[] { - $"{CustomContentType}; charset=utf-8" + CustomContentTypeWithUtf8CharSet }), new KeyValuePair>(ContentLengthFieldName, new[] { - $"{_options.MaxContentLength}" + $"{options.MaxContentLength}" }) }; var request = new RequestExecutingEvent(new Uri(Url)); @@ -140,7 +229,7 @@ public async Task Run_NotAllowedMimeReceived_TruncatesBody() _mockHttp.Expect(Url).Respond(HttpStatusCode.OK, CustomContentType, body); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert result.Should().BeEquivalentTo(new @@ -155,17 +244,18 @@ public async Task Run_NotAllowedMimeReceived_TruncatesBody() public async Task Run_HttpStatusException_ReturnsResponse() { // arrange + var sut = CreateSut(); var request = new RequestExecutingEvent(new Uri(Url)); - _mockHttp.Expect(Url).Respond(HttpStatusCode.ServiceUnavailable, JsonContentType, Content); + _mockHttp.Expect(Url).Respond(HttpStatusCode.ServiceUnavailable, JsonContentType, JsonContent); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert result.Should().BeEquivalentTo(new { StatusCode = (int)HttpStatusCode.ServiceUnavailable, - Body = Content + Body = JsonContent }); } @@ -173,11 +263,12 @@ public async Task Run_HttpStatusException_ReturnsResponse() public async Task Run_TcpException_ReturnsResponse() { // arrange + var sut = CreateSut(); var request = new RequestExecutingEvent(new Uri(Url)); _mockHttp.Expect(Url).Throw(new SocketException((int)SocketError.ConnectionRefused)); // act - var result = await _sut.Run(request); + var result = await sut.Run(request); // assert result.Should().BeEquivalentTo(new @@ -185,4 +276,61 @@ public async Task Run_TcpException_ReturnsResponse() ErrorCode = "ConnectionRefused" }, options => options.Using(ctx => ctx.Subject.Should().BeOfType()).When(info => info.Path.EndsWith("Message"))); } + + [Fact] + public async Task Run_BypassesStrictHttpValidation() + { + // arrange + var sut = CreateSut(); + var headers = new[] + { + new KeyValuePair>(HostFieldName, new[] + { + InvalidHostHeaderValue + }) + }; + var request = new RequestExecutingEvent(new Uri(Url)) + { + Headers = headers + }; + _mockHttp.Expect(Url) + .With(message => message.Headers.ToString().Contains(InvalidHostHeaderValue, StringComparison.OrdinalIgnoreCase)) + .Respond(HttpStatusCode.NoContent); + + // act + await sut.Run(request); + + // assert + _mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task Run_AcceptsContentHeaders() + { + // arrange + var sut = CreateSut(); + var headers = new[] + { + new KeyValuePair>(ContentTypeFieldName, new[] + { + JsonContentType + }) + }; + var request = new RequestExecutingEvent(new Uri(Url)) + { + Method = HttpMethod.Post, + Headers = headers, + Body = JsonContent + }; + _mockHttp + .Expect(Url) + .With(message => (bool)message.Content?.Headers.ContentType?.MediaType?.Equals(JsonContentType, StringComparison.OrdinalIgnoreCase)) + .Respond(HttpStatusCode.NoContent); + + // act + await sut.Run(request); + + // assert + _mockHttp.VerifyNoOutstandingExpectation(); + } } diff --git a/test/SecTester.Repeater.Tests/Usings.cs b/test/SecTester.Repeater.Tests/Usings.cs index 38c8031..4f013b1 100644 --- a/test/SecTester.Repeater.Tests/Usings.cs +++ b/test/SecTester.Repeater.Tests/Usings.cs @@ -1,5 +1,6 @@ global using System.Net; global using System.Net.Sockets; +global using System.Text; global using System.Text.RegularExpressions; global using System.Timers; global using FluentAssertions; @@ -13,7 +14,6 @@ global using SecTester.Core.Bus; global using SecTester.Core.Exceptions; global using SecTester.Core.Utils; -global using SecTester.Core.Utils; global using SecTester.Repeater.Api; global using SecTester.Repeater.Bus; global using SecTester.Repeater.Extensions;