From cf12baa48790589c09359d2123de87efe80bfd15 Mon Sep 17 00:00:00 2001 From: Arlie Davis Date: Sat, 25 May 2024 20:21:41 -0700 Subject: [PATCH] Fix bug in COM interface chain support The definition of a COM interface may inherit from another interface. These are known as "interface chains". The `#[implement]` macro allows designers to specify only the minimal set of interface chains that are needed for a given COM object implementation. The `#[implement]` macro (and the `#[interface]` macro) work together to pull in the implementations of all interfaces along the chain. Unfortunately there is a bug in the implementation of `QueryInterface` for interface chains. The current `QueryInterface` implementation will only check the IIDs of the interfaces at the root of the chian, i.e. the "most-derived" interface. `QueryInterface` will not search the IIDs of interfaces that are in the inheritance chain. This bug is demonstrated (detected) by the new unit tests in `crates/tests/implement_core/src/com_chain.rs`. This PR fixes the bug by generating an `fn matches()` method that checks the current IID and then checks the parent interface (if any) by calling its `match()` method. This fixes the unit test. --- crates/libs/interface/src/lib.rs | 17 ++++++- crates/tests/implement_core/src/com_chain.rs | 50 ++++++++++++++++++++ crates/tests/implement_core/src/lib.rs | 1 + 3 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 crates/tests/implement_core/src/com_chain.rs diff --git a/crates/libs/interface/src/lib.rs b/crates/libs/interface/src/lib.rs index 090581fd37..908e38a137 100644 --- a/crates/libs/interface/src/lib.rs +++ b/crates/libs/interface/src/lib.rs @@ -211,6 +211,17 @@ impl Interface { let parent_vtable_generics = if self.parent_is_iunknown() { quote!(Identity, OFFSET) } else { quote!(Identity, Impl, OFFSET) }; let parent_vtable = self.parent_vtable(); + // or_parent_matches will be `|| parent::matches(iid)` if this interface inherits from another + // interface (except for IUnknown) or will be empty if this is not applicable. This is what allows + // QueryInterface to work correctly for all interfaces in an inheritance chain, e.g. + // IFoo3 derives from IFoo2 derives from IFoo. + // + // We avoid matching IUnknown because object identity depends on the uniqueness of the IUnknown pointer. + let or_parent_matches = match parent_vtable.as_ref() { + Some(parent) if !self.parent_is_iunknown() => quote! (|| <#parent>::matches(iid)), + _ => quote!(), + }; + let functions = self .methods .iter() @@ -287,8 +298,10 @@ impl Interface { Self { base__: #parent_vtable::new::<#parent_vtable_generics>(), #(#entries),* } } - pub fn matches(iid: &windows_core::GUID) -> bool { - iid == &<#name as ::windows_core::Interface>::IID + #[inline(always)] + pub fn matches(iid: &::windows_core::GUID) -> bool { + *iid == <#name as ::windows_core::Interface>::IID + #or_parent_matches } } } diff --git a/crates/tests/implement_core/src/com_chain.rs b/crates/tests/implement_core/src/com_chain.rs new file mode 100644 index 0000000000..5124b390e4 --- /dev/null +++ b/crates/tests/implement_core/src/com_chain.rs @@ -0,0 +1,50 @@ +use windows_core::*; + +#[interface("cccccccc-0000-0000-0000-000000000001")] +unsafe trait IFoo: IUnknown {} + +#[interface("cccccccc-0000-0000-0000-000000000002")] +unsafe trait IFoo2: IFoo {} + +#[interface("cccccccc-0000-0000-0000-000000000003")] +unsafe trait IFoo3: IFoo2 {} + +// ObjectA implements a single interface chain, which consists of 3 different +// interfaces: IFoo3, IFoo2, and IFoo. You do not need to explicitly list all +// of the interfaces in the interface chain. Listing all of the interfaces is +// less efficient because it generates redundant interface chains (pointer +// fields in the the generated ObjectA_Impl type), which will never be used. +#[implement(IFoo3)] +struct ObjectWithChains {} + +impl IFoo_Impl for ObjectWithChains {} +impl IFoo2_Impl for ObjectWithChains {} +impl IFoo3_Impl for ObjectWithChains {} + +#[test] +fn interface_chain_query() { + let object = ComObject::new(ObjectWithChains {}); + let unknown: IUnknown = object.to_interface(); + let _foo: IFoo = unknown.cast().expect("QueryInterface for IFoo"); + let _foo2: IFoo2 = unknown.cast().expect("QueryInterface for IFoo2"); + let _foo3: IFoo3 = unknown.cast().expect("QueryInterface for IFoo3"); +} + +// ObjectRedundantChains implements the same interfaces as ObjectWithChains, +// but it defines more than one interface chain. This is unnecessary because it +// is redundant, but we are verifying that this works. +#[implement(IFoo3, IFoo2, IFoo)] +struct ObjectRedundantChains {} + +impl IFoo_Impl for ObjectRedundantChains {} +impl IFoo2_Impl for ObjectRedundantChains {} +impl IFoo3_Impl for ObjectRedundantChains {} + +#[test] +fn redundant_interface_chains() { + let object = ComObject::new(ObjectRedundantChains {}); + let unknown: IUnknown = object.to_interface(); + let _foo: IFoo = unknown.cast().expect("QueryInterface for IFoo"); + let _foo2: IFoo2 = unknown.cast().expect("QueryInterface for IFoo2"); + let _foo3: IFoo3 = unknown.cast().expect("QueryInterface for IFoo3"); +} diff --git a/crates/tests/implement_core/src/lib.rs b/crates/tests/implement_core/src/lib.rs index e083fca649..aa8f3bec53 100644 --- a/crates/tests/implement_core/src/lib.rs +++ b/crates/tests/implement_core/src/lib.rs @@ -3,4 +3,5 @@ #![cfg(test)] +mod com_chain; mod com_object;