Skip to content

Commit

Permalink
Read in chunks for large messages.
Browse files Browse the repository at this point in the history
`BinaryDelimited` also moved to this model in #1382; it can help stop OOM
attacks when there really aren't enough bytes being sent.
  • Loading branch information
thomasvl committed Aug 24, 2023
1 parent 9604b52 commit 34c9781
Showing 1 changed file with 52 additions and 29 deletions.
81 changes: 52 additions & 29 deletions Sources/SwiftProtobuf/AsyncMessageSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ public struct AsyncMessageSequence<
Base: AsyncSequence,
M: Message
>: AsyncSequence where Base.Element == UInt8 {

/// The message type in this asynchronous sequence.
public typealias Element = M

private let base: Base
private let extensions: ExtensionMap?
private let partial: Bool
private let options: BinaryDecodingOptions

/// Reads size-delimited messages from the given sequence of bytes. Delimited
/// format allows a single file or stream to contain multiple messages. A delimited message
/// is a varint encoding the message size followed by a message of exactly that size.
Expand Down Expand Up @@ -95,7 +95,7 @@ public struct AsyncMessageSequence<
self.partial = partial
self.options = options
}

/// An asynchronous iterator that produces the messages of this asynchronous sequence
public struct AsyncIterator: AsyncIteratorProtocol {
@usableFromInline
Expand All @@ -106,7 +106,7 @@ public struct AsyncMessageSequence<
let partial: Bool
@usableFromInline
let options: BinaryDecodingOptions

init(
iterator: Base.AsyncIterator,
extensions: ExtensionMap?,
Expand All @@ -118,13 +118,13 @@ public struct AsyncMessageSequence<
self.partial = partial
self.options = options
}

/// Asynchronously reads the next varint
@inlinable
mutating func nextVarInt() async throws -> UInt64? {
var messageSize: UInt64 = 0
var shift: UInt64 = 0

while let byte = try await iterator?.next() {
messageSize |= UInt64(byte & 0x7f) << shift
shift += UInt64(7)
Expand All @@ -143,7 +143,39 @@ public struct AsyncMessageSequence<
}
return nil // End of stream reached.
}


/// Reads exactly `size` bytes from the underlying byte iterator.
///
/// - Parameter size: The number of bytes to read.
/// - Returns: The bytes that were read, in the order they were produced.
/// - Throws: `BinaryDecodingError.truncated` if the iterator is exhausted (or
///   already cleared) before `size` bytes were read; any error thrown by the
///   underlying iterator is propagated.
@usableFromInline
mutating func readBytes(_ size: Int) async throws -> [UInt8] {
  // The payload is pulled in through a scratch chunk capped at 16 MiB rather
  // than allocating `size` bytes up front. A stream that claims a huge size
  // but ends early therefore fails with `.truncated` before the full amount
  // is ever allocated, which catches at least some OOM attacks. Appending a
  // genuinely large payload can still fail hard if memory runs out.
  let maxChunkLength = 16 * 1024 * 1024
  var scratch = [UInt8](repeating: 0, count: Swift.min(size, maxChunkLength))
  var result: [UInt8] = []
  var remaining = size

  while remaining > 0 {
    // The last pass may need fewer bytes than the scratch chunk can hold.
    let wanted = Swift.min(remaining, scratch.count)
    for index in 0..<wanted {
      guard let byte = try await iterator?.next() else {
        // The stream ended before the announced payload was fully read.
        throw BinaryDecodingError.truncated
      }
      scratch[index] = byte
    }
    if wanted == scratch.count {
      result += scratch
    } else {
      // Only the filled prefix of the scratch chunk is valid on this pass.
      result += scratch[0..<wanted]
    }
    remaining -= wanted
  }
  return result
}

/// Asynchronously advances to the next message and returns it, or ends the
/// sequence if there is no next message.
///
Expand All @@ -155,37 +187,28 @@ public struct AsyncMessageSequence<
iterator = nil
return nil
}
guard messageSize <= UInt64(0x7fffffff) else {
iterator = nil
throw BinaryDecodingError.tooLarge
}
if messageSize == 0 {
return try M(
serializedBytes: [],
extensions: extensions,
partial: partial,
options: options
)
} else if messageSize > 0x7fffffff {
iterator = nil
throw BinaryDecodingError.tooLarge
}

var buffer = [UInt8](repeating: 0, count: Int(messageSize))
var consumedBytes = 0

while let byte = try await iterator?.next() {
buffer[consumedBytes] = byte
consumedBytes += 1
if consumedBytes == messageSize {
return try M(
serializedBytes: buffer,
extensions: extensions,
partial: partial,
options: options
)
}
}
throw BinaryDecodingError.truncated // The buffer was not filled.
let buffer = try await readBytes(Int(messageSize))
return try M(
serializedBytes: buffer,
extensions: extensions,
partial: partial,
options: options
)
}
}

/// Creates the asynchronous iterator that produces elements of this
/// asynchronous sequence.
///
Expand Down

0 comments on commit 34c9781

Please sign in to comment.