diff --git a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs index aac3f8091a913..788c31e92230f 100644 --- a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs +++ b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs @@ -69,6 +69,7 @@ internal partial class WinHttp public const uint WINHTTP_QUERY_STATUS_TEXT = 20; public const uint WINHTTP_QUERY_RAW_HEADERS = 21; public const uint WINHTTP_QUERY_RAW_HEADERS_CRLF = 22; + public const uint WINHTTP_QUERY_FLAG_TRAILERS = 0x02000000; public const uint WINHTTP_QUERY_CONTENT_ENCODING = 29; public const uint WINHTTP_QUERY_SET_COOKIE = 43; public const uint WINHTTP_QUERY_CUSTOM = 65535; diff --git a/src/libraries/Common/src/System/Net/SecurityProtocol.cs b/src/libraries/Common/src/System/Net/SecurityProtocol.cs index f6de449ef3b8d..a8ff9b3cbc698 100644 --- a/src/libraries/Common/src/System/Net/SecurityProtocol.cs +++ b/src/libraries/Common/src/System/Net/SecurityProtocol.cs @@ -8,7 +8,7 @@ namespace System.Net internal static class SecurityProtocol { public const SslProtocols DefaultSecurityProtocols = -#if !NETSTANDARD2_0 && !NETFRAMEWORK +#if !NETSTANDARD2_0 && !NETSTANDARD2_1 && !NETFRAMEWORK SslProtocols.Tls13 | #endif SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12; diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj b/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj index 8e62c79985b8c..0761a799f6410 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj @@ -1,7 +1,7 @@ - + true - netstandard2.0-windows;netstandard2.0;net461-windows + netstandard2.0-windows;netstandard2.0;netstandard2.1-windows;netstandard2.1;net461-windows true enable diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs index 1f9c88a019a91..5196fc868a973 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseParser.cs @@ -29,7 +29,7 @@ public static HttpResponseMessage CreateResponseMessage( // Create a single buffer to use for all subsequent WinHttpQueryHeaders string interop calls. // This buffer is the length needed for WINHTTP_QUERY_RAW_HEADERS_CRLF, which includes the status line // and all headers separated by CRLF, so it should be large enough for any individual status line or header queries. - int bufferLength = GetResponseHeaderCharBufferLength(requestHandle, Interop.WinHttp.WINHTTP_QUERY_RAW_HEADERS_CRLF); + int bufferLength = GetResponseHeaderCharBufferLength(requestHandle, Interop.WinHttp.WINHTTP_QUERY_RAW_HEADERS_CRLF, isTrailingHeaders: false); char[] buffer = ArrayPool.Shared.Rent(bufferLength); try { @@ -58,7 +58,7 @@ public static HttpResponseMessage CreateResponseMessage( string.Empty; // Create response stream and wrap it in a StreamContent object. - var responseStream = new WinHttpResponseStream(requestHandle, state); + var responseStream = new WinHttpResponseStream(requestHandle, state, response); state.RequestHandle = null; // ownership successfully transfered to WinHttpResponseStram. Stream decompressedStream = responseStream; @@ -93,7 +93,7 @@ public static HttpResponseMessage CreateResponseMessage( response.RequestMessage = request; // Parse raw response headers and place them into response message. - ParseResponseHeaders(requestHandle, response, buffer, stripEncodingHeaders); + ParseResponseHeaders(requestHandle, Interop.WinHttp.WINHTTP_QUERY_RAW_HEADERS_CRLF, response, buffer, stripEncodingHeaders, isTrailers: false); if (response.RequestMessage.Method != HttpMethod.Head) { @@ -223,7 +223,7 @@ private static unsafe int GetResponseHeader(SafeWinHttpHandle requestHandle, uin /// /// Returns the size of the char array buffer. /// - private static unsafe int GetResponseHeaderCharBufferLength(SafeWinHttpHandle requestHandle, uint infoLevel) + public static unsafe int GetResponseHeaderCharBufferLength(SafeWinHttpHandle requestHandle, uint infoLevel, bool isTrailingHeaders) { char* buffer = null; int bufferLength = 0; @@ -233,11 +233,21 @@ private static unsafe int GetResponseHeaderCharBufferLength(SafeWinHttpHandle re { int lastError = Marshal.GetLastWin32Error(); - Debug.Assert(lastError != Interop.WinHttp.ERROR_WINHTTP_HEADER_NOT_FOUND); + if (!isTrailingHeaders) + { + Debug.Assert(lastError != Interop.WinHttp.ERROR_WINHTTP_HEADER_NOT_FOUND); - if (lastError != Interop.WinHttp.ERROR_INSUFFICIENT_BUFFER) + if (lastError != Interop.WinHttp.ERROR_INSUFFICIENT_BUFFER) + { + throw WinHttpException.CreateExceptionUsingError(lastError, nameof(Interop.WinHttp.WinHttpQueryHeaders)); + } + } + else { - throw WinHttpException.CreateExceptionUsingError(lastError, nameof(Interop.WinHttp.WinHttpQueryHeaders)); + if (!(lastError == Interop.WinHttp.ERROR_INSUFFICIENT_BUFFER || lastError == Interop.WinHttp.ERROR_WINHTTP_HEADER_NOT_FOUND)) + { + throw WinHttpException.CreateExceptionUsingError(lastError, nameof(Interop.WinHttp.WinHttpQueryHeaders)); + } } } @@ -286,24 +296,39 @@ private static string GetReasonPhrase(HttpStatusCode statusCode, char[] buffer, new string(buffer, 0, bufferLength); } - private static void ParseResponseHeaders( + private class HttpResponseTrailers : HttpHeaders + { + } + + public static void ParseResponseHeaders( SafeWinHttpHandle requestHandle, + uint infoLevel, HttpResponseMessage response, char[] buffer, - bool stripEncodingHeaders) + bool stripEncodingHeaders, + bool isTrailers) { HttpResponseHeaders responseHeaders = response.Headers; HttpContentHeaders contentHeaders = response.Content.Headers; +#if NETSTANDARD2_1 + HttpResponseHeaders responseTrailers = response.TrailingHeaders; +#else + HttpResponseTrailers responseTrailers = new HttpResponseTrailers(); + response.RequestMessage.Properties["__ResponseTrailers"] = responseTrailers; +#endif int bufferLength = GetResponseHeader( requestHandle, - Interop.WinHttp.WINHTTP_QUERY_RAW_HEADERS_CRLF, + infoLevel, buffer); var reader = new WinHttpResponseHeaderReader(buffer, 0, bufferLength); - // Skip the first line which contains status code, etc. information that we already parsed. - reader.ReadLine(); + if (!isTrailers) + { + // Skip the first line which contains status code, etc. information that we already parsed. + reader.ReadLine(); + } // Parse the array of headers and split them between Content headers and Response headers. string headerName; @@ -311,22 +336,29 @@ private static void ParseResponseHeaders( while (reader.ReadHeader(out headerName, out headerValue)) { - if (!responseHeaders.TryAddWithoutValidation(headerName, headerValue)) + if (!isTrailers) { - if (stripEncodingHeaders) + if (!responseHeaders.TryAddWithoutValidation(headerName, headerValue)) { - // Remove Content-Length and Content-Encoding headers if we are - // decompressing the response stream in the handler (due to - // WINHTTP not supporting it in a particular downlevel platform). - // This matches the behavior of WINHTTP when it does decompression itself. - if (string.Equals(HttpKnownHeaderNames.ContentLength, headerName, StringComparison.OrdinalIgnoreCase) || - string.Equals(HttpKnownHeaderNames.ContentEncoding, headerName, StringComparison.OrdinalIgnoreCase)) + if (stripEncodingHeaders) { - continue; + // Remove Content-Length and Content-Encoding headers if we are + // decompressing the response stream in the handler (due to + // WINHTTP not supporting it in a particular downlevel platform). + // This matches the behavior of WINHTTP when it does decompression itself. + if (string.Equals(HttpKnownHeaderNames.ContentLength, headerName, StringComparison.OrdinalIgnoreCase) || + string.Equals(HttpKnownHeaderNames.ContentEncoding, headerName, StringComparison.OrdinalIgnoreCase)) + { + continue; + } } - } - contentHeaders.TryAddWithoutValidation(headerName, headerValue); + contentHeaders.TryAddWithoutValidation(headerName, headerValue); + } + } + else + { + responseTrailers.TryAddWithoutValidation(headerName, headerValue); } } } diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseStream.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseStream.cs index 812a56e7232cf..e8b4b30d37519 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseStream.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpResponseStream.cs @@ -16,11 +16,14 @@ internal sealed class WinHttpResponseStream : Stream { private volatile bool _disposed; private readonly WinHttpRequestState _state; + private readonly HttpResponseMessage _responseMessage; private SafeWinHttpHandle _requestHandle; + private bool _readTrailingHeaders; - internal WinHttpResponseStream(SafeWinHttpHandle requestHandle, WinHttpRequestState state) + internal WinHttpResponseStream(SafeWinHttpHandle requestHandle, WinHttpRequestState state, HttpResponseMessage responseMessage) { _state = state; + _responseMessage = responseMessage; _requestHandle = requestHandle; } @@ -126,6 +129,7 @@ private async Task CopyToAsyncCore(Stream destination, byte[] buffer, Cancellati int bytesAvailable = await _state.LifecycleAwaitable; if (bytesAvailable == 0) { + ReadResponseTrailers(); break; } Debug.Assert(bytesAvailable > 0); @@ -142,12 +146,17 @@ private async Task CopyToAsyncCore(Stream destination, byte[] buffer, Cancellati int bytesRead = await _state.LifecycleAwaitable; if (bytesRead == 0) { + ReadResponseTrailers(); break; } Debug.Assert(bytesRead > 0); // Write that data out to the output stream +#if NETSTANDARD2_1 + await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false); +#else await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); +#endif } } finally @@ -240,7 +249,14 @@ private async Task ReadAsyncCore(byte[] buffer, int offset, int count, Canc } } - return await _state.LifecycleAwaitable; + int bytesRead = await _state.LifecycleAwaitable; + + if (bytesRead == 0) + { + ReadResponseTrailers(); + } + + return bytesRead; } finally { @@ -249,6 +265,37 @@ private async Task ReadAsyncCore(byte[] buffer, int offset, int count, Canc } } + private void ReadResponseTrailers() + { + // Only load response trailers if: + // 1. HTTP/2 or later + // 2. Response trailers not already loaded + if (_readTrailingHeaders || _responseMessage.Version < WinHttpHandler.HttpVersion20) + { + return; + } + + _readTrailingHeaders = true; + + var bufferLength = WinHttpResponseParser.GetResponseHeaderCharBufferLength( + _requestHandle, + Interop.WinHttp.WINHTTP_QUERY_RAW_HEADERS_CRLF | Interop.WinHttp.WINHTTP_QUERY_FLAG_TRAILERS, + isTrailingHeaders: true); + + if (bufferLength != 0) + { + char[] trailersBuffer = ArrayPool.Shared.Rent(bufferLength); + try + { + WinHttpResponseParser.ParseResponseHeaders(_requestHandle, Interop.WinHttp.WINHTTP_QUERY_RAW_HEADERS_CRLF | Interop.WinHttp.WINHTTP_QUERY_FLAG_TRAILERS, _responseMessage, trailersBuffer, stripEncodingHeaders: false, isTrailers: true); + } + finally + { + ArrayPool.Shared.Return(trailersBuffer); + } + } + } + public override int Read(byte[] buffer, int offset, int count) { return ReadAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj index 4bbd2f502e116..481eca84a0eae 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/System.Net.Http.WinHttpHandler.Functional.Tests.csproj @@ -14,11 +14,11 @@ - + + Link="Common\System\IO\DelegateStream.cs" /> + Link="Common\System\Net\Http\ResponseStreamTest.cs" /> + diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs new file mode 100644 index 0000000000000..4414a713956e3 --- /dev/null +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs @@ -0,0 +1,215 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http.Functional.Tests; +using System.Net.Http.Headers; +using System.Net.Test.Common; +using System.Text; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.Http.WinHttpHandlerFunctional.Tests +{ + public class TrailingHeadersTest : HttpClientHandlerTestBase + { + public TrailingHeadersTest(ITestOutputHelper output) : base(output) + { } + + protected override Version UseVersion => new Version(2, 0); + + protected static byte[] DataBytes = Encoding.ASCII.GetBytes("data"); + + protected static readonly IList TrailingHeaders = new HttpHeaderData[] { + new HttpHeaderData("MyCoolTrailerHeader", "amazingtrailer"), + new HttpHeaderData("EmptyHeader", ""), + new HttpHeaderData("Accept-Encoding", "identity,gzip"), + new HttpHeaderData("Hello", "World") }; + + protected static Frame MakeDataFrame(int streamId, byte[] data, bool endStream = false) => + new DataFrame(data, (endStream ? FrameFlags.EndStream : FrameFlags.None), 0, streamId); + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsync_NoTrailingHeaders_EmptyCollection() + { + using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) + using (HttpClient client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(); + + int streamId = await connection.ReadRequestHeaderAsync(); + + // Response header. + await connection.SendDefaultResponseHeadersAsync(streamId); + + // Response data. + await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes, endStream: true)); + + // Server doesn't send trailing header frame. + HttpResponseMessage response = await sendTask; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var trailingHeaders = GetTrailingHeaders(response); + Assert.NotNull(trailingHeaders); + Assert.Equal(0, trailingHeaders.Count()); + } + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted() + { + using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) + using (HttpClient client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(); + + int streamId = await connection.ReadRequestHeaderAsync(); + + // Response header. + await connection.SendDefaultResponseHeadersAsync(streamId); + + // Response data, missing Trailers. + await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes)); + + // Additional trailing header frame. + await connection.SendResponseHeadersAsync(streamId, isTrailingHeader: true, headers: TrailingHeaders, endStream: true); + + HttpResponseMessage response = await sendTask; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var trailingHeaders = GetTrailingHeaders(response); + Assert.Equal(TrailingHeaders.Count, trailingHeaders.Count()); + Assert.Contains("amazingtrailer", trailingHeaders.GetValues("MyCoolTrailerHeader")); + Assert.Contains("World", trailingHeaders.GetValues("Hello")); + } + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Available() + { + using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) + using (HttpClient client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address, HttpCompletionOption.ResponseHeadersRead); + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(); + + int streamId = await connection.ReadRequestHeaderAsync(); + + // Response header. + await connection.SendDefaultResponseHeadersAsync(streamId); + + // Response data, missing Trailers. + await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes)); + + HttpResponseMessage response = await sendTask; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Pending read on the response content. + var trailingHeaders = GetTrailingHeaders(response); + Assert.True(trailingHeaders == null || trailingHeaders.Count() == 0); + + Stream stream = await response.Content.ReadAsStreamAsync(TestAsync); + Byte[] data = new Byte[100]; + await stream.ReadAsync(data, 0, data.Length); + + // Intermediate test - haven't reached stream EOF yet. + trailingHeaders = GetTrailingHeaders(response); + Assert.True(trailingHeaders == null || trailingHeaders.Count() == 0); + + // Finish data stream and write out trailing headers. + await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes)); + await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: TrailingHeaders); + + // Read data until EOF is reached + while (stream.Read(data, 0, data.Length) != 0) ; + + trailingHeaders = GetTrailingHeaders(response); + Assert.Equal(TrailingHeaders.Count, trailingHeaders.Count()); + Assert.Contains("amazingtrailer", trailingHeaders.GetValues("MyCoolTrailerHeader")); + Assert.Contains("World", trailingHeaders.GetValues("Hello")); + + // Read when already zero. Trailers shouldn't be changed. + stream.Read(data, 0, data.Length); + + trailingHeaders = GetTrailingHeaders(response); + Assert.Equal(TrailingHeaders.Count, trailingHeaders.Count()); + } + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsync_TrailerHeaders_TrailingHeaderNoBody() + { + using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) + using (HttpClient client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(); + + int streamId = await connection.ReadRequestHeaderAsync(); + + // Response header. + await connection.SendDefaultResponseHeadersAsync(streamId); + await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: TrailingHeaders); + + HttpResponseMessage response = await sendTask; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var trailingHeaders = GetTrailingHeaders(response); + Assert.Equal(TrailingHeaders.Count, trailingHeaders.Count()); + Assert.Contains("amazingtrailer", trailingHeaders.GetValues("MyCoolTrailerHeader")); + Assert.Contains("World", trailingHeaders.GetValues("Hello")); + } + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsync_TrailingHeaders_NoData_EmptyResponseObserved() + { + using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) + using (HttpClient client = CreateHttpClient()) + { + Task sendTask = client.GetAsync(server.Address); + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(); + + int streamId = await connection.ReadRequestHeaderAsync(); + + // Response header. + await connection.SendDefaultResponseHeadersAsync(streamId); + + // No data. + + // Response trailing headers + await connection.SendResponseHeadersAsync(streamId, isTrailingHeader: true, headers: TrailingHeaders); + + HttpResponseMessage response = await sendTask; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal(Array.Empty(), await response.Content.ReadAsByteArrayAsync()); + + var trailingHeaders = GetTrailingHeaders(response); + Assert.Contains("amazingtrailer", trailingHeaders.GetValues("MyCoolTrailerHeader")); + Assert.Contains("World", trailingHeaders.GetValues("Hello")); + } + } + + private HttpHeaders GetTrailingHeaders(HttpResponseMessage responseMessage) + { +#if !NET48 + return responseMessage.TrailingHeaders; +#else +#pragma warning disable CS0618 // Type or member is obsolete + responseMessage.RequestMessage.Properties.TryGetValue("__ResponseTrailers", out object trailers); +#pragma warning restore CS0618 // Type or member is obsolete + return (HttpHeaders)trailers; +#endif + } + } +} diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/WinHttpResponseStreamTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/WinHttpResponseStreamTest.cs index cac00483a083d..90900a0517e07 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/WinHttpResponseStreamTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/WinHttpResponseStreamTest.cs @@ -376,7 +376,7 @@ internal Stream MakeResponseStream() handle.Context = state.ToIntPtr(); state.RequestHandle = handle; - return new WinHttpResponseStream(handle, state); + return new WinHttpResponseStream(handle, state, new HttpResponseMessage()); } } }