Skip to content

Commit

Permalink
Add Memory Overrides to Streams (#47125)
Browse files Browse the repository at this point in the history
* System.IO.Compression.DeflateStream.CopyToStream

-Implemented memory-based WriteAsync in DeflateStream.CopyToStream class.

This required implementing a memory-based overload of System.IO.Inflater.SetInput(). Previously, Inflater used a GCHandle to pin the array that was passed into SetInput. I converted it to use a MemoryHandle, and changed the array-based overload of SetInput to delegate to the new Memory-based overload.

* Implement suggested changes

* Memorify RequestStream

- Memory overrides for System.Net.RequestStream
- Memory overrides for System.Net.NetworkStreamWrapper

* Spanified ChunkedMemoryStream

WriteAsync is implemented in terms of Write, so I went ahead and implemented Write(ReadOnlySpan<byte>). 
For some reason, AsSpan() isn't available in this file.

* Apply suggested changes

* Apply suggestions from code review

Co-authored-by: Adam Sitnik <[email protected]>
  • Loading branch information
NewellClark and adamsitnik authored Mar 5, 2021
1 parent b76f17e commit 62f6c08
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 24 deletions.
28 changes: 22 additions & 6 deletions src/libraries/Common/src/System/IO/ChunkedMemoryStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using System;

namespace System.IO
{
Expand Down Expand Up @@ -34,24 +35,28 @@ public byte[] ToArray()

public override void Write(byte[] buffer, int offset, int count)
{
while (count > 0)
Write(new ReadOnlySpan<byte>(buffer, offset, count));
}

public override void Write(ReadOnlySpan<byte> buffer)
{
while (!buffer.IsEmpty)
{
if (_currentChunk != null)
{
int remaining = _currentChunk._buffer.Length - _currentChunk._freeOffset;
if (remaining > 0)
{
int toCopy = Math.Min(remaining, count);
Buffer.BlockCopy(buffer, offset, _currentChunk._buffer, _currentChunk._freeOffset, toCopy);
count -= toCopy;
offset += toCopy;
int toCopy = Math.Min(remaining, buffer.Length);
buffer.Slice(0, toCopy).CopyTo(new Span<byte>(_currentChunk._buffer, _currentChunk._freeOffset, toCopy));
buffer = buffer.Slice(toCopy);
_totalLength += toCopy;
_currentChunk._freeOffset += toCopy;
continue;
}
}

AppendChunk(count);
AppendChunk(buffer.Length);
}
}

Expand All @@ -66,6 +71,17 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return Task.CompletedTask;
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

Write(buffer.Span);
return ValueTask.CompletedTask;
}

private void AppendChunk(long count)
{
int nextChunkLength = _currentChunk != null ? _currentChunk._buffer.Length * 2 : InitialChunkDefaultSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,25 +907,38 @@ public void CopyFromSourceToDestination()
}
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
// Validate inputs
Debug.Assert(buffer != _arrayPoolBuffer);
_deflateStream.EnsureNotDisposed();
if (count <= 0)
{
return;
return Task.CompletedTask;
}
else if (count > buffer.Length - offset)
{
// The buffer stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.GenericInvalidData);
return Task.FromException(new InvalidDataException(SR.GenericInvalidData));
}

Debug.Assert(_deflateStream._inflater != null);
// Feed the data from base stream into the decompression engine.
_deflateStream._inflater.SetInput(buffer, offset, count);
return WriteAsyncCore(buffer.AsMemory(offset, count), cancellationToken).AsTask();
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
_deflateStream.EnsureNotDisposed();

return WriteAsyncCore(buffer, cancellationToken);
}

private async ValueTask WriteAsyncCore(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
Debug.Assert(_deflateStream._inflater is not null);

// Feed the data from base stream into decompression engine.
_deflateStream._inflater.SetInput(buffer);

// While there's more decompressed data available, forward it to the buffer stream.
while (!_deflateStream._inflater.Finished())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
Expand All @@ -20,7 +21,7 @@ internal sealed class Inflater : IDisposable
private bool _isDisposed; // Prevents multiple disposals
private readonly int _windowBits; // The WindowBits parameter passed to Inflater construction
private ZLibNative.ZLibStreamHandle _zlibStream; // The handle to the primary underlying zlib stream
private GCHandle _inputBufferHandle; // The handle to the buffer that provides input to _zlibStream
private MemoryHandle _inputBufferHandle; // The handle to the buffer that provides input to _zlibStream
private readonly long _uncompressedSize;
private long _currentInflatedCount;

Expand Down Expand Up @@ -110,7 +111,7 @@ public unsafe int InflateVerified(byte* bufPtr, int length)
finally
{
// Before returning, make sure to release input buffer if necessary:
if (0 == _zlibStream.AvailIn && _inputBufferHandle.IsAllocated)
if (0 == _zlibStream.AvailIn && IsInputBufferHandleAllocated)
{
DeallocateInputBufferHandle();
}
Expand All @@ -121,7 +122,7 @@ private unsafe void ReadOutput(byte* bufPtr, int length, out int bytesRead)
{
if (ReadInflateOutput(bufPtr, length, ZLibNative.FlushCode.NoFlush, out bytesRead) == ZLibNative.ErrorCode.StreamEnd)
{
if (!NeedsInput() && IsGzipStream() && _inputBufferHandle.IsAllocated)
if (!NeedsInput() && IsGzipStream() && IsInputBufferHandleAllocated)
{
_finished = ResetStreamForLeftoverInput();
}
Expand All @@ -142,7 +143,7 @@ private unsafe bool ResetStreamForLeftoverInput()
{
Debug.Assert(!NeedsInput());
Debug.Assert(IsGzipStream());
Debug.Assert(_inputBufferHandle.IsAllocated);
Debug.Assert(IsInputBufferHandleAllocated);

lock (SyncLock)
{
Expand Down Expand Up @@ -180,16 +181,24 @@ public void SetInput(byte[] inputBuffer, int startIndex, int count)
Debug.Assert(NeedsInput(), "We have something left in previous input!");
Debug.Assert(inputBuffer != null);
Debug.Assert(startIndex >= 0 && count >= 0 && count + startIndex <= inputBuffer.Length);
Debug.Assert(!_inputBufferHandle.IsAllocated);
Debug.Assert(!IsInputBufferHandleAllocated);

if (0 == count)
SetInput(inputBuffer.AsMemory(startIndex, count));
}

public unsafe void SetInput(ReadOnlyMemory<byte> inputBuffer)
{
Debug.Assert(NeedsInput(), "We have something left in previous input!");
Debug.Assert(!IsInputBufferHandleAllocated);

if (inputBuffer.IsEmpty)
return;

lock (SyncLock)
{
_inputBufferHandle = GCHandle.Alloc(inputBuffer, GCHandleType.Pinned);
_zlibStream.NextIn = _inputBufferHandle.AddrOfPinnedObject() + startIndex;
_zlibStream.AvailIn = (uint)count;
_inputBufferHandle = inputBuffer.Pin();
_zlibStream.NextIn = (IntPtr)_inputBufferHandle.Pointer;
_zlibStream.AvailIn = (uint)inputBuffer.Length;
_finished = false;
}
}
Expand All @@ -201,7 +210,7 @@ private void Dispose(bool disposing)
if (disposing)
_zlibStream.Dispose();

if (_inputBufferHandle.IsAllocated)
if (IsInputBufferHandleAllocated)
DeallocateInputBufferHandle();

_isDisposed = true;
Expand Down Expand Up @@ -313,14 +322,16 @@ private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode)
/// </summary>
private void DeallocateInputBufferHandle()
{
Debug.Assert(_inputBufferHandle.IsAllocated);
Debug.Assert(IsInputBufferHandleAllocated);

lock (SyncLock)
{
_zlibStream.AvailIn = 0;
_zlibStream.NextIn = ZLibNative.ZNullPtr;
_inputBufferHandle.Free();
_inputBufferHandle.Dispose();
}
}

private unsafe bool IsInputBufferHandleAllocated => _inputBufferHandle.Pointer != default;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
return _networkStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _networkStream.ReadAsync(buffer, cancellationToken);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback? callback, object? state)
{
return _networkStream.BeginWrite(buffer, offset, size, callback, state);
Expand All @@ -204,6 +209,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return _networkStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _networkStream.WriteAsync(buffer, cancellationToken);
}

public override void Flush()
{
_networkStream.Flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return _buffer.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _buffer.WriteAsync(buffer, cancellationToken);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState)
{
ValidateBufferArguments(buffer, offset, count);
Expand Down

0 comments on commit 62f6c08

Please sign in to comment.