diff --git a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs index a5ad55f72e68..8a53af07090d 100644 --- a/DllImportGenerator/Ancillary.Interop/MarshalEx.cs +++ b/DllImportGenerator/Ancillary.Interop/MarshalEx.cs @@ -10,32 +10,12 @@ namespace System.Runtime.InteropServices /// public static class MarshalEx { - /// - /// Create an instance of the given . - /// - /// Type of the SafeHandle - /// New instance of - /// - /// The must be non-abstract and have a parameterless constructor. - /// - public static TSafeHandle CreateSafeHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor | DynamicallyAccessedMemberTypes.NonPublicConstructors)]TSafeHandle>() - where TSafeHandle : SafeHandle - { - if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance | BindingFlags.Instance, null, Type.EmptyTypes, null) == null) - { - throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor."); - } - - TSafeHandle safeHandle = (TSafeHandle)Activator.CreateInstance(typeof(TSafeHandle), nonPublic: true)!; - return safeHandle; - } - /// /// Sets the handle of to the specified . /// /// instance to update /// Pre-existing handle - public static void SetHandle(SafeHandle safeHandle, IntPtr handle) + public static void InitHandle(SafeHandle safeHandle, IntPtr handle) { typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle }); } diff --git a/DllImportGenerator/DllImportGenerator.IntegrationTests/SafeHandleTests.cs b/DllImportGenerator/DllImportGenerator.IntegrationTests/SafeHandleTests.cs index 1a6b47a63bed..fda3852b3d88 100644 --- a/DllImportGenerator/DllImportGenerator.IntegrationTests/SafeHandleTests.cs +++ b/DllImportGenerator/DllImportGenerator.IntegrationTests/SafeHandleTests.cs @@ -7,7 +7,7 @@ namespace DllImportGenerator.IntegrationTests { partial class NativeExportsNE { - public class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + public partial class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid { private NativeExportsSafeHandle() : base(ownsHandle: true) { } @@ -18,6 +18,12 @@ protected override bool ReleaseHandle() Assert.True(didRelease); return didRelease; } + + public static NativeExportsSafeHandle CreateNewHandle() => AllocateHandle(); + + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "alloc_handle")] + private static partial NativeExportsSafeHandle AllocateHandle(); } [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "alloc_handle")] @@ -48,6 +54,14 @@ public void ReturnValue_CreatesSafeHandle() Assert.False(handle.IsInvalid); } + [Fact] + public void ReturnValue_CreatesSafeHandle_DirectConstructorCall() + { + using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.NativeExportsSafeHandle.CreateNewHandle(); + Assert.False(handle.IsClosed); + Assert.False(handle.IsInvalid); + } + [Fact] public void ByValue_CorrectlyUnwrapsHandle() { diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs index 324b66ece135..f27cd8467185 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs @@ -919,5 +919,15 @@ public IntStructWrapperNative(IntStructWrapper managed) public IntStructWrapper ToManaged() => new IntStructWrapper { Value = value }; } "; + + public static string SafeHandleWithCustomDefaultConstructorAccessibility(bool privateCtor) => BasicParametersAndModifiers("MySafeHandle") + $@" +class MySafeHandle : SafeHandle +{{ + {(privateCtor ? "private" : "public")} MySafeHandle() : base(System.IntPtr.Zero, true) {{ }} + + public override bool IsInvalid => handle == System.IntPtr.Zero; + + protected override bool ReleaseHandle() => true; +}}"; } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs index 1db8e4b0ead8..004d2a01627f 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs @@ -124,6 +124,8 @@ public static IEnumerable CodeSnippetsToCompile() // SafeHandle yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") }; + yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: false) }; + yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: true) }; // PreserveSig yield return new[] { CodeSnippets.PreserveSigFalseVoidReturn }; diff --git a/DllImportGenerator/DllImportGenerator/DllImportStub.cs b/DllImportGenerator/DllImportGenerator/DllImportStub.cs index 9fe488f7dae2..85ac9bef9b57 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportStub.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportStub.cs @@ -133,9 +133,10 @@ public static DllImportStub Create( // Since we're generating source for the method, we know that the current type // has to be declared in source. TypeDeclarationSyntax typeDecl = (TypeDeclarationSyntax)currType.DeclaringSyntaxReferences[0].GetSyntax(); - // Remove current members and attributes so we don't double declare them. + // Remove current members, attributes, and base list so we don't double declare them. typeDecl = typeDecl.WithMembers(List()) - .WithAttributeLists(List()); + .WithAttributeLists(List()) + .WithBaseList(null); containingTypes.Add(typeDecl); @@ -162,7 +163,7 @@ public static DllImportStub Create( for (int i = 0; i < method.Parameters.Length; i++) { var param = method.Parameters[i]; - var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics); + var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics, method.ContainingType); typeInfo = typeInfo with { ManagedIndex = i, @@ -171,7 +172,7 @@ public static DllImportStub Create( paramsTypeInfo.Add(typeInfo); } - TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics); + TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics, method.ContainingType); retTypeInfo = retTypeInfo with { ManagedIndex = TypePositionInfo.ReturnIndex, diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index 5ff9b53c66cd..ba80a029e843 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -81,19 +81,39 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont SingletonSeparatedList( VariableDeclarator(addRefdIdentifier) .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); - } + + var safeHandleCreationExpression = ((SafeHandleMarshallingInfo)info.MarshallingAttributeInfo).AccessibleDefaultConstructor + ? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.AsTypeSyntax(), ArgumentList(), initializer: null) + : CastExpression( + info.ManagedType.AsTypeSyntax(), + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Activator), + IdentifierName("CreateInstance"))) + .WithArgumentList( + ArgumentList( + SeparatedList( + new []{ + Argument( + TypeOfExpression( + info.ManagedType.AsTypeSyntax())), + Argument( + LiteralExpression( + SyntaxKind.TrueLiteralExpression)) + .WithNameColon( + NameColon( + IdentifierName("nonPublic"))) + })))); + if (info.IsManagedReturnPosition) { yield return ExpressionStatement( AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(managedIdentifier), - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.MarshalEx(options)), - GenericName(Identifier("CreateSafeHandle"), - TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), - ArgumentList()))); + safeHandleCreationExpression + )); } else if (info.IsByRef && info.RefKind != RefKind.In) { @@ -105,13 +125,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont info.ManagedType.AsTypeSyntax(), SingletonSeparatedList( VariableDeclarator(newHandleObjectIdentifier) - .WithInitializer(EqualsValueClause( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.MarshalEx(options)), - GenericName(Identifier("CreateSafeHandle"), - TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), - ArgumentList())))))); + .WithInitializer(EqualsValueClause(safeHandleCreationExpression))))); if (info.RefKind != RefKind.Out) { yield return LocalDeclarationStatement( @@ -168,7 +182,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseTypeName(TypeNames.MarshalEx(options)), - IdentifierName("SetHandle")), + IdentifierName("InitHandle")), ArgumentList(SeparatedList( new [] { diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index 891eed70a4ac..a9c036948d56 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -103,7 +103,7 @@ internal sealed record GeneratedNativeMarshallingAttributeInfo( /// /// The type of the element is a SafeHandle-derived type with no marshalling attributes. /// - internal sealed record SafeHandleMarshallingInfo : MarshallingInfo; + internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo; /// diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index de5670fba75e..457a4db0386a 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -22,6 +22,8 @@ static class TypeNames public const string System_Span_Metadata = "System.Span`1"; public const string System_Span = "System.Span"; + public const string System_Activator = "System.Activator"; + public const string System_Runtime_InteropServices_StructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute"; public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute"; diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 10c707d9c68f..9cb216590cc0 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -82,9 +82,9 @@ private TypePositionInfo() public MarshallingInfo MarshallingAttributeInfo { get; init; } - public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics) + public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) { - var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics); + var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics, scopeSymbol); var typeInfo = new TypePositionInfo() { ManagedType = paramSymbol.Type, @@ -98,9 +98,9 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, return typeInfo; } - public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics) + public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) { - var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics); + var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics, scopeSymbol); var typeInfo = new TypePositionInfo() { ManagedType = type, @@ -127,7 +127,7 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo m return typeInfo; } - private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics) + private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) { // Look at attributes passed in - usage specific. foreach (var attrData in attributes) @@ -137,7 +137,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) { // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics); + return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics, scopeSymbol); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass)) { @@ -167,7 +167,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< // If the type doesn't have custom attributes that dictate marshalling, // then consider the type itself. - if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, out MarshallingInfo infoMaybe)) + if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, scopeSymbol, out MarshallingInfo infoMaybe)) { return infoMaybe; } @@ -183,7 +183,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< return NoMarshallingInfo.Instance; - static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics) + static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) { object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; UnmanagedType unmanagedType = unmanagedTypeObj is short @@ -252,7 +252,7 @@ static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrDat } else if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) { - elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics); + elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol); } return new ArrayMarshalAsInfo( @@ -307,7 +307,7 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null); } - static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, out MarshallingInfo marshallingInfo) + static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, out MarshallingInfo marshallingInfo) { var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); if (conversion.Exists @@ -315,13 +315,25 @@ static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshalli && conversion.IsReference && !type.IsAbstract) { - marshallingInfo = new SafeHandleMarshallingInfo(); + bool hasAccessibleDefaultConstructor = false; + if (type is INamedTypeSymbol named && named.InstanceConstructors.Length > 0) + { + foreach (var ctor in named.InstanceConstructors) + { + if (ctor.Parameters.Length == 0) + { + hasAccessibleDefaultConstructor = compilation.IsSymbolAccessibleWithin(ctor, scopeSymbol); + break; + } + } + } + marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor); return true; } if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) { - marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics)); + marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol)); return true; } marshallingInfo = NoMarshallingInfo.Instance;