Skip to content

Commit

Permalink
Override more Stream members on System.IO.Compression streams (#54518)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Jun 25, 2021
1 parent 419506b commit e9f101c
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -97,19 +97,22 @@ public override void SetLength(long value)
public override int Read(byte[] buffer, int offset, int count)
{
ValidateBufferArguments(buffer, offset, count);
return Read(new Span<byte>(buffer, offset, count));
}

public override int Read(Span<byte> buffer)
{
EnsureNotDisposed();

int bytesRead;
int currentOffset = offset;
int remainingCount = count;
int initialLength = buffer.Length;

int bytesRead;
while (true)
{
bytesRead = _inflater.Inflate(buffer, currentOffset, remainingCount);
currentOffset += bytesRead;
remainingCount -= bytesRead;
bytesRead = _inflater.Inflate(buffer);
buffer = buffer.Slice(bytesRead);

if (remainingCount == 0)
if (buffer.Length == 0)
{
break;
}
Expand All @@ -136,7 +139,13 @@ public override int Read(byte[] buffer, int offset, int count)
_inflater.SetInput(_buffer, 0, bytes);
}

return count - remainingCount;
return initialLength - buffer.Length;
}

public override int ReadByte()
{
byte b = default;
return Read(MemoryMarshal.CreateSpan(ref b, 1)) == 1 ? b : -1;
}

private void EnsureNotDisposed()
Expand Down Expand Up @@ -169,7 +178,7 @@ private ValueTask<int> ReadAsyncInternal(Memory<byte> buffer, CancellationToken
try
{
// Try to read decompressed data in output buffer
int bytesRead = _inflater.Inflate(buffer);
int bytesRead = _inflater.Inflate(buffer.Span);
if (bytesRead != 0)
{
// If decompression output buffer is not empty, return immediately.
Expand Down Expand Up @@ -224,7 +233,7 @@ private async ValueTask<int> ReadAsyncCore(ValueTask<int> readTask, Memory<byte>

// Feed the data from base stream into decompression engine
_inflater.SetInput(_buffer, 0, bytesRead);
bytesRead = _inflater.Inflate(buffer);
bytesRead = _inflater.Inflate(buffer.Span);

if (bytesRead == 0 && !_inflater.Finished())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void SetInput(byte[] inputBytes, int offset, int length) =>

public int AvailableOutput => _output.AvailableBytes;

public int Inflate(Memory<byte> bytes)
public int Inflate(Span<byte> bytes)
{
// copy bytes from output to outputbytes if we have available bytes
// if buffer is not filled up. keep decoding until no input are available
Expand Down Expand Up @@ -139,7 +139,7 @@ public int Inflate(Memory<byte> bytes)
return count;
}

public int Inflate(byte[] bytes, int offset, int length) => Inflate(bytes.AsMemory(offset, length));
public int Inflate(byte[] bytes, int offset, int length) => Inflate(bytes.AsSpan(offset, length));

//Each block of compressed data begins with 3 header bits
// containing the following data:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public int CopyFrom(InputBuffer input, int length)
public int AvailableBytes => _bytesUsed;

/// <summary>Copy the decompressed bytes to output buffer.</summary>
public int CopyTo(Memory<byte> output)
public int CopyTo(Span<byte> output)
{
int copy_end;

Expand All @@ -140,19 +140,13 @@ public int CopyTo(Memory<byte> output)
{
// this means we need to copy two parts separately
// copy the taillen bytes from the end of the output window
_window.AsSpan(WindowSize - tailLen, tailLen).CopyTo(output.Span);
_window.AsSpan(WindowSize - tailLen, tailLen).CopyTo(output);
output = output.Slice(tailLen, copy_end);
}
_window.AsSpan(copy_end - output.Length, output.Length).CopyTo(output.Span);
_window.AsSpan(copy_end - output.Length, output.Length).CopyTo(output);
_bytesUsed -= copied;
Debug.Assert(_bytesUsed >= 0, "check this function and find why we copied more bytes than we have");
return copied;
}

/// <summary>Copy the decompressed bytes to output array.</summary>
public int CopyTo(byte[] output, int offset, int length)
{
return CopyTo(output.AsMemory(offset, length));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace System.IO.Compression
{
Expand Down Expand Up @@ -1222,6 +1225,38 @@ public override void Write(ReadOnlySpan<byte> source)
_position += source.Length;
}

public override void WriteByte(byte value) =>
Write(MemoryMarshal.CreateReadOnlySpan(ref value, 1));

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateBufferArguments(buffer, offset, count);
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
Debug.Assert(CanWrite);

return !buffer.IsEmpty ?
Core(buffer, cancellationToken) :
default;

async ValueTask Core(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
if (!_everWritten)
{
_everWritten = true;
// write local header, we are good to go
_usedZip64inLH = _entry.WriteLocalFileHeader(isEmptyFile: false);
}

await _crcSizeStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
_position += buffer.Length;
}
}

public override void Flush()
{
ThrowIfDisposed();
Expand All @@ -1230,6 +1265,14 @@ public override void Flush()
_crcSizeStream.Flush();
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
Debug.Assert(CanWrite);

return _crcSizeStream.FlushAsync(cancellationToken);
}

protected override void Dispose(bool disposing)
{
if (disposing && !_isDisposed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

namespace System.IO.Compression
{
Expand Down Expand Up @@ -95,6 +98,38 @@ public override int Read(byte[] buffer, int offset, int count)
return _baseStream.Read(buffer, offset, count);
}

public override int Read(Span<byte> buffer)
{
ThrowIfDisposed();
ThrowIfCantRead();

return _baseStream.Read(buffer);
}

public override int ReadByte()
{
ThrowIfDisposed();
ThrowIfCantRead();

return _baseStream.ReadByte();
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfCantRead();

return _baseStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
ThrowIfCantRead();

return _baseStream.ReadAsync(buffer, cancellationToken);
}

public override long Seek(long offset, SeekOrigin origin)
{
ThrowIfDisposed();
Expand Down Expand Up @@ -128,6 +163,30 @@ public override void Write(ReadOnlySpan<byte> source)
_baseStream.Write(source);
}

public override void WriteByte(byte value)
{
ThrowIfDisposed();
ThrowIfCantWrite();

_baseStream.WriteByte(value);
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfCantWrite();

return _baseStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
ThrowIfCantWrite();

return _baseStream.WriteAsync(buffer, cancellationToken);
}

public override void Flush()
{
ThrowIfDisposed();
Expand All @@ -136,6 +195,14 @@ public override void Flush()
_baseStream.Flush();
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfCantWrite();

return _baseStream.FlushAsync(cancellationToken);
}

protected override void Dispose(bool disposing)
{
if (disposing && !_isDisposed)
Expand Down Expand Up @@ -259,6 +326,43 @@ public override int Read(Span<byte> destination)
return ret;
}

public override int ReadByte()
{
byte b = default;
return Read(MemoryMarshal.CreateSpan(ref b, 1)) == 1 ? b : -1;
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateBufferArguments(buffer, offset, count);
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
ThrowIfCantRead();
return Core(buffer, cancellationToken);

async ValueTask<int> Core(Memory<byte> buffer, CancellationToken cancellationToken)
{
if (_superStream.Position != _positionInSuperStream)
{
_superStream.Seek(_positionInSuperStream, SeekOrigin.Begin);
}

if (_positionInSuperStream > _endInSuperStream - buffer.Length)
{
buffer = buffer.Slice(0, (int)(_endInSuperStream - _positionInSuperStream));
}

int ret = await _superStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);

_positionInSuperStream += ret;
return ret;
}
}

public override long Seek(long offset, SeekOrigin origin)
{
ThrowIfDisposed();
Expand Down Expand Up @@ -437,6 +541,39 @@ public override void Write(ReadOnlySpan<byte> source)
_position += source.Length;
}

public override void WriteByte(byte value) =>
Write(MemoryMarshal.CreateReadOnlySpan(ref value, 1));

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateBufferArguments(buffer, offset, count);
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
Debug.Assert(CanWrite);

return !buffer.IsEmpty ?
Core(buffer, cancellationToken) :
default;

async ValueTask Core(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
if (!_everWritten)
{
_initialPosition = _baseBaseStream.Position;
_everWritten = true;
}

_checksum = Crc32Helper.UpdateCrc32(_checksum, buffer.Span);

await _baseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
_position += buffer.Length;
}
}

public override void Flush()
{
ThrowIfDisposed();
Expand All @@ -447,6 +584,12 @@ public override void Flush()
_baseStream.Flush();
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();
return _baseStream.FlushAsync(cancellationToken);
}

protected override void Dispose(bool disposing)
{
if (disposing && !_isDisposed)
Expand Down

0 comments on commit e9f101c

Please sign in to comment.