Skip to content

Commit

Permalink
decode tests & bugfixes (#42256)
Browse files Browse the repository at this point in the history
* decode tests & bugfixes

* roundtrip tests

* more tests

* better errors | remove duplicate test
  • Loading branch information
jaschrep-msft authored Feb 29, 2024
1 parent 034c9cd commit c809883
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 47 deletions.
3 changes: 3 additions & 0 deletions sdk/storage/Azure.Storage.Common/src/Shared/Errors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public static ArgumentException CannotDeferTransactionalHashVerification()
public static ArgumentException CannotInitializeWriteStreamWithData()
=> new ArgumentException("Initialized buffer for StorageWriteStream must be empty.");

public static InvalidDataException InvalidStructuredMessage(string optionalMessage = default)
=> new InvalidDataException(("Invalid structured message data. " + optionalMessage ?? "").Trim());

internal static void VerifyStreamPosition(Stream stream, string streamName)
{
if (stream != null && stream.CanSeek && stream.Length > 0 && stream.Position >= stream.Length)
Expand Down
15 changes: 11 additions & 4 deletions sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@ public static class V1_0
public const byte MessageVersionByte = 1;

public const int StreamHeaderLength = 13;
public const int StreamHeaderVersionOffset = 0;
public const int StreamHeaderMessageLengthOffset = 1;
public const int StreamHeaderFlagsOffset = 9;
public const int StreamHeaderSegmentCountOffset = 11;

public const int SegmentHeaderLength = 10;
public const int SegmentHeaderNumOffset = 0;
public const int SegmentHeaderContentLengthOffset = 2;

#region Stream Header
public static void ReadStreamHeader(
Expand All @@ -35,13 +42,13 @@ public static void ReadStreamHeader(
out int totalSegments)
{
Errors.AssertBufferExactSize(buffer, 13, nameof(buffer));
if (buffer[0] != 1)
if (buffer[StreamHeaderVersionOffset] != 1)
{
throw new InvalidDataException("Unrecognized version of structured message.");
}
messageLength = (long)BinaryPrimitives.ReadUInt64LittleEndian(buffer.Slice(1, 8));
flags = (Flags)BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(9, 2));
totalSegments = BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(11, 2));
messageLength = (long)BinaryPrimitives.ReadUInt64LittleEndian(buffer.Slice(StreamHeaderMessageLengthOffset, 8));
flags = (Flags)BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(StreamHeaderFlagsOffset, 2));
totalSegments = BinaryPrimitives.ReadUInt16LittleEndian(buffer.Slice(StreamHeaderSegmentCountOffset, 2));
}

public static int WriteStreamHeader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ namespace Azure.Storage.Shared;
/// </remarks>
internal class StructuredMessageDecodingStream : Stream
{
private enum SMRegion
{
StreamHeader,
StreamFooter,
SegmentHeader,
SegmentFooter,
SegmentContent,
}

private readonly Stream _innerBufferedStream;

private byte[] _metadataBuffer = ArrayPool<byte>.Shared.Rent(Constants.KB);
Expand All @@ -48,9 +57,12 @@ internal class StructuredMessageDecodingStream : Stream
private int _streamFooterLength;
private int _segmentHeaderLength;
private int _segmentFooterLength;

private int _totalSegments;
private long _innerStreamLength;

private StructuredMessage.Flags _flags;
private bool _processedFooter = false;
private bool _disposed;

private StorageCrc64HashAlgorithm _totalContentCrc;
Expand All @@ -76,19 +88,6 @@ public override long Position
set => throw new NotSupportedException();
}

#region Position
private long _innerStreamLength;

private enum SMRegion
{
StreamHeader,
StreamFooter,
SegmentHeader,
SegmentFooter,
SegmentContent,
}
#endregion

public StructuredMessageDecodingStream(
Stream innerStream)
{
Expand Down Expand Up @@ -116,9 +115,15 @@ public override int Read(byte[] buf, int offset, int count)
do
{
read = _innerBufferedStream.Read(buf, offset, count);
_innerStreamConsumed += read;
decodedRead = Decode(new Span<byte>(buf, offset, read));
} while (decodedRead <= 0 && read > 0);

if (read <= 0)
{
AssertDecodeFinished();
}

return decodedRead;
}

Expand All @@ -129,9 +134,15 @@ public override async Task<int> ReadAsync(byte[] buf, int offset, int count, Can
do
{
read = await _innerBufferedStream.ReadAsync(buf, offset, count, cancellationToken).ConfigureAwait(false);
_innerStreamConsumed += read;
decodedRead = Decode(new Span<byte>(buf, offset, read));
} while (decodedRead <= 0 && read > 0);

if (read <= 0)
{
AssertDecodeFinished();
}

return decodedRead;
}

Expand All @@ -143,9 +154,15 @@ public override int Read(Span<byte> buf)
do
{
read = _innerBufferedStream.Read(buf);
_innerStreamConsumed += read;
decodedRead = Decode(buf.Slice(0, read));
} while (decodedRead <= 0 && read > 0);

if (read <= 0)
{
AssertDecodeFinished();
}

return decodedRead;
}

Expand All @@ -156,15 +173,31 @@ public override async ValueTask<int> ReadAsync(Memory<byte> buf, CancellationTok
do
{
read = await _innerBufferedStream.ReadAsync(buf).ConfigureAwait(false);
_innerStreamConsumed += read;
decodedRead = Decode(buf.Slice(0, read).Span);
} while (decodedRead <= 0 && read > 0);

if (read <= 0)
{
AssertDecodeFinished();
}

return decodedRead;
}
#endif

private SMRegion _currentRegion;
private int _currentSegmentNum;
private void AssertDecodeFinished()
{
if (_streamFooterLength > 0 && !_processedFooter)
{
throw Errors.InvalidStructuredMessage("Missing or incomplete trailer.");
}
_processedFooter = true;
}

private long _innerStreamConsumed = 0;
private SMRegion _currentRegion = SMRegion.StreamHeader;
private int _currentSegmentNum = 0;
private long _currentSegmentContentLength;
private long _currentSegmentContentRemaining;
private long CurrentRegionLength => _currentRegion switch
Expand Down Expand Up @@ -363,6 +396,17 @@ private int ProcessStreamFooter(ReadOnlySpan<byte> span)
}
}
}

if (_innerStreamConsumed != _innerStreamLength)
{
throw Errors.InvalidStructuredMessage("Unexpected message size.");
}
if (_currentSegmentNum != _totalSegments)
{
throw Errors.InvalidStructuredMessage("Missing expected message segments.");
}

_processedFooter = true;
return totalProcessed;
}

Expand All @@ -375,7 +419,7 @@ private int ProcessSegmentHeader(ReadOnlySpan<byte> span)
_currentSegmentContentRemaining = _currentSegmentContentLength;
if (newSegNum != _currentSegmentNum + 1)
{
throw new InvalidDataException("Unexpected segment number in structured message.");
throw Errors.InvalidStructuredMessage("Unexpected segment number in structured message.");
}
_currentSegmentNum = newSegNum;
_currentRegion = SMRegion.SegmentContent;
Expand Down
Loading

0 comments on commit c809883

Please sign in to comment.