From 74a703b27c84d8be2be2dd13080951aa7f39cf73 Mon Sep 17 00:00:00 2001 From: Rain Date: Thu, 21 Nov 2024 16:23:19 -0800 Subject: [PATCH] Add HostAlignedByteCount to enforce alignment at compile time (#9620) * Add HostAlignedByteCount to enforce alignment at compile time As part of the work to allow mmaps to be backed by other implementations, I realized that we didn't have any way to track whether a particular usize is host-page-aligned at compile time. Add a `HostAlignedByteCount` which tracks that a particular usize is aligned to the host page size. This also does not expose safe unchecked arithmetic operations, to ensure that overflows always error out. With `HostAlignedByteCount`, a lot of runtime checks can go away thanks to the type-level assertion. In the interest of keeping the diff relatively small, I haven't converted everything over yet. More can be converted over as time permits. * Make zero-sized mprotects a no-op, add tests --- crates/wasmtime/src/runtime/vm.rs | 9 + crates/wasmtime/src/runtime/vm/byte_count.rs | 303 ++++++++++++++++++ crates/wasmtime/src/runtime/vm/cow.rs | 17 +- .../instance/allocator/pooling/memory_pool.rs | 27 +- .../instance/allocator/pooling/table_pool.rs | 27 +- .../allocator/pooling/unix_stack_pool.rs | 63 ++-- crates/wasmtime/src/runtime/vm/memory/mmap.rs | 93 +++--- crates/wasmtime/src/runtime/vm/mmap.rs | 110 +++++-- .../src/runtime/vm/sys/custom/mmap.rs | 28 +- .../wasmtime/src/runtime/vm/sys/miri/mmap.rs | 33 +- .../wasmtime/src/runtime/vm/sys/unix/mmap.rs | 24 +- .../src/runtime/vm/sys/windows/mmap.rs | 31 +- 12 files changed, 611 insertions(+), 154 deletions(-) create mode 100644 crates/wasmtime/src/runtime/vm/byte_count.rs diff --git a/crates/wasmtime/src/runtime/vm.rs b/crates/wasmtime/src/runtime/vm.rs index ef40d9f357fc..dca1fec1b16d 100644 --- a/crates/wasmtime/src/runtime/vm.rs +++ b/crates/wasmtime/src/runtime/vm.rs @@ -85,6 +85,8 @@ pub use send_sync_unsafe_cell::SendSyncUnsafeCell; mod module_id; pub use module_id::CompiledModuleId; +#[cfg(feature = "signals-based-traps")] +mod byte_count; #[cfg(feature = "signals-based-traps")] mod cow; #[cfg(not(feature = "signals-based-traps"))] @@ -94,6 +96,7 @@ mod mmap; cfg_if::cfg_if! { if #[cfg(feature = "signals-based-traps")] { + pub use crate::runtime::vm::byte_count::*; pub use crate::runtime::vm::mmap::Mmap; pub use self::cow::{MemoryImage, MemoryImageSlot, ModuleMemoryImages}; } else { @@ -365,6 +368,8 @@ pub fn host_page_size() -> usize { } /// Is `bytes` a multiple of the host page size? +/// +/// (Deprecated: consider switching to `HostAlignedByteCount`.) #[cfg(feature = "signals-based-traps")] pub fn usize_is_multiple_of_host_page_size(bytes: usize) -> bool { bytes % host_page_size() == 0 @@ -373,6 +378,8 @@ pub fn usize_is_multiple_of_host_page_size(bytes: usize) -> bool { /// Round the given byte size up to a multiple of the host OS page size. /// /// Returns an error if rounding up overflows. +/// +/// (Deprecated: consider switching to `HostAlignedByteCount`.) #[cfg(feature = "signals-based-traps")] pub fn round_u64_up_to_host_pages(bytes: u64) -> Result { let page_size = u64::try_from(crate::runtime::vm::host_page_size()).err2anyhow()?; @@ -386,6 +393,8 @@ pub fn round_u64_up_to_host_pages(bytes: u64) -> Result { } /// Same as `round_u64_up_to_host_pages` but for `usize`s. +/// +/// (Deprecated: consider switching to `HostAlignedByteCount`.) #[cfg(feature = "signals-based-traps")] pub fn round_usize_up_to_host_pages(bytes: usize) -> Result { let bytes = u64::try_from(bytes).err2anyhow()?; diff --git a/crates/wasmtime/src/runtime/vm/byte_count.rs b/crates/wasmtime/src/runtime/vm/byte_count.rs new file mode 100644 index 000000000000..33cf3716ab49 --- /dev/null +++ b/crates/wasmtime/src/runtime/vm/byte_count.rs @@ -0,0 +1,303 @@ +use core::fmt; + +use super::host_page_size; + +/// A number of bytes that's guaranteed to be aligned to the host page size. +/// +/// This is used to manage page-aligned memory allocations. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct HostAlignedByteCount( + // Invariant: this is always a multiple of the host page size. + usize, +); + +impl HostAlignedByteCount { + /// A zero byte count. + pub const ZERO: Self = Self(0); + + /// Creates a new `HostAlignedByteCount` from an aligned byte count. + /// + /// Returns an error if `bytes` is not page-aligned. + pub fn new(bytes: usize) -> Result { + let host_page_size = host_page_size(); + if bytes % host_page_size == 0 { + Ok(Self(bytes)) + } else { + Err(ByteCountNotAligned(bytes)) + } + } + + /// Creates a new `HostAlignedByteCount` from an aligned byte count without + /// checking validity. + /// + /// ## Safety + /// + /// The caller must ensure that `bytes` is page-aligned. + pub unsafe fn new_unchecked(bytes: usize) -> Self { + debug_assert!( + bytes % host_page_size() == 0, + "byte count {bytes} is not page-aligned (page size = {})", + host_page_size(), + ); + Self(bytes) + } + + /// Creates a new `HostAlignedByteCount`, rounding up to the nearest page. + /// + /// Returns an error if `bytes + page_size - 1` overflows. + pub fn new_rounded_up(bytes: usize) -> Result { + let page_size = host_page_size(); + debug_assert!(page_size.is_power_of_two()); + match bytes.checked_add(page_size - 1) { + Some(v) => Ok(Self(v & !(page_size - 1))), + None => Err(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::RoundUp)), + } + } + + /// Creates a new `HostAlignedByteCount` from a `u64`, rounding up to the nearest page. + /// + /// Returns an error if the `u64` overflows `usize`, or if `bytes + + /// page_size - 1` overflows. + pub fn new_rounded_up_u64(bytes: u64) -> Result { + let bytes = bytes + .try_into() + .map_err(|_| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::ConvertU64))?; + Self::new_rounded_up(bytes) + } + + /// Returns the host page size. + pub fn host_page_size() -> HostAlignedByteCount { + // The host page size is always a multiple of itself. + HostAlignedByteCount(host_page_size()) + } + + /// Returns true if the page count is zero. + #[inline] + pub fn is_zero(self) -> bool { + self == Self::ZERO + } + + /// Returns the number of bytes as a `usize`. + #[inline] + pub fn byte_count(self) -> usize { + self.0 + } + + /// Add two aligned byte counts together. + /// + /// Returns an error if the result overflows. + pub fn checked_add(self, bytes: HostAlignedByteCount) -> Result { + // aligned + aligned = aligned + self.0 + .checked_add(bytes.0) + .map(Self) + .ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Add)) + } + + /// Compute `self - bytes`. + /// + /// Returns an error if the result underflows. + pub fn checked_sub(self, bytes: HostAlignedByteCount) -> Result { + // aligned - aligned = aligned + self.0 + .checked_sub(bytes.0) + .map(Self) + .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Sub)) + } + + /// Multiply an aligned byte count by a scalar value. + /// + /// Returns an error if the result overflows. + pub fn checked_mul(self, scalar: usize) -> Result { + // aligned * scalar = aligned + self.0 + .checked_mul(scalar) + .map(Self) + .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Mul)) + } + + /// Unchecked multiplication by a scalar value. + /// + /// ## Safety + /// + /// The result must not overflow. + #[inline] + pub unsafe fn unchecked_mul(self, n: usize) -> Self { + Self(self.0 * n) + } +} + +impl PartialEq for HostAlignedByteCount { + #[inline] + fn eq(&self, other: &usize) -> bool { + self.0 == *other + } +} + +impl PartialEq for usize { + #[inline] + fn eq(&self, other: &HostAlignedByteCount) -> bool { + *self == other.0 + } +} + +struct LowerHexDisplay(T); + +impl fmt::Display for LowerHexDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Use the LowerHex impl as the Display impl, ensuring that there's + // always a 0x in the beginning (i.e. that the alternate formatter is + // used.) + if f.alternate() { + fmt::LowerHex::fmt(&self.0, f) + } else { + // Unfortunately, fill and alignment aren't respected this way, but + // it's quite hard to construct a new formatter with mostly the same + // options but the alternate flag set. + // https://github.com/rust-lang/rust/pull/118159 would make this + // easier. + write!(f, "{:#x}", self.0) + } + } +} + +impl fmt::Display for HostAlignedByteCount { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Use the LowerHex impl as the Display impl, ensuring that there's + // always a 0x in the beginning (i.e. that the alternate formatter is + // used.) + fmt::Display::fmt(&LowerHexDisplay(self.0), f) + } +} + +impl fmt::LowerHex for HostAlignedByteCount { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::LowerHex::fmt(&self.0, f) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ByteCountNotAligned(usize); + +impl fmt::Display for ByteCountNotAligned { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "byte count not page-aligned: {}", + LowerHexDisplay(self.0) + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ByteCountNotAligned {} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ByteCountOutOfBounds(ByteCountOutOfBoundsKind); + +impl fmt::Display for ByteCountOutOfBounds { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ByteCountOutOfBounds {} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ByteCountOutOfBoundsKind { + // We don't carry the arguments that errored out to avoid the error type + // becoming too big. + RoundUp, + ConvertU64, + Add, + Sub, + Mul, +} + +impl fmt::Display for ByteCountOutOfBoundsKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ByteCountOutOfBoundsKind::RoundUp => f.write_str("byte count overflow rounding up"), + ByteCountOutOfBoundsKind::ConvertU64 => { + f.write_str("byte count overflow converting u64") + } + ByteCountOutOfBoundsKind::Add => f.write_str("byte count overflow during addition"), + ByteCountOutOfBoundsKind::Sub => f.write_str("byte count underflow during subtraction"), + ByteCountOutOfBoundsKind::Mul => { + f.write_str("byte count overflow during multiplication") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn byte_count_display() { + // Pages should hopefully be 64k or smaller. + let byte_count = HostAlignedByteCount::new(65536).unwrap(); + + assert_eq!(format!("{byte_count}"), "0x10000"); + assert_eq!(format!("{byte_count:x}"), "10000"); + assert_eq!(format!("{byte_count:#x}"), "0x10000"); + } + + #[test] + fn byte_count_ops() { + let host_page_size = host_page_size(); + HostAlignedByteCount::new(0).expect("0 is aligned"); + HostAlignedByteCount::new(host_page_size).expect("host_page_size is aligned"); + HostAlignedByteCount::new(host_page_size * 2).expect("host_page_size * 2 is aligned"); + HostAlignedByteCount::new(host_page_size + 1) + .expect_err("host_page_size + 1 is not aligned"); + HostAlignedByteCount::new(host_page_size / 2) + .expect_err("host_page_size / 2 is not aligned"); + + // Rounding up. + HostAlignedByteCount::new_rounded_up(usize::MAX).expect_err("usize::MAX overflows"); + assert_eq!( + HostAlignedByteCount::new_rounded_up(usize::MAX - host_page_size) + .expect("(usize::MAX - 1 page) is in bounds"), + HostAlignedByteCount::new((usize::MAX - host_page_size) + 1) + .expect("usize::MAX is 2**N - 1"), + ); + + // Addition. + let half_max = HostAlignedByteCount::new((usize::MAX >> 1) + 1) + .expect("(usize::MAX >> 1) + 1 is aligned"); + half_max + .checked_add(HostAlignedByteCount::host_page_size()) + .expect("half max + page size is in bounds"); + half_max + .checked_add(half_max) + .expect_err("half max + half max is out of bounds"); + + // Subtraction. + let half_max_minus_one = half_max + .checked_sub(HostAlignedByteCount::host_page_size()) + .expect("(half_max - 1 page) is in bounds"); + assert_eq!( + half_max.checked_sub(half_max), + Ok(HostAlignedByteCount::ZERO) + ); + assert_eq!( + half_max.checked_sub(half_max_minus_one), + Ok(HostAlignedByteCount::host_page_size()) + ); + half_max_minus_one + .checked_sub(half_max) + .expect_err("(half_max - 1 page) - half_max is out of bounds"); + + // Multiplication. + half_max + .checked_mul(2) + .expect_err("half max * 2 is out of bounds"); + half_max_minus_one + .checked_mul(2) + .expect("(half max - 1 page) * 2 is in bounds"); + } +} diff --git a/crates/wasmtime/src/runtime/vm/cow.rs b/crates/wasmtime/src/runtime/vm/cow.rs index 581716222991..0febb9d7e616 100644 --- a/crates/wasmtime/src/runtime/vm/cow.rs +++ b/crates/wasmtime/src/runtime/vm/cow.rs @@ -745,9 +745,9 @@ impl Drop for MemoryImageSlot { #[cfg(all(test, target_os = "linux", not(miri)))] mod test { use super::*; - use crate::runtime::vm::host_page_size; - use crate::runtime::vm::mmap::Mmap; + use crate::runtime::vm::mmap::{AlignedLength, Mmap}; use crate::runtime::vm::sys::vm::decommit_pages; + use crate::{runtime::vm::host_page_size, vm::HostAlignedByteCount}; use std::sync::Arc; use wasmtime_environ::{IndexType, Limits, Memory}; @@ -776,6 +776,11 @@ mod test { } } + fn mmap_4mib_inaccessible() -> Mmap { + let four_mib = HostAlignedByteCount::new(4 << 20).expect("4 MiB is page aligned"); + Mmap::accessible_reserved(HostAlignedByteCount::ZERO, four_mib).unwrap() + } + #[test] fn instantiate_no_image() { let ty = dummy_memory(); @@ -784,7 +789,7 @@ mod test { ..Tunables::default_miri() }; // 4 MiB mmap'd area, not accessible - let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap(); + let mut mmap = mmap_4mib_inaccessible(); // Create a MemoryImageSlot on top of it let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20); memfd.no_clear_on_drop(); @@ -824,7 +829,7 @@ mod test { ..Tunables::default_miri() }; // 4 MiB mmap'd area, not accessible - let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap(); + let mut mmap = mmap_4mib_inaccessible(); // Create a MemoryImageSlot on top of it let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20); memfd.no_clear_on_drop(); @@ -896,7 +901,7 @@ mod test { memory_reservation: 100 << 16, ..Tunables::default_miri() }; - let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap(); + let mut mmap = mmap_4mib_inaccessible(); let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20); memfd.no_clear_on_drop(); @@ -951,7 +956,7 @@ mod test { ..Tunables::default_miri() }; - let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap(); + let mut mmap = mmap_4mib_inaccessible(); let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20); memfd.no_clear_on_drop(); let image = Arc::new(create_memfd_with_data(page_size, &[1, 2, 3, 4]).unwrap()); diff --git a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/memory_pool.rs b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/memory_pool.rs index be937d4b8d3e..fc6419110c4b 100644 --- a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/memory_pool.rs +++ b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/memory_pool.rs @@ -54,12 +54,15 @@ use super::{ index_allocator::{MemoryInModule, ModuleAffinityIndexAllocator, SlotId}, MemoryAllocationIndex, }; -use crate::runtime::vm::mpk::{self, ProtectionKey, ProtectionMask}; use crate::runtime::vm::{ mmap::AlignedLength, CompiledModuleId, InstanceAllocationRequest, InstanceLimits, Memory, MemoryImageSlot, Mmap, MpkEnabled, PoolingInstanceAllocatorConfig, }; use crate::{prelude::*, vm::round_usize_up_to_host_pages}; +use crate::{ + runtime::vm::mpk::{self, ProtectionKey, ProtectionMask}, + vm::HostAlignedByteCount, +}; use std::ffi::c_void; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Mutex; @@ -187,8 +190,9 @@ impl MemoryPool { "creating memory pool: {constraints:?} -> {layout:?} (total: {})", layout.total_slab_bytes()? ); - let mut mapping = Mmap::accessible_reserved(0, layout.total_slab_bytes()?) - .context("failed to create memory pool mapping")?; + let mut mapping = + Mmap::accessible_reserved(HostAlignedByteCount::ZERO, layout.total_slab_bytes()?) + .context("failed to create memory pool mapping")?; // Then, stripe the memory with the available protection keys. This is // unnecessary if there is only one stripe color. @@ -618,12 +622,19 @@ impl SlabLayout { /// │pre_slab_guard_bytes│slot 1│slot 2│...│slot n│post_slab_guard_bytes│ /// └────────────────────┴──────┴──────┴───┴──────┴─────────────────────┘ /// ``` - fn total_slab_bytes(&self) -> Result { - self.slot_bytes + fn total_slab_bytes(&self) -> Result { + let byte_count = self + .slot_bytes .checked_mul(self.num_slots) .and_then(|c| c.checked_add(self.pre_slab_guard_bytes)) .and_then(|c| c.checked_add(self.post_slab_guard_bytes)) - .ok_or_else(|| anyhow!("total size of memory reservation exceeds addressable memory")) + .ok_or_else(|| { + anyhow!("total size of memory reservation exceeds addressable memory") + })?; + + // TODO: pre_slab_guard_bytes and post_slab_guard_bytes should be + // HostAlignedByteCount instances. + HostAlignedByteCount::new(byte_count).err2anyhow() } /// Returns the number of Wasm bytes from the beginning of one slot to the @@ -916,10 +927,6 @@ mod tests { is_aligned(s.post_slab_guard_bytes), "post-slab guard region is not page-aligned: {c:?} => {s:?}" ); - assert!( - is_aligned(s.total_slab_bytes().unwrap()), - "slab is not page-aligned: {c:?} => {s:?}" - ); // Check that we use no more or less stripes than needed. assert!(s.num_stripes >= 1, "not enough stripes: {c:?} => {s:?}"); diff --git a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/table_pool.rs b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/table_pool.rs index a67a652ae806..bf6e0130689d 100644 --- a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/table_pool.rs +++ b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/table_pool.rs @@ -2,11 +2,11 @@ use super::{ index_allocator::{SimpleIndexAllocator, SlotId}, round_up_to_pow2, TableAllocationIndex, }; -use crate::prelude::*; use crate::runtime::vm::{ mmap::AlignedLength, InstanceAllocationRequest, Mmap, PoolingInstanceAllocatorConfig, SendSyncPtr, Table, }; +use crate::{prelude::*, vm::HostAlignedByteCount}; use crate::{runtime::vm::sys::vm::commit_pages, vm::round_usize_up_to_host_pages}; use std::mem; use std::ptr::NonNull; @@ -20,7 +20,7 @@ use wasmtime_environ::{Module, Tunables}; pub struct TablePool { index_allocator: SimpleIndexAllocator, mapping: Mmap, - table_size: usize, + table_size: HostAlignedByteCount, max_total_tables: usize, tables_per_instance: usize, page_size: usize, @@ -33,19 +33,20 @@ impl TablePool { pub fn new(config: &PoolingInstanceAllocatorConfig) -> Result { let page_size = crate::runtime::vm::host_page_size(); - let table_size = round_up_to_pow2( + let table_size = HostAlignedByteCount::new_rounded_up( mem::size_of::<*mut u8>() .checked_mul(config.limits.table_elements) .ok_or_else(|| anyhow!("table size exceeds addressable memory"))?, - page_size, - ); + ) + .err2anyhow()?; let max_total_tables = usize::try_from(config.limits.total_tables).unwrap(); let tables_per_instance = usize::try_from(config.limits.max_tables_per_module).unwrap(); let allocation_size = table_size .checked_mul(max_total_tables) - .ok_or_else(|| anyhow!("total size of tables exceeds addressable memory"))?; + .err2anyhow() + .context("total size of tables exceeds addressable memory")?; let mapping = Mmap::accessible_reserved(allocation_size, allocation_size) .context("failed to create table pool mapping")?; @@ -108,7 +109,14 @@ impl TablePool { unsafe { self.mapping .as_ptr() - .add(table_index.index() * self.table_size) + .add( + self.table_size + .checked_mul(table_index.index()) + .expect( + "checked in constructor that table_size * table_index doesn't overflow", + ) + .byte_count(), + ) .cast_mut() } } @@ -234,7 +242,10 @@ mod tests { for i in 0..7 { let index = TableAllocationIndex(i); let ptr = pool.get(index); - assert_eq!(ptr as usize - base, i as usize * pool.table_size); + assert_eq!( + ptr as usize - base, + pool.table_size.checked_mul(i as usize).unwrap() + ); } Ok(()) diff --git a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/unix_stack_pool.rs b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/unix_stack_pool.rs index 687c890c4922..e294ce9016a0 100644 --- a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/unix_stack_pool.rs +++ b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling/unix_stack_pool.rs @@ -1,13 +1,11 @@ #![cfg_attr(asan, allow(dead_code))] -use super::{ - index_allocator::{SimpleIndexAllocator, SlotId}, - round_up_to_pow2, -}; +use super::index_allocator::{SimpleIndexAllocator, SlotId}; use crate::prelude::*; use crate::runtime::vm::sys::vm::commit_pages; use crate::runtime::vm::{ - mmap::AlignedLength, round_usize_up_to_host_pages, Mmap, PoolingInstanceAllocatorConfig, + mmap::AlignedLength, round_usize_up_to_host_pages, HostAlignedByteCount, Mmap, + PoolingInstanceAllocatorConfig, }; /// Represents a pool of execution stacks (used for the async fiber implementation). @@ -23,9 +21,9 @@ use crate::runtime::vm::{ #[derive(Debug)] pub struct StackPool { mapping: Mmap, - stack_size: usize, + stack_size: HostAlignedByteCount, max_stacks: usize, - page_size: usize, + page_size: HostAlignedByteCount, index_allocator: SimpleIndexAllocator, async_stack_zeroing: bool, async_stack_keep_resident: usize, @@ -35,34 +33,43 @@ impl StackPool { pub fn new(config: &PoolingInstanceAllocatorConfig) -> Result { use rustix::mm::{mprotect, MprotectFlags}; - let page_size = crate::runtime::vm::host_page_size(); + let page_size = HostAlignedByteCount::host_page_size(); // Add a page to the stack size for the guard page when using fiber stacks let stack_size = if config.stack_size == 0 { - 0 + HostAlignedByteCount::ZERO } else { - round_up_to_pow2(config.stack_size, page_size) - .checked_add(page_size) - .ok_or_else(|| anyhow!("stack size exceeds addressable memory"))? + HostAlignedByteCount::new_rounded_up(config.stack_size) + .and_then(|size| size.checked_add(HostAlignedByteCount::host_page_size())) + .err2anyhow() + .context("stack size exceeds addressable memory")? }; let max_stacks = usize::try_from(config.limits.total_stacks).unwrap(); let allocation_size = stack_size .checked_mul(max_stacks) - .ok_or_else(|| anyhow!("total size of execution stacks exceeds addressable memory"))?; + .err2anyhow() + .context("total size of execution stacks exceeds addressable memory")?; let mapping = Mmap::accessible_reserved(allocation_size, allocation_size) .context("failed to create stack pool mapping")?; // Set up the stack guard pages. - if allocation_size > 0 { + if !allocation_size.is_zero() { unsafe { for i in 0..max_stacks { + // Safety: i < max_stacks and we've already checked that + // stack_size * max_stacks is valid. + let offset = stack_size.unchecked_mul(i); // Make the stack guard page inaccessible. - let bottom_of_stack = mapping.as_ptr().add(i * stack_size).cast_mut(); - mprotect(bottom_of_stack.cast(), page_size, MprotectFlags::empty()) - .context("failed to protect stack guard page")?; + let bottom_of_stack = mapping.as_ptr().add(offset.byte_count()).cast_mut(); + mprotect( + bottom_of_stack.cast(), + page_size.byte_count(), + MprotectFlags::empty(), + ) + .context("failed to protect stack guard page")?; } } } @@ -102,19 +109,19 @@ impl StackPool { unsafe { // Remove the guard page from the size - let size_without_guard = self.stack_size - self.page_size; + let size_without_guard = self.stack_size.byte_count() - self.page_size.byte_count(); let bottom_of_stack = self .mapping .as_ptr() - .add(index * self.stack_size) + .add(self.stack_size.unchecked_mul(index).byte_count()) .cast_mut(); commit_pages(bottom_of_stack, size_without_guard)?; let stack = wasmtime_fiber::FiberStack::from_raw_parts( bottom_of_stack, - self.page_size, + self.page_size.byte_count(), size_without_guard, )?; Ok(stack) @@ -154,11 +161,11 @@ impl StackPool { ); // Remove the guard page from the size - let stack_size = self.stack_size - self.page_size; + let stack_size = self.stack_size.byte_count() - self.page_size.byte_count(); let bottom_of_stack = top - stack_size; - let start_of_stack = bottom_of_stack - self.page_size; + let start_of_stack = bottom_of_stack - self.page_size.byte_count(); assert!(start_of_stack >= base && start_of_stack < (base + len)); - assert!((start_of_stack - base) % self.stack_size == 0); + assert!((start_of_stack - base) % self.stack_size.byte_count() == 0); // Manually zero the top of the stack to keep the pages resident in // memory and avoid future page faults. Use the system to deallocate @@ -202,13 +209,13 @@ impl StackPool { ); // Remove the guard page from the size - let stack_size = self.stack_size - self.page_size; + let stack_size = self.stack_size.byte_count() - self.page_size.byte_count(); let bottom_of_stack = top - stack_size; - let start_of_stack = bottom_of_stack - self.page_size; + let start_of_stack = bottom_of_stack - self.page_size.byte_count(); assert!(start_of_stack >= base && start_of_stack < (base + len)); - assert!((start_of_stack - base) % self.stack_size == 0); + assert!((start_of_stack - base) % self.stack_size.byte_count() == 0); - let index = (start_of_stack - base) / self.stack_size; + let index = (start_of_stack - base) / self.stack_size.byte_count(); assert!(index < self.max_stacks); let index = u32::try_from(index).unwrap(); @@ -248,7 +255,7 @@ mod tests { for i in 0..10 { let stack = pool.allocate().expect("allocation should succeed"); assert_eq!( - ((stack.top().unwrap() as usize - base) / pool.stack_size) - 1, + ((stack.top().unwrap() as usize - base) / pool.stack_size.byte_count()) - 1, i ); stacks.push(stack); diff --git a/crates/wasmtime/src/runtime/vm/memory/mmap.rs b/crates/wasmtime/src/runtime/vm/memory/mmap.rs index 0d0b7f16f1ab..66840f8dcf5c 100644 --- a/crates/wasmtime/src/runtime/vm/memory/mmap.rs +++ b/crates/wasmtime/src/runtime/vm/memory/mmap.rs @@ -3,8 +3,7 @@ use crate::prelude::*; use crate::runtime::vm::memory::RuntimeLinearMemory; -use crate::runtime::vm::mmap::{AlignedLength, Mmap}; -use crate::runtime::vm::{round_usize_up_to_host_pages, usize_is_multiple_of_host_page_size}; +use crate::runtime::vm::{mmap::AlignedLength, HostAlignedByteCount, Mmap}; use wasmtime_environ::Tunables; /// A linear memory instance. @@ -35,12 +34,12 @@ pub struct MmapMemory { // The amount of extra bytes to reserve whenever memory grows. This is // specified so that the cost of repeated growth is amortized. - extra_to_reserve_on_growth: usize, + extra_to_reserve_on_growth: HostAlignedByteCount, // Size in bytes of extra guard pages before the start and after the end to // optimize loads and stores with constant offsets. - pre_guard_size: usize, - offset_guard_size: usize, + pre_guard_size: HostAlignedByteCount, + offset_guard_size: HostAlignedByteCount, } impl MmapMemory { @@ -57,12 +56,14 @@ impl MmapMemory { // found (mostly an issue for hypothetical 32-bit hosts). // // Also be sure to round up to the host page size for this value. - let offset_guard_bytes = usize::try_from(tunables.memory_guard_size).unwrap(); - let offset_guard_bytes = round_usize_up_to_host_pages(offset_guard_bytes)?; + let offset_guard_bytes = + HostAlignedByteCount::new_rounded_up_u64(tunables.memory_guard_size) + .err2anyhow() + .context("tunable.memory_guard_size overflows")?; let pre_guard_bytes = if tunables.guard_before_linear_memory { offset_guard_bytes } else { - 0 + HostAlignedByteCount::ZERO }; // Calculate how much is going to be allocated for this linear memory in @@ -90,21 +91,24 @@ impl MmapMemory { // Convert `alloc_bytes` and `extra_to_reserve_on_growth` to // page-aligned `usize` values. - let alloc_bytes = usize::try_from(alloc_bytes).unwrap(); - let extra_to_reserve_on_growth = usize::try_from(extra_to_reserve_on_growth).unwrap(); - let alloc_bytes = round_usize_up_to_host_pages(alloc_bytes)?; - let extra_to_reserve_on_growth = round_usize_up_to_host_pages(extra_to_reserve_on_growth)?; + let alloc_bytes = HostAlignedByteCount::new_rounded_up_u64(alloc_bytes) + .err2anyhow() + .context("tunables.memory_reservation overflows")?; + let extra_to_reserve_on_growth = + HostAlignedByteCount::new_rounded_up_u64(extra_to_reserve_on_growth) + .err2anyhow() + .context("tunables.memory_reservation_for_growth overflows")?; let request_bytes = pre_guard_bytes .checked_add(alloc_bytes) .and_then(|i| i.checked_add(offset_guard_bytes)) - .ok_or_else(|| format_err!("cannot allocate {} with guard regions", minimum))?; - assert!(usize_is_multiple_of_host_page_size(request_bytes)); + .err2anyhow() + .with_context(|| format!("cannot allocate {minimum} with guard regions"))?; - let mut mmap = Mmap::accessible_reserved(0, request_bytes)?; + let mut mmap = Mmap::accessible_reserved(HostAlignedByteCount::ZERO, request_bytes)?; if minimum > 0 { - let accessible = round_usize_up_to_host_pages(minimum)?; + let accessible = HostAlignedByteCount::new_rounded_up(minimum).err2anyhow()?; mmap.make_accessible(pre_guard_bytes, accessible)?; } @@ -121,12 +125,21 @@ impl MmapMemory { /// Get the length of the accessible portion of the underlying `mmap`. This /// is the same region as `self.len` but rounded up to a multiple of the /// host page size. - fn accessible(&self) -> usize { - let accessible = - round_usize_up_to_host_pages(self.len).expect("accessible region always fits in usize"); - debug_assert!(accessible <= self.mmap.len() - self.offset_guard_size - self.pre_guard_size); + fn accessible(&self) -> HostAlignedByteCount { + let accessible = HostAlignedByteCount::new_rounded_up(self.len) + .expect("accessible region always fits in usize"); + debug_assert!(accessible <= self.current_capacity()); accessible } + + /// Get the amount to which this memory can grow. + fn current_capacity(&self) -> HostAlignedByteCount { + let mmap_len = self.mmap.len_aligned(); + mmap_len + .checked_sub(self.offset_guard_size) + .and_then(|i| i.checked_sub(self.pre_guard_size)) + .expect("guard regions fit in mmap.len") + } } impl RuntimeLinearMemory for MmapMemory { @@ -135,16 +148,13 @@ impl RuntimeLinearMemory for MmapMemory { } fn byte_capacity(&self) -> usize { - self.mmap.len() - self.offset_guard_size - self.pre_guard_size + self.current_capacity().byte_count() } fn grow_to(&mut self, new_size: usize) -> Result<()> { - assert!(usize_is_multiple_of_host_page_size(self.offset_guard_size)); - assert!(usize_is_multiple_of_host_page_size(self.pre_guard_size)); - assert!(usize_is_multiple_of_host_page_size(self.mmap.len())); - - let new_accessible = round_usize_up_to_host_pages(new_size)?; - if new_accessible > self.mmap.len() - self.offset_guard_size - self.pre_guard_size { + let new_accessible = HostAlignedByteCount::new_rounded_up(new_size).err2anyhow()?; + let current_capacity = self.current_capacity(); + if new_accessible > current_capacity { // If the new size of this heap exceeds the current size of the // allocation we have, then this must be a dynamic heap. Use // `new_size` to calculate a new size of an allocation, allocate it, @@ -154,17 +164,19 @@ impl RuntimeLinearMemory for MmapMemory { .checked_add(new_accessible) .and_then(|s| s.checked_add(self.extra_to_reserve_on_growth)) .and_then(|s| s.checked_add(self.offset_guard_size)) - .ok_or_else(|| format_err!("overflow calculating size of memory allocation"))?; - assert!(usize_is_multiple_of_host_page_size(request_bytes)); + .err2anyhow() + .context("overflow calculating size of memory allocation")?; - let mut new_mmap = Mmap::accessible_reserved(0, request_bytes)?; + let mut new_mmap = + Mmap::accessible_reserved(HostAlignedByteCount::ZERO, request_bytes)?; new_mmap.make_accessible(self.pre_guard_size, new_accessible)?; // This method has an exclusive reference to `self.mmap` and just // created `new_mmap` so it should be safe to acquire references // into both of them and copy between them. unsafe { - let range = self.pre_guard_size..self.pre_guard_size + self.len; + let range = + self.pre_guard_size.byte_count()..(self.pre_guard_size.byte_count() + self.len); let src = self.mmap.slice(range.clone()); let dst = new_mmap.slice_mut(range); dst.copy_from_slice(src); @@ -178,23 +190,20 @@ impl RuntimeLinearMemory for MmapMemory { // or "dynamic" heaps which have some space reserved after the // initial allocation to grow into before the heap is moved in // memory. - assert!(new_size > self.len); + assert!(new_size <= current_capacity.byte_count()); assert!(self.maximum.map_or(true, |max| new_size <= max)); - assert!(new_size <= self.mmap.len() - self.offset_guard_size - self.pre_guard_size); - - let new_accessible = round_usize_up_to_host_pages(new_size)?; - assert!( - new_accessible <= self.mmap.len() - self.offset_guard_size - self.pre_guard_size, - ); // If the Wasm memory's page size is smaller than the host's page // size, then we might not need to actually change permissions, // since we are forced to round our accessible range up to the // host's page size. - if new_accessible > self.accessible() { + if let Ok(difference) = new_accessible.checked_sub(self.accessible()) { self.mmap.make_accessible( - self.pre_guard_size + self.accessible(), - new_accessible - self.accessible(), + self.pre_guard_size + .checked_add(self.accessible()) + .err2anyhow() + .context("overflow calculating new accessible region")?, + difference, )?; } } @@ -209,6 +218,6 @@ impl RuntimeLinearMemory for MmapMemory { } fn base_ptr(&self) -> *mut u8 { - unsafe { self.mmap.as_mut_ptr().add(self.pre_guard_size) } + unsafe { self.mmap.as_mut_ptr().add(self.pre_guard_size.byte_count()) } } } diff --git a/crates/wasmtime/src/runtime/vm/mmap.rs b/crates/wasmtime/src/runtime/vm/mmap.rs index 657d89e6088d..f9c0d692d43f 100644 --- a/crates/wasmtime/src/runtime/vm/mmap.rs +++ b/crates/wasmtime/src/runtime/vm/mmap.rs @@ -1,8 +1,9 @@ //! Low-level abstraction for allocating and managing zero-filled pages //! of memory. +use super::HostAlignedByteCount; +use crate::prelude::*; use crate::runtime::vm::sys::mmap; -use crate::{prelude::*, vm::usize_is_multiple_of_host_page_size}; use core::ops::Range; #[cfg(feature = "std")] use std::{fs::File, sync::Arc}; @@ -62,7 +63,7 @@ impl Mmap { /// Create a new `Mmap` pointing to at least `size` bytes of page-aligned /// accessible memory. pub fn with_at_least(size: usize) -> Result { - let rounded_size = crate::runtime::vm::round_usize_up_to_host_pages(size)?; + let rounded_size = HostAlignedByteCount::new_rounded_up(size).err2anyhow()?; Self::accessible_reserved(rounded_size, rounded_size) } @@ -73,13 +74,14 @@ impl Mmap { /// # Panics /// /// This function will panic if `accessible_size` is greater than - /// `mapping_size` or if either of them are not page-aligned. - pub fn accessible_reserved(accessible_size: usize, mapping_size: usize) -> Result { + /// `mapping_size`. + pub fn accessible_reserved( + accessible_size: HostAlignedByteCount, + mapping_size: HostAlignedByteCount, + ) -> Result { assert!(accessible_size <= mapping_size); - assert!(usize_is_multiple_of_host_page_size(mapping_size)); - assert!(usize_is_multiple_of_host_page_size(accessible_size)); - if mapping_size == 0 { + if mapping_size.is_zero() { Ok(Mmap { sys: mmap::Mmap::new_empty(), data: AlignedLength {}, @@ -96,10 +98,12 @@ impl Mmap { .context(format!("mmap failed to reserve {mapping_size:#x} bytes"))?, data: AlignedLength {}, }; - if accessible_size > 0 { - result.make_accessible(0, accessible_size).context(format!( - "mmap failed to allocate {accessible_size:#x} bytes" - ))?; + if !accessible_size.is_zero() { + result + .make_accessible(HostAlignedByteCount::ZERO, accessible_size) + .context(format!( + "mmap failed to allocate {accessible_size:#x} bytes" + ))?; } Ok(result) } @@ -119,20 +123,39 @@ impl Mmap { } } + /// Returns the length of the memory mapping as an aligned byte count. + pub fn len_aligned(&self) -> HostAlignedByteCount { + // SAFETY: The type parameter indicates that self.sys.len() is aligned. + unsafe { HostAlignedByteCount::new_unchecked(self.sys.len()) } + } + /// Make the memory starting at `start` and extending for `len` bytes /// accessible. `start` and `len` must be native page-size multiples and /// describe a range within `self`'s reserved memory. /// /// # Panics /// - /// This function will panic if `start` or `len` is not page aligned or if - /// either are outside the bounds of this mapping. - pub fn make_accessible(&mut self, start: usize, len: usize) -> Result<()> { - let page_size = crate::runtime::vm::host_page_size(); - assert_eq!(start & (page_size - 1), 0); - assert_eq!(len & (page_size - 1), 0); - assert!(len <= self.len()); - assert!(start <= self.len() - len); + /// Panics if `start + len >= self.len()`. + pub fn make_accessible( + &mut self, + start: HostAlignedByteCount, + len: HostAlignedByteCount, + ) -> Result<()> { + if len.is_zero() { + // A zero-sized mprotect (or equivalent) is allowed on some + // platforms but not others (notably Windows). Treat it as a no-op + // everywhere. + return Ok(()); + } + + let end = start + .checked_add(len) + .expect("start + len must not overflow"); + assert!( + end <= self.len_aligned(), + "start + len ({end}) must be <= mmap region {}", + self.len_aligned() + ); self.sys.make_accessible(start, len) } @@ -213,6 +236,9 @@ impl Mmap { /// /// This is the byte length of this entire mapping which includes both /// addressable and non-addressable memory. + /// + /// If the length is statically known to be page-aligned via the + /// [`AlignedLength`] type parameter, use [`Self::len_aligned`]. #[inline] pub fn len(&self) -> usize { self.sys.len() @@ -242,6 +268,14 @@ impl Mmap { range.start % crate::runtime::vm::host_page_size() == 0, "changing of protections isn't page-aligned", ); + + if range.start == range.end { + // A zero-sized mprotect (or equivalent) is allowed on some + // platforms but not others (notably Windows). Treat it as a no-op + // everywhere. + return Ok(()); + } + self.sys .make_executable(range, enable_branch_protection) .context("failed to make memory executable") @@ -256,6 +290,14 @@ impl Mmap { range.start % crate::runtime::vm::host_page_size() == 0, "changing of protections isn't page-aligned", ); + + if range.start == range.end { + // A zero-sized mprotect (or equivalent) is allowed on some + // platforms but not others (notably Windows). Treat it as a no-op + // everywhere. + return Ok(()); + } + self.sys .make_readonly(range) .context("failed to make memory readonly") @@ -273,3 +315,33 @@ impl From> for Mmap { mmap.into_unaligned() } } + +#[cfg(test)] +mod tests { + use super::*; + + /// Test zero-length calls to mprotect (or the OS equivalent). + /// + /// These should be treated as no-ops on all platforms. This test ensures + /// that such calls at least don't error out. + #[test] + fn mprotect_zero_length() { + let page_size = HostAlignedByteCount::host_page_size(); + let pagex2 = page_size.checked_mul(2).unwrap(); + let pagex3 = page_size.checked_mul(3).unwrap(); + let pagex4 = page_size.checked_mul(4).unwrap(); + + let mut mem = Mmap::accessible_reserved(pagex2, pagex4).expect("allocated memory"); + + mem.make_accessible(pagex3, HostAlignedByteCount::ZERO) + .expect("make_accessible succeeded"); + + unsafe { + mem.make_executable(pagex3.byte_count()..pagex3.byte_count(), false) + .expect("make_executable succeeded"); + + mem.make_readonly(pagex3.byte_count()..pagex3.byte_count()) + .expect("make_readonly succeeded"); + }; + } +} diff --git a/crates/wasmtime/src/runtime/vm/sys/custom/mmap.rs b/crates/wasmtime/src/runtime/vm/sys/custom/mmap.rs index 469531c587c5..1bcd9b2bfd0f 100644 --- a/crates/wasmtime/src/runtime/vm/sys/custom/mmap.rs +++ b/crates/wasmtime/src/runtime/vm/sys/custom/mmap.rs @@ -1,7 +1,7 @@ use super::cvt; use crate::prelude::*; use crate::runtime::vm::sys::capi; -use crate::runtime::vm::SendSyncPtr; +use crate::runtime::vm::{HostAlignedByteCount, SendSyncPtr}; use core::ops::Range; use core::ptr::{self, NonNull}; #[cfg(feature = "std")] @@ -24,20 +24,24 @@ impl Mmap { } } - pub fn new(size: usize) -> Result { + pub fn new(size: HostAlignedByteCount) -> Result { let mut ptr = ptr::null_mut(); cvt(unsafe { - capi::wasmtime_mmap_new(size, capi::PROT_READ | capi::PROT_WRITE, &mut ptr) + capi::wasmtime_mmap_new( + size.byte_count(), + capi::PROT_READ | capi::PROT_WRITE, + &mut ptr, + ) })?; - let memory = ptr::slice_from_raw_parts_mut(ptr.cast(), size); + let memory = ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Mmap { memory }) } - pub fn reserve(size: usize) -> Result { + pub fn reserve(size: HostAlignedByteCount) -> Result { let mut ptr = ptr::null_mut(); - cvt(unsafe { capi::wasmtime_mmap_new(size, 0, &mut ptr) })?; - let memory = ptr::slice_from_raw_parts_mut(ptr.cast(), size); + cvt(unsafe { capi::wasmtime_mmap_new(size.byte_count(), 0, &mut ptr) })?; + let memory = ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Mmap { memory }) } @@ -47,12 +51,16 @@ impl Mmap { anyhow::bail!("not supported on this platform"); } - pub fn make_accessible(&mut self, start: usize, len: usize) -> Result<()> { + pub fn make_accessible( + &mut self, + start: HostAlignedByteCount, + len: HostAlignedByteCount, + ) -> Result<()> { let ptr = self.memory.as_ptr(); unsafe { cvt(capi::wasmtime_mprotect( - ptr.byte_add(start).cast(), - len, + ptr.byte_add(start.byte_count()).cast(), + len.byte_count(), capi::PROT_READ | capi::PROT_WRITE, ))?; } diff --git a/crates/wasmtime/src/runtime/vm/sys/miri/mmap.rs b/crates/wasmtime/src/runtime/vm/sys/miri/mmap.rs index cd474060f1e1..05bd596e9658 100644 --- a/crates/wasmtime/src/runtime/vm/sys/miri/mmap.rs +++ b/crates/wasmtime/src/runtime/vm/sys/miri/mmap.rs @@ -6,7 +6,7 @@ //! but it's enough to get various tests running relying on memories and such. use crate::prelude::*; -use crate::runtime::vm::SendSyncPtr; +use crate::runtime::vm::{HostAlignedByteCount, SendSyncPtr}; use std::alloc::{self, Layout}; use std::fs::File; use std::ops::Range; @@ -29,23 +29,23 @@ impl Mmap { } } - pub fn new(size: usize) -> Result { + pub fn new(size: HostAlignedByteCount) -> Result { let mut ret = Mmap::reserve(size)?; - ret.make_accessible(0, size)?; + ret.make_accessible(HostAlignedByteCount::ZERO, size)?; Ok(ret) } - pub fn reserve(size: usize) -> Result { - if size > 1 << 32 { + pub fn reserve(size: HostAlignedByteCount) -> Result { + if size.byte_count() > 1 << 32 { bail!("failed to allocate memory"); } - let layout = Layout::from_size_align(size, crate::runtime::vm::host_page_size()).unwrap(); + let layout = make_layout(size.byte_count()); let ptr = unsafe { alloc::alloc(layout) }; if ptr.is_null() { bail!("failed to allocate memory"); } - let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size); + let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Mmap { memory }) } @@ -54,11 +54,19 @@ impl Mmap { bail!("not supported on miri"); } - pub fn make_accessible(&mut self, start: usize, len: usize) -> Result<()> { + pub fn make_accessible( + &mut self, + start: HostAlignedByteCount, + len: HostAlignedByteCount, + ) -> Result<()> { // The memory is technically always accessible but this marks it as // initialized for miri-level checking. unsafe { - std::ptr::write_bytes(self.as_mut_ptr().add(start), 0u8, len); + std::ptr::write_bytes( + self.as_mut_ptr().add(start.byte_count()), + 0u8, + len.byte_count(), + ); } Ok(()) } @@ -94,9 +102,12 @@ impl Drop for Mmap { return; } unsafe { - let layout = - Layout::from_size_align(self.len(), crate::runtime::vm::host_page_size()).unwrap(); + let layout = make_layout(self.len()); alloc::dealloc(self.as_mut_ptr(), layout); } } } + +fn make_layout(size: usize) -> Layout { + Layout::from_size_align(size, crate::runtime::vm::host_page_size()).unwrap() +} diff --git a/crates/wasmtime/src/runtime/vm/sys/unix/mmap.rs b/crates/wasmtime/src/runtime/vm/sys/unix/mmap.rs index e12819996077..c2a00e0f1abe 100644 --- a/crates/wasmtime/src/runtime/vm/sys/unix/mmap.rs +++ b/crates/wasmtime/src/runtime/vm/sys/unix/mmap.rs @@ -1,5 +1,5 @@ use crate::prelude::*; -use crate::runtime::vm::SendSyncPtr; +use crate::runtime::vm::{HostAlignedByteCount, SendSyncPtr}; use rustix::mm::{mprotect, MprotectFlags}; use std::ops::Range; use std::ptr::{self, NonNull}; @@ -41,26 +41,26 @@ impl Mmap { } } - pub fn new(size: usize) -> Result { + pub fn new(size: HostAlignedByteCount) -> Result { let ptr = unsafe { rustix::mm::mmap_anonymous( ptr::null_mut(), - size, + size.byte_count(), rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE, rustix::mm::MapFlags::PRIVATE | MMAP_NORESERVE_FLAG, ) .err2anyhow()? }; - let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size); + let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Mmap { memory }) } - pub fn reserve(size: usize) -> Result { + pub fn reserve(size: HostAlignedByteCount) -> Result { let ptr = unsafe { rustix::mm::mmap_anonymous( ptr::null_mut(), - size, + size.byte_count(), rustix::mm::ProtFlags::empty(), // Astute readers might be wondering why a function called "reserve" passes in a // NORESERVE flag. That's because "reserve" in this context means one of two @@ -78,7 +78,7 @@ impl Mmap { .err2anyhow()? }; - let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size); + let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Mmap { memory }) } @@ -109,12 +109,16 @@ impl Mmap { Ok(Mmap { memory }) } - pub fn make_accessible(&mut self, start: usize, len: usize) -> Result<()> { + pub fn make_accessible( + &mut self, + start: HostAlignedByteCount, + len: HostAlignedByteCount, + ) -> Result<()> { let ptr = self.memory.as_ptr(); unsafe { mprotect( - ptr.byte_add(start).cast(), - len, + ptr.byte_add(start.byte_count()).cast(), + len.byte_count(), MprotectFlags::READ | MprotectFlags::WRITE, ) .err2anyhow()?; diff --git a/crates/wasmtime/src/runtime/vm/sys/windows/mmap.rs b/crates/wasmtime/src/runtime/vm/sys/windows/mmap.rs index 1d9c2d899744..b45c3488c445 100644 --- a/crates/wasmtime/src/runtime/vm/sys/windows/mmap.rs +++ b/crates/wasmtime/src/runtime/vm/sys/windows/mmap.rs @@ -1,5 +1,5 @@ use crate::prelude::*; -use crate::runtime::vm::SendSyncPtr; +use crate::runtime::vm::{HostAlignedByteCount, SendSyncPtr}; use std::fs::{File, OpenOptions}; use std::io; use std::ops::Range; @@ -40,11 +40,11 @@ impl Mmap { } } - pub fn new(size: usize) -> Result { + pub fn new(size: HostAlignedByteCount) -> Result { let ptr = unsafe { VirtualAlloc( ptr::null_mut(), - size, + size.byte_count(), MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE, ) @@ -53,7 +53,7 @@ impl Mmap { bail!(io::Error::last_os_error()) } - let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size); + let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Self { memory, @@ -61,12 +61,19 @@ impl Mmap { }) } - pub fn reserve(size: usize) -> Result { - let ptr = unsafe { VirtualAlloc(ptr::null_mut(), size, MEM_RESERVE, PAGE_NOACCESS) }; + pub fn reserve(size: HostAlignedByteCount) -> Result { + let ptr = unsafe { + VirtualAlloc( + ptr::null_mut(), + size.byte_count(), + MEM_RESERVE, + PAGE_NOACCESS, + ) + }; if ptr.is_null() { bail!(io::Error::last_os_error()) } - let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size); + let memory = std::ptr::slice_from_raw_parts_mut(ptr.cast(), size.byte_count()); let memory = SendSyncPtr::new(NonNull::new(memory).unwrap()); Ok(Self { memory, @@ -139,11 +146,15 @@ impl Mmap { } } - pub fn make_accessible(&mut self, start: usize, len: usize) -> Result<()> { + pub fn make_accessible( + &mut self, + start: HostAlignedByteCount, + len: HostAlignedByteCount, + ) -> Result<()> { if unsafe { VirtualAlloc( - self.as_ptr().add(start) as _, - len, + self.as_ptr().add(start.byte_count()) as _, + len.byte_count(), MEM_COMMIT, PAGE_READWRITE, )