diff --git a/src/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBuffer.cs b/src/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBuffer.cs index 02547dcafbd7..f96532d57b59 100644 --- a/src/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBuffer.cs +++ b/src/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBuffer.cs @@ -39,7 +39,7 @@ internal class WebSocketBuffer : IDisposable private readonly ArraySegment _propertyBuffer; private readonly int _sendBufferSize; private volatile int _payloadOffset; - private volatile WebSocketReceiveResult _bufferedPayloadReceiveResult; + private volatile PayloadReceiveResult _bufferedPayloadReceiveResult; private long _pinnedSendBufferStartAddress; private long _pinnedSendBufferEndAddress; private ArraySegment _pinnedSendBuffer; @@ -305,7 +305,7 @@ internal void BufferPayload(ArraySegment payload, Debug.Assert(_payloadOffset == 0, "'m_PayloadOffset' MUST be '0' at this point."); Debug.Assert(_bufferedPayloadReceiveResult == null || _bufferedPayloadReceiveResult.Count == 0, - "'m_BufferedPayloadReceiveResult.Count' MUST be '0' at this point."); + "'_bufferedPayloadReceiveResult.Count' MUST be '0' at this point."); Buffer.BlockCopy(payload.Array, payload.Offset + unconsumedDataOffset, @@ -314,7 +314,7 @@ internal void BufferPayload(ArraySegment payload, bytesBuffered); _bufferedPayloadReceiveResult = - new WebSocketReceiveResult(bytesBuffered, messageType, endOfMessage); + new PayloadReceiveResult(bytesBuffered, messageType, endOfMessage); this.ValidateBufferedPayload(); } @@ -326,12 +326,12 @@ internal bool ReceiveFromBufferedPayload(ArraySegment buffer, out WebSocke int bytesTransferred = Math.Min(buffer.Count, _bufferedPayloadReceiveResult.Count); + _bufferedPayloadReceiveResult.Count -= bytesTransferred; + receiveResult = new WebSocketReceiveResult( bytesTransferred, _bufferedPayloadReceiveResult.MessageType, - bytesTransferred == 0 && _bufferedPayloadReceiveResult.EndOfMessage, - _bufferedPayloadReceiveResult.CloseStatus, - _bufferedPayloadReceiveResult.CloseStatusDescription); + _bufferedPayloadReceiveResult.Count == 0 && _bufferedPayloadReceiveResult.EndOfMessage); Buffer.BlockCopy(_payloadBuffer.Array, _payloadBuffer.Offset + _payloadOffset, @@ -558,9 +558,9 @@ private void ThrowIfDisposed() private void ValidateBufferedPayload() { Debug.Assert(_bufferedPayloadReceiveResult != null, - "'m_BufferedPayloadReceiveResult' MUST NOT be NULL."); + "'_bufferedPayloadReceiveResult' MUST NOT be NULL."); Debug.Assert(_bufferedPayloadReceiveResult.Count >= 0, - "'m_BufferedPayloadReceiveResult.Count' MUST NOT be negative."); + "'_bufferedPayloadReceiveResult.Count' MUST NOT be negative."); Debug.Assert(_payloadOffset >= 0, "'m_PayloadOffset' MUST NOT be smaller than 0."); Debug.Assert(_payloadOffset <= _payloadBuffer.Count, "'m_PayloadOffset' MUST NOT be bigger than 'm_PayloadBuffer.Count'."); @@ -685,5 +685,24 @@ private static class SendBufferState public const int None = 0; public const int SendPayloadSpecified = 1; } + + private class PayloadReceiveResult + { + public int Count { get; set; } + public bool EndOfMessage { get; } + public WebSocketMessageType MessageType { get; } + + public PayloadReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) + { + if (count < 0) + { + throw new ArgumentOutOfRangeException(nameof(count)); + } + + Count = count; + EndOfMessage = endOfMessage; + MessageType = messageType; + } + } } -} \ No newline at end of file +} diff --git a/src/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs b/src/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs index 5ea8149957c4..a0697b15cbe0 100644 --- a/src/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs +++ b/src/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs @@ -112,6 +112,40 @@ public async Task ReceiveAsync_ReadWholeBuffer_Success(WebSocketMessageType mess Assert.Equal(Text, Encoding.ASCII.GetString(receivedBytes)); } + [ConditionalTheory(nameof(IsNotWindows7))] + [InlineData(300)] + [InlineData(500)] + [InlineData(1000)] + [InlineData(1300)] + public async Task ReceiveAsync_DetectEndOfMessage_Success(int bufferSize) + { + const int StringLength = 1000; + string sendString = new string('A', StringLength); + byte[] sentBytes = Encoding.ASCII.GetBytes(sendString); + + HttpListenerWebSocketContext context = await GetWebSocketContext(); + await ClientConnectTask; + + await Client.SendAsync(new ArraySegment(sentBytes), WebSocketMessageType.Text, true, new CancellationToken()); + + byte[] receivedBytes = new byte[bufferSize]; + List compoundBuffer = new List(); + + WebSocketReceiveResult result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, false); + while (!result.EndOfMessage) + { + result = await (context.WebSocket).ReceiveAsync(new ArraySegment(receivedBytes), new CancellationToken()); + + byte[] readBytes = new byte[result.Count]; + Array.Copy(receivedBytes, readBytes, result.Count); + compoundBuffer.AddRange(readBytes); + } + + Assert.True(result.EndOfMessage); + string msg = Encoding.UTF8.GetString(compoundBuffer.ToArray()); + Assert.Equal(sendString, msg); + } + [ConditionalFact(nameof(IsNotWindows7))] public async Task ReceiveAsync_NoInnerBuffer_ThrowsArgumentNullException() {