diff --git a/Sources/SwiftProtobuf/AsyncMessageSequence.swift b/Sources/SwiftProtobuf/AsyncMessageSequence.swift new file mode 100644 index 000000000..59a1ae03b --- /dev/null +++ b/Sources/SwiftProtobuf/AsyncMessageSequence.swift @@ -0,0 +1,208 @@ +// +// Sources/SwiftProtobuf/AsyncMessageSequence.swift - Async sequence over binary delimited protobuf +// +// Copyright (c) 2023 Apple Inc. and the project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See LICENSE.txt for license information: +// https://github.com/apple/swift-protobuf/blob/main/LICENSE.txt +// +// ----------------------------------------------------------------------------- +/// +/// An async sequence of messages decoded from a binary delimited protobuf stream. +/// +// ----------------------------------------------------------------------------- + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncSequence where Element == UInt8 { + /// Creates an asynchronous sequence of size-delimited messages from this sequence of bytes. + /// Delimited format allows a single file or stream to contain multiple messages. A delimited message + /// is a varint encoding the message size followed by a message of exactly that size. + /// + /// - Parameters: + /// - messageType: The type of message to read. + /// - extensions: An `ExtensionMap` used to look up and decode any extensions in + /// messages encoded by this sequence, or in messages nested within these messages. + /// - partial: If `false` (the default), after decoding a message, `Message.isInitialized` + /// will be checked to ensure all fields are present. If any are missing, + /// `BinaryDecodingError.missingRequiredFields` will be thrown. + /// - options: The BinaryDecodingOptions to use. + /// - Returns: An asynchronous sequence of messages read from the `AsyncSequence` of bytes. + /// - Throws: `BinaryDecodingError` if decoding fails, throws + /// `BinaryDelimited.Error` for some reading errors, + /// `BinaryDecodingError.truncated` if the stream ends before fully decoding a + /// message or a delimiter, + /// `BinaryDecodingError.malformedProtobuf`if a delimiter could not be read and + /// `BinaryDecodingError.tooLarge` if a size delimiter of 2GB or greater is found. + @inlinable + public func binaryProtobufDelimitedMessages( + of messageType: M.Type = M.self, + extensions: ExtensionMap? = nil, + partial: Bool = false, + options: BinaryDecodingOptions = BinaryDecodingOptions() + ) -> AsyncMessageSequence { + AsyncMessageSequence( + base: self, + extensions: extensions, + partial: partial, + options: options + ) + } +} + +/// An asynchronous sequence of messages decoded from an asynchronous sequence of bytes. +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +public struct AsyncMessageSequence< + Base: AsyncSequence, + M: Message +>: AsyncSequence where Base.Element == UInt8 { + + /// The message type in this asynchronous sequence. + public typealias Element = M + + private let base: Base + private let extensions: ExtensionMap? + private let partial: Bool + private let options: BinaryDecodingOptions + + /// Reads size-delimited messages from the given sequence of bytes. Delimited + /// format allows a single file or stream to contain multiple messages. A delimited message + /// is a varint encoding the message size followed by a message of exactly that size. + /// + /// - Parameters: + /// - baseSequence: The `AsyncSequence` to read messages from. + /// - extensions: An `ExtensionMap` used to look up and decode any extensions in + /// messages encoded by this sequence, or in messages nested within these messages. + /// - partial: If `false` (the default), after decoding a message, `Message.isInitialized` + /// will be checked to ensure all fields are present. If any are missing, + /// `BinaryDecodingError.missingRequiredFields` will be thrown. + /// - options: The BinaryDecodingOptions to use. + /// - Returns: An asynchronous sequence of messages read from the `AsyncSequence` of bytes. + /// - Throws: `BinaryDecodingError` if decoding fails, throws + /// `BinaryDelimited.Error` for some reading errors, + /// `BinaryDecodingError.truncated` if the stream ends before fully decoding a + /// message or a delimiter, + /// `BinaryDecodingError.malformedProtobuf`if a delimiter could not be read and + /// `BinaryDecodingError.tooLarge` if a size delimiter of 2GB or greater is found. + public init( + base: Base, + extensions: ExtensionMap? = nil, + partial: Bool = false, + options: BinaryDecodingOptions = BinaryDecodingOptions() + ) { + self.base = base + self.extensions = extensions + self.partial = partial + self.options = options + } + + /// An asynchronous iterator that produces the messages of this asynchronous sequence + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + var iterator: Base.AsyncIterator? + @usableFromInline + let extensions: ExtensionMap? + @usableFromInline + let partial: Bool + @usableFromInline + let options: BinaryDecodingOptions + + init( + iterator: Base.AsyncIterator, + extensions: ExtensionMap?, + partial: Bool, + options: BinaryDecodingOptions + ) { + self.iterator = iterator + self.extensions = extensions + self.partial = partial + self.options = options + } + + /// Aysnchronously reads the next varint + @inlinable + mutating func nextVarInt() async throws -> UInt64? { + var messageSize: UInt64 = 0 + var shift: UInt64 = 0 + + while let byte = try await iterator?.next() { + messageSize |= UInt64(byte & 0x7f) << shift + shift += UInt64(7) + if shift > 35 { + iterator = nil + throw BinaryDecodingError.malformedProtobuf + } + if (byte & 0x80 == 0) { + return messageSize + } + } + if (shift > 0) { + // The stream has ended inside a varint. + iterator = nil + throw BinaryDecodingError.truncated + } + return nil // End of stream reached. + } + + /// Asynchronously advances to the next message and returns it, or ends the + /// sequence if there is no next message. + /// + /// - Returns: The next message, if it exists, or `nil` to signal the end of + /// the sequence. + @inlinable + public mutating func next() async throws -> M? { + guard let messageSize = try await nextVarInt() else { + iterator = nil + return nil + } + if messageSize == 0 { + return try M( + serializedBytes: [], + extensions: extensions, + partial: partial, + options: options + ) + } else if messageSize > 0x7fffffff { + iterator = nil + throw BinaryDecodingError.tooLarge + } + + var buffer = [UInt8](repeating: 0, count: Int(messageSize)) + var consumedBytes = 0 + + while let byte = try await iterator?.next() { + buffer[consumedBytes] = byte + consumedBytes += 1 + if consumedBytes == messageSize { + return try M( + serializedBytes: buffer, + extensions: extensions, + partial: partial, + options: options + ) + } + } + throw BinaryDecodingError.truncated // The buffer was not filled. + } + } + + /// Creates the asynchronous iterator that produces elements of this + /// asynchronous sequence. + /// + /// - Returns: An instance of the `AsyncIterator` type used to produce + /// messages in the asynchronous sequence. + public func makeAsyncIterator() -> AsyncMessageSequence.AsyncIterator { + AsyncIterator( + iterator: base.makeAsyncIterator(), + extensions: extensions, + partial: partial, + options: options + ) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncMessageSequence: Sendable where Base: Sendable { } + +@available(*, unavailable) +extension AsyncMessageSequence.AsyncIterator: Sendable { } diff --git a/Tests/SwiftProtobufTests/Test_AsyncMessageSequence.swift b/Tests/SwiftProtobufTests/Test_AsyncMessageSequence.swift new file mode 100644 index 000000000..91a27ac47 --- /dev/null +++ b/Tests/SwiftProtobufTests/Test_AsyncMessageSequence.swift @@ -0,0 +1,223 @@ +// Tests/SwiftProtobufTests/Test_AsyncMessageSequence.swift - +// +// Copyright (c) 2023 Apple Inc. and the project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See LICENSE.txt for license information: +// https://github.com/apple/swift-protobuf/blob/main/LICENSE.txt +// +// ----------------------------------------------------------------------------- +/// +/// Tests the decoding of binary-delimited message streams, ensuring various invalid stream scenarios are +/// handled gracefully. +/// +// ----------------------------------------------------------------------------- + +import Foundation +import XCTest +import SwiftProtobuf + +final class Test_AsyncMessageSequence: XCTestCase { + + // Decode a valid binary delimited stream + func testValidSequence() async throws { + let expected: [Int32] = Array(1...5) + var messages = [SwiftProtoTesting_TestAllTypes]() + for messageNumber in expected { + let message = SwiftProtoTesting_TestAllTypes.with { + $0.optionalInt32 = messageNumber + } + messages.append(message) + } + let serialized = try serializedMessageData(messages: messages) + let asyncBytes = asyncByteStream(bytes: serialized) + + // Recreate the original array + let decoded = asyncBytes.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + let observed = try await decoded.reduce(into: [Int32]()) { array, element in + array.append(element.optionalInt32) + } + XCTAssertEqual(observed, expected, "The original and re-created arrays should be equal.") + } + + // Decode a message from a stream, discarding unknown fields + func testBinaryDecodingOptions() async throws { + let unknownFields: [UInt8] = [ + // Field 1, 150 + 0x08, 0x96, 0x01, + // Field 2, string "testing" + 0x12, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67 + ] + let message = try SwiftProtoTesting_TestEmptyMessage(serializedBytes: unknownFields) + let serialized = try serializedMessageData(messages: [message]) + var asyncBytes = asyncByteStream(bytes: serialized) + var decodingOptions = BinaryDecodingOptions() + let decodedWithUnknown = asyncBytes.binaryProtobufDelimitedMessages( + of: SwiftProtoTesting_TestEmptyMessage.self, + options: decodingOptions + ) + + // First ensure unknown fields are decoded + for try await message in decodedWithUnknown { + XCTAssertEqual(Array(message.unknownFields.data), unknownFields) + } + asyncBytes = asyncByteStream(bytes: serialized) + // Then re-run ensuring unknowh fields are discarded + decodingOptions.discardUnknownFields = true + let decodedWithUnknownDiscarded = asyncBytes.binaryProtobufDelimitedMessages( + of: SwiftProtoTesting_TestEmptyMessage.self, + options: decodingOptions + ) + var count = 0; + for try await message in decodedWithUnknownDiscarded { + XCTAssertTrue(message.unknownFields.data.isEmpty) + count += 1 + } + XCTAssertEqual(count, 1, "Expected one message with unknown fields discarded.") + } + + // Decode zero length messages + func testZeroLengthMessages() async throws { + var messages = [SwiftProtoTesting_TestAllTypes]() + for _ in 1...5 { + messages.append(SwiftProtoTesting_TestAllTypes()) + } + let serialized = try serializedMessageData(messages: messages) + let asyncBytes = asyncByteStream(bytes: serialized) + + var count = 0 + let decoded = AsyncMessageSequence, SwiftProtoTesting_TestAllTypes>(base: asyncBytes) + for try await message in decoded { + XCTAssertEqual(message, SwiftProtoTesting_TestAllTypes()) + count += 1 + } + XCTAssertEqual(count, 5, "Expected five messages with default fields.") + } + + // Stream with a single zero varint + func testStreamZeroVarintOnly() async throws { + let seq = asyncByteStream(bytes: [0]) + let decoded = seq.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + + var count = 0 + for try await message in decoded { + XCTAssertEqual(message, SwiftProtoTesting_TestAllTypes()) + count += 1 + } + XCTAssertEqual(count, 1) + } + + // Empty stream with zero bytes + func testEmptyStream() async throws { + let asyncBytes = asyncByteStream(bytes: []) + let messages = asyncBytes.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + for try await _ in messages { + XCTFail("Shouldn't have returned a value for an empty stream.") + } + } + + // A stream with legal non-zero varint but no message + func testNonZeroVarintNoMessage() async throws { + let asyncBytes = asyncByteStream(bytes: [0x96, 0x01]) + let decoded = asyncBytes.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + var truncatedThrown = false + do { + for try await _ in decoded { + XCTFail("Shouldn't have returned a value for an empty stream.") + } + } catch { + if error as! BinaryDecodingError == .truncated { + truncatedThrown = true + } + } + XCTAssertTrue(truncatedThrown, "Should throw a BinaryDecodingError.truncated") + } + + // Single varint describing a 2GB message + func testTooLarge() async throws { + let asyncBytes = asyncByteStream(bytes: [128, 128, 128, 128, 8]) + var tooLargeThrown = false + let decoded = asyncBytes.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + do { + for try await _ in decoded { + XCTFail("Shouldn't have returned a value for an invalid stream.") + } + } catch { + if error as! BinaryDecodingError == .tooLarge { + tooLargeThrown = true + } + } + XCTAssertTrue(tooLargeThrown, "Should throw a BinaryDecodingError.tooLarge") + } + + // Stream with truncated varint + func testTruncatedVarint() async throws { + let asyncBytes = asyncByteStream(bytes: [192]) + + let decoded = asyncBytes.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + var truncatedThrown = false + do { + for try await _ in decoded { + XCTFail("Shouldn't have returned a value for an empty stream.") + } + } catch { + if error as! BinaryDecodingError == .truncated { + truncatedThrown = true + } + } + XCTAssertTrue(truncatedThrown, "Should throw a BinaryDecodingError.truncated") + } + + // Stream with a valid varint and message, but the following varint is truncated + func testValidMessageThenTruncatedVarint() async throws { + var truncatedThrown = false + let msg = SwiftProtoTesting_TestAllTypes.with { + $0.optionalInt64 = 123456789 + } + let truncatedVarint: [UInt8] = [224, 216] + var serialized = try serializedMessageData(messages: [msg]) + serialized += truncatedVarint + let asyncBytes = asyncByteStream(bytes: serialized) + + do { + var count = 0 + let decoded = asyncBytes.binaryProtobufDelimitedMessages(of: SwiftProtoTesting_TestAllTypes.self) + for try await message in decoded { + XCTAssertEqual(message, SwiftProtoTesting_TestAllTypes.with { + $0.optionalInt64 = 123456789 + }) + count += 1 + if count > 1 { + XCTFail("Expected one message only.") + } + } + XCTAssertEqual(count, 1, "One message should be deserialized") + } catch { + if error as! BinaryDecodingError == .truncated { + truncatedThrown = true + } + } + XCTAssertTrue(truncatedThrown, "Should throw a BinaryDecodingError.truncated") + } + + fileprivate func asyncByteStream(bytes: [UInt8]) -> AsyncStream { + AsyncStream(UInt8.self) { continuation in + for byte in bytes { + continuation.yield(byte) + } + continuation.finish() + } + } + + fileprivate func serializedMessageData(messages: [Message]) throws -> [UInt8] { + let memoryOutputStream = OutputStream.toMemory() + memoryOutputStream.open() + for message in messages { + XCTAssertNoThrow(try BinaryDelimited.serialize(message: message, to: memoryOutputStream)) + } + memoryOutputStream.close() + let nsData = memoryOutputStream.property(forKey: .dataWrittenToMemoryStreamKey) as! NSData + let data = Data(referencing: nsData) + return [UInt8](data) + } +}