Skip to content

Commit

Permalink
Merge pull request #85 from swiftwasm/katei/buffer-generator-protocol
Browse files Browse the repository at this point in the history
Add RandomBufferGenerator protocol for WASI random_get
  • Loading branch information
kateinoigakukun authored Apr 26, 2024
2 parents da13542 + 99252ae commit 949ff74
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 3 deletions.
48 changes: 48 additions & 0 deletions Sources/WASI/RandomBufferGenerator.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import SwiftShims // For swift_stdlib_random

/// A type that provides random bytes.
///
/// This type is similar to `RandomNumberGenerator` in Swift standard library,
/// but it provides a way to fill a buffer with random bytes instead of a single
/// random number.
public protocol RandomBufferGenerator {

/// Fills the buffer with random bytes.
///
/// - Parameter buffer: The destination buffer to fill with random bytes.
mutating func fill(buffer: UnsafeMutableBufferPointer<UInt8>)
}

extension RandomBufferGenerator where Self: RandomNumberGenerator {
public mutating func fill(buffer: UnsafeMutableBufferPointer<UInt8>) {
// The buffer is filled with 8 bytes at once.
let count = buffer.count / 8
for i in 0..<count {
let random = self.next()
withUnsafeBytes(of: random) { randomBytes in
let startOffset = i * 8
let destination = UnsafeMutableBufferPointer(rebasing: buffer[startOffset..<(startOffset + 8)])
randomBytes.copyBytes(to: destination)
}
}

// If the buffer size is not a multiple of 8, fill the remaining bytes.
let remaining = buffer.count % 8
if remaining > 0 {
let random = self.next()
withUnsafeBytes(of: random) { randomBytes in
let startOffset = count * 8
let destination = UnsafeMutableBufferPointer(rebasing: buffer[startOffset..<(startOffset + remaining)])
randomBytes.copyBytes(to: destination)
}
}
}
}

extension SystemRandomNumberGenerator: RandomBufferGenerator {
public mutating func fill(buffer: UnsafeMutableBufferPointer<UInt8>) {
guard let baseAddress = buffer.baseAddress else { return }
// Directly call underlying C function of SystemRandomNumberGenerator
swift_stdlib_random(baseAddress, Int(buffer.count))
}
}
8 changes: 5 additions & 3 deletions Sources/WASI/WASI.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import Foundation
import SwiftShims // For swift_stdlib_random
import SystemExtras
import SystemPackage
import WasmTypes
Expand Down Expand Up @@ -1356,14 +1355,16 @@ public class WASIBridgeToHost: WASI {
private let args: [String]
private let environment: [String: String]
private var fdTable: FdTable
private var randomGenerator: RandomBufferGenerator

public init(
args: [String] = [],
environment: [String: String] = [:],
preopens: [String: String] = [:],
stdin: FileDescriptor = .standardInput,
stdout: FileDescriptor = .standardOutput,
stderr: FileDescriptor = .standardError
stderr: FileDescriptor = .standardError,
randomGenerator: RandomBufferGenerator = SystemRandomNumberGenerator()
) throws {
self.args = args
self.environment = environment
Expand All @@ -1386,6 +1387,7 @@ public class WASIBridgeToHost: WASI {
}
}
self.fdTable = fdTable
self.randomGenerator = randomGenerator
}

public var wasiHostModules: [String: WASIHostModule] { _hostModules }
Expand Down Expand Up @@ -1860,7 +1862,7 @@ public class WASIBridgeToHost: WASI {
func random_get(buffer: UnsafeGuestPointer<UInt8>, length: WASIAbi.Size) {
guard length > 0 else { return }
buffer.withHostPointer(count: Int(length)) {
swift_stdlib_random($0.baseAddress!, Int(length))
self.randomGenerator.fill(buffer: $0)
}
}
}
35 changes: 35 additions & 0 deletions Tests/WASITests/RandomBufferGeneratorTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import XCTest

@testable import WASI

final class RandomBufferGeneratorTests: XCTestCase {
struct DeterministicGenerator: RandomNumberGenerator, RandomBufferGenerator {
var items: [UInt64]

mutating func next() -> UInt64 {
items.removeFirst()
}
}
func testDefaultFill() {
var generator = DeterministicGenerator(items: [
0x0123456789abcdef, 0xfedcba9876543210, 0xdeadbeefbaddcafe
])
for (bufferSize, expectedBytes): (Int, [UInt8]) in [
(10, [0xef, 0xcd, 0xab, 0x89, 0x67, 0x45, 0x23, 0x01, 0x10, 0x32]),
(2, [0xfe, 0xca]),
(0, [])
] {
var buffer: [UInt8] = Array(repeating: 0, count: bufferSize)
buffer.withUnsafeMutableBufferPointer {
generator.fill(buffer: $0)
}
let expected: [UInt8]
#if _endian(little)
expected = expectedBytes
#else
expected = Array(expectedBytes.reversed())
#endif
XCTAssertEqual(buffer, expected)
}
}
}

0 comments on commit 949ff74

Please sign in to comment.