Add HostAlignedByteCount to enforce alignment at compile time
As part of the work to allow mmaps to be backed by other implementations, I
realized that we didn't have any way to track whether a particular usize is
host-page-aligned at compile time.

Add a `HostAlignedByteCount` which tracks that a particular usize is aligned to
the host page size. The type also avoids exposing unchecked arithmetic as safe
operations, ensuring that overflows always error out.

With `HostAlignedByteCount`, a lot of runtime checks can go away thanks to the
type-level assertion.
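
For illustration, a minimal sketch (not part of this diff) of the before/after
shape of a call site; the `reserve_pages_*` functions are hypothetical, and the
sketch assumes `HostAlignedByteCount` and `host_page_size` from this commit are
in scope:

    // Before: every call site re-checks alignment at runtime.
    fn reserve_pages_before(size: usize) -> anyhow::Result<()> {
        anyhow::ensure!(size % host_page_size() == 0, "misaligned size: {size}");
        // ... reserve `size` bytes ...
        Ok(())
    }

    // After: the check happens once, when the `HostAlignedByteCount` is
    // constructed, so this function cannot be handed a misaligned size.
    fn reserve_pages_after(size: HostAlignedByteCount) -> anyhow::Result<()> {
        let _bytes = size.byte_count(); // guaranteed page-aligned by the type
        // ... reserve `_bytes` bytes ...
        Ok(())
    }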

In the interest of keeping the diff relatively small, I haven't converted
everything over yet; more can follow as time permits.
sunshowers committed Nov 19, 2024
1 parent 15b464b commit 7546b03
Showing 11 changed files with 533 additions and 163 deletions.
276 changes: 276 additions & 0 deletions crates/wasmtime/src/runtime/vm.rs
@@ -336,6 +336,216 @@ impl ModuleRuntimeInfo {
}
}

/// A number of bytes that's guaranteed to be aligned to the host page size.
///
/// This is used to manage page-aligned memory allocations.
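///
/// # Example (illustrative sketch, not exercised in this diff)
///
/// ```ignore
/// let page = HostAlignedByteCount::host_page_size();
/// let two_pages = page.checked_mul(2).expect("two pages fit in usize");
/// assert_eq!(two_pages.byte_count() % host_page_size(), 0);
/// ```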
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct HostAlignedByteCount(
// Invariant: this is always a multiple of the host page size.
usize,
);

impl HostAlignedByteCount {
/// A zero byte count.
pub const ZERO: Self = Self(0);

/// Creates a new `HostAlignedByteCount` from an aligned byte count.
///
/// Returns an error if `bytes` is not page-aligned.
pub fn new(bytes: usize) -> Result<Self, ByteCountNotAligned> {
let host_page_size = host_page_size();
if bytes % host_page_size == 0 {
Ok(Self(bytes))
} else {
Err(ByteCountNotAligned(bytes))
}
}

/// Creates a new `HostAlignedByteCount` from an aligned byte count without
/// checking validity.
///
/// ## Safety
///
/// The caller must ensure that `bytes` is page-aligned.
pub unsafe fn new_unchecked(bytes: usize) -> Self {
Self(bytes)
}

/// Creates a new `HostAlignedByteCount`, rounding up to the nearest page.
///
/// Returns an error if `bytes + page_size - 1` overflows.
pub fn new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds> {
let page_size = host_page_size();
debug_assert!(page_size.is_power_of_two());
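// Sketch: because `page_size` is a power of two, adding `page_size - 1` and
// masking with `!(page_size - 1)` rounds up to the next page boundary. With a
// hypothetical 4 KiB page (mask = !0xfff):
//   (5000 + 0xfff) & !0xfff == 8192  // rounds up to two pages
//   (8192 + 0xfff) & !0xfff == 8192  // already aligned: unchanged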
match bytes.checked_add(page_size - 1) {
Some(v) => Ok(Self(v & !(page_size - 1))),
None => Err(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::RoundUp(
bytes,
))),
}
}

/// Creates a new `HostAlignedByteCount` from a `u64`, rounding up to the nearest page.
///
/// Returns an error if the `u64` overflows `usize`, or if `bytes +
/// page_size - 1` overflows.
pub fn new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds> {
let bytes = bytes
.try_into()
.map_err(|_| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::U64(bytes)))?;
Self::new_rounded_up(bytes)
}

/// Returns the host page size.
pub fn host_page_size() -> HostAlignedByteCount {
// The host page size is always a multiple of itself.
HostAlignedByteCount(host_page_size())
}

/// Returns true if the byte count is zero.
#[inline]
pub fn is_zero(self) -> bool {
self == Self::ZERO
}

/// Returns the number of bytes as a `usize`.
#[inline]
pub fn byte_count(self) -> usize {
self.0
}

/// Add two aligned byte counts together.
///
/// Returns an error if the result overflows.
pub fn checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
// aligned + aligned = aligned
self.0
.checked_add(bytes.0)
.map(Self)
.ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Add(
self.0, bytes.0,
)))
}

/// Subtract two aligned byte counts.
///
/// Returns an error if the result underflows.
pub fn checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
// aligned - aligned = aligned
self.0
.checked_sub(bytes.0)
.map(Self)
.ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Sub(
self.0, bytes.0,
)))
}

/// Multiply an aligned byte count by a scalar value.
///
/// Returns an error if the result overflows.
pub fn checked_mul(self, n: usize) -> Result<Self, ByteCountOutOfBounds> {
// aligned * scalar = aligned
self.0
.checked_mul(n)
.map(Self)
.ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Mul(
self.0, n,
)))
}

/// Unchecked multiplication by a scalar value.
///
/// ## Safety
///
/// The result must not overflow.
#[inline]
pub unsafe fn unchecked_mul(self, n: usize) -> Self {
Self(self.0 * n)
}
}

impl PartialEq<usize> for HostAlignedByteCount {
#[inline]
fn eq(&self, other: &usize) -> bool {
self.0 == *other
}
}

impl PartialEq<HostAlignedByteCount> for usize {
#[inline]
fn eq(&self, other: &HostAlignedByteCount) -> bool {
*self == other.0
}
}

impl fmt::Display for HostAlignedByteCount {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl fmt::LowerHex for HostAlignedByteCount {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::LowerHex::fmt(&self.0, f)
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ByteCountNotAligned(usize);

impl fmt::Display for ByteCountNotAligned {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "byte count not page-aligned: {}", self.0)
}
}

#[cfg(feature = "std")]
impl std::error::Error for ByteCountNotAligned {}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ByteCountOutOfBounds(ByteCountOutOfBoundsKind);

impl fmt::Display for ByteCountOutOfBounds {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}

#[cfg(feature = "std")]
impl std::error::Error for ByteCountOutOfBounds {}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteCountOutOfBoundsKind {
RoundUp(usize),
U64(u64),
Add(usize, usize),
Sub(usize, usize),
Mul(usize, usize),
}

impl fmt::Display for ByteCountOutOfBoundsKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ByteCountOutOfBoundsKind::RoundUp(bytes) => {
write!(f, "byte count overflow rounding up: {bytes}")
}
ByteCountOutOfBoundsKind::U64(bytes) => {
write!(f, "byte count overflow converting u64: {bytes}")
}
ByteCountOutOfBoundsKind::Add(a, b) => {
write!(f, "byte count overflow adding {a} and {b}")
}
ByteCountOutOfBoundsKind::Sub(a, b) => {
write!(f, "byte count underflow subtracting {b} from {a}")
}
ByteCountOutOfBoundsKind::Mul(a, b) => {
write!(f, "byte count overflow multiplying {a} by {b}")
}
}
}
}

/// Returns the host OS page size, in bytes.
pub fn host_page_size() -> usize {
static PAGE_SIZE: AtomicUsize = AtomicUsize::new(0);
@@ -352,13 +562,17 @@ pub fn host_page_size() -> usize {
}
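
// Sketch (hypothetical; the body is elided in the diff above): given the
// `PAGE_SIZE` static, the function presumably lazy-initializes and caches the
// OS page size along these lines, where `get_page_size_from_os` is a
// stand-in name:
//
//     match PAGE_SIZE.load(Ordering::Relaxed) {
//         0 => {
//             let size = get_page_size_from_os(); // e.g. sysconf(_SC_PAGESIZE)
//             PAGE_SIZE.store(size, Ordering::Relaxed);
//             size
//         }
//         n => n,
//     }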

/// Is `bytes` a multiple of the host page size?
///
/// (Deprecated: consider switching to `HostAlignedByteCount`.)
pub fn usize_is_multiple_of_host_page_size(bytes: usize) -> bool {
bytes % host_page_size() == 0
}

/// Round the given byte size up to a multiple of the host OS page size.
///
/// Returns an error if rounding up overflows.
///
/// (Deprecated: consider switching to `HostAlignedByteCount`.)
pub fn round_u64_up_to_host_pages(bytes: u64) -> Result<u64> {
let page_size = u64::try_from(crate::runtime::vm::host_page_size()).err2anyhow()?;
debug_assert!(page_size.is_power_of_two());
@@ -371,6 +585,8 @@ pub fn round_u64_up_to_host_pages(bytes: u64) -> Result<u64> {
}

/// Same as `round_u64_up_to_host_pages` but for `usize`s.
///
/// (Deprecated: consider switching to `HostAlignedByteCount`.)
pub fn round_usize_up_to_host_pages(bytes: usize) -> Result<usize> {
let bytes = u64::try_from(bytes).err2anyhow()?;
let rounded = round_u64_up_to_host_pages(bytes)?;
@@ -409,3 +625,63 @@ impl fmt::Display for WasmFault {
)
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn aligned_byte_count() {
let host_page_size = host_page_size();
HostAlignedByteCount::new(0).expect("0 is aligned");
HostAlignedByteCount::new(host_page_size).expect("host_page_size is aligned");
HostAlignedByteCount::new(host_page_size * 2).expect("host_page_size * 2 is aligned");
HostAlignedByteCount::new(host_page_size + 1)
.expect_err("host_page_size + 1 is not aligned");
HostAlignedByteCount::new(host_page_size / 2)
.expect_err("host_page_size / 2 is not aligned");

// Rounding up.
HostAlignedByteCount::new_rounded_up(usize::MAX).expect_err("usize::MAX overflows");
assert_eq!(
HostAlignedByteCount::new_rounded_up(usize::MAX - host_page_size)
.expect("(usize::MAX - 1 page) is in bounds"),
HostAlignedByteCount::new((usize::MAX - host_page_size) + 1)
.expect("usize::MAX is 2**N - 1"),
);

// Addition.
let half_max = HostAlignedByteCount::new((usize::MAX >> 1) + 1)
.expect("(usize::MAX >> 1) + 1 is aligned");
half_max
.checked_add(HostAlignedByteCount::host_page_size())
.expect("half max + page size is in bounds");
half_max
.checked_add(half_max)
.expect_err("half max + half max is out of bounds");

// Subtraction.
let half_max_minus_one = half_max
.checked_sub(HostAlignedByteCount::host_page_size())
.expect("(half_max - 1 page) is in bounds");
assert_eq!(
half_max.checked_sub(half_max),
Ok(HostAlignedByteCount::ZERO)
);
assert_eq!(
half_max.checked_sub(half_max_minus_one),
Ok(HostAlignedByteCount::host_page_size())
);
half_max_minus_one
.checked_sub(half_max)
.expect_err("(half_max - 1 page) - half_max is out of bounds");

// Multiplication.
half_max
.checked_mul(2)
.expect_err("half max * 2 is out of bounds");
half_max_minus_one
.checked_mul(2)
.expect("(half max - 1 page) * 2 is in bounds");
}
}
15 changes: 10 additions & 5 deletions crates/wasmtime/src/runtime/vm/cow.rs
@@ -747,9 +747,9 @@ impl Drop for MemoryImageSlot {
#[cfg(all(test, target_os = "linux", not(miri)))]
mod test {
use super::*;
use crate::runtime::vm::host_page_size;
use crate::runtime::vm::mmap::Mmap;
use crate::runtime::vm::sys::vm::decommit_pages;
use crate::{runtime::vm::host_page_size, vm::HostAlignedByteCount};
use std::sync::Arc;
use wasmtime_environ::{IndexType, Limits, Memory};

Expand Down Expand Up @@ -778,6 +778,11 @@ mod test {
}
}

fn mmap_4mib_inaccessible() -> Mmap {
let four_mib = HostAlignedByteCount::new(4 << 20).expect("4 MiB is page aligned");
Mmap::accessible_reserved(HostAlignedByteCount::ZERO, four_mib).unwrap()
}

#[test]
fn instantiate_no_image() {
let ty = dummy_memory();
@@ -786,7 +791,7 @@
..Tunables::default_miri()
};
// 4 MiB mmap'd area, not accessible
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut mmap = mmap_4mib_inaccessible();
// Create a MemoryImageSlot on top of it
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
memfd.no_clear_on_drop();
@@ -826,7 +831,7 @@
..Tunables::default_miri()
};
// 4 MiB mmap'd area, not accessible
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut mmap = mmap_4mib_inaccessible();
// Create a MemoryImageSlot on top of it
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
memfd.no_clear_on_drop();
@@ -898,7 +903,7 @@
memory_reservation: 100 << 16,
..Tunables::default_miri()
};
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut mmap = mmap_4mib_inaccessible();
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
memfd.no_clear_on_drop();

@@ -953,7 +958,7 @@
..Tunables::default_miri()
};

let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut mmap = mmap_4mib_inaccessible();
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
memfd.no_clear_on_drop();
let image = Arc::new(create_memfd_with_data(page_size, &[1, 2, 3, 4]).unwrap());