Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SafeHandle codegen to match the approved API. #570

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions DllImportGenerator/Ancillary.Interop/MarshalEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,12 @@ namespace System.Runtime.InteropServices
/// </summary>
public static class MarshalEx
{
/// <summary>
/// Create an instance of the given <typeparamref name="TSafeHandle"/>.
/// </summary>
/// <typeparam name="TSafeHandle">Type of the SafeHandle</typeparam>
/// <returns>New instance of <typeparamref name="TSafeHandle"/></returns>
/// <remarks>
/// The <typeparamref name="TSafeHandle"/> must be non-abstract and have a parameterless constructor.
/// </remarks>
public static TSafeHandle CreateSafeHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor | DynamicallyAccessedMemberTypes.NonPublicConstructors)]TSafeHandle>()
where TSafeHandle : SafeHandle
{
if (typeof(TSafeHandle).IsAbstract || typeof(TSafeHandle).GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.CreateInstance | BindingFlags.Instance, null, Type.EmptyTypes, null) == null)
{
throw new MissingMemberException($"The safe handle type '{typeof(TSafeHandle).FullName}' must be a non-abstract type with a parameterless constructor.");
}

TSafeHandle safeHandle = (TSafeHandle)Activator.CreateInstance(typeof(TSafeHandle), nonPublic: true)!;
return safeHandle;
}

/// <summary>
/// Sets the handle of <paramref name="safeHandle"/> to the specified <paramref name="handle"/>.
/// </summary>
/// <param name="safeHandle"><see cref="SafeHandle"/> instance to update</param>
/// <param name="handle">Pre-existing handle</param>
public static void SetHandle(SafeHandle safeHandle, IntPtr handle)
public static void InitHandle(SafeHandle safeHandle, IntPtr handle)
{
typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace DllImportGenerator.IntegrationTests
{
partial class NativeExportsNE
{
public class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
public partial class NativeExportsSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
private NativeExportsSafeHandle() : base(ownsHandle: true)
{ }
Expand All @@ -18,6 +18,12 @@ protected override bool ReleaseHandle()
Assert.True(didRelease);
return didRelease;
}

public static NativeExportsSafeHandle CreateNewHandle() => AllocateHandle();


[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "alloc_handle")]
private static partial NativeExportsSafeHandle AllocateHandle();
}

[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "alloc_handle")]
Expand Down Expand Up @@ -48,6 +54,14 @@ public void ReturnValue_CreatesSafeHandle()
Assert.False(handle.IsInvalid);
}

[Fact]
public void ReturnValue_CreatesSafeHandle_DirectConstructorCall()
{
using NativeExportsNE.NativeExportsSafeHandle handle = NativeExportsNE.NativeExportsSafeHandle.CreateNewHandle();
Assert.False(handle.IsClosed);
Assert.False(handle.IsInvalid);
}

[Fact]
public void ByValue_CorrectlyUnwrapsHandle()
{
Expand Down
10 changes: 10 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -919,5 +919,15 @@ public IntStructWrapperNative(IntStructWrapper managed)
public IntStructWrapper ToManaged() => new IntStructWrapper { Value = value };
}
";

public static string SafeHandleWithCustomDefaultConstructorAccessibility(bool privateCtor) => BasicParametersAndModifiers("MySafeHandle") + $@"
class MySafeHandle : SafeHandle
{{
{(privateCtor ? "private" : "public")} MySafeHandle() : base(System.IntPtr.Zero, true) {{ }}

public override bool IsInvalid => handle == System.IntPtr.Zero;

protected override bool ReleaseHandle() => true;
}}";
}
}
2 changes: 2 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()

// SafeHandle
yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") };
yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: false) };
yield return new[] { CodeSnippets.SafeHandleWithCustomDefaultConstructorAccessibility(privateCtor: true) };

// PreserveSig
yield return new[] { CodeSnippets.PreserveSigFalseVoidReturn };
Expand Down
9 changes: 5 additions & 4 deletions DllImportGenerator/DllImportGenerator/DllImportStub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ public static DllImportStub Create(
// Since we're generating source for the method, we know that the current type
// has to be declared in source.
TypeDeclarationSyntax typeDecl = (TypeDeclarationSyntax)currType.DeclaringSyntaxReferences[0].GetSyntax();
// Remove current members and attributes so we don't double declare them.
// Remove current members, attributes, and base list so we don't double declare them.
typeDecl = typeDecl.WithMembers(List<MemberDeclarationSyntax>())
.WithAttributeLists(List<AttributeListSyntax>());
.WithAttributeLists(List<AttributeListSyntax>())
.WithBaseList(null);

containingTypes.Add(typeDecl);

Expand All @@ -162,7 +163,7 @@ public static DllImportStub Create(
for (int i = 0; i < method.Parameters.Length; i++)
{
var param = method.Parameters[i];
var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics);
var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics, method.ContainingType);
typeInfo = typeInfo with
{
ManagedIndex = i,
Expand All @@ -171,7 +172,7 @@ public static DllImportStub Create(
paramsTypeInfo.Add(typeInfo);
}

TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics);
TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics, method.ContainingType);
retTypeInfo = retTypeInfo with
{
ManagedIndex = TypePositionInfo.ReturnIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,39 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
SingletonSeparatedList(
VariableDeclarator(addRefdIdentifier)
.WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression))))));

}

var safeHandleCreationExpression = ((SafeHandleMarshallingInfo)info.MarshallingAttributeInfo).AccessibleDefaultConstructor
? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.AsTypeSyntax(), ArgumentList(), initializer: null)
: CastExpression(
info.ManagedType.AsTypeSyntax(),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Activator),
IdentifierName("CreateInstance")))
.WithArgumentList(
ArgumentList(
SeparatedList(
new []{
Argument(
TypeOfExpression(
info.ManagedType.AsTypeSyntax())),
Argument(
LiteralExpression(
SyntaxKind.TrueLiteralExpression))
.WithNameColon(
NameColon(
IdentifierName("nonPublic")))
}))));

if (info.IsManagedReturnPosition)
{
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifier),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.MarshalEx(options)),
GenericName(Identifier("CreateSafeHandle"),
TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))),
ArgumentList())));
safeHandleCreationExpression
));
}
else if (info.IsByRef && info.RefKind != RefKind.In)
{
Expand All @@ -105,13 +125,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
info.ManagedType.AsTypeSyntax(),
SingletonSeparatedList(
VariableDeclarator(newHandleObjectIdentifier)
.WithInitializer(EqualsValueClause(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.MarshalEx(options)),
GenericName(Identifier("CreateSafeHandle"),
TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))),
ArgumentList()))))));
.WithInitializer(EqualsValueClause(safeHandleCreationExpression)))));
if (info.RefKind != RefKind.Out)
{
yield return LocalDeclarationStatement(
Expand Down Expand Up @@ -168,7 +182,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.MarshalEx(options)),
IdentifierName("SetHandle")),
IdentifierName("InitHandle")),
ArgumentList(SeparatedList(
new []
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ internal sealed record GeneratedNativeMarshallingAttributeInfo(
/// <summary>
/// The type of the element is a SafeHandle-derived type with no marshalling attributes.
/// </summary>
internal sealed record SafeHandleMarshallingInfo : MarshallingInfo;
internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo;


/// <summary>
Expand Down
2 changes: 2 additions & 0 deletions DllImportGenerator/DllImportGenerator/TypeNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ static class TypeNames
public const string System_Span_Metadata = "System.Span`1";
public const string System_Span = "System.Span";

public const string System_Activator = "System.Activator";

public const string System_Runtime_InteropServices_StructLayoutAttribute = "System.Runtime.InteropServices.StructLayoutAttribute";

public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute";
Expand Down
36 changes: 24 additions & 12 deletions DllImportGenerator/DllImportGenerator/TypePositionInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ private TypePositionInfo()

public MarshallingInfo MarshallingAttributeInfo { get; init; }

public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh. These two functions (i.e. CreateForParameter and CreateForType) are hideous... Let's make a mental note that they should not be exposed as is when we enable extending the source generator.

/cc @elinor-fung

{
var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics);
var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics, scopeSymbol);
var typeInfo = new TypePositionInfo()
{
ManagedType = paramSymbol.Type,
Expand All @@ -98,9 +98,9 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol,
return typeInfo;
}

public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics);
var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics, scopeSymbol);
var typeInfo = new TypePositionInfo()
{
ManagedType = type,
Expand All @@ -127,7 +127,7 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo m
return typeInfo;
}

private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<AttributeData> attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
// Look at attributes passed in - usage specific.
foreach (var attrData in attributes)
Expand All @@ -137,7 +137,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<
if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass))
{
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute
return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics);
return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics, scopeSymbol);
}
else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass))
{
Expand Down Expand Up @@ -167,7 +167,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<

// If the type doesn't have custom attributes that dictate marshalling,
// then consider the type itself.
if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, out MarshallingInfo infoMaybe))
if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, scopeSymbol, out MarshallingInfo infoMaybe))
{
return infoMaybe;
}
Expand All @@ -183,7 +183,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable<

return NoMarshallingInfo.Instance;

static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics)
static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol)
{
object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!;
UnmanagedType unmanagedType = unmanagedTypeObj is short
Expand Down Expand Up @@ -252,7 +252,7 @@ static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrDat
}
else if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType })
{
elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics);
elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics, scopeSymbol);
}

return new ArrayMarshalAsInfo(
Expand Down Expand Up @@ -307,21 +307,33 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty
NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null);
}

static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, out MarshallingInfo marshallingInfo)
static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, out MarshallingInfo marshallingInfo)
{
var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!);
if (conversion.Exists
&& conversion.IsImplicit
&& conversion.IsReference
&& !type.IsAbstract)
{
marshallingInfo = new SafeHandleMarshallingInfo();
bool hasAccessibleDefaultConstructor = false;
if (type is INamedTypeSymbol named && named.InstanceConstructors.Length > 0)
{
foreach (var ctor in named.InstanceConstructors)
{
if (ctor.Parameters.Length == 0)
{
hasAccessibleDefaultConstructor = compilation.IsSymbolAccessibleWithin(ctor, scopeSymbol);
jkoritzinsky marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}
}
marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor);
return true;
}

if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType })
{
marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics));
marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty<AttributeData>(), defaultInfo, compilation, diagnostics, scopeSymbol));
return true;
}
marshallingInfo = NoMarshallingInfo.Instance;
Expand Down