From 5d2b7db154b996353bfb84062e31940edeea9c2a Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Wed, 11 Mar 2020 09:39:14 -0700 Subject: [PATCH 01/12] Use registered ComWrappers in Marshal APIs for object <-> IUnknown --- .../Runtime/InteropServices/ComWrappers.cs | 67 +++++++++-- .../InteropServices/Marshal.CoreCLR.cs | 74 +++++++++++- src/coreclr/src/vm/ecalllist.h | 4 +- src/coreclr/src/vm/interoplibinterface.cpp | 110 +++++++++--------- src/coreclr/src/vm/marshalnative.cpp | 43 ++----- src/coreclr/src/vm/marshalnative.h | 4 +- .../src/Interop/COM/ComWrappers/Program.cs | 55 ++++++++- .../Marshal/GetComInterfaceForObjectTests.cs | 4 +- .../GetUniqueObjectForIUnknownTests.cs | 2 +- 9 files changed, 257 insertions(+), 106 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index f5d0e06b016a3..283d6dca228b8 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -118,11 +118,29 @@ private struct ComInterfaceInstance /// Flags used to configure the generated interface. /// The generated COM interface that can be passed outside the .NET runtime. public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) + { + IntPtr ptr = GetOrCreateComInterfaceForObjectInternal(this, instance, flags); + if (ptr == IntPtr.Zero) + throw new ArgumentException(); + + return ptr; + } + + /// + /// Create a COM representation of the supplied object that can be passed to a non-managed environment. + /// + /// The implemenentation to use when creating the COM representation. + /// The managed object to expose outside the .NET runtime. + /// Flags used to configure the generated interface. + /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created. + /// + /// If is null, the global instance (if registered) will be used. + /// + internal static IntPtr GetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags) { if (instance == null) throw new ArgumentNullException(nameof(instance)); - ComWrappers impl = this; return GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags); } @@ -146,7 +164,16 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa // Call to execute the abstract instance function internal static unsafe void* CallComputeVtables(ComWrappers? comWrappersImpl, object obj, CreateComInterfaceFlags flags, out int count) - => (comWrappersImpl ?? s_globalInstance!).ComputeVtables(obj, flags, out count); + { + ComWrappers? impl = comWrappersImpl ?? s_globalInstance; + if (impl is null) + { + count = -1; + return null; + } + + return impl.ComputeVtables(obj, flags, out count); + } /// /// Get the currently registered managed object or creates a new managed object and registers it. @@ -156,7 +183,11 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// Returns a managed object associated with the supplied external COM object. public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags) { - return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, null); + object? obj = GetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, null); + if (obj == null) + throw new ArgumentNullException(); + + return obj; } /// @@ -172,7 +203,13 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb // Call to execute the abstract instance function internal static object? CallCreateObject(ComWrappers? comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags) - => (comWrappersImpl ?? s_globalInstance!).CreateObject(externalComObject, flags); + { + ComWrappers? impl = comWrappersImpl ?? s_globalInstance; + if (impl == null) + return null; + + return impl.CreateObject(externalComObject, flags); + } /// /// Get the currently registered managed object or uses the supplied managed object and registers it. @@ -189,20 +226,34 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create if (wrapper == null) throw new ArgumentNullException(nameof(externalComObject)); - return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, wrapper); + object? obj = GetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, wrapper); + if (obj == null) + throw new ArgumentNullException(); + + return obj; } - private object GetOrCreateObjectForComInstanceInternal(IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe) + /// + /// Get the currently registered managed object or creates a new managed object and registers it. + /// + /// The implemenentation to use when creating the managed object. + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// The to be used as the wrapper for the external object. + /// Returns a managed object associated with the supplied external COM object or null if it could not be created. + /// + /// If is null, the global instance (if registered) will be used. + /// + internal static object? GetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe) { if (externalComObject == IntPtr.Zero) throw new ArgumentNullException(nameof(externalComObject)); - ComWrappers impl = this; object? wrapperMaybeLocal = wrapperMaybe; object? retValue = null; GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); - return retValue!; + return retValue; } [DllImport(RuntimeHelpers.QCall)] diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index 07b7ce866c3d6..64511d7185eb8 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -326,6 +326,16 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IUnknown* */ GetIUnknownForObject(object o) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + IntPtr ptrMaybe = ComWrappers.GetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.None); + if (ptrMaybe != IntPtr.Zero) + return ptrMaybe; + return GetIUnknownForObjectNative(o, false); } @@ -344,6 +354,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IDispatch */ GetIDispatchForObject(object o) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + return GetIDispatchForObjectNative(o, false); } @@ -356,6 +371,16 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + + if (T is null) + { + throw new ArgumentNullException(nameof(T)); + } + return GetComInterfaceForObjectNative(o, T, false, true); } @@ -368,15 +393,58 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T, CustomQueryInterfaceMode mode) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + + if (T is null) + { + throw new ArgumentNullException(nameof(T)); + } + bool bEnableCustomizedQueryInterface = ((mode == CustomQueryInterfaceMode.Allow) ? true : false); return GetComInterfaceForObjectNative(o, T, false, bEnableCustomizedQueryInterface); } [MethodImpl(MethodImplOptions.InternalCall)] - private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnalbeCustomizedQueryInterface); + private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnableCustomizedQueryInterface); + + /// + /// Return the managed object representing the IUnknown* + /// + public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) + { + if (pUnk == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(pUnk)); + } + + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.None, wrapperMaybe: null); + if (objMaybe != null) + return objMaybe; + + return GetObjectForIUnknownNative(pUnk); + } [MethodImpl(MethodImplOptions.InternalCall)] - public static extern object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk); + private static extern object GetObjectForIUnknownNative(IntPtr /* IUnknown* */ pUnk); + + public static object GetUniqueObjectForIUnknown(IntPtr unknown) + { + if (unknown == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(unknown)); + } + + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.UniqueInstance, wrapperMaybe: null); + if (objMaybe != null) + return objMaybe; + + return GetUniqueObjectForIUnknownNative(unknown); + } /// /// Return a unique Object given an IUnknown. This ensures that you receive a fresh @@ -385,7 +453,7 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// ReleaseComObject on a RCW and not worry about other active uses ofsaid RCW. /// [MethodImpl(MethodImplOptions.InternalCall)] - public static extern object GetUniqueObjectForIUnknown(IntPtr unknown); + private static extern object GetUniqueObjectForIUnknownNative(IntPtr unknown); /// /// Return an Object for IUnknown, using the Type T. diff --git a/src/coreclr/src/vm/ecalllist.h b/src/coreclr/src/vm/ecalllist.h index 3cbfa0e5ec40e..4025d2162fd11 100644 --- a/src/coreclr/src/vm/ecalllist.h +++ b/src/coreclr/src/vm/ecalllist.h @@ -805,8 +805,8 @@ FCFuncStart(gInteropMarshalFuncs) FCFuncElement("GetHRForException", MarshalNative::GetHRForException) FCFuncElement("GetRawIUnknownForComObjectNoAddRef", MarshalNative::GetRawIUnknownForComObjectNoAddRef) FCFuncElement("IsComObject", MarshalNative::IsComObject) - FCFuncElement("GetObjectForIUnknown", MarshalNative::GetObjectForIUnknown) - FCFuncElement("GetUniqueObjectForIUnknown", MarshalNative::GetUniqueObjectForIUnknown) + FCFuncElement("GetObjectForIUnknownNative", MarshalNative::GetObjectForIUnknownNative) + FCFuncElement("GetUniqueObjectForIUnknownNative", MarshalNative::GetUniqueObjectForIUnknownNative) FCFuncElement("AddRef", MarshalNative::AddRef) FCFuncElement("GetNativeVariantForObject", MarshalNative::GetNativeVariantForObject) FCFuncElement("GetObjectForNativeVariant", MarshalNative::GetObjectForNativeVariant) diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 70e5f14f2c7df..34e603d2c765c 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -488,7 +488,6 @@ namespace THROWS; MODE_COOPERATIVE; PRECONDITION(instance != NULL); - POSTCONDITION(CheckPointer(RETVAL)); } CONTRACT_END; @@ -525,7 +524,8 @@ namespace void* vtables = CallComputeVTables(&gc.implRef, &gc.instRef, flags, &vtableCount); // Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw)) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw) + && ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0))) { OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); @@ -567,10 +567,8 @@ namespace wrapperRaw = newWrapper.Extract(); STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRaw); } - else + else if (wrapperRaw != NULL) { - _ASSERTE(wrapperRaw != NULL); - // It is possible the supplied wrapper is no longer valid. If so, reactivate the // wrapper using the protected OBJECTREF. IUnknown* wrapper = static_cast(wrapperRaw); @@ -602,7 +600,6 @@ namespace THROWS; MODE_COOPERATIVE; PRECONDITION(identity != NULL); - POSTCONDITION(RETVAL != NULL); } CONTRACT_END; @@ -655,63 +652,64 @@ namespace if (gc.objRef == NULL) { gc.objRef = CallGetObject(&gc.implRef, identity, flags); - if (gc.objRef == NULL) - COMPlusThrow(kArgumentNullException); } - // Construct the new context with the object details. - DWORD flags = (resultHolder.Result.FromTrackerRuntime - ? ExternalObjectContext::Flags_ReferenceTracker - : ExternalObjectContext::Flags_None) | - (uniqueInstance - ? ExternalObjectContext::Flags_None - : ExternalObjectContext::Flags_InCache); - ExternalObjectContext::Construct( - resultHolder.GetContext(), - identity, - GetCurrentCtxCookie(), - gc.objRef->GetSyncBlockIndex(), - flags); - - if (uniqueInstance) + if (gc.objRef != NULL) { - extObjCxt = resultHolder.GetContext(); - } - else - { - // Attempt to insert the new context into the cache. - ExtObjCxtCache::LockHolder lock(cache); - extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext()); - } - - // If the returned context matches the new context it means the - // new context was inserted or a unique instance was requested. - if (extObjCxt == resultHolder.GetContext()) - { - // Update the object's SyncBlock with a handle to the context for runtime cleanup. - SyncBlock* syncBlock = gc.objRef->GetSyncBlock(); - InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo(); - _ASSERTE(syncBlock->IsPrecious()); - - // Since the caller has the option of providing a wrapper, it is - // possible the supplied wrapper already has an associated external - // object and an object can only be associated with one external object. - if (!interopInfo->TrySetExternalComObjectContext((void**)extObjCxt)) + // Construct the new context with the object details. + DWORD flags = (resultHolder.Result.FromTrackerRuntime + ? ExternalObjectContext::Flags_ReferenceTracker + : ExternalObjectContext::Flags_None) | + (uniqueInstance + ? ExternalObjectContext::Flags_None + : ExternalObjectContext::Flags_InCache); + ExternalObjectContext::Construct( + resultHolder.GetContext(), + identity, + GetCurrentCtxCookie(), + gc.objRef->GetSyncBlockIndex(), + flags); + + if (uniqueInstance) { - // Failed to set the context; one must already exist. - // Remove from the cache above as well. + extObjCxt = resultHolder.GetContext(); + } + else + { + // Attempt to insert the new context into the cache. ExtObjCxtCache::LockHolder lock(cache); - cache->Remove(resultHolder.GetContext()); + extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext()); + } + + // If the returned context matches the new context it means the + // new context was inserted or a unique instance was requested. + if (extObjCxt == resultHolder.GetContext()) + { + // Update the object's SyncBlock with a handle to the context for runtime cleanup. + SyncBlock* syncBlock = gc.objRef->GetSyncBlock(); + InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo(); + _ASSERTE(syncBlock->IsPrecious()); + + // Since the caller has the option of providing a wrapper, it is + // possible the supplied wrapper already has an associated external + // object and an object can only be associated with one external object. + if (!interopInfo->TrySetExternalComObjectContext((void**)extObjCxt)) + { + // Failed to set the context; one must already exist. + // Remove from the cache above as well. + ExtObjCxtCache::LockHolder lock(cache); + cache->Remove(resultHolder.GetContext()); - COMPlusThrow(kNotSupportedException); + COMPlusThrow(kNotSupportedException); + } + + // Detach from the holder to avoid cleanup. + (void)resultHolder.DetachContext(); + STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt); } - // Detach from the holder to avoid cleanup. - (void)resultHolder.DetachContext(); - STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt); + _ASSERTE(extObjCxt->IsActive()); } - - _ASSERTE(extObjCxt->IsActive()); } GCPROTECT_END(); @@ -957,6 +955,9 @@ namespace InteropLibImports externalObjectFlags, gc.wrapperMaybeRef); + if (gc.objRef == NULL) + COMPlusThrow(kArgumentNullException); + // Get wrapper for managed object *trackerTarget = GetOrCreateComInterfaceForObjectInternal( gc.implRef, @@ -1078,7 +1079,6 @@ void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject( END_QCALL; - _ASSERTE(wrapper != NULL); return wrapper; } diff --git a/src/coreclr/src/vm/marshalnative.cpp b/src/coreclr/src/vm/marshalnative.cpp index adc20310648cb..d0d3b27c68980 100644 --- a/src/coreclr/src/vm/marshalnative.cpp +++ b/src/coreclr/src/vm/marshalnative.cpp @@ -372,7 +372,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE) FieldDesc *pField = refField->GetField(); TypeHandle th = TypeHandle(pField->GetApproxEnclosingMethodTable()); - + if (th.IsBlittable()) { return pField->GetOffset(); @@ -388,7 +388,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE) { // It isn't marshalable so throw an ArgumentException. StackSString strTypeName; - TypeString::AppendType(strTypeName, th); + TypeString::AppendType(strTypeName, th); COMPlusThrow(kArgumentException, IDS_CANNOT_MARSHAL, strTypeName.GetUnicode(), NULL, NULL); } EEClassNativeLayoutInfo const* pNativeLayoutInfo = th.GetMethodTable()->GetNativeLayoutInfo(); @@ -401,7 +401,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE) #endif while (numReferenceFields--) { - if (pNFD->GetFieldDesc() == pField) + if (pNFD->GetFieldDesc() == pField) { externalOffset = pNFD->GetExternalOffset(); INDEBUG(foundField = true); @@ -802,11 +802,7 @@ FCIMPL2(IUnknown*, MarshalNative::GetIUnknownForObjectNative, Object* orefUNSAFE OBJECTREF oref = (OBJECTREF) orefUNSAFE; HELPER_METHOD_FRAME_BEGIN_RET_1(oref); - HRESULT hr = S_OK; - - if(!oref) - COMPlusThrowArgumentNull(W("o")); - + _ASSERTE(oref != NULL); // Ensure COM is started up. EnsureComStarted(); @@ -868,11 +864,7 @@ FCIMPL2(IDispatch*, MarshalNative::GetIDispatchForObjectNative, Object* orefUNSA OBJECTREF oref = (OBJECTREF) orefUNSAFE; HELPER_METHOD_FRAME_BEGIN_RET_1(oref); - HRESULT hr = S_OK; - - if(!oref) - COMPlusThrowArgumentNull(W("o")); - + _ASSERTE(oref != NULL); // Ensure COM is started up. EnsureComStarted(); @@ -899,13 +891,8 @@ FCIMPL4(IUnknown*, MarshalNative::GetComInterfaceForObjectNative, Object* orefUN REFLECTCLASSBASEREF refClass = (REFLECTCLASSBASEREF) refClassUNSAFE; HELPER_METHOD_FRAME_BEGIN_RET_2(oref, refClass); - HRESULT hr = S_OK; - - if(!oref) - COMPlusThrowArgumentNull(W("o")); - if(!refClass) - COMPlusThrowArgumentNull(W("t")); - + _ASSERTE(oref != NULL); + _ASSERTE(refClass != NULL); // Ensure COM is started up. EnsureComStarted(); @@ -946,23 +933,18 @@ FCIMPLEND //==================================================================== // return an Object for IUnknown //==================================================================== -FCIMPL1(Object*, MarshalNative::GetObjectForIUnknown, IUnknown* pUnk) +FCIMPL1(Object*, MarshalNative::GetObjectForIUnknownNative, IUnknown* pUnk) { CONTRACTL { FCALL_CHECK; - PRECONDITION(CheckPointer(pUnk, NULL_OK)); + PRECONDITION(CheckPointer(pUnk)); } CONTRACTL_END; OBJECTREF oref = NULL; HELPER_METHOD_FRAME_BEGIN_RET_1(oref); - HRESULT hr = S_OK; - - if(!pUnk) - COMPlusThrowArgumentNull(W("pUnk")); - // Ensure COM is started up. EnsureComStarted(); @@ -974,12 +956,12 @@ FCIMPL1(Object*, MarshalNative::GetObjectForIUnknown, IUnknown* pUnk) FCIMPLEND -FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknown, IUnknown* pUnk) +FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknownNative, IUnknown* pUnk) { CONTRACTL { FCALL_CHECK; - PRECONDITION(CheckPointer(pUnk, NULL_OK)); + PRECONDITION(CheckPointer(pUnk)); } CONTRACTL_END; @@ -988,9 +970,6 @@ FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknown, IUnknown* pUnk) HRESULT hr = S_OK; - if(!pUnk) - COMPlusThrowArgumentNull(W("pUnk")); - // Ensure COM is started up. EnsureComStarted(); diff --git a/src/coreclr/src/vm/marshalnative.h b/src/coreclr/src/vm/marshalnative.h index 7ff167c586fc3..3325a0d05544e 100644 --- a/src/coreclr/src/vm/marshalnative.h +++ b/src/coreclr/src/vm/marshalnative.h @@ -106,12 +106,12 @@ class MarshalNative //==================================================================== // return an Object for IUnknown //==================================================================== - static FCDECL1(Object*, GetObjectForIUnknown, IUnknown* pUnk); + static FCDECL1(Object*, GetObjectForIUnknownNative, IUnknown* pUnk); //==================================================================== // return a unique cacheless Object for IUnknown //==================================================================== - static FCDECL1(Object*, GetUniqueObjectForIUnknown, IUnknown* pUnk); + static FCDECL1(Object*, GetUniqueObjectForIUnknownNative, IUnknown* pUnk); //==================================================================== // return a unique cacheless Object for IUnknown diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs index e85f87a4b8269..f06951b1f6fe5 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs @@ -155,8 +155,20 @@ class TestComWrappers : ComWrappers { public static readonly TestComWrappers Global = new TestComWrappers(); + public bool ReturnInvalid { get; set; } + + public object LastComputeVtablesObject { get; private set; } + protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) { + LastComputeVtablesObject = obj; + + if (ReturnInvalid) + { + count = -1; + return null; + } + Assert.IsTrue(obj is Test); IntPtr fpQueryInteface = default; @@ -187,10 +199,13 @@ class TestComWrappers : ComWrappers protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag) { + if (ReturnInvalid) + return null; + var iid = typeof(ITrackerObject).GUID; IntPtr iTestComObject; int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTestComObject); - Assert.AreEqual(hr, 0); + Assert.AreEqual(0, hr); return new ITrackerObjectWrapper(iTestComObject); } @@ -479,6 +494,9 @@ static void ValidateGlobalInstanceScenarios() wrappers2.RegisterAsGlobalInstance(); }, "Should not be able to reset for global ComWrappers"); + ValidateMarshalAPIs(wrappers1, true); + ValidateMarshalAPIs(wrappers1, false); + Console.WriteLine($"Validate NotifyEndOfReferenceTrackingOnThread()..."); int hr; @@ -489,6 +507,41 @@ static void ValidateGlobalInstanceScenarios() Assert.AreEqual(TestComWrappers.ReleaseObjectsCallAck, hr); } + private static void ValidateMarshalAPIs(TestComWrappers registeredWrapper, bool validateUseRegistered) + { + registeredWrapper.ReturnInvalid = !validateUseRegistered; + + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Validate Marshal.GetIUnknownForObject: {scenario}..."); + + var testObj = new Test(); + IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj); + Assert.AreNotEqual(IntPtr.Zero, comWrapper1); + Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + + IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj); + Assert.AreEqual(comWrapper1, comWrapper2); + + Marshal.Release(comWrapper1); + Marshal.Release(comWrapper2); + + Console.WriteLine($"Validate Marshal.GetObjectForIUnknown: {scenario}..."); + + IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); + object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(validateUseRegistered, objWrapper1 is ITrackerObjectWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(ITrackerObjectWrapper)} instance"); + object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(objWrapper1, objWrapper2); + + Console.WriteLine($"Validate Marshal.GetUniqueObjectForIUnknown: {scenario}..."); + + object objWrapper3 = Marshal.GetUniqueObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(validateUseRegistered, objWrapper3 is ITrackerObjectWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(ITrackerObjectWrapper)} instance"); + Assert.AreNotEqual(objWrapper1, objWrapper3); + + Marshal.Release(trackerObjRaw); + } + static int Main(string[] doNotUse) { try diff --git a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs index f72fa3af23a00..26e1876686853 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs @@ -141,8 +141,8 @@ public void GetComInterfaceForObject_NullObject_ThrowsArgumentNullException() [PlatformSpecific(TestPlatforms.Windows)] public void GetComInterfaceForObject_NullType_ThrowsArgumentNullException() { - AssertExtensions.Throws("t", () => Marshal.GetComInterfaceForObject(new object(), null)); - AssertExtensions.Throws("t", () => Marshal.GetComInterfaceForObject(new object(), null, CustomQueryInterfaceMode.Allow)); + AssertExtensions.Throws("T", () => Marshal.GetComInterfaceForObject(new object(), null)); + AssertExtensions.Throws("T", () => Marshal.GetComInterfaceForObject(new object(), null, CustomQueryInterfaceMode.Allow)); } public static IEnumerable GetComInterfaceForObject_InvalidType_TestData() diff --git a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs index 1da3b1529af9b..88d36ee5b5b35 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs @@ -63,7 +63,7 @@ public void GetUniqueObjectForIUnknown_Unix_ThrowsPlatformNotSupportedException( [PlatformSpecific(TestPlatforms.Windows)] public void GetUniqueObjectForIUnknown_NullPointer_ThrowsArgumentNullException() { - AssertExtensions.Throws("pUnk", () => Marshal.GetUniqueObjectForIUnknown(IntPtr.Zero)); + AssertExtensions.Throws("unknown", () => Marshal.GetUniqueObjectForIUnknown(IntPtr.Zero)); } private static void NonGenericMethod(int i) { } From 0a3ddb48aeb87ad89bcb6cf4c406f0855583ace4 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Wed, 11 Mar 2020 15:51:59 -0700 Subject: [PATCH 02/12] PR feedback --- .../Runtime/InteropServices/ComWrappers.cs | 3 +- .../InteropServices/Marshal.CoreCLR.cs | 6 +-- src/coreclr/src/vm/interoplibinterface.cpp | 41 ++++++++++--------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 283d6dca228b8..cd4780d970d0c 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -158,7 +158,8 @@ internal static IntPtr GetOrCreateComInterfaceForObjectInternal(ComWrappers? imp /// All memory returned from this function must either be unmanaged memory, pinned managed memory, or have been /// allocated with the API. /// - /// If the interface entries cannot be created and null is returned, the call to will throw a . + /// If the interface entries cannot be created and a negative or null and a non-zero are returned, + /// the call to will throw a . /// protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count); diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index 64511d7185eb8..54aa45f6bbee3 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -332,7 +332,7 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) } // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - IntPtr ptrMaybe = ComWrappers.GetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.None); + IntPtr ptrMaybe = ComWrappers.GetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.TrackerSupport); if (ptrMaybe != IntPtr.Zero) return ptrMaybe; @@ -421,7 +421,7 @@ public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) } // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.None, wrapperMaybe: null); + object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.TrackerObject, wrapperMaybe: null); if (objMaybe != null) return objMaybe; @@ -439,7 +439,7 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown) } // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.UniqueInstance, wrapperMaybe: null); + object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.TrackerObject | CreateObjectFlags.UniqueInstance, wrapperMaybe: null); if (objMaybe != null) return objMaybe; diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 34e603d2c765c..1fba4f62a87fd 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -494,7 +494,7 @@ namespace HRESULT hr; SafeComHolder newWrapper; - void* wrapperRaw = NULL; + void* wrapperRawMaybe = NULL; struct { @@ -513,7 +513,7 @@ namespace _ASSERTE(syncBlock->IsPrecious()); // Query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw)) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)) { // Compute VTables for the new existing COM object using the supplied COM Wrappers implementation. // @@ -524,7 +524,7 @@ namespace void* vtables = CallComputeVTables(&gc.implRef, &gc.instRef, flags, &vtableCount); // Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe) && ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0))) { OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); @@ -551,7 +551,7 @@ namespace // If the managed object wrapper couldn't be set, then // it should be possible to get the current one. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw)) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)) { UNREACHABLE(); } @@ -564,18 +564,18 @@ namespace { // A new managed object wrapper was created, remove the object from the holder. // No AddRef() here since the wrapper should be created with a reference. - wrapperRaw = newWrapper.Extract(); - STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRaw); + wrapperRawMaybe = newWrapper.Extract(); + STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRawMaybe); } - else if (wrapperRaw != NULL) + else if (wrapperRawMaybe != NULL) { // It is possible the supplied wrapper is no longer valid. If so, reactivate the // wrapper using the protected OBJECTREF. - IUnknown* wrapper = static_cast(wrapperRaw); + IUnknown* wrapper = static_cast(wrapperRawMaybe); hr = InteropLib::Com::IsActiveWrapper(wrapper); if (hr == S_FALSE) { - STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRaw); + STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRawMaybe); OBJECTHANDLE h = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); hr = InteropLib::Com::ReactivateWrapper(wrapper, static_cast(h)); } @@ -586,7 +586,7 @@ namespace GCPROTECT_END(); - RETURN wrapperRaw; + RETURN wrapperRawMaybe; } OBJECTREF GetOrCreateObjectForComInstanceInternal( @@ -610,7 +610,7 @@ namespace { OBJECTREF implRef; OBJECTREF wrapperMaybeRef; - OBJECTREF objRef; + OBJECTREF objRefMaybe; } gc; ::ZeroMemory(&gc, sizeof(gc)); GCPROTECT_BEGIN(gc); @@ -631,7 +631,7 @@ namespace if (extObjCxt != NULL) { - gc.objRef = extObjCxt->GetObjectRef(); + gc.objRefMaybe = extObjCxt->GetObjectRef(); } else { @@ -646,15 +646,18 @@ namespace COMPlusThrowHR(hr); // The user could have supplied a wrapper so assign that now. - gc.objRef = gc.wrapperMaybeRef; + gc.objRefMaybe = gc.wrapperMaybeRef; // If the wrapper hasn't been set yet, call the implementation to create one. - if (gc.objRef == NULL) + if (gc.objRefMaybe == NULL) { - gc.objRef = CallGetObject(&gc.implRef, identity, flags); + gc.objRefMaybe = CallGetObject(&gc.implRef, identity, flags); } - if (gc.objRef != NULL) + // The object may be null if the specified ComWrapper implementation returns null + // or there is no registered global instance. It is the caller's responsibility + // to handle this case and error if necessary. + if (gc.objRefMaybe != NULL) { // Construct the new context with the object details. DWORD flags = (resultHolder.Result.FromTrackerRuntime @@ -667,7 +670,7 @@ namespace resultHolder.GetContext(), identity, GetCurrentCtxCookie(), - gc.objRef->GetSyncBlockIndex(), + gc.objRefMaybe->GetSyncBlockIndex(), flags); if (uniqueInstance) @@ -686,7 +689,7 @@ namespace if (extObjCxt == resultHolder.GetContext()) { // Update the object's SyncBlock with a handle to the context for runtime cleanup. - SyncBlock* syncBlock = gc.objRef->GetSyncBlock(); + SyncBlock* syncBlock = gc.objRefMaybe->GetSyncBlock(); InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo(); _ASSERTE(syncBlock->IsPrecious()); @@ -714,7 +717,7 @@ namespace GCPROTECT_END(); - RETURN gc.objRef; + RETURN gc.objRefMaybe; } } From dad304e6c50544f9bd23f6b1173781f82559c4fe Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Thu, 12 Mar 2020 14:15:36 -0700 Subject: [PATCH 03/12] Switch to Try semantics --- .../Runtime/InteropServices/ComWrappers.cs | 38 +++++------ .../InteropServices/Marshal.CoreCLR.cs | 16 ++--- src/coreclr/src/vm/ecalllist.h | 4 +- src/coreclr/src/vm/interoplibinterface.cpp | 63 ++++++++++++------- src/coreclr/src/vm/interoplibinterface.h | 7 ++- 5 files changed, 74 insertions(+), 54 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index cd4780d970d0c..731437a123009 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -119,8 +119,8 @@ private struct ComInterfaceInstance /// The generated COM interface that can be passed outside the .NET runtime. public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) { - IntPtr ptr = GetOrCreateComInterfaceForObjectInternal(this, instance, flags); - if (ptr == IntPtr.Zero) + IntPtr ptr; + if (!TryGetOrCreateComInterfaceForObjectInternal(this, instance, flags, out ptr)) throw new ArgumentException(); return ptr; @@ -132,20 +132,21 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// The implemenentation to use when creating the COM representation. /// The managed object to expose outside the .NET runtime. /// Flags used to configure the generated interface. - /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created. + /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created. + /// Returns true if a CEM representation could be created, false otherwise /// /// If is null, the global instance (if registered) will be used. /// - internal static IntPtr GetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags) + internal static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue) { if (instance == null) throw new ArgumentNullException(nameof(instance)); - return GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags); + return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue); } [DllImport(RuntimeHelpers.QCall)] - private static extern IntPtr GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags); + private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue); /// /// Compute the desired Vtable for respecting the values of . @@ -184,11 +185,11 @@ internal static IntPtr GetOrCreateComInterfaceForObjectInternal(ComWrappers? imp /// Returns a managed object associated with the supplied external COM object. public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags) { - object? obj = GetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, null); - if (obj == null) + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, null, out obj)) throw new ArgumentNullException(); - return obj; + return obj!; } /// @@ -227,11 +228,11 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create if (wrapper == null) throw new ArgumentNullException(nameof(externalComObject)); - object? obj = GetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, wrapper); - if (obj == null) + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, wrapper, out obj)) throw new ArgumentNullException(); - return obj; + return obj!; } /// @@ -241,24 +242,23 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create /// Object to import for usage into the .NET runtime. /// Flags used to describe the external object. /// The to be used as the wrapper for the external object. - /// Returns a managed object associated with the supplied external COM object or null if it could not be created. + /// The managed object associated with the supplied external COM object or null if it could not be created. + /// Returns true if a managed object could be retrieved/created, false otherwise /// /// If is null, the global instance (if registered) will be used. /// - internal static object? GetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe) + internal static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue) { if (externalComObject == IntPtr.Zero) throw new ArgumentNullException(nameof(externalComObject)); object? wrapperMaybeLocal = wrapperMaybe; - object? retValue = null; - GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); - - return retValue; + retValue = null; + return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); } [DllImport(RuntimeHelpers.QCall)] - private static extern void GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue); + private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue); /// /// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime. diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index 54aa45f6bbee3..11b0fdf1241cb 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -332,8 +332,8 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) } // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - IntPtr ptrMaybe = ComWrappers.GetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.TrackerSupport); - if (ptrMaybe != IntPtr.Zero) + IntPtr ptrMaybe; + if (ComWrappers.TryGetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.TrackerSupport, out ptrMaybe)) return ptrMaybe; return GetIUnknownForObjectNative(o, false); @@ -421,9 +421,9 @@ public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) } // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.TrackerObject, wrapperMaybe: null); - if (objMaybe != null) - return objMaybe; + object? objMaybe; + if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.TrackerObject, wrapperMaybe: null, out objMaybe)) + return objMaybe!; return GetObjectForIUnknownNative(pUnk); } @@ -439,9 +439,9 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown) } // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe = ComWrappers.GetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.TrackerObject | CreateObjectFlags.UniqueInstance, wrapperMaybe: null); - if (objMaybe != null) - return objMaybe; + object? objMaybe; + if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.TrackerObject | CreateObjectFlags.UniqueInstance, wrapperMaybe: null, out objMaybe)) + return objMaybe!; return GetUniqueObjectForIUnknownNative(unknown); } diff --git a/src/coreclr/src/vm/ecalllist.h b/src/coreclr/src/vm/ecalllist.h index 4025d2162fd11..03b2d7c3eee32 100644 --- a/src/coreclr/src/vm/ecalllist.h +++ b/src/coreclr/src/vm/ecalllist.h @@ -997,8 +997,8 @@ FCFuncEnd() #ifdef FEATURE_COMWRAPPERS FCFuncStart(gComWrappersFuncs) QCFuncElement("GetIUnknownImplInternal", ComWrappersNative::GetIUnknownImpl) - QCFuncElement("GetOrCreateComInterfaceForObjectInternal", ComWrappersNative::GetOrCreateComInterfaceForObject) - QCFuncElement("GetOrCreateObjectForComInstanceInternal", ComWrappersNative::GetOrCreateObjectForComInstance) + QCFuncElement("TryGetOrCreateComInterfaceForObjectInternal", ComWrappersNative::TryGetOrCreateComInterfaceForObject) + QCFuncElement("TryGetOrCreateObjectForComInstanceInternal", ComWrappersNative::TryGetOrCreateObjectForComInstance) FCFuncEnd() #endif // FEATURE_COMWRAPPERS diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 1fba4f62a87fd..ea2035b9bd930 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -478,16 +478,18 @@ namespace CALL_MANAGED_METHOD_NORET(args); } - void* GetOrCreateComInterfaceForObjectInternal( + bool TryGetOrCreateComInterfaceForObjectInternal( _In_opt_ OBJECTREF impl, _In_ OBJECTREF instance, - _In_ CreateComInterfaceFlags flags) + _In_ CreateComInterfaceFlags flags, + _Outptr_ void** wrapperRaw) { - CONTRACT(void*) + CONTRACT(bool) { THROWS; MODE_COOPERATIVE; PRECONDITION(instance != NULL); + PRECONDITION(wrapperRaw != NULL); } CONTRACT_END; @@ -586,20 +588,23 @@ namespace GCPROTECT_END(); - RETURN wrapperRawMaybe; + *wrapperRaw = wrapperRawMaybe; + RETURN (wrapperRawMaybe != NULL); } - OBJECTREF GetOrCreateObjectForComInstanceInternal( + bool TryGetOrCreateObjectForComInstanceInternal( _In_opt_ OBJECTREF impl, _In_ IUnknown* identity, _In_ CreateObjectFlags flags, - _In_opt_ OBJECTREF wrapperMaybe) + _In_opt_ OBJECTREF wrapperMaybe, + _Out_ OBJECTREF* objRef) { - CONTRACT(OBJECTREF) + CONTRACT(bool) { THROWS; MODE_COOPERATIVE; PRECONDITION(identity != NULL); + PRECONDITION(objRef != NULL); } CONTRACT_END; @@ -717,7 +722,8 @@ namespace GCPROTECT_END(); - RETURN gc.objRefMaybe; + *objRef = gc.objRefMaybe; + RETURN (gc.objRefMaybe != NULL); } } @@ -952,20 +958,25 @@ namespace InteropLibImports gc.wrapperMaybeRef = NULL; // No supplied wrapper here. // Get wrapper for external object - gc.objRef = GetOrCreateObjectForComInstanceInternal( + bool success = TryGetOrCreateObjectForComInstanceInternal( gc.implRef, externalComObject, externalObjectFlags, - gc.wrapperMaybeRef); + gc.wrapperMaybeRef, + &gc.objRef); - if (gc.objRef == NULL) + if (!success) COMPlusThrow(kArgumentNullException); // Get wrapper for managed object - *trackerTarget = GetOrCreateComInterfaceForObjectInternal( + success = TryGetOrCreateComInterfaceForObjectInternal( gc.implRef, gc.objRef, - trackerTargetFlags); + trackerTargetFlags, + trackerTarget); + + if (!success) + COMPlusThrow(kArgumentException); STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created Target for External: 0x%p => 0x%p\n", OBJECTREFToObject(gc.objRef), *trackerTarget); GCPROTECT_END(); @@ -1059,14 +1070,15 @@ namespace InteropLibImports #ifdef FEATURE_COMWRAPPERS -void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject( +BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ QCall::ObjectHandleOnStack instance, - _In_ INT32 flags) + _In_ INT32 flags, + _Outptr_ void** wrapper) { QCALL_CONTRACT; - void* wrapper = NULL; + bool success; BEGIN_QCALL; @@ -1074,18 +1086,19 @@ void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject( // are being manipulated. { GCX_COOP(); - wrapper = GetOrCreateComInterfaceForObjectInternal( + success = TryGetOrCreateComInterfaceForObjectInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), ObjectToOBJECTREF(*instance.m_ppObject), - (CreateComInterfaceFlags)flags); + (CreateComInterfaceFlags)flags, + wrapper); } END_QCALL; - return wrapper; + return success; } -void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance( +BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ void* ext, _In_ INT32 flags, @@ -1096,6 +1109,8 @@ void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance( _ASSERTE(ext != NULL); + bool success; + BEGIN_QCALL; HRESULT hr; @@ -1110,17 +1125,21 @@ void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance( // are being manipulated. { GCX_COOP(); - OBJECTREF newObj = GetOrCreateObjectForComInstanceInternal( + OBJECTREF newObj; + success = TryGetOrCreateObjectForComInstanceInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), identity, (CreateObjectFlags)flags, - ObjectToOBJECTREF(*wrapperMaybe.m_ppObject)); + ObjectToOBJECTREF(*wrapperMaybe.m_ppObject), + &newObj); // Set the return value retValue.Set(newObj); } END_QCALL; + + return success; } void QCALLTYPE ComWrappersNative::GetIUnknownImpl( diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index 2ed54cfc90706..5ea4b55c28f50 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -17,12 +17,13 @@ class ComWrappersNative _Out_ void** fpAddRef, _Out_ void** fpRelease); - static void* QCALLTYPE GetOrCreateComInterfaceForObject( + static BOOL QCALLTYPE TryGetOrCreateComInterfaceForObject( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ QCall::ObjectHandleOnStack instance, - _In_ INT32 flags); + _In_ INT32 flags, + _Outptr_ void** wrapperRaw); - static void QCALLTYPE GetOrCreateObjectForComInstance( + static BOOL QCALLTYPE TryGetOrCreateObjectForComInstance( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ void* externalComObject, _In_ INT32 flags, From 234c733bde7e74c684d5b7423333007ab52c41ab Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Thu, 12 Mar 2020 15:05:58 -0700 Subject: [PATCH 04/12] Check if global instance is registered --- .../Runtime/InteropServices/ComWrappers.cs | 8 ++++ .../InteropServices/Marshal.CoreCLR.cs | 39 +++++++++++++------ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 731437a123009..6ac8a57a2001b 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -289,6 +289,14 @@ public void RegisterAsGlobalInstance() } } + /// + /// Get whether or not a global instance has been registered. + /// + internal static bool IsGlobalInstanceRegistered() + { + return s_globalInstance != null; + } + /// /// Get the runtime provided IUnknown implementation. /// diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index 11b0fdf1241cb..b702d8bff8ed1 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -331,10 +331,15 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) throw new ArgumentNullException(nameof(o)); } - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - IntPtr ptrMaybe; - if (ComWrappers.TryGetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.TrackerSupport, out ptrMaybe)) - return ptrMaybe; + if (ComWrappers.IsGlobalInstanceRegistered()) + { + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + IntPtr ptrMaybe; + if (ComWrappers.TryGetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.TrackerSupport, out ptrMaybe)) + { + return ptrMaybe; + } + } return GetIUnknownForObjectNative(o, false); } @@ -420,10 +425,15 @@ public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) throw new ArgumentNullException(nameof(pUnk)); } - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe; - if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.TrackerObject, wrapperMaybe: null, out objMaybe)) - return objMaybe!; + if (ComWrappers.IsGlobalInstanceRegistered()) + { + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + object? objMaybe; + if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.TrackerObject, wrapperMaybe: null, out objMaybe)) + { + return objMaybe!; + } + } return GetObjectForIUnknownNative(pUnk); } @@ -438,10 +448,15 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown) throw new ArgumentNullException(nameof(unknown)); } - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe; - if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.TrackerObject | CreateObjectFlags.UniqueInstance, wrapperMaybe: null, out objMaybe)) - return objMaybe!; + if (ComWrappers.IsGlobalInstanceRegistered()) + { + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + object? objMaybe; + if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.TrackerObject | CreateObjectFlags.UniqueInstance, wrapperMaybe: null, out objMaybe)) + { + return objMaybe!; + } + } return GetUniqueObjectForIUnknownNative(unknown); } From e090745848830f031a9a42496874f58003431cf8 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Thu, 12 Mar 2020 19:07:12 -0700 Subject: [PATCH 05/12] PR feedback --- .../src/System/Runtime/InteropServices/ComWrappers.cs | 2 +- src/coreclr/src/vm/interoplibinterface.cpp | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 6ac8a57a2001b..908523957b196 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -133,7 +133,7 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// The managed object to expose outside the .NET runtime. /// Flags used to configure the generated interface. /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created. - /// Returns true if a CEM representation could be created, false otherwise + /// Returns true if a COM representation could be created, false otherwise /// /// If is null, the global instance (if registered) will be used. /// diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index ea2035b9bd930..00474e30253f8 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -1095,7 +1095,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject( END_QCALL; - return success; + return (success ? TRUE : FALSE); } BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( @@ -1134,12 +1134,13 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( &newObj); // Set the return value - retValue.Set(newObj); + if (success) + retValue.Set(newObj); } END_QCALL; - return success; + return (success ? TRUE : FALSE); } void QCALLTYPE ComWrappersNative::GetIUnknownImpl( From 17e6db51bc40b0be56669628a38242de52b2aa1b Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Tue, 17 Mar 2020 13:41:53 -0700 Subject: [PATCH 06/12] Use registered ComWrappers in InterfaceMarshaaler for object <-> IUnknown --- .../src/System/StubHelpers.cs | 54 ++++++++++++++++++- src/coreclr/src/vm/ecalllist.h | 4 +- .../ReferenceTrackerRuntime.cpp | 11 ++++ .../src/Interop/COM/ComWrappers/Program.cs | 30 +++++++++++ 4 files changed, 95 insertions(+), 4 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs index 8c94325722e13..62a7ff1907de8 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs @@ -686,11 +686,61 @@ internal static long ConvertToManaged(double nativeDate) #if FEATURE_COMINTEROP internal static class InterfaceMarshaler { + // See interopconverter.h + [Flags] + private enum ItfMarshalFlags : int + { + ITF_MARSHAL_INSP_ITF = 0x01, + ITF_MARSHAL_SUPPRESS_ADDREF = 0x02, + ITF_MARSHAL_CLASS_IS_HINT = 0x04, + ITF_MARSHAL_DISP_ITF = 0x08, + ITF_MARSHAL_USE_BASIC_ITF = 0x10, + ITF_MARSHAL_WINRT_SCENARIO = 0x20, + }; + + private static bool IsForIUnknown(int flags) + { + ItfMarshalFlags interfaceFlags = (ItfMarshalFlags)flags; + return (interfaceFlags & ItfMarshalFlags.ITF_MARSHAL_USE_BASIC_ITF) != 0 + && (interfaceFlags & ItfMarshalFlags.ITF_MARSHAL_INSP_ITF) == 0 + && (interfaceFlags & ItfMarshalFlags.ITF_MARSHAL_DISP_ITF) == 0; + } + + internal static IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags) + { + if (ComWrappers.IsGlobalInstanceRegistered() && IsForIUnknown(flags)) + { + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + IntPtr ptrMaybe; + if (ComWrappers.TryGetOrCreateComInterfaceForObjectInternal(impl: null, objSrc, CreateComInterfaceFlags.TrackerSupport, out ptrMaybe)) + { + return ptrMaybe; + } + } + + return ConvertToNativeInternal(objSrc, itfMT, classMT, flags); + } + [MethodImpl(MethodImplOptions.InternalCall)] - internal static extern IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags); + internal static extern IntPtr ConvertToNativeInternal(object objSrc, IntPtr itfMT, IntPtr classMT, int flags); + + internal static object ConvertToManaged(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags) + { + if (ComWrappers.IsGlobalInstanceRegistered() && IsForIUnknown(flags)) + { + // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) + object? objMaybe; + if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, Marshal.ReadIntPtr(pUnk), CreateObjectFlags.TrackerObject, wrapperMaybe: null, out objMaybe)) + { + return objMaybe!; + } + } + + return ConvertToManagedInternal(pUnk, itfMT, classMT, flags); + } [MethodImpl(MethodImplOptions.InternalCall)] - internal static extern object ConvertToManaged(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags); + internal static extern object ConvertToManagedInternal(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags); [DllImport(RuntimeHelpers.QCall)] internal static extern void ClearNative(IntPtr pUnk); diff --git a/src/coreclr/src/vm/ecalllist.h b/src/coreclr/src/vm/ecalllist.h index 03b2d7c3eee32..57b12dc9d79d7 100644 --- a/src/coreclr/src/vm/ecalllist.h +++ b/src/coreclr/src/vm/ecalllist.h @@ -958,8 +958,8 @@ FCFuncStart(gObjectMarshalerFuncs) FCFuncEnd() FCFuncStart(gInterfaceMarshalerFuncs) - FCFuncElement("ConvertToNative", StubHelpers::InterfaceMarshaler__ConvertToNative) - FCFuncElement("ConvertToManaged", StubHelpers::InterfaceMarshaler__ConvertToManaged) + FCFuncElement("ConvertToNativeInternal", StubHelpers::InterfaceMarshaler__ConvertToNative) + FCFuncElement("ConvertToManagedInternal", StubHelpers::InterfaceMarshaler__ConvertToManaged) QCFuncElement("ClearNative", StubHelpers::InterfaceMarshaler__ClearNative) FCFuncElement("ConvertToManagedWithoutUnboxing", StubHelpers::InterfaceMarshaler__ConvertToManagedWithoutUnboxing) FCFuncEnd() diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp index f0a0f30b277e7..6afa734c06a34 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp @@ -337,3 +337,14 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE Trigger_NotifyEndOfReferenceTracking { return TrackerRuntimeManager.NotifyEndOfReferenceTrackingOnThread(); } + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObject(IUnknown *obj, int i) +{ + if (obj == nullptr) + return E_POINTER; + + HRESULT hr; + ComSmartPtr testObj; + RETURN_IF_FAILED(obj->QueryInterface(&testObj)) + return testObj->SetValue(i); +} diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs index f06951b1f6fe5..ef496f52df60c 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs @@ -83,6 +83,16 @@ struct MockReferenceTrackerRuntime extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread(); } + struct MarshalInterface + { + [DllImport(nameof(MockReferenceTrackerRuntime))] + [return: MarshalAs(UnmanagedType.IUnknown)] + extern public static object? CreateTrackerObject(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static void UpdateTestObject([MarshalAs(UnmanagedType.IUnknown)] object testObj, int i); + } + [Guid("42951130-245C-485E-B60B-4ED4254256F8")] public interface ITrackerObject { @@ -497,6 +507,9 @@ static void ValidateGlobalInstanceScenarios() ValidateMarshalAPIs(wrappers1, true); ValidateMarshalAPIs(wrappers1, false); + ValidateInterfaceMarshaler(wrappers1, true); + ValidateInterfaceMarshaler(wrappers1, false); + Console.WriteLine($"Validate NotifyEndOfReferenceTrackingOnThread()..."); int hr; @@ -542,6 +555,23 @@ private static void ValidateMarshalAPIs(TestComWrappers registeredWrapper, bool Marshal.Release(trackerObjRaw); } + private static void ValidateInterfaceMarshaler(TestComWrappers registeredWrapper, bool validateUseRegistered) + { + registeredWrapper.ReturnInvalid = !validateUseRegistered; + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + + Console.WriteLine($"Validate ConvertToNative: {scenario}..."); + var testObj = new Test(); + int value = 10; + MarshalInterface.UpdateTestObject(testObj, value); + Assert.AreEqual(validateUseRegistered, value == testObj.GetValue()); + Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + + Console.WriteLine($"Validate ConvertToManaged: {scenario}..."); + object obj = MarshalInterface.CreateTrackerObject(); + Assert.AreEqual(validateUseRegistered, obj is ITrackerObjectWrapper, $"Should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(ITrackerObjectWrapper)} instance"); + } + static int Main(string[] doNotUse) { try From ad962e7e1f41f111218f00eb9ea88b730966abe7 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Tue, 17 Mar 2020 15:41:54 -0700 Subject: [PATCH 07/12] PR feedback --- .../src/System.Private.CoreLib/src/System/StubHelpers.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs index 62a7ff1907de8..7defd5ff6a6a9 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs @@ -688,7 +688,7 @@ internal static class InterfaceMarshaler { // See interopconverter.h [Flags] - private enum ItfMarshalFlags : int + private enum ItfMarshalFlags { ITF_MARSHAL_INSP_ITF = 0x01, ITF_MARSHAL_SUPPRESS_ADDREF = 0x02, From 33280fe997cd4c1f06083038428f0d2c91d6c956 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Thu, 19 Mar 2020 18:38:17 -0700 Subject: [PATCH 08/12] Use registered ComWrappers for all object <-> COM interface --- .../Runtime/InteropServices/ComWrappers.cs | 15 +- .../InteropServices/Marshal.CoreCLR.cs | 28 ---- .../src/System/StubHelpers.cs | 54 +------- src/coreclr/src/interop/comwrappers.hpp | 3 +- src/coreclr/src/interop/inc/interoplib.h | 6 + src/coreclr/src/interop/interoplib.cpp | 37 +++++ src/coreclr/src/vm/ecalllist.h | 5 +- src/coreclr/src/vm/interopconverter.cpp | 128 +++++++++++++++--- src/coreclr/src/vm/interoplibinterface.cpp | 108 ++++++++++++++- src/coreclr/src/vm/interoplibinterface.h | 20 +++ src/coreclr/src/vm/runtimecallablewrapper.cpp | 3 + 11 files changed, 293 insertions(+), 114 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 908523957b196..07df6081b24b8 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -137,7 +137,7 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// /// If is null, the global instance (if registered) will be used. /// - internal static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue) + private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue) { if (instance == null) throw new ArgumentNullException(nameof(instance)); @@ -247,7 +247,7 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create /// /// If is null, the global instance (if registered) will be used. /// - internal static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue) + private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue) { if (externalComObject == IntPtr.Zero) throw new ArgumentNullException(nameof(externalComObject)); @@ -287,16 +287,13 @@ public void RegisterAsGlobalInstance() { throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance); } - } - /// - /// Get whether or not a global instance has been registered. - /// - internal static bool IsGlobalInstanceRegistered() - { - return s_globalInstance != null; + SetGlobalInstanceRegistered(); } + [DllImport(RuntimeHelpers.QCall)] + private static extern void SetGlobalInstanceRegistered(); + /// /// Get the runtime provided IUnknown implementation. /// diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index b702d8bff8ed1..dfef4090c4590 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -331,16 +331,6 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) throw new ArgumentNullException(nameof(o)); } - if (ComWrappers.IsGlobalInstanceRegistered()) - { - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - IntPtr ptrMaybe; - if (ComWrappers.TryGetOrCreateComInterfaceForObjectInternal(impl: null, o, CreateComInterfaceFlags.TrackerSupport, out ptrMaybe)) - { - return ptrMaybe; - } - } - return GetIUnknownForObjectNative(o, false); } @@ -425,15 +415,6 @@ public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) throw new ArgumentNullException(nameof(pUnk)); } - if (ComWrappers.IsGlobalInstanceRegistered()) - { - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe; - if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, pUnk, CreateObjectFlags.TrackerObject, wrapperMaybe: null, out objMaybe)) - { - return objMaybe!; - } - } return GetObjectForIUnknownNative(pUnk); } @@ -448,15 +429,6 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown) throw new ArgumentNullException(nameof(unknown)); } - if (ComWrappers.IsGlobalInstanceRegistered()) - { - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe; - if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, unknown, CreateObjectFlags.TrackerObject | CreateObjectFlags.UniqueInstance, wrapperMaybe: null, out objMaybe)) - { - return objMaybe!; - } - } return GetUniqueObjectForIUnknownNative(unknown); } diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs index 7defd5ff6a6a9..691113f665be0 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs @@ -686,61 +686,11 @@ internal static long ConvertToManaged(double nativeDate) #if FEATURE_COMINTEROP internal static class InterfaceMarshaler { - // See interopconverter.h - [Flags] - private enum ItfMarshalFlags - { - ITF_MARSHAL_INSP_ITF = 0x01, - ITF_MARSHAL_SUPPRESS_ADDREF = 0x02, - ITF_MARSHAL_CLASS_IS_HINT = 0x04, - ITF_MARSHAL_DISP_ITF = 0x08, - ITF_MARSHAL_USE_BASIC_ITF = 0x10, - ITF_MARSHAL_WINRT_SCENARIO = 0x20, - }; - - private static bool IsForIUnknown(int flags) - { - ItfMarshalFlags interfaceFlags = (ItfMarshalFlags)flags; - return (interfaceFlags & ItfMarshalFlags.ITF_MARSHAL_USE_BASIC_ITF) != 0 - && (interfaceFlags & ItfMarshalFlags.ITF_MARSHAL_INSP_ITF) == 0 - && (interfaceFlags & ItfMarshalFlags.ITF_MARSHAL_DISP_ITF) == 0; - } - - internal static IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags) - { - if (ComWrappers.IsGlobalInstanceRegistered() && IsForIUnknown(flags)) - { - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - IntPtr ptrMaybe; - if (ComWrappers.TryGetOrCreateComInterfaceForObjectInternal(impl: null, objSrc, CreateComInterfaceFlags.TrackerSupport, out ptrMaybe)) - { - return ptrMaybe; - } - } - - return ConvertToNativeInternal(objSrc, itfMT, classMT, flags); - } - [MethodImpl(MethodImplOptions.InternalCall)] - internal static extern IntPtr ConvertToNativeInternal(object objSrc, IntPtr itfMT, IntPtr classMT, int flags); - - internal static object ConvertToManaged(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags) - { - if (ComWrappers.IsGlobalInstanceRegistered() && IsForIUnknown(flags)) - { - // Passing null as the ComWrapper implementation will use the globally registered wrappper (if available) - object? objMaybe; - if (ComWrappers.TryGetOrCreateObjectForComInstanceInternal(impl: null, Marshal.ReadIntPtr(pUnk), CreateObjectFlags.TrackerObject, wrapperMaybe: null, out objMaybe)) - { - return objMaybe!; - } - } - - return ConvertToManagedInternal(pUnk, itfMT, classMT, flags); - } + internal static extern IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags); [MethodImpl(MethodImplOptions.InternalCall)] - internal static extern object ConvertToManagedInternal(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags); + internal static extern object ConvertToManaged(IntPtr ppUnk, IntPtr itfMT, IntPtr classMT, int flags); [DllImport(RuntimeHelpers.QCall)] internal static extern void ClearNative(IntPtr pUnk); diff --git a/src/coreclr/src/interop/comwrappers.hpp b/src/coreclr/src/interop/comwrappers.hpp index d2337746e39b8..cf37f30b53dce 100644 --- a/src/coreclr/src/interop/comwrappers.hpp +++ b/src/coreclr/src/interop/comwrappers.hpp @@ -17,8 +17,9 @@ enum class CreateComInterfaceFlagsEx : int32_t // Highest bit is reserved for internal usage IsPegged = 1 << 31, + IsComActivated = 2 << 31, - InternalMask = IsPegged, + InternalMask = IsPegged | IsComActivated, }; DEFINE_ENUM_FLAG_OPERATORS(CreateComInterfaceFlagsEx); diff --git a/src/coreclr/src/interop/inc/interoplib.h b/src/coreclr/src/interop/inc/interoplib.h index df59b8b90a584..2572efc5ece95 100644 --- a/src/coreclr/src/interop/inc/interoplib.h +++ b/src/coreclr/src/interop/inc/interoplib.h @@ -45,6 +45,12 @@ namespace InteropLib // Reactivate the supplied wrapper. HRESULT ReactivateWrapper(_In_ IUnknown* wrapper, _In_ InteropLib::OBJECTHANDLE handle) noexcept; + // Get the object for the supplied wrapper + HRESULT GetObjectForWrapper(_In_ IUnknown *wrapper, _Out_ OBJECTHANDLE *object) noexcept; + + HRESULT MarkComActivated(_In_ IUnknown *wrapper) noexcept; + HRESULT IsComActivated(_In_ IUnknown *wrapper) noexcept; + struct ExternalWrapperResult { // The returned context memory is guaranteed to be initialized to zero. diff --git a/src/coreclr/src/interop/interoplib.cpp b/src/coreclr/src/interop/interoplib.cpp index 8283024ad5d4a..669b6ad2fb1fe 100644 --- a/src/coreclr/src/interop/interoplib.cpp +++ b/src/coreclr/src/interop/interoplib.cpp @@ -89,6 +89,43 @@ namespace InteropLib return S_OK; } + HRESULT GetObjectForWrapper(_In_ IUnknown *wrapper, _Out_ OBJECTHANDLE *object) noexcept + { + if (object == nullptr) + return E_POINTER; + + *object = nullptr; + + HRESULT hr = IsActiveWrapper(wrapper); + if (hr != S_OK) + return hr; + + ManagedObjectWrapper *mow = ManagedObjectWrapper::MapFromIUnknown(wrapper); + _ASSERTE(mow != nullptr); + + *object = mow->Target; + return S_OK; + } + + HRESULT MarkComActivated(_In_ IUnknown *wrapperMaybe) noexcept + { + ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); + if (wrapper == nullptr) + return E_INVALIDARG; + + wrapper->SetFlag(CreateComInterfaceFlagsEx::IsComActivated); + return S_OK; + } + + HRESULT IsComActivated(_In_ IUnknown *wrapperMaybe) noexcept + { + ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); + if (wrapper == nullptr) + return E_INVALIDARG; + + return wrapper->IsSet(CreateComInterfaceFlagsEx::IsComActivated) ? S_OK : S_FALSE; + } + HRESULT CreateWrapperForExternal( _In_ IUnknown* external, _In_ enum CreateObjectFlags flags, diff --git a/src/coreclr/src/vm/ecalllist.h b/src/coreclr/src/vm/ecalllist.h index 57b12dc9d79d7..691bad54bf537 100644 --- a/src/coreclr/src/vm/ecalllist.h +++ b/src/coreclr/src/vm/ecalllist.h @@ -958,8 +958,8 @@ FCFuncStart(gObjectMarshalerFuncs) FCFuncEnd() FCFuncStart(gInterfaceMarshalerFuncs) - FCFuncElement("ConvertToNativeInternal", StubHelpers::InterfaceMarshaler__ConvertToNative) - FCFuncElement("ConvertToManagedInternal", StubHelpers::InterfaceMarshaler__ConvertToManaged) + FCFuncElement("ConvertToNative", StubHelpers::InterfaceMarshaler__ConvertToNative) + FCFuncElement("ConvertToManaged", StubHelpers::InterfaceMarshaler__ConvertToManaged) QCFuncElement("ClearNative", StubHelpers::InterfaceMarshaler__ClearNative) FCFuncElement("ConvertToManagedWithoutUnboxing", StubHelpers::InterfaceMarshaler__ConvertToManagedWithoutUnboxing) FCFuncEnd() @@ -999,6 +999,7 @@ FCFuncStart(gComWrappersFuncs) QCFuncElement("GetIUnknownImplInternal", ComWrappersNative::GetIUnknownImpl) QCFuncElement("TryGetOrCreateComInterfaceForObjectInternal", ComWrappersNative::TryGetOrCreateComInterfaceForObject) QCFuncElement("TryGetOrCreateObjectForComInstanceInternal", ComWrappersNative::TryGetOrCreateObjectForComInstance) + QCFuncElement("SetGlobalInstanceRegistered", GlobalComWrappers::SetGlobalInstanceRegistered) FCFuncEnd() #endif // FEATURE_COMWRAPPERS diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index 58ac452295bd9..8bec4fabf09ab 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -17,9 +17,72 @@ #include "runtimecallablewrapper.h" #include "cominterfacemarshaler.h" #include "binder.h" +#include // CreateComInterfaceFlags, CreateObjectFlags +#include #include "winrttypenameconverter.h" #include "typestring.h" +namespace +{ + bool TryGetComIPFromObjectRefUsingComWrappers( + _In_ OBJECTREF instance, + _Outptr_ IUnknown** wrapperRaw) + { +#ifdef FEATURE_COMWRAPPERS + InteropLib::Com::CreateComInterfaceFlags flags = InteropLib::Com::CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport; + return GlobalComWrappers::TryGetOrCreateComInterfaceForObject(instance, flags, (void**)wrapperRaw); +#else + return false; +#endif // FEATURE_COMWRAPPERS + } + + bool TryGetObjectRefFromComIPUsingComWrappers( + _In_ IUnknown* pUnknown, + _In_ DWORD dwFlags, + _Out_ OBJECTREF *pObjOut) + { +#ifdef FEATURE_COMWRAPPERS + int flags = InteropLib::Com::CreateObjectFlags::CreateObjectFlags_TrackerObject; + if ((dwFlags & ObjFromComIP::UNIQUE_OBJECT) != 0) + flags |= InteropLib::Com::CreateObjectFlags::CreateObjectFlags_UniqueInstance; + + return GlobalComWrappers::TryGetOrCreateObjectForComInstance(pUnknown, flags, pObjOut); +#else + return false; +#endif // FEATURE_COMWRAPPERS + } + + void EnsureObjectRefIsValidForSpecifiedClass( + _In_ OBJECTREF obj, + _In_ DWORD dwFlags, + _In_ MethodTable *pMTClass) + { + _ASSERTE(obj != NULL); + _ASSERTE(pMTClass != NULL); + + if ((dwFlags & ObjFromComIP::CLASS_IS_HINT) != 0) + return; + + // make sure we can cast to the specified class + FAULT_NOT_FATAL(); + + // Bad format exception thrown for backward compatibility + THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass); + + if (CanCastComObject(obj, pMTClass)) + return; + + StackSString ssObjClsName; + StackSString ssDestClsName; + + obj->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName); + pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName); + + COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST, + ssObjClsName.GetUnicode(), ssDestClsName.GetUnicode()); + } +} + //-------------------------------------------------------------------------------- // IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, ...) // Convert ObjectRef to a COM IP, based on MethodTable* pMT. @@ -45,6 +108,11 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, BOOL bSecuri if (*poref == NULL) RETURN NULL; + if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) + { + pUnk.SuppressRelease(); + RETURN pUnk; + } SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -111,6 +179,28 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (*poref == NULL) RETURN NULL; + if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) + { + hr = S_OK; + void *pvObj; + if (ReqIpType & ComIpType_Dispatch) + { + hr = pUnk->QueryInterface(IID_IDispatch, &pvObj); + } + else if (ReqIpType & ComIpType_Inspectable) + { + hr = pUnk->QueryInterface(IID_IInspectable, &pvObj); + } + + if (FAILED(hr)) + COMPlusThrowHR(hr); + + if (pFetchedIpType != NULL) + *pFetchedIpType = ReqIpType; + + RETURN pUnk; + } + MethodTable *pMT = (*poref)->GetMethodTable(); SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -376,6 +466,16 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI if (*poref == NULL) RETURN NULL; + if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) + { + void *pvObj; + hr = pUnk->QueryInterface(iid, &pvObj); + if (FAILED(hr)) + COMPlusThrowHR(hr); + + RETURN pUnk; + } + MethodTable *pMT = (*poref)->GetMethodTable(); SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -434,9 +534,18 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM _ASSERTE(g_fComStarted && "COM has not been started up, make sure EnsureComStarted is called before any COM objects are used!"); IUnknown *pUnk = *ppUnk; + *pObjOut = NULL; + + if (TryGetObjectRefFromComIPUsingComWrappers(pUnk, dwFlags, pObjOut)) + { + if (pMTClass != NULL) + EnsureObjectRefIsValidForSpecifiedClass(*pObjOut, dwFlags, pMTClass); + + return; + } + Thread * pThread = GetThread(); - *pObjOut = NULL; IUnknown* pOuter = pUnk; SafeComHolder pAutoOuterUnk = NULL; @@ -537,22 +646,7 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM // make sure we can cast to the specified class if (pMTClass != NULL) { - FAULT_NOT_FATAL(); - - // Bad format exception thrown for backward compatibility - THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass); - - if (!CanCastComObject(*pObjOut, pMTClass)) - { - StackSString ssObjClsName; - StackSString ssDestClsName; - - (*pObjOut)->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName); - pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName); - - COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST, - ssObjClsName.GetUnicode(), ssDestClsName.GetUnicode()); - } + EnsureObjectRefIsValidForSpecifiedClass(*pObjOut, dwFlags, pMTClass); } else if (dwFlags & ObjFromComIP::REQUIRE_IINSPECTABLE) { diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 00474e30253f8..985a1b3415353 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -596,6 +596,7 @@ namespace _In_opt_ OBJECTREF impl, _In_ IUnknown* identity, _In_ CreateObjectFlags flags, + _In_ bool unwrap, _In_opt_ OBJECTREF wrapperMaybe, _Out_ OBJECTREF* objRef) { @@ -624,6 +625,7 @@ namespace gc.wrapperMaybeRef = wrapperMaybe; ExtObjCxtCache* cache = ExtObjCxtCache::GetInstance(); + InteropLib::OBJECTHANDLE handle = NULL; // Check if the user requested a unique instance. bool uniqueInstance = !!(flags & CreateObjectFlags::CreateObjectFlags_UniqueInstance); @@ -632,21 +634,41 @@ namespace // Query the external object cache ExtObjCxtCache::LockHolder lock(cache); extObjCxt = cache->Find(identity); + + if (extObjCxt == NULL && unwrap) + { + InteropLib::OBJECTHANDLE handleLocal; + if (InteropLib::Com::GetObjectForWrapper(identity, &handleLocal) == S_OK + && InteropLib::Com::IsComActivated(identity) == S_FALSE) + { + handle = handleLocal; + } + } } if (extObjCxt != NULL) { gc.objRefMaybe = extObjCxt->GetObjectRef(); } + else if (handle != NULL) + { + ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle); + gc.objRefMaybe = ObjectFromHandle(objectHandle); + } else { // Create context instance for the possibly new external object. ExternalWrapperResultHolder resultHolder; - hr = InteropLib::Com::CreateWrapperForExternal( - identity, - flags, - sizeof(ExternalObjectContext), - &resultHolder); + + { + GCX_PREEMP(); + hr = InteropLib::Com::CreateWrapperForExternal( + identity, + flags, + sizeof(ExternalObjectContext), + &resultHolder); + } + if (FAILED(hr)) COMPlusThrowHR(hr); @@ -725,6 +747,8 @@ namespace *objRef = gc.objRefMaybe; RETURN (gc.objRefMaybe != NULL); } + + bool g_IsGlobalComWrappersRegistered; } namespace InteropLibImports @@ -962,6 +986,7 @@ namespace InteropLibImports gc.implRef, externalComObject, externalObjectFlags, + false /*unwrap*/, gc.wrapperMaybeRef, &gc.objRef); @@ -1130,6 +1155,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), identity, (CreateObjectFlags)flags, + false /*unwrap*/, ObjectToOBJECTREF(*wrapperMaybe.m_ppObject), &newObj); @@ -1221,6 +1247,78 @@ void ComWrappersNative::MarkExternalComObjectContextCollected(_In_ void* context } } +void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapper) +{ + CONTRACTL + { + NOTHROW; + MODE_ANY; + PRECONDITION(wrapper != NULL); + } + CONTRACTL_END; + + InteropLib::Com::MarkComActivated(wrapper); +} + +void QCALLTYPE GlobalComWrappers::SetGlobalInstanceRegistered() +{ + QCALL_CONTRACT; + + _ASSERTE(!g_IsGlobalComWrappersRegistered); + + BEGIN_QCALL; + + g_IsGlobalComWrappersRegistered = true;; + + END_QCALL; +} + +bool GlobalComWrappers::TryGetOrCreateComInterfaceForObject( + _In_ OBJECTREF instance, + _In_ INT32 flags, + _Outptr_ void** wrapperRaw) +{ + if (!g_IsGlobalComWrappersRegistered) + return false; + + // Switch to Cooperative mode since object references + // are being manipulated. + { + GCX_COOP(); + + // Passing NULL as the ComWrappers implementation indicates using the globally registered instance + return TryGetOrCreateComInterfaceForObjectInternal( + NULL, + instance, + (CreateComInterfaceFlags)flags, + wrapperRaw); + } +} + +bool GlobalComWrappers::TryGetOrCreateObjectForComInstance( + _In_ IUnknown* externalComObject, + _In_ INT32 flags, + _Out_ OBJECTREF* objRef) +{ + if (!g_IsGlobalComWrappersRegistered) + return false; + + // Switch to Cooperative mode since object references + // are being manipulated. + { + GCX_COOP(); + + // Passing NULL as the ComWrappers implementation indicates using the globally registered instance + return TryGetOrCreateObjectForComInstanceInternal( + NULL /*comWrappersImpl*/, + externalComObject, + (CreateObjectFlags)flags, + true /*unwrap*/, + NULL /*wrapperMaybe*/, + objRef); + } +} + #endif // FEATURE_COMWRAPPERS void Interop::OnGCStarted(_In_ int nCondemnedGeneration) diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index 5ea4b55c28f50..03f2e585a932c 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -34,6 +34,26 @@ class ComWrappersNative static void DestroyManagedObjectComWrapper(_In_ void* wrapper); static void DestroyExternalComObjectContext(_In_ void* context); static void MarkExternalComObjectContextCollected(_In_ void* context); + +public: // COM activation + static void MarkWrapperAsComActivated(_In_ IUnknown* wrapper); +}; + +class GlobalComWrappers +{ +public: + static void QCALLTYPE SetGlobalInstanceRegistered(); + +public: // Functions operating on a registered global instance + static bool TryGetOrCreateComInterfaceForObject( + _In_ OBJECTREF instance, + _In_ INT32 flags, + _Outptr_ void** wrapperRaw); + + static bool TryGetOrCreateObjectForComInstance( + _In_ IUnknown* externalComObject, + _In_ INT32 flags, + _Out_ OBJECTREF* objRef); }; #endif // FEATURE_COMWRAPPERS diff --git a/src/coreclr/src/vm/runtimecallablewrapper.cpp b/src/coreclr/src/vm/runtimecallablewrapper.cpp index f61ed2b61b417..61475bf9ba0d5 100644 --- a/src/coreclr/src/vm/runtimecallablewrapper.cpp +++ b/src/coreclr/src/vm/runtimecallablewrapper.cpp @@ -52,6 +52,7 @@ SLIST_HEADER RCW::s_RCWStandbyList; #endif // FEATURE_COMINTEROP_APARTMENT_SUPPORT #ifdef FEATURE_COMINTEROP_UNMANAGED_ACTIVATION +#include "interoplibinterface.h" #ifndef CROSSGEN_COMPILE @@ -247,6 +248,8 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF if (ccw != NULL) ccw->MarkComActivated(); + ComWrappersNative::MarkWrapperAsComActivated(pUnk); + pUnk.SuppressRelease(); RETURN pUnk; } From 6ff4d2b6dd73e91158ff43e060e37e79b4a39cc8 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Mon, 23 Mar 2020 17:27:43 -0700 Subject: [PATCH 09/12] Add tests --- .../{ => API}/ComWrappersTests.csproj | 3 +- .../COM/ComWrappers/{ => API}/Program.cs | 264 +---------- .../src/Interop/COM/ComWrappers/Common.cs | 146 +++++++ .../ComWrappers/GlobalInstance/App.manifest | 26 ++ .../GlobalInstance/CoreShim.X.manifest | 16 + .../GlobalInstance/GlobalInstanceTests.csproj | 37 ++ .../COM/ComWrappers/GlobalInstance/Program.cs | 410 ++++++++++++++++++ .../ReferenceTrackerRuntime.cpp | 32 +- .../COM/NETServer/ConsumeNETServerTesting.cs | 3 + 9 files changed, 674 insertions(+), 263 deletions(-) rename src/coreclr/tests/src/Interop/COM/ComWrappers/{ => API}/ComWrappersTests.csproj (82%) rename src/coreclr/tests/src/Interop/COM/ComWrappers/{ => API}/Program.cs (56%) create mode 100644 src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs create mode 100644 src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest create mode 100644 src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest create mode 100644 src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj create mode 100644 src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj similarity index 82% rename from src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj rename to src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj index e82960ae1e101..83acfa1f6fd5e 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj @@ -9,8 +9,9 @@ + - + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs similarity index 56% rename from src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs rename to src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs index ef496f52df60c..2c30bb6c2a61d 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs @@ -7,178 +7,18 @@ namespace ComWrappersTests using System; using System.Collections; using System.Collections.Generic; - using System.IO; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; + using ComWrappersTests.Common; using TestLibrary; class Program { - // - // Managed object with native wrapper definition. - // - [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")] - interface ITest - { - void SetValue(int i); - } - - class Test : ITest - { - public static int InstanceCount = 0; - - private int value = -1; - public Test() { InstanceCount++; } - ~Test() { InstanceCount--; } - - public void SetValue(int i) => this.value = i; - public int GetValue() => this.value; - } - - public struct IUnknownVtbl - { - public IntPtr QueryInterface; - public IntPtr AddRef; - public IntPtr Release; - } - - public struct ITestVtbl - { - public IUnknownVtbl IUnknownImpl; - public IntPtr SetValue; - - public delegate int _SetValue(IntPtr thisPtr, int i); - public static _SetValue pSetValue = new _SetValue(SetValueInternal); - - public static int SetValueInternal(IntPtr dispatchPtr, int i) - { - unsafe - { - try - { - ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)dispatchPtr).SetValue(i); - } - catch (Exception e) - { - return e.HResult; - } - } - return 0; // S_OK; - } - } - - // - // Native interface defintion with managed wrapper for tracker object - // - struct MockReferenceTrackerRuntime - { - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static IntPtr CreateTrackerObject(); - - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static void ReleaseAllTrackerObjects(); - - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread(); - } - - struct MarshalInterface - { - [DllImport(nameof(MockReferenceTrackerRuntime))] - [return: MarshalAs(UnmanagedType.IUnknown)] - extern public static object? CreateTrackerObject(); - - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static void UpdateTestObject([MarshalAs(UnmanagedType.IUnknown)] object testObj, int i); - } - - [Guid("42951130-245C-485E-B60B-4ED4254256F8")] - public interface ITrackerObject - { - int AddObjectRef(IntPtr obj); - void DropObjectRef(int id); - }; - - public struct VtblPtr - { - public IntPtr Vtbl; - } - - public class ITrackerObjectWrapper : ITrackerObject - { - private struct ITrackerObjectWrapperVtbl - { - public IntPtr QueryInterface; - public _AddRef AddRef; - public _Release Release; - public _AddObjectRef AddObjectRef; - public _DropObjectRef DropObjectRef; - } - - private delegate int _AddRef(IntPtr This); - private delegate int _Release(IntPtr This); - private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id); - private delegate int _DropObjectRef(IntPtr This, int id); - - private readonly IntPtr instance; - private readonly ITrackerObjectWrapperVtbl vtable; - - public ITrackerObjectWrapper(IntPtr instance) - { - var inst = Marshal.PtrToStructure(instance); - this.vtable = Marshal.PtrToStructure(inst.Vtbl); - this.instance = instance; - } - - ~ITrackerObjectWrapper() - { - if (this.instance != IntPtr.Zero) - { - this.vtable.Release(this.instance); - } - } - - public int AddObjectRef(IntPtr obj) - { - int id; - int hr = this.vtable.AddObjectRef(this.instance, obj, out id); - if (hr != 0) - { - throw new COMException($"{nameof(AddObjectRef)}", hr); - } - - return id; - } - - public void DropObjectRef(int id) - { - int hr = this.vtable.DropObjectRef(this.instance, id); - if (hr != 0) - { - throw new COMException($"{nameof(DropObjectRef)}", hr); - } - } - } - class TestComWrappers : ComWrappers { - public static readonly TestComWrappers Global = new TestComWrappers(); - - public bool ReturnInvalid { get; set; } - - public object LastComputeVtablesObject { get; private set; } - protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) { - LastComputeVtablesObject = obj; - - if (ReturnInvalid) - { - count = -1; - return null; - } - Assert.IsTrue(obj is Test); IntPtr fpQueryInteface = default; @@ -209,15 +49,12 @@ class TestComWrappers : ComWrappers protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag) { - if (ReturnInvalid) - return null; - var iid = typeof(ITrackerObject).GUID; - IntPtr iTestComObject; - int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTestComObject); + IntPtr iTrackerComObject; + int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTrackerComObject); Assert.AreEqual(0, hr); - return new ITrackerObjectWrapper(iTestComObject); + return new ITrackerObjectWrapper(iTrackerComObject); } public const int ReleaseObjectsCallAck = unchecked((int)-1); @@ -483,95 +320,6 @@ static void ValidateRuntimeTrackerScenario() GC.Collect(); } - static void ValidateGlobalInstanceScenarios() - { - Console.WriteLine($"Running {nameof(ValidateGlobalInstanceScenarios)}..."); - Console.WriteLine($"Validate RegisterAsGlobalInstance()..."); - - var wrappers1 = TestComWrappers.Global; - wrappers1.RegisterAsGlobalInstance(); - - Assert.Throws( - () => - { - wrappers1.RegisterAsGlobalInstance(); - }, "Should not be able to re-register for global ComWrappers"); - - var wrappers2 = new TestComWrappers(); - Assert.Throws( - () => - { - wrappers2.RegisterAsGlobalInstance(); - }, "Should not be able to reset for global ComWrappers"); - - ValidateMarshalAPIs(wrappers1, true); - ValidateMarshalAPIs(wrappers1, false); - - ValidateInterfaceMarshaler(wrappers1, true); - ValidateInterfaceMarshaler(wrappers1, false); - - Console.WriteLine($"Validate NotifyEndOfReferenceTrackingOnThread()..."); - - int hr; - var cw = TestComWrappers.Global; - - // Trigger the thread lifetime end API and verify the callback occurs. - hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread(); - Assert.AreEqual(TestComWrappers.ReleaseObjectsCallAck, hr); - } - - private static void ValidateMarshalAPIs(TestComWrappers registeredWrapper, bool validateUseRegistered) - { - registeredWrapper.ReturnInvalid = !validateUseRegistered; - - string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; - Console.WriteLine($"Validate Marshal.GetIUnknownForObject: {scenario}..."); - - var testObj = new Test(); - IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj); - Assert.AreNotEqual(IntPtr.Zero, comWrapper1); - Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); - - IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj); - Assert.AreEqual(comWrapper1, comWrapper2); - - Marshal.Release(comWrapper1); - Marshal.Release(comWrapper2); - - Console.WriteLine($"Validate Marshal.GetObjectForIUnknown: {scenario}..."); - - IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); - object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw); - Assert.AreEqual(validateUseRegistered, objWrapper1 is ITrackerObjectWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(ITrackerObjectWrapper)} instance"); - object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw); - Assert.AreEqual(objWrapper1, objWrapper2); - - Console.WriteLine($"Validate Marshal.GetUniqueObjectForIUnknown: {scenario}..."); - - object objWrapper3 = Marshal.GetUniqueObjectForIUnknown(trackerObjRaw); - Assert.AreEqual(validateUseRegistered, objWrapper3 is ITrackerObjectWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(ITrackerObjectWrapper)} instance"); - Assert.AreNotEqual(objWrapper1, objWrapper3); - - Marshal.Release(trackerObjRaw); - } - - private static void ValidateInterfaceMarshaler(TestComWrappers registeredWrapper, bool validateUseRegistered) - { - registeredWrapper.ReturnInvalid = !validateUseRegistered; - string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; - - Console.WriteLine($"Validate ConvertToNative: {scenario}..."); - var testObj = new Test(); - int value = 10; - MarshalInterface.UpdateTestObject(testObj, value); - Assert.AreEqual(validateUseRegistered, value == testObj.GetValue()); - Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); - - Console.WriteLine($"Validate ConvertToManaged: {scenario}..."); - object obj = MarshalInterface.CreateTrackerObject(); - Assert.AreEqual(validateUseRegistered, obj is ITrackerObjectWrapper, $"Should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(ITrackerObjectWrapper)} instance"); - } - static int Main(string[] doNotUse) { try @@ -582,10 +330,6 @@ static int Main(string[] doNotUse) ValidateIUnknownImpls(); ValidateBadComWrapperImpl(); ValidateRuntimeTrackerScenario(); - - // Perform all global impacting test scenarios last to - // avoid polluting non-global tests. - ValidateGlobalInstanceScenarios(); } catch (Exception e) { diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs new file mode 100644 index 0000000000000..06c052b0a237a --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs @@ -0,0 +1,146 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace ComWrappersTests.Common +{ + using System; + using System.Runtime.InteropServices; + + // + // Managed object with native wrapper definition. + // + [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")] + interface ITest + { + void SetValue(int i); + } + + class Test : ITest + { + public static int InstanceCount = 0; + + private int value = -1; + public Test() { InstanceCount++; } + ~Test() { InstanceCount--; } + + public void SetValue(int i) => this.value = i; + public int GetValue() => this.value; + } + + public struct IUnknownVtbl + { + public IntPtr QueryInterface; + public IntPtr AddRef; + public IntPtr Release; + } + + public struct ITestVtbl + { + public IUnknownVtbl IUnknownImpl; + public IntPtr SetValue; + + public delegate int _SetValue(IntPtr thisPtr, int i); + public static _SetValue pSetValue = new _SetValue(SetValueInternal); + + public static int SetValueInternal(IntPtr dispatchPtr, int i) + { + unsafe + { + try + { + ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)dispatchPtr).SetValue(i); + } + catch (Exception e) + { + return e.HResult; + } + } + return 0; // S_OK; + } + } + + // + // Native interface defintion with managed wrapper for tracker object + // + struct MockReferenceTrackerRuntime + { + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static IntPtr CreateTrackerObject(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static void ReleaseAllTrackerObjects(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread(); + } + + [Guid("42951130-245C-485E-B60B-4ED4254256F8")] + public interface ITrackerObject + { + int AddObjectRef(IntPtr obj); + void DropObjectRef(int id); + }; + + public struct VtblPtr + { + public IntPtr Vtbl; + } + + public class ITrackerObjectWrapper : ITrackerObject + { + private struct ITrackerObjectWrapperVtbl + { + public IntPtr QueryInterface; + public _AddRef AddRef; + public _Release Release; + public _AddObjectRef AddObjectRef; + public _DropObjectRef DropObjectRef; + } + + private delegate int _AddRef(IntPtr This); + private delegate int _Release(IntPtr This); + private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id); + private delegate int _DropObjectRef(IntPtr This, int id); + + private readonly IntPtr instance; + private readonly ITrackerObjectWrapperVtbl vtable; + + public ITrackerObjectWrapper(IntPtr instance) + { + var inst = Marshal.PtrToStructure(instance); + this.vtable = Marshal.PtrToStructure(inst.Vtbl); + this.instance = instance; + } + + ~ITrackerObjectWrapper() + { + if (this.instance != IntPtr.Zero) + { + this.vtable.Release(this.instance); + } + } + + public int AddObjectRef(IntPtr obj) + { + int id; + int hr = this.vtable.AddObjectRef(this.instance, obj, out id); + if (hr != 0) + { + throw new COMException($"{nameof(AddObjectRef)}", hr); + } + + return id; + } + + public void DropObjectRef(int id) + { + int hr = this.vtable.DropObjectRef(this.instance, id); + if (hr != 0) + { + throw new COMException($"{nameof(DropObjectRef)}", hr); + } + } + } +} + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest new file mode 100644 index 0000000000000..bb7ec83fae8b8 --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest new file mode 100644 index 0000000000000..abb39fbb21c7d --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest @@ -0,0 +1,16 @@ + + + + + + + + + + + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj new file mode 100644 index 0000000000000..1e5e14f330e2e --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj @@ -0,0 +1,37 @@ + + + Exe + App.manifest + true + true + + BuildOnly + + true + true + + true + true + + + + + + + + + + + + + false + Content + Always + + + + + PreserveNewest + + + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs new file mode 100644 index 0000000000000..f4489ccad211b --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs @@ -0,0 +1,410 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace ComWrappersTests.GlobalInstance +{ + using System; + using System.Collections; + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + + using ComWrappersTests.Common; + using TestLibrary; + + class Program + { + struct MarshalInterface + { + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [return: MarshalAs(UnmanagedType.IUnknown)] + extern public static object CreateTrackerObjectAsIUnknown(); + + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [return: MarshalAs(UnmanagedType.Interface)] + extern public static FakeWrapper CreateTrackerObjectAsInterface(); + + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint = nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [return: MarshalAs(UnmanagedType.Interface)] + extern public static Test CreateTrackerObjectWrongType(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsIUnknown( + [MarshalAs(UnmanagedType.IUnknown)] object testObj, + int i, + [MarshalAs(UnmanagedType.IUnknown)] out object ret); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsIDispatch( + [MarshalAs(UnmanagedType.IDispatch)] object testObj, + int i, + [MarshalAs(UnmanagedType.IDispatch)] out object ret); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsIInspectable( + [MarshalAs(UnmanagedType.IInspectable)] object testObj, + int i, + [MarshalAs(UnmanagedType.IInspectable)] out object ret); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsInterface( + [MarshalAs(UnmanagedType.Interface)] Test testObj, + int i, + [Out, MarshalAs(UnmanagedType.Interface)] out Test ret); + } + + private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046"; + private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90"; + class TestEx : Test + { + public readonly Guid[] Interfaces; + public TestEx(params string[] iids) + { + Interfaces = new Guid[iids.Length]; + for (int i = 0; i < iids.Length; i++) + Interfaces[i] = Guid.Parse(iids[i]); + } + } + + class FakeWrapper + { + private delegate int _AddRef(IntPtr This); + private delegate int _Release(IntPtr This); + private struct IUnknownWrapperVtbl + { + public IntPtr QueryInterface; + public _AddRef AddRef; + public _Release Release; + } + + private readonly IntPtr wrappedInstance; + + private readonly IUnknownWrapperVtbl vtable; + + public FakeWrapper(IntPtr instance) + { + this.wrappedInstance = instance; + var inst = Marshal.PtrToStructure(instance); + this.vtable = Marshal.PtrToStructure(inst.Vtbl); + } + + ~FakeWrapper() + { + if (this.wrappedInstance != IntPtr.Zero) + { + this.vtable.Release(this.wrappedInstance); + } + } + } + + class GlobalComWrappers : ComWrappers + { + public static GlobalComWrappers Instance = new GlobalComWrappers(); + + public bool ReturnInvalid { get; set; } + + public object LastComputeVtablesObject { get; private set; } + + protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + LastComputeVtablesObject = obj; + + if (ReturnInvalid || !(obj is Test)) + { + count = -1; + return null; + } + + IntPtr fpQueryInteface = default; + IntPtr fpAddRef = default; + IntPtr fpRelease = default; + ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); + + var vtbl = new ITestVtbl() + { + IUnknownImpl = new IUnknownVtbl() + { + QueryInterface = fpQueryInteface, + AddRef = fpAddRef, + Release = fpRelease + }, + SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) + }; + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + int countLocal = obj is TestEx ? ((TestEx)obj).Interfaces.Length + 1 : 1; + var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry) * countLocal); + entryRaw[0].IID = typeof(ITest).GUID; + entryRaw[0].Vtable = vtblRaw; + + if (obj is TestEx) + { + var testEx = (TestEx)obj; + + for (int i = 1; i < testEx.Interfaces.Length + 1; i++) + { + // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. + entryRaw[i].IID = testEx.Interfaces[i-1]; + entryRaw[i].Vtable = IntPtr.Zero; + } + } + + count = countLocal; + return entryRaw; + } + + protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag) + { + if (ReturnInvalid) + return null; + + Guid[] iids = { + typeof(ITrackerObject).GUID, + typeof(ITest).GUID, + typeof(Server.Contract.IDispatchTesting).GUID, + typeof(Server.Contract.IConsumeNETServer).GUID + }; + + for (var i = 0; i < iids.Length; i++) + { + var iid = iids[i]; + IntPtr comObject; + int hr = Marshal.QueryInterface(externalComObject, ref iid, out comObject); + if (hr == 0) + return new FakeWrapper(comObject); + } + + return null; + } + + public const int ReleaseObjectsCallAck = unchecked((int)-1); + + protected override void ReleaseObjects(IEnumerable objects) + { + throw new Exception() { HResult = ReleaseObjectsCallAck }; + } + } + + private static void ValidateRegisterAsGlobalInstance() + { + Console.WriteLine($"Running {nameof(ValidateRegisterAsGlobalInstance)}..."); + + var wrappers1 = GlobalComWrappers.Instance; + wrappers1.RegisterAsGlobalInstance(); + Assert.Throws( + () => + { + wrappers1.RegisterAsGlobalInstance(); + }, "Should not be able to re-register for global ComWrappers"); + + var wrappers2 = new GlobalComWrappers(); + Assert.Throws( + () => + { + wrappers2.RegisterAsGlobalInstance(); + }, "Should not be able to reset for global ComWrappers"); + } + + private static void ValidateMarshalAPIs(bool validateUseRegistered) + { + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Running {nameof(ValidateMarshalAPIs)}: {scenario}..."); + + GlobalComWrappers registeredWrapper = GlobalComWrappers.Instance; + registeredWrapper.ReturnInvalid = !validateUseRegistered; + + Console.WriteLine($" -- Validate Marshal.GetIUnknownForObject..."); + + var testObj = new Test(); + IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj); + Assert.AreNotEqual(IntPtr.Zero, comWrapper1); + Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + + IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj); + Assert.AreEqual(comWrapper1, comWrapper2); + + Marshal.Release(comWrapper1); + Marshal.Release(comWrapper2); + + Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject..."); + + Assert.Throws(() => Marshal.GetIDispatchForObject(testObj)); + + if (validateUseRegistered) + { + var dispatchObj = new TestEx(IID_IDISPATCH); + IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj); + Assert.AreNotEqual(IntPtr.Zero, dispatchWrapper); + Assert.AreEqual(dispatchObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + } + + Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown..."); + + IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); + object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(validateUseRegistered, objWrapper1 is FakeWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(objWrapper1, objWrapper2); + + Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown..."); + + object objWrapper3 = Marshal.GetUniqueObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(validateUseRegistered, objWrapper3 is FakeWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + + Assert.AreNotEqual(objWrapper1, objWrapper3); + + Marshal.Release(trackerObjRaw); + } + + private static void ValidatePInvokes(bool validateUseRegistered) + { + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Running {nameof(ValidatePInvokes)}: {scenario}..."); + + GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered; + + Console.WriteLine($" -- Validate MarshalAs IUnknown..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIUnknown, validateUseRegistered); + object obj = MarshalInterface.CreateTrackerObjectAsIUnknown(); + Assert.AreEqual(validateUseRegistered, obj is FakeWrapper, $"Should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + + if (validateUseRegistered) + { + Console.WriteLine($" -- Validate MarshalAs IDispatch..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIDispatch, validateUseRegistered, new TestEx(IID_IDISPATCH)); + + Console.WriteLine($" -- Validate MarshalAs IInspectable..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIInspectable, validateUseRegistered, new TestEx(IID_IINSPECTABLE)); + } + + Console.WriteLine($" -- Validate MarshalAs Interface..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsInterface, validateUseRegistered); + + if (validateUseRegistered) + { + Assert.Throws(() => MarshalInterface.CreateTrackerObjectWrongType()); + + FakeWrapper wrapper = MarshalInterface.CreateTrackerObjectAsInterface(); + Assert.IsNotNull(obj, $"Should have returned {nameof(FakeWrapper)} instance"); + } + } + + private delegate int UpdateTestObject(T testObj, int i, out T ret) where T : class; + private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool validateUseRegistered, Test testObj = null) where T : class + { + const int E_NOINTERFACE = unchecked((int)0x80004002); + int value = 10; + + if (testObj == null) + testObj = new Test(); + + T retObj; + int hr = func(testObj as T, value, out retObj); + Assert.AreEqual(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + if (validateUseRegistered) + { + Assert.IsTrue(retObj is Test); + Assert.AreEqual(value, testObj.GetValue()); + Assert.AreEqual(testObj, retObj); + } + else + { + Assert.AreEqual(E_NOINTERFACE, hr); + } + } + + private static void ValidateComActivation(bool validateUseRegistered) + { + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Running {nameof(ValidateComActivation)}: {scenario}..."); + GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered; + + Console.WriteLine($" -- Validate native server..."); + ValidateNativeServerActivation(); + + Console.WriteLine($" -- Validate managed server..."); + ValidateManagedServerActivation(); + } + + private static void ValidateNativeServerActivation() + { + bool returnValid = !GlobalComWrappers.Instance.ReturnInvalid; + + Type t= Type.GetTypeFromCLSID(Guid.Parse(Server.Contract.Guids.DispatchTesting)); + var server = Activator.CreateInstance(t); + Assert.AreEqual(returnValid, server is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + + IntPtr ptr = Marshal.GetIUnknownForObject(server); + var obj = Marshal.GetObjectForIUnknown(ptr); + Assert.AreEqual(server, obj); + } + + private static void ValidateManagedServerActivation() + { + bool returnValid = !GlobalComWrappers.Instance.ReturnInvalid; + + // Initialize CoreShim and hostpolicymock + HostPolicyMock.Initialize(Environment.CurrentDirectory, null); + Environment.SetEnvironmentVariable("CORESHIM_COMACT_ASSEMBLYNAME", "NETServer"); + Environment.SetEnvironmentVariable("CORESHIM_COMACT_TYPENAME", "ConsumeNETServerTesting"); + + using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(0, string.Empty, string.Empty, string.Empty)) + { + Type t = Type.GetTypeFromCLSID(Guid.Parse(Server.Contract.Guids.ConsumeNETServerTesting)); + var server = Activator.CreateInstance(t); + Assert.AreEqual(returnValid, server is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + object serverUnwrapped = GlobalComWrappers.Instance.LastComputeVtablesObject; + Assert.AreEqual("ConsumeNETServerTesting", serverUnwrapped.GetType().Name); + + IntPtr ptr = Marshal.GetIUnknownForObject(server); + var obj = Marshal.GetObjectForIUnknown(ptr); + Assert.AreEqual(server, obj); + Assert.AreEqual(returnValid, obj is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + serverUnwrapped.GetType().GetMethod("NotEqualByRCW").Invoke(serverUnwrapped, new object[] { obj }); + } + } + + private static void ValidateNotifyEndOfReferenceTrackingOnThread() + { + Console.WriteLine($"Running {nameof(ValidateNotifyEndOfReferenceTrackingOnThread)}..."); + + // Make global instance return invalid object so that the Exception thrown by + // GlobalComWrappers.ReleaseObjects is marshalled using the built-in system. + GlobalComWrappers.Instance.ReturnInvalid = true; + + // Trigger the thread lifetime end API and verify the callback occurs. + int hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread(); + Assert.AreEqual(GlobalComWrappers.ReleaseObjectsCallAck, hr); + } + + static int Main(string[] doNotUse) + { + try + { + // The first test registereds a global ComWrappers instance + // Subsequents tests assume the global instance has already been registered. + ValidateRegisterAsGlobalInstance(); + + ValidateMarshalAPIs(validateUseRegistered: true); + ValidateMarshalAPIs(validateUseRegistered: false); + + ValidatePInvokes(validateUseRegistered: true); + ValidatePInvokes(validateUseRegistered: false); + + ValidateComActivation(validateUseRegistered: true); + ValidateComActivation(validateUseRegistered: false); + + ValidateNotifyEndOfReferenceTrackingOnThread(); + } + catch (Exception e) + { + Console.WriteLine($"Test Failure: {e}"); + return 101; + } + + return 100; + } + } +} + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp index 6afa734c06a34..867efa01866e4 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace API { @@ -338,7 +339,7 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE Trigger_NotifyEndOfReferenceTracking return TrackerRuntimeManager.NotifyEndOfReferenceTrackingOnThread(); } -extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObject(IUnknown *obj, int i) +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIUnknown(IUnknown *obj, int i, IUnknown **out) { if (obj == nullptr) return E_POINTER; @@ -346,5 +347,32 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObject(IUnknown *obj, int HRESULT hr; ComSmartPtr testObj; RETURN_IF_FAILED(obj->QueryInterface(&testObj)) - return testObj->SetValue(i); + RETURN_IF_FAILED(testObj->SetValue(i)); + + *out = testObj.Detach(); + return S_OK; +} + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIDispatch(IDispatch *obj, int i, IDispatch **out) +{ + if (obj == nullptr) + return E_POINTER; + + return UpdateTestObjectAsIUnknown(obj, i, (IUnknown**)out); +} + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIInspectable(IInspectable * obj, int i, IInspectable **out) +{ + if (obj == nullptr) + return E_POINTER; + + return UpdateTestObjectAsIUnknown(obj, i, (IUnknown **)out); +} + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsInterface(ITest *obj, int i, ITest **out) +{ + if (obj == nullptr) + return E_POINTER; + + return UpdateTestObjectAsIUnknown(obj, i, (IUnknown**)out); } diff --git a/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs b/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs index d8dba73982d9c..ec9f92ff3e69a 100644 --- a/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs +++ b/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs @@ -15,6 +15,9 @@ public class ConsumeNETServerTesting : Server.Contract.IConsumeNETServer public ConsumeNETServerTesting() { _ccw = Marshal.GetIUnknownForObject(this); + + // At this point, the CCW has not been marked as COM-activated, + // so the returned RCW will be unwrapped. _rcwUnwrapped = Marshal.GetObjectForIUnknown(_ccw); } From 9581b5797b02000e1396d753a03c1431cde8cf85 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Mon, 23 Mar 2020 19:58:41 -0700 Subject: [PATCH 10/12] Fix flag for marking as COM-activated --- src/coreclr/src/interop/comwrappers.hpp | 2 +- .../COM/ComWrappers/GlobalInstance/Program.cs | 99 +++++++++++++------ 2 files changed, 69 insertions(+), 32 deletions(-) diff --git a/src/coreclr/src/interop/comwrappers.hpp b/src/coreclr/src/interop/comwrappers.hpp index cf37f30b53dce..e9838646b0189 100644 --- a/src/coreclr/src/interop/comwrappers.hpp +++ b/src/coreclr/src/interop/comwrappers.hpp @@ -16,8 +16,8 @@ enum class CreateComInterfaceFlagsEx : int32_t TrackerSupport = InteropLib::Com::CreateComInterfaceFlags_TrackerSupport, // Highest bit is reserved for internal usage + IsComActivated = 1 << 30, IsPegged = 1 << 31, - IsComActivated = 2 << 31, InternalMask = IsPegged | IsComActivated, }; diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs index f4489ccad211b..8d9552b03c243 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs @@ -53,6 +53,8 @@ extern public static int UpdateTestObjectAsInterface( [Out, MarshalAs(UnmanagedType.Interface)] out Test ret); } + private const string ManagedServerTypeName = "ConsumeNETServerTesting"; + private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046"; private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90"; class TestEx : Test @@ -109,49 +111,43 @@ class GlobalComWrappers : ComWrappers { LastComputeVtablesObject = obj; - if (ReturnInvalid || !(obj is Test)) + if (ReturnInvalid) { count = -1; return null; } - IntPtr fpQueryInteface = default; - IntPtr fpAddRef = default; - IntPtr fpRelease = default; - ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); - - var vtbl = new ITestVtbl() + if (obj is Test) { - IUnknownImpl = new IUnknownVtbl() + return ComputeVtablesForTestObject((Test)obj, out count); + } + else if (string.Equals(ManagedServerTypeName, obj.GetType().Name)) + { + IntPtr fpQueryInteface = default; + IntPtr fpAddRef = default; + IntPtr fpRelease = default; + ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); + + var vtbl = new IUnknownVtbl() { QueryInterface = fpQueryInteface, AddRef = fpAddRef, Release = fpRelease - }, - SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) - }; - var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); - Marshal.StructureToPtr(vtbl, vtblRaw, false); - - int countLocal = obj is TestEx ? ((TestEx)obj).Interfaces.Length + 1 : 1; - var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry) * countLocal); - entryRaw[0].IID = typeof(ITest).GUID; - entryRaw[0].Vtable = vtblRaw; + }; + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(IUnknownVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); - if (obj is TestEx) - { - var testEx = (TestEx)obj; + // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. + var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(ComInterfaceEntry)); + entryRaw[0].IID = typeof(Server.Contract.IConsumeNETServer).GUID; + entryRaw[0].Vtable = vtblRaw; - for (int i = 1; i < testEx.Interfaces.Length + 1; i++) - { - // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. - entryRaw[i].IID = testEx.Interfaces[i-1]; - entryRaw[i].Vtable = IntPtr.Zero; - } + count = 1; + return entryRaw; } - count = countLocal; - return entryRaw; + count = -1; + return null; } protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag) @@ -184,6 +180,47 @@ protected override void ReleaseObjects(IEnumerable objects) { throw new Exception() { HResult = ReleaseObjectsCallAck }; } + + private unsafe ComInterfaceEntry* ComputeVtablesForTestObject(Test obj, out int count) + { + IntPtr fpQueryInteface = default; + IntPtr fpAddRef = default; + IntPtr fpRelease = default; + ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); + + var vtbl = new ITestVtbl() + { + IUnknownImpl = new IUnknownVtbl() + { + QueryInterface = fpQueryInteface, + AddRef = fpAddRef, + Release = fpRelease + }, + SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) + }; + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + int countLocal = obj is TestEx ? ((TestEx)obj).Interfaces.Length + 1 : 1; + var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry) * countLocal); + entryRaw[0].IID = typeof(ITest).GUID; + entryRaw[0].Vtable = vtblRaw; + + if (obj is TestEx) + { + var testEx = (TestEx)obj; + + for (int i = 1; i < testEx.Interfaces.Length + 1; i++) + { + // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. + entryRaw[i].IID = testEx.Interfaces[i - 1]; + entryRaw[i].Vtable = IntPtr.Zero; + } + } + + count = countLocal; + return entryRaw; + } } private static void ValidateRegisterAsGlobalInstance() @@ -347,7 +384,7 @@ private static void ValidateManagedServerActivation() // Initialize CoreShim and hostpolicymock HostPolicyMock.Initialize(Environment.CurrentDirectory, null); Environment.SetEnvironmentVariable("CORESHIM_COMACT_ASSEMBLYNAME", "NETServer"); - Environment.SetEnvironmentVariable("CORESHIM_COMACT_TYPENAME", "ConsumeNETServerTesting"); + Environment.SetEnvironmentVariable("CORESHIM_COMACT_TYPENAME", ManagedServerTypeName); using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(0, string.Empty, string.Empty, string.Empty)) { @@ -355,7 +392,7 @@ private static void ValidateManagedServerActivation() var server = Activator.CreateInstance(t); Assert.AreEqual(returnValid, server is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); object serverUnwrapped = GlobalComWrappers.Instance.LastComputeVtablesObject; - Assert.AreEqual("ConsumeNETServerTesting", serverUnwrapped.GetType().Name); + Assert.AreEqual(ManagedServerTypeName, serverUnwrapped.GetType().Name); IntPtr ptr = Marshal.GetIUnknownForObject(server); var obj = Marshal.GetObjectForIUnknown(ptr); From 10ef23677dfb5313943f05f61238078aecc6fbfc Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Mon, 23 Mar 2020 22:04:45 -0700 Subject: [PATCH 11/12] PR fixes Fix tests --- .../Runtime/InteropServices/ComWrappers.cs | 4 +- .../InteropServices/Marshal.CoreCLR.cs | 2 - src/coreclr/src/interop/inc/interoplib.h | 6 +-- src/coreclr/src/interop/interoplib.cpp | 6 +-- src/coreclr/src/vm/interopconverter.cpp | 18 +++++---- src/coreclr/src/vm/interoplibinterface.cpp | 40 ++++++++++++++----- src/coreclr/src/vm/interoplibinterface.h | 4 +- .../COM/ComWrappers/GlobalInstance/Program.cs | 20 ++++++---- 8 files changed, 63 insertions(+), 37 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 07df6081b24b8..8bc2678e9622e 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -129,7 +129,7 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// /// Create a COM representation of the supplied object that can be passed to a non-managed environment. /// - /// The implemenentation to use when creating the COM representation. + /// The implementation to use when creating the COM representation. /// The managed object to expose outside the .NET runtime. /// Flags used to configure the generated interface. /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created. @@ -238,7 +238,7 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create /// /// Get the currently registered managed object or creates a new managed object and registers it. /// - /// The implemenentation to use when creating the managed object. + /// The implementation to use when creating the managed object. /// Object to import for usage into the .NET runtime. /// Flags used to describe the external object. /// The to be used as the wrapper for the external object. diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index dfef4090c4590..e9b8b3df601b5 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -415,7 +415,6 @@ public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) throw new ArgumentNullException(nameof(pUnk)); } - return GetObjectForIUnknownNative(pUnk); } @@ -429,7 +428,6 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown) throw new ArgumentNullException(nameof(unknown)); } - return GetUniqueObjectForIUnknownNative(unknown); } diff --git a/src/coreclr/src/interop/inc/interoplib.h b/src/coreclr/src/interop/inc/interoplib.h index 2572efc5ece95..fa0e0c3e25918 100644 --- a/src/coreclr/src/interop/inc/interoplib.h +++ b/src/coreclr/src/interop/inc/interoplib.h @@ -46,10 +46,10 @@ namespace InteropLib HRESULT ReactivateWrapper(_In_ IUnknown* wrapper, _In_ InteropLib::OBJECTHANDLE handle) noexcept; // Get the object for the supplied wrapper - HRESULT GetObjectForWrapper(_In_ IUnknown *wrapper, _Out_ OBJECTHANDLE *object) noexcept; + HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept; - HRESULT MarkComActivated(_In_ IUnknown *wrapper) noexcept; - HRESULT IsComActivated(_In_ IUnknown *wrapper) noexcept; + HRESULT MarkComActivated(_In_ IUnknown* wrapper) noexcept; + HRESULT IsComActivated(_In_ IUnknown* wrapper) noexcept; struct ExternalWrapperResult { diff --git a/src/coreclr/src/interop/interoplib.cpp b/src/coreclr/src/interop/interoplib.cpp index 669b6ad2fb1fe..fb7af5aa7ec65 100644 --- a/src/coreclr/src/interop/interoplib.cpp +++ b/src/coreclr/src/interop/interoplib.cpp @@ -89,7 +89,7 @@ namespace InteropLib return S_OK; } - HRESULT GetObjectForWrapper(_In_ IUnknown *wrapper, _Out_ OBJECTHANDLE *object) noexcept + HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept { if (object == nullptr) return E_POINTER; @@ -107,7 +107,7 @@ namespace InteropLib return S_OK; } - HRESULT MarkComActivated(_In_ IUnknown *wrapperMaybe) noexcept + HRESULT MarkComActivated(_In_ IUnknown* wrapperMaybe) noexcept { ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); if (wrapper == nullptr) @@ -117,7 +117,7 @@ namespace InteropLib return S_OK; } - HRESULT IsComActivated(_In_ IUnknown *wrapperMaybe) noexcept + HRESULT IsComActivated(_In_ IUnknown* wrapperMaybe) noexcept { ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); if (wrapper == nullptr) diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index 8bec4fabf09ab..045f6b1a05c09 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -53,11 +53,11 @@ namespace } void EnsureObjectRefIsValidForSpecifiedClass( - _In_ OBJECTREF obj, + _In_ OBJECTREF *obj, _In_ DWORD dwFlags, _In_ MethodTable *pMTClass) { - _ASSERTE(obj != NULL); + _ASSERTE(*obj != NULL); _ASSERTE(pMTClass != NULL); if ((dwFlags & ObjFromComIP::CLASS_IS_HINT) != 0) @@ -69,13 +69,13 @@ namespace // Bad format exception thrown for backward compatibility THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass); - if (CanCastComObject(obj, pMTClass)) + if (CanCastComObject(*obj, pMTClass)) return; StackSString ssObjClsName; StackSString ssDestClsName; - obj->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName); + (*obj)->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName); pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName); COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST, @@ -182,13 +182,15 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) { hr = S_OK; - void *pvObj; + + SafeComHolder pvObj; if (ReqIpType & ComIpType_Dispatch) { hr = pUnk->QueryInterface(IID_IDispatch, &pvObj); } else if (ReqIpType & ComIpType_Inspectable) { + SafeComHolder pvObj; hr = pUnk->QueryInterface(IID_IInspectable, &pvObj); } @@ -468,7 +470,7 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) { - void *pvObj; + SafeComHolder pvObj; hr = pUnk->QueryInterface(iid, &pvObj); if (FAILED(hr)) COMPlusThrowHR(hr); @@ -539,7 +541,7 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM if (TryGetObjectRefFromComIPUsingComWrappers(pUnk, dwFlags, pObjOut)) { if (pMTClass != NULL) - EnsureObjectRefIsValidForSpecifiedClass(*pObjOut, dwFlags, pMTClass); + EnsureObjectRefIsValidForSpecifiedClass(pObjOut, dwFlags, pMTClass); return; } @@ -646,7 +648,7 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM // make sure we can cast to the specified class if (pMTClass != NULL) { - EnsureObjectRefIsValidForSpecifiedClass(*pObjOut, dwFlags, pMTClass); + EnsureObjectRefIsValidForSpecifiedClass(pObjOut, dwFlags, pMTClass); } else if (dwFlags & ObjFromComIP::REQUIRE_IINSPECTABLE) { diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 985a1b3415353..669003601c24c 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -397,9 +397,12 @@ namespace } }; - // Global instance + // Global instance of the external object cache Volatile ExtObjCxtCache::g_Instance; + // Indicator for if a ComWrappers implementation is globally registered + bool g_IsGlobalComWrappersRegistered; + // Defined handle types for the specific object uses. const HandleType InstanceHandleType{ HNDTYPE_STRONG }; @@ -592,6 +595,11 @@ namespace RETURN (wrapperRawMaybe != NULL); } + // The unwrap parameter indicates whether or not COM instances that are actually CCWs should + // be unwrapped to the original managed object. + // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance), + // CCWs should be unwrapped to allow for round-tripping object -> COM instance -> object. + // For explicit usage of ComWrappers (i.e. directly called via a ComWrappers APIs), CCWs should not be unwrapped. bool TryGetOrCreateObjectForComInstanceInternal( _In_opt_ OBJECTREF impl, _In_ IUnknown* identity, @@ -635,8 +643,11 @@ namespace ExtObjCxtCache::LockHolder lock(cache); extObjCxt = cache->Find(identity); + // If is no object found in the cache, check if the object COM instance is actually the CCW + // representing a managed object. if (extObjCxt == NULL && unwrap) { + // If the COM instance is a CCW that is not COM-activated, use the object of that wrapper object. InteropLib::OBJECTHANDLE handleLocal; if (InteropLib::Com::GetObjectForWrapper(identity, &handleLocal) == S_OK && InteropLib::Com::IsComActivated(identity) == S_FALSE) @@ -652,6 +663,8 @@ namespace } else if (handle != NULL) { + // We have an object handle from the COM instance which is a CCW. Use that object. + // This allows for the round-trip from object -> COM instance -> object. ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle); gc.objRefMaybe = ObjectFromHandle(objectHandle); } @@ -747,8 +760,6 @@ namespace *objRef = gc.objRefMaybe; RETURN (gc.objRefMaybe != NULL); } - - bool g_IsGlobalComWrappersRegistered; } namespace InteropLibImports @@ -980,13 +991,14 @@ namespace InteropLibImports gc.implRef = NULL; // Use the globally registered implementation. gc.wrapperMaybeRef = NULL; // No supplied wrapper here. + bool unwrapIfManagedObjectWrapper = false; // Don't unwrap CCWs // Get wrapper for external object bool success = TryGetOrCreateObjectForComInstanceInternal( gc.implRef, externalComObject, externalObjectFlags, - false /*unwrap*/, + unwrapIfManagedObjectWrapper, gc.wrapperMaybeRef, &gc.objRef); @@ -1150,12 +1162,14 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( // are being manipulated. { GCX_COOP(); + + bool unwrapIfManagedObjectWrapper = false; // Don't unwrap CCWs OBJECTREF newObj; success = TryGetOrCreateObjectForComInstanceInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), identity, (CreateObjectFlags)flags, - false /*unwrap*/, + unwrapIfManagedObjectWrapper, ObjectToOBJECTREF(*wrapperMaybe.m_ppObject), &newObj); @@ -1247,17 +1261,19 @@ void ComWrappersNative::MarkExternalComObjectContextCollected(_In_ void* context } } -void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapper) +void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe) { CONTRACTL { NOTHROW; MODE_ANY; - PRECONDITION(wrapper != NULL); + PRECONDITION(wrapperMaybe != NULL); } CONTRACTL_END; - InteropLib::Com::MarkComActivated(wrapper); + // The IUnknown may or may not represent a wrapper, so E_INVALIDARG is okay here. + HRESULT hr = InteropLib::Com::MarkComActivated(wrapperMaybe); + _ASSERTE(SUCCEEDED(hr) || hr == E_INVALIDARG); } void QCALLTYPE GlobalComWrappers::SetGlobalInstanceRegistered() @@ -1268,7 +1284,7 @@ void QCALLTYPE GlobalComWrappers::SetGlobalInstanceRegistered() BEGIN_QCALL; - g_IsGlobalComWrappersRegistered = true;; + g_IsGlobalComWrappersRegistered = true; END_QCALL; } @@ -1308,12 +1324,16 @@ bool GlobalComWrappers::TryGetOrCreateObjectForComInstance( { GCX_COOP(); + // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance), + // unwrap CCWs to allow for round-tripping object -> COM instance -> object. + bool unwrapIfManagedObjectWrapper = true; + // Passing NULL as the ComWrappers implementation indicates using the globally registered instance return TryGetOrCreateObjectForComInstanceInternal( NULL /*comWrappersImpl*/, externalComObject, (CreateObjectFlags)flags, - true /*unwrap*/, + unwrapIfManagedObjectWrapper, NULL /*wrapperMaybe*/, objRef); } diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index 03f2e585a932c..b3d9f0f47cc66 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -36,12 +36,14 @@ class ComWrappersNative static void MarkExternalComObjectContextCollected(_In_ void* context); public: // COM activation - static void MarkWrapperAsComActivated(_In_ IUnknown* wrapper); + static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe); }; class GlobalComWrappers { public: + // Native QCall for the ComWrappers managed type to indicate a global instance is registered + // This should be set if the private static member representing the global instance on ComWrappers is non-null. static void QCALLTYPE SetGlobalInstanceRegistered(); public: // Functions operating on a registered global instance diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs index 8d9552b03c243..3bd14141570cd 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs @@ -188,14 +188,16 @@ protected override void ReleaseObjects(IEnumerable objects) IntPtr fpRelease = default; ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); + var iUnknownVtbl = new IUnknownVtbl() + { + QueryInterface = fpQueryInteface, + AddRef = fpAddRef, + Release = fpRelease + }; + var vtbl = new ITestVtbl() { - IUnknownImpl = new IUnknownVtbl() - { - QueryInterface = fpQueryInteface, - AddRef = fpAddRef, - Release = fpRelease - }, + IUnknownImpl = iUnknownVtbl, SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) }; var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); @@ -208,13 +210,15 @@ protected override void ReleaseObjects(IEnumerable objects) if (obj is TestEx) { - var testEx = (TestEx)obj; + var iUnknownVtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(IUnknownVtbl)); + Marshal.StructureToPtr(iUnknownVtbl, iUnknownVtblRaw, false); + var testEx = (TestEx)obj; for (int i = 1; i < testEx.Interfaces.Length + 1; i++) { // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. entryRaw[i].IID = testEx.Interfaces[i - 1]; - entryRaw[i].Vtable = IntPtr.Zero; + entryRaw[i].Vtable = iUnknownVtblRaw; } } From 574af1982571eb7e29934a8337b701a76341d05d Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Tue, 24 Mar 2020 10:32:12 -0700 Subject: [PATCH 12/12] PR feedback --- .../Runtime/InteropServices/ComWrappers.cs | 1 + src/coreclr/src/vm/interopconverter.cpp | 10 ++-------- src/coreclr/src/vm/interoplibinterface.cpp | 19 ++++++++++--------- src/coreclr/src/vm/interoplibinterface.h | 3 +-- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 8bc2678e9622e..59f3d47ac3585 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -292,6 +292,7 @@ public void RegisterAsGlobalInstance() } [DllImport(RuntimeHelpers.QCall)] + [SuppressGCTransition] private static extern void SetGlobalInstanceRegistered(); /// diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index 045f6b1a05c09..44470d7d297b4 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -17,7 +17,6 @@ #include "runtimecallablewrapper.h" #include "cominterfacemarshaler.h" #include "binder.h" -#include // CreateComInterfaceFlags, CreateObjectFlags #include #include "winrttypenameconverter.h" #include "typestring.h" @@ -29,8 +28,7 @@ namespace _Outptr_ IUnknown** wrapperRaw) { #ifdef FEATURE_COMWRAPPERS - InteropLib::Com::CreateComInterfaceFlags flags = InteropLib::Com::CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport; - return GlobalComWrappers::TryGetOrCreateComInterfaceForObject(instance, flags, (void**)wrapperRaw); + return GlobalComWrappers::TryGetOrCreateComInterfaceForObject(instance, (void**)wrapperRaw); #else return false; #endif // FEATURE_COMWRAPPERS @@ -42,11 +40,7 @@ namespace _Out_ OBJECTREF *pObjOut) { #ifdef FEATURE_COMWRAPPERS - int flags = InteropLib::Com::CreateObjectFlags::CreateObjectFlags_TrackerObject; - if ((dwFlags & ObjFromComIP::UNIQUE_OBJECT) != 0) - flags |= InteropLib::Com::CreateObjectFlags::CreateObjectFlags_UniqueInstance; - - return GlobalComWrappers::TryGetOrCreateObjectForComInstance(pUnknown, flags, pObjOut); + return GlobalComWrappers::TryGetOrCreateObjectForComInstance(pUnknown, dwFlags, pObjOut); #else return false; #endif // FEATURE_COMWRAPPERS diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 669003601c24c..c71b0d955fd53 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -1278,20 +1278,15 @@ void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe) void QCALLTYPE GlobalComWrappers::SetGlobalInstanceRegistered() { - QCALL_CONTRACT; + // QCALL contracts are not used here because the managed declaration + // uses the SuppressGCTransition attribute _ASSERTE(!g_IsGlobalComWrappersRegistered); - - BEGIN_QCALL; - g_IsGlobalComWrappersRegistered = true; - - END_QCALL; } bool GlobalComWrappers::TryGetOrCreateComInterfaceForObject( _In_ OBJECTREF instance, - _In_ INT32 flags, _Outptr_ void** wrapperRaw) { if (!g_IsGlobalComWrappersRegistered) @@ -1302,18 +1297,20 @@ bool GlobalComWrappers::TryGetOrCreateComInterfaceForObject( { GCX_COOP(); + CreateComInterfaceFlags flags = CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport; + // Passing NULL as the ComWrappers implementation indicates using the globally registered instance return TryGetOrCreateComInterfaceForObjectInternal( NULL, instance, - (CreateComInterfaceFlags)flags, + flags, wrapperRaw); } } bool GlobalComWrappers::TryGetOrCreateObjectForComInstance( _In_ IUnknown* externalComObject, - _In_ INT32 flags, + _In_ INT32 objFromComIPFlags, _Out_ OBJECTREF* objRef) { if (!g_IsGlobalComWrappersRegistered) @@ -1324,6 +1321,10 @@ bool GlobalComWrappers::TryGetOrCreateObjectForComInstance( { GCX_COOP(); + int flags = CreateObjectFlags::CreateObjectFlags_TrackerObject; + if ((objFromComIPFlags & ObjFromComIP::UNIQUE_OBJECT) != 0) + flags |= CreateObjectFlags::CreateObjectFlags_UniqueInstance; + // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance), // unwrap CCWs to allow for round-tripping object -> COM instance -> object. bool unwrapIfManagedObjectWrapper = true; diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index b3d9f0f47cc66..2985412db4b29 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -49,12 +49,11 @@ class GlobalComWrappers public: // Functions operating on a registered global instance static bool TryGetOrCreateComInterfaceForObject( _In_ OBJECTREF instance, - _In_ INT32 flags, _Outptr_ void** wrapperRaw); static bool TryGetOrCreateObjectForComInstance( _In_ IUnknown* externalComObject, - _In_ INT32 flags, + _In_ INT32 objFromComIPFlags, _Out_ OBJECTREF* objRef); };