diff --git a/crates/libs/registry/Cargo.toml b/crates/libs/registry/Cargo.toml index 77a672b6be..fd5b280a75 100644 --- a/crates/libs/registry/Cargo.toml +++ b/crates/libs/registry/Cargo.toml @@ -24,3 +24,7 @@ path = "../targets" [dependencies.windows-result] version = "0.1.1" path = "../result" + +[dependencies.windows-strings] +version = "0.1.0" +path = "../strings" diff --git a/crates/libs/registry/src/key.rs b/crates/libs/registry/src/key.rs index 978890f239..59eb84f211 100644 --- a/crates/libs/registry/src/key.rs +++ b/crates/libs/registry/src/key.rs @@ -98,6 +98,15 @@ impl Key { unsafe { self.set_value(name, REG_SZ, value.as_ptr() as _, value.len() * 2) } } + /// Sets the name and value in the registry key. + pub fn set_hstring>( + &self, + name: T, + value: &windows_strings::HSTRING, + ) -> Result<()> { + unsafe { self.set_value(name, REG_SZ, value.as_ptr() as _, value.len() * 2) } + } + /// Sets the name and value in the registry key. pub fn set_multi_string>(&self, name: T, value: &[T]) -> Result<()> { let mut packed = value.iter().fold(vec![0u16; 0], |mut packed, value| { @@ -278,6 +287,40 @@ impl Key { } } + /// Gets the value for the name in the registry key. + pub fn get_hstring>(&self, name: T) -> Result { + let name = pcwstr(name); + let mut ty = 0; + let mut len = 0; + + let result = unsafe { + RegQueryValueExW(self.0, name.as_ptr(), null(), &mut ty, null_mut(), &mut len) + }; + + win32_error(result)?; + + if !matches!(ty, REG_SZ | REG_EXPAND_SZ) { + return Err(invalid_data()); + } + + let mut value = HStringBuilder::new(len as usize / 2)?; + + let result = unsafe { + RegQueryValueExW( + self.0, + name.as_ptr(), + null(), + null_mut(), + value.as_mut_ptr() as _, + &mut len, + ) + }; + + win32_error(result)?; + value.trim_end(); + Ok(value.into()) + } + /// Gets the value for the name in the registry key. pub fn get_bytes>(&self, name: T) -> Result> { let name = pcwstr(name); diff --git a/crates/libs/registry/src/lib.rs b/crates/libs/registry/src/lib.rs index 5c647e224a..2e900f0154 100644 --- a/crates/libs/registry/src/lib.rs +++ b/crates/libs/registry/src/lib.rs @@ -30,6 +30,9 @@ pub use r#type::Type; pub use windows_result::Result; use windows_result::*; +pub use windows_strings::HSTRING; +use windows_strings::*; + /// The predefined `HKEY_CLASSES_ROOT` registry key. pub const CLASSES_ROOT: &Key = &Key(HKEY_CLASSES_ROOT); diff --git a/crates/libs/result/src/bindings.rs b/crates/libs/result/src/bindings.rs index 6836443f37..f83832ed9d 100644 --- a/crates/libs/result/src/bindings.rs +++ b/crates/libs/result/src/bindings.rs @@ -19,7 +19,6 @@ pub type BOOL = i32; pub type BSTR = *const u16; pub const ERROR_INVALID_DATA: WIN32_ERROR = 13u32; pub const ERROR_NO_UNICODE_TRANSLATION: WIN32_ERROR = 1113u32; -pub const E_INVALIDARG: HRESULT = 0x80070057_u32 as _; pub const E_UNEXPECTED: HRESULT = 0x8000FFFF_u32 as _; pub const FORMAT_MESSAGE_ALLOCATE_BUFFER: FORMAT_MESSAGE_OPTIONS = 256u32; pub const FORMAT_MESSAGE_FROM_HMODULE: FORMAT_MESSAGE_OPTIONS = 2048u32; diff --git a/crates/libs/strings/.natvis b/crates/libs/strings/.natvis index 1b6a5c7565..ae1dc31c1c 100644 --- a/crates/libs/strings/.natvis +++ b/crates/libs/strings/.natvis @@ -1,8 +1,8 @@ - - + + "" {header()->data,[header()->len]su} diff --git a/crates/libs/strings/src/heap.rs b/crates/libs/strings/src/heap.rs deleted file mode 100644 index ffcb305bfd..0000000000 --- a/crates/libs/strings/src/heap.rs +++ /dev/null @@ -1,40 +0,0 @@ -use super::*; -use core::ffi::c_void; - -/// Allocate memory of size `bytes` using `HeapAlloc`. -pub fn heap_alloc(bytes: usize) -> crate::Result<*mut c_void> { - #[cfg(windows)] - let ptr: *mut c_void = unsafe { bindings::HeapAlloc(bindings::GetProcessHeap(), 0, bytes) }; - - #[cfg(not(windows))] - let ptr: *mut c_void = unsafe { - extern "C" { - fn malloc(bytes: usize) -> *mut c_void; - } - - malloc(bytes) - }; - - if ptr.is_null() { - Err(Error::from_hresult(HRESULT(bindings::E_OUTOFMEMORY))) - } else { - Ok(ptr) - } -} - -/// Free memory allocated by `heap_alloc`. -pub unsafe fn heap_free(ptr: *mut c_void) { - #[cfg(windows)] - { - bindings::HeapFree(bindings::GetProcessHeap(), 0, ptr); - } - - #[cfg(not(windows))] - { - extern "C" { - fn free(ptr: *mut c_void); - } - - free(ptr); - } -} diff --git a/crates/libs/strings/src/hstring.rs b/crates/libs/strings/src/hstring.rs index 8de3c265f2..b6766b89f2 100644 --- a/crates/libs/strings/src/hstring.rs +++ b/crates/libs/strings/src/hstring.rs @@ -1,27 +1,27 @@ use super::*; -/// A WinRT string ([HSTRING](https://docs.microsoft.com/en-us/windows/win32/winrt/hstring)) -/// is reference-counted and immutable. +/// An ([HSTRING](https://docs.microsoft.com/en-us/windows/win32/winrt/hstring)) +/// is a reference-counted and immutable UTF-16 string type. #[repr(transparent)] -pub struct HSTRING(Option>); +pub struct HSTRING(pub(crate) *mut HStringHeader); impl HSTRING { /// Create an empty `HSTRING`. /// /// This function does not allocate memory. pub const fn new() -> Self { - Self(None) + Self(core::ptr::null_mut()) } /// Returns `true` if the string is empty. - pub const fn is_empty(&self) -> bool { + pub fn is_empty(&self) -> bool { // An empty HSTRING is represented by a null pointer. - self.0.is_none() + self.0.is_null() } /// Returns the length of the string. The length is measured in `u16`s (UTF-16 code units), not including the terminating null character. pub fn len(&self) -> usize { - if let Some(header) = self.get_header() { + if let Some(header) = self.as_header() { header.len as usize } else { 0 @@ -35,7 +35,7 @@ impl HSTRING { /// Returns a raw pointer to the `HSTRING` buffer. pub fn as_ptr(&self) -> *const u16 { - if let Some(header) = self.get_header() { + if let Some(header) = self.as_header() { header.data } else { const EMPTY: [u16; 1] = [0]; @@ -66,7 +66,7 @@ impl HSTRING { return Ok(Self::new()); } - let ptr = Header::alloc(len.try_into()?)?; + let ptr = HStringHeader::alloc(len.try_into()?)?; // Place each utf-16 character into the buffer and // increase len as we go along. @@ -79,11 +79,11 @@ impl HSTRING { // Write a 0 byte to the end of the buffer. (*ptr).data.offset((*ptr).len as isize).write(0); - Ok(Self(core::ptr::NonNull::new(ptr))) + Ok(Self(ptr)) } - fn get_header(&self) -> Option<&Header> { - self.0.map(|header| unsafe { header.as_ref() }) + fn as_header(&self) -> Option<&HStringHeader> { + unsafe { self.0.as_ref() } } } @@ -95,8 +95,8 @@ impl Default for HSTRING { impl Clone for HSTRING { fn clone(&self) -> Self { - if let Some(header) = self.get_header() { - Self(core::ptr::NonNull::new(header.duplicate().unwrap())) + if let Some(header) = self.as_header() { + Self(header.duplicate().unwrap()) } else { Self::new() } @@ -105,17 +105,12 @@ impl Clone for HSTRING { impl Drop for HSTRING { fn drop(&mut self) { - if self.is_empty() { - return; - } - - if let Some(header) = self.0.take() { - // REFERENCE_FLAG indicates a string backed by static or stack memory that is + if let Some(header) = self.as_header() { + // HSTRING_REFERENCE_FLAG indicates a string backed by static or stack memory that is // thus not reference-counted and does not need to be freed. unsafe { - let header = header.as_ref(); - if header.flags & REFERENCE_FLAG == 0 && header.count.release() == 0 { - heap_free(header as *const _ as *mut _); + if header.flags & HSTRING_REFERENCE_FLAG == 0 && header.count.release() == 0 { + HStringHeader::free(self.0); } } } @@ -407,54 +402,3 @@ impl From for std::ffi::OsString { Self::from(&hstring) } } - -const REFERENCE_FLAG: u32 = 1; - -#[repr(C)] -struct Header { - flags: u32, - len: u32, - _0: u32, - _1: u32, - data: *mut u16, - count: RefCount, - buffer_start: u16, -} - -impl Header { - fn alloc(len: u32) -> Result<*mut Header> { - debug_assert!(len != 0); - // Allocate enough space for header and two bytes per character. - // The space for the terminating null character is already accounted for inside of `Header`. - let alloc_size = core::mem::size_of::
() + 2 * len as usize; - - let header = heap_alloc(alloc_size)? as *mut Header; - - unsafe { - // Use `ptr::write` (since `header` is unintialized). `Header` is safe to be all zeros. - header.write(core::mem::MaybeUninit::
::zeroed().assume_init()); - (*header).len = len; - (*header).count = RefCount::new(1); - (*header).data = &mut (*header).buffer_start; - } - - Ok(header) - } - - fn duplicate(&self) -> Result<*mut Header> { - if self.flags & REFERENCE_FLAG == 0 { - // If this is not a "fast pass" string then simply increment the reference count. - self.count.add_ref(); - Ok(self as *const Header as *mut Header) - } else { - // Otherwise, allocate a new string and copy the value into the new string. - let copy = Header::alloc(self.len)?; - // SAFETY: since we are duplicating the string it is safe to copy all data from self to the initialized `copy`. - // We copy `len + 1` characters since `len` does not account for the terminating null character. - unsafe { - core::ptr::copy_nonoverlapping(self.data, (*copy).data, self.len as usize + 1); - } - Ok(copy) - } - } -} diff --git a/crates/libs/strings/src/hstring_builder.rs b/crates/libs/strings/src/hstring_builder.rs new file mode 100644 index 0000000000..446451f0df --- /dev/null +++ b/crates/libs/strings/src/hstring_builder.rs @@ -0,0 +1,83 @@ +use super::*; + +/// An [HSTRING] builder that supports preallocating the `HSTRING` to avoid extra allocations and copies. +/// +/// This is similar to the `WindowsPreallocateStringBuffer` function but implemented directly in Rust for efficiency. +/// It is implemented as a separate type since [HSTRING] values are immutable. +pub struct HStringBuilder(*mut HStringHeader); + +impl HStringBuilder { + /// Creates a preallocated `HSTRING` value. + pub fn new(len: usize) -> Result { + Ok(Self(HStringHeader::alloc(len.try_into()?)?)) + } + + /// Shortens the string by removing any trailing 0 characters. + pub fn trim_end(&mut self) { + if let Some(header) = self.as_header_mut() { + while header.len > 0 + && unsafe { header.data.offset(header.len as isize - 1).read() == 0 } + { + header.len -= 1; + } + + if header.len == 0 { + unsafe { + HStringHeader::free(self.0); + } + self.0 = core::ptr::null_mut(); + } + } + } + + fn as_header(&self) -> Option<&HStringHeader> { + unsafe { self.0.as_ref() } + } + + fn as_header_mut(&mut self) -> Option<&mut HStringHeader> { + unsafe { self.0.as_mut() } + } +} + +impl From for HSTRING { + fn from(value: HStringBuilder) -> Self { + if let Some(header) = value.as_header() { + unsafe { header.data.offset(header.len as isize).write(0) }; + let result = Self(value.0); + core::mem::forget(value); + result + } else { + Self::new() + } + } +} + +impl core::ops::Deref for HStringBuilder { + type Target = [u16]; + + fn deref(&self) -> &[u16] { + if let Some(header) = self.as_header() { + unsafe { core::slice::from_raw_parts(header.data, header.len as usize) } + } else { + &[] + } + } +} + +impl core::ops::DerefMut for HStringBuilder { + fn deref_mut(&mut self) -> &mut [u16] { + if let Some(header) = self.as_header() { + unsafe { core::slice::from_raw_parts_mut(header.data, header.len as usize) } + } else { + &mut [] + } + } +} + +impl Drop for HStringBuilder { + fn drop(&mut self) { + unsafe { + HStringHeader::free(self.0); + } + } +} diff --git a/crates/libs/strings/src/hstring_header.rs b/crates/libs/strings/src/hstring_header.rs new file mode 100644 index 0000000000..84c3851519 --- /dev/null +++ b/crates/libs/strings/src/hstring_header.rs @@ -0,0 +1,92 @@ +use super::*; + +pub const HSTRING_REFERENCE_FLAG: u32 = 1; + +#[repr(C)] +pub struct HStringHeader { + pub flags: u32, + pub len: u32, + pub _0: u32, + pub _1: u32, + pub data: *mut u16, + pub count: RefCount, + pub buffer_start: u16, +} + +impl HStringHeader { + pub fn alloc(len: u32) -> Result<*mut HStringHeader> { + if len == 0 { + return Ok(core::ptr::null_mut()); + } + + // Allocate enough space for header and two bytes per character. + // The space for the terminating null character is already accounted for inside of `HStringHeader`. + let bytes = core::mem::size_of::() + 2 * len as usize; + + #[cfg(windows)] + let header = unsafe { bindings::HeapAlloc(bindings::GetProcessHeap(), 0, bytes) } + as *mut HStringHeader; + + #[cfg(not(windows))] + let header = unsafe { + extern "C" { + fn malloc(bytes: usize) -> *mut core::ffi::c_void; + } + + malloc(bytes) as *mut HStringHeader + }; + + if header.is_null() { + return Err(Error::from_hresult(HRESULT(bindings::E_OUTOFMEMORY))); + } + + unsafe { + // Use `ptr::write` (since `header` is unintialized). `HStringHeader` is safe to be all zeros. + header.write(core::mem::MaybeUninit::::zeroed().assume_init()); + (*header).len = len; + (*header).count = RefCount::new(1); + (*header).data = &mut (*header).buffer_start; + } + + Ok(header) + } + + pub unsafe fn free(header: *mut HStringHeader) { + if header.is_null() { + return; + } + + let header = header as *mut _; + + #[cfg(windows)] + { + bindings::HeapFree(bindings::GetProcessHeap(), 0, header); + } + + #[cfg(not(windows))] + { + extern "C" { + fn free(ptr: *mut core::ffi::c_void); + } + + free(header); + } + } + + pub fn duplicate(&self) -> Result<*mut HStringHeader> { + if self.flags & HSTRING_REFERENCE_FLAG == 0 { + // If this is not a "fast pass" string then simply increment the reference count. + self.count.add_ref(); + Ok(self as *const HStringHeader as *mut HStringHeader) + } else { + // Otherwise, allocate a new string and copy the value into the new string. + let copy = HStringHeader::alloc(self.len)?; + // SAFETY: since we are duplicating the string it is safe to copy all data from self to the initialized `copy`. + // We copy `len + 1` characters since `len` does not account for the terminating null character. + unsafe { + core::ptr::copy_nonoverlapping(self.data, (*copy).data, self.len as usize + 1); + } + Ok(copy) + } + } +} diff --git a/crates/libs/strings/src/lib.rs b/crates/libs/strings/src/lib.rs index 9039e9b54f..88e5c89c83 100644 --- a/crates/libs/strings/src/lib.rs +++ b/crates/libs/strings/src/lib.rs @@ -21,6 +21,12 @@ pub use bstr::*; mod hstring; pub use hstring::*; +mod hstring_builder; +pub use hstring_builder::*; + +mod hstring_header; +use hstring_header::*; + mod bindings; mod decode; @@ -29,9 +35,6 @@ use decode::*; mod ref_count; use ref_count::*; -mod heap; -use heap::*; - mod literals; pub use literals::*; diff --git a/crates/tests/registry/Cargo.toml b/crates/tests/registry/Cargo.toml index b76595089c..f4ddc3d3d2 100644 --- a/crates/tests/registry/Cargo.toml +++ b/crates/tests/registry/Cargo.toml @@ -21,3 +21,6 @@ features = ["Win32_System_Registry"] [dependencies.windows] path = "../../libs/windows" features = ["Win32_System_Registry"] + +[dependencies.windows-strings] +path = "../../libs/strings" diff --git a/crates/tests/registry/tests/bad_string.rs b/crates/tests/registry/tests/bad_string.rs index 412d8c33a4..aeaeebafcf 100644 --- a/crates/tests/registry/tests/bad_string.rs +++ b/crates/tests/registry/tests/bad_string.rs @@ -32,5 +32,9 @@ fn bad_string() -> Result<()> { let value_as_bytes = key.get_bytes("name")?; assert_eq!(value_as_bytes, bad_string_bytes); + + let value_as_hstring = key.get_hstring("name")?; + assert_eq!(value_as_hstring.to_string_lossy(), "�ā"); + Ok(()) } diff --git a/crates/tests/registry/tests/hstring.rs b/crates/tests/registry/tests/hstring.rs new file mode 100644 index 0000000000..13816f5652 --- /dev/null +++ b/crates/tests/registry/tests/hstring.rs @@ -0,0 +1,21 @@ +use windows_registry::*; +use windows_strings::h; + +#[test] +fn hstring() -> Result<()> { + let test_key = "software\\windows-rs\\tests\\hstring"; + _ = CURRENT_USER.remove_tree(test_key); + let key = CURRENT_USER.create(test_key)?; + + key.set_hstring("hstring", h!("simple"))?; + assert_eq!(&key.get_hstring("hstring")?, h!("simple")); + + // You can embed nulls. + key.set_hstring("hstring", h!("hstring\0value\0"))?; + + // And get_hstring will only trim any trailing nulls. + let value: HSTRING = key.get_hstring("hstring")?; + assert_eq!(&value, h!("hstring\0value")); + + Ok(()) +} diff --git a/crates/tests/strings/tests/hstring.rs b/crates/tests/strings/tests/hstring.rs index 8c98566c57..13597f4b90 100644 --- a/crates/tests/strings/tests/hstring.rs +++ b/crates/tests/strings/tests/hstring.rs @@ -7,3 +7,45 @@ fn hstring() -> Result<()> { Ok(()) } + +#[test] +fn hstring_builder() -> Result<()> { + // Dropping a builder is fine. + _ = HStringBuilder::new(10)?; + + // A zero length builder is also fine. + let b = HStringBuilder::new(0)?; + let h: HSTRING = b.into(); + assert!(h.is_empty()); + + // Trimming a zero length builder is also fine. + let mut b = HStringBuilder::new(0)?; + b.trim_end(); + let h: HSTRING = b.into(); + assert!(h.is_empty()); + + // This depends on DerefMut. + const HELLO: [u16; 5] = [0x48, 0x65, 0x6C, 0x6C, 0x6F]; + let mut b = HStringBuilder::new(5)?; + b.copy_from_slice(&HELLO); + let h: HSTRING = b.into(); + assert_eq!(&h, "Hello"); + + // HSTRING can handle embedded nulls. + const HELLO00: [u16; 7] = [0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x00, 0x00]; + let mut b = HStringBuilder::new(7)?; + b.copy_from_slice(&HELLO00); + let h: HSTRING = b.into(); + assert_eq!(h.len(), 7); + assert_eq!(h.as_wide(), HELLO00); + + // But trim_end can avoid that. + let mut b = HStringBuilder::new(7)?; + b.copy_from_slice(&HELLO00); + b.trim_end(); + let h: HSTRING = b.into(); + assert_eq!(h.len(), 5); + assert_eq!(h.as_wide(), HELLO); + + Ok(()) +} diff --git a/crates/tools/bindings/src/result.txt b/crates/tools/bindings/src/result.txt index 32236b86e3..7974c109c0 100644 --- a/crates/tools/bindings/src/result.txt +++ b/crates/tools/bindings/src/result.txt @@ -2,7 +2,6 @@ --config flatten sys minimal vtbl no-bindgen-comment --filter - Windows.Win32.Foundation.E_INVALIDARG Windows.Win32.Foundation.E_UNEXPECTED Windows.Win32.Foundation.ERROR_INVALID_DATA Windows.Win32.Foundation.ERROR_NO_UNICODE_TRANSLATION