diff --git a/crates/libs/core/src/inspectable.rs b/crates/libs/core/src/inspectable.rs index a632a216ea..f6d1e94389 100644 --- a/crates/libs/core/src/inspectable.rs +++ b/crates/libs/core/src/inspectable.rs @@ -17,6 +17,15 @@ impl IInspectable { Ok(std::mem::transmute(abi)) } } + + /// Gets the trust level of the current object. + pub fn GetTrustLevel(&self) -> Result { + unsafe { + let mut value = 0; + (self.vtable().GetTrustLevel)(std::mem::transmute_copy(self), &mut value).ok()?; + Ok(value) + } + } } #[doc(hidden)] @@ -60,14 +69,16 @@ impl IInspectable_Vtbl { *value = std::mem::transmute(h); HRESULT(0) } - unsafe extern "system" fn GetTrustLevel(_: *mut std::ffi::c_void, value: *mut i32) -> HRESULT { - // Note: even if we end up implementing this in future, it still doesn't need a this pointer - // since the data to be returned is type- not instance-specific so can be shared for all - // interfaces. - *value = 0; - HRESULT(0) + unsafe extern "system" fn GetTrustLevel(this: *mut std::ffi::c_void, value: *mut i32) -> HRESULT { + let this = (this as *mut *mut std::ffi::c_void).offset(OFFSET) as *mut T; + (*this).GetTrustLevel(value) + } + Self { + base: IUnknown_Vtbl::new::(), + GetIids, + GetRuntimeClassName: GetRuntimeClassName::, + GetTrustLevel: GetTrustLevel::, } - Self { base: IUnknown_Vtbl::new::(), GetIids, GetRuntimeClassName: GetRuntimeClassName::, GetTrustLevel } } } diff --git a/crates/libs/core/src/unknown.rs b/crates/libs/core/src/unknown.rs index 3d3e6d92e6..a6f62e00e0 100644 --- a/crates/libs/core/src/unknown.rs +++ b/crates/libs/core/src/unknown.rs @@ -72,8 +72,10 @@ pub trait IUnknownImpl { /// This function is safe to call as long as the interface pointer is non-null and valid for writes /// of an interface pointer. unsafe fn QueryInterface(&self, iid: *const GUID, interface: *mut *mut std::ffi::c_void) -> HRESULT; + /// Increments the reference count of the interface fn AddRef(&self) -> u32; + /// Decrements the reference count causing the interface's memory to be freed when the count is 0 /// /// # Safety @@ -81,6 +83,9 @@ pub trait IUnknownImpl { /// This function should only be called when the interfacer pointer is no longer used as calling `Release` /// on a non-aliased interface pointer and then using that interface pointer may result in use after free. unsafe fn Release(&self) -> u32; + + /// Gets the trust level of the current object. + unsafe fn GetTrustLevel(&self, value: *mut i32) -> HRESULT; } #[cfg(feature = "implement")] diff --git a/crates/libs/implement/src/lib.rs b/crates/libs/implement/src/lib.rs index 23e23134fe..297e202ad9 100644 --- a/crates/libs/implement/src/lib.rs +++ b/crates/libs/implement/src/lib.rs @@ -74,6 +74,8 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } }); + let trust_level = proc_macro2::Literal::usize_unsuffixed(attributes.trust_level); + let conversions = attributes.implement.iter().enumerate().map(|(enumerate, implement)| { let interface_ident = implement.to_ident(); let offset = proc_macro2::Literal::usize_unsuffixed(enumerate); @@ -162,6 +164,13 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro: } remaining } + unsafe fn GetTrustLevel(&self, value: *mut i32) -> ::windows::core::HRESULT { + if value.is_null() { + return ::windows::core::HRESULT(-2147467261); // E_POINTER + } + *value = #trust_level; + ::windows::core::HRESULT(0) + } } impl #generics #original_ident::#generics where #constraints { /// Try casting as the provided interface @@ -225,6 +234,7 @@ impl ImplementType { #[derive(Default)] struct ImplementAttributes { pub implement: Vec, + pub trust_level: usize, } impl syn::parse::Parse for ImplementAttributes { @@ -269,6 +279,7 @@ impl ImplementAttributes { self.walk_implement(tree, namespace)?; } } + UseTree2::TrustLevel(input) => self.trust_level = *input, } Ok(()) @@ -279,6 +290,7 @@ enum UseTree2 { Path(UsePath2), Name(UseName2), Group(UseGroup2), + TrustLevel(usize), } impl UseTree2 { @@ -308,6 +320,7 @@ impl UseTree2 { Ok(ImplementType { type_name, generics }) } UseTree2::Group(input) => Err(syn::parse::Error::new(input.brace_token.span.join(), "Syntax not supported")), + _ => unimplemented!(), } } } @@ -336,6 +349,18 @@ impl syn::parse::Parse for UseTree2 { if input.peek(syn::Token![::]) { input.parse::()?; Ok(UseTree2::Path(UsePath2 { ident, tree: Box::new(input.parse()?) })) + } else if input.peek(syn::Token![=]) { + if ident != "TrustLevel" { + return Err(syn::parse::Error::new(ident.span(), "Unrecognized key-value pair")); + } + input.parse::()?; + let span = input.span(); + let value = input.call(syn::Ident::parse_any)?; + match value.to_string().as_str() { + "Partial" => Ok(UseTree2::TrustLevel(1)), + "Full" => Ok(UseTree2::TrustLevel(2)), + _ => Err(syn::parse::Error::new(span, "`TrustLevel` must be `Partial` or `Full`")), + } } else { let generics = if input.peek(syn::Token![<]) { input.parse::()?; diff --git a/crates/tests/implement/tests/trust_level.rs b/crates/tests/implement/tests/trust_level.rs new file mode 100644 index 0000000000..aae60ddd7b --- /dev/null +++ b/crates/tests/implement/tests/trust_level.rs @@ -0,0 +1,54 @@ +#![allow(non_snake_case)] + +use windows::core::*; +use windows::Foundation::*; + +#[implement(IStringable)] +struct BaseTrust; + +impl IStringable_Impl for BaseTrust { + fn ToString(&self) -> Result { + Ok("BaseTrust".into()) + } +} + +#[implement(IClosable, TrustLevel = Partial, IStringable)] +struct PartialTrust; + +impl IStringable_Impl for PartialTrust { + fn ToString(&self) -> Result { + Ok("PartialTrust".into()) + } +} + +impl IClosable_Impl for PartialTrust { + fn Close(&self) -> Result<()> { + Ok(()) + } +} + +#[implement(IStringable, TrustLevel = Full)] +struct FullTrust; + +impl IStringable_Impl for FullTrust { + fn ToString(&self) -> Result { + Ok("FullTrust".into()) + } +} + +#[test] +fn test() -> Result<()> { + let base: IStringable = BaseTrust.into(); + assert_eq!(base.ToString()?, "BaseTrust"); + assert_eq!(base.cast::()?.GetTrustLevel()?, 0); + + let partial: IStringable = PartialTrust.into(); + assert_eq!(partial.ToString()?, "PartialTrust"); + assert_eq!(partial.cast::()?.GetTrustLevel()?, 1); + + let full: IStringable = FullTrust.into(); + assert_eq!(full.ToString()?, "FullTrust"); + assert_eq!(full.cast::()?.GetTrustLevel()?, 2); + + Ok(()) +}