Skip to content

Commit

Permalink
Use proper SafeHandle types for specific APIs
Browse files Browse the repository at this point in the history
Some APIs return a common handle type but require the handle to be released with a less-common method. The metadata now includes this detail, and with this change, CsWin32 honors that.

Fixes microsoft/win32metadata#1581
  • Loading branch information
AArnott committed May 24, 2023
1 parent cf70149 commit 3f59bd0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
23 changes: 15 additions & 8 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
};

MethodSignature<TypeHandleInfo> originalSignature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null);
CustomAttributeHandleCollection? returnTypeAttributes = null;
var parameters = externMethodDeclaration.ParameterList.Parameters.Select(StripAttributes).ToList();
var lengthParamUsedBy = new Dictionary<int, int>();
var parametersToRemove = new List<int>();
Expand All @@ -67,17 +68,23 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
foreach (ParameterHandle paramHandle in methodDefinition.GetParameters())
{
Parameter param = this.Reader.GetParameter(paramHandle);
if (param.SequenceNumber == 0)
{
returnTypeAttributes = param.GetCustomAttributes();
}

if (param.SequenceNumber == 0 || param.SequenceNumber - 1 >= parameters.Count)
{
continue;
}

bool isOptional = (param.Attributes & ParameterAttributes.Optional) == ParameterAttributes.Optional;
bool isReserved = this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), "ReservedAttribute") is not null;
CustomAttributeHandleCollection paramAttributes = param.GetCustomAttributes();
bool isReserved = this.FindInteropDecorativeAttribute(paramAttributes, "ReservedAttribute") is not null;
isOptional |= isReserved; // Per metadata decision made at https://github.com/microsoft/win32metadata/issues/1421#issuecomment-1372608090
bool isIn = (param.Attributes & ParameterAttributes.In) == ParameterAttributes.In;
bool isConst = this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), "ConstAttribute") is not null;
bool isComOutPtr = this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), "ComOutPtrAttribute") is not null;
bool isConst = this.FindInteropDecorativeAttribute(paramAttributes, "ConstAttribute") is not null;
bool isComOutPtr = this.FindInteropDecorativeAttribute(paramAttributes, "ComOutPtrAttribute") is not null;
bool isOut = isComOutPtr || (param.Attributes & ParameterAttributes.Out) == ParameterAttributes.Out;

// TODO:
Expand Down Expand Up @@ -105,7 +112,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
bool hasOut = externParam.Modifiers.Any(SyntaxKind.OutKeyword);
arguments[param.SequenceNumber - 1] = arguments[param.SequenceNumber - 1].WithRefKindKeyword(TokenWithSpace(hasOut ? SyntaxKind.OutKeyword : SyntaxKind.RefKeyword));
}
else if (isOut && !isIn && !isReleaseMethod && parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElementInfo } && this.TryGetHandleReleaseMethod(pointedElementInfo.Handle, out string? outReleaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, outReleaseMethod))
else if (isOut && !isIn && !isReleaseMethod && parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElementInfo } && this.TryGetHandleReleaseMethod(pointedElementInfo.Handle, paramAttributes, out string? outReleaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, outReleaseMethod))
{
if (this.RequestSafeHandle(outReleaseMethod) is TypeSyntax safeHandleType)
{
Expand Down Expand Up @@ -134,7 +141,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))));
}
}
else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod)
else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod)
&& !(this.TryGetTypeDefFieldType(parameterHandleTypeInfo, out TypeHandleInfo? fieldType) && !this.IsSafeHandleCompatibleTypeDefFieldType(fieldType)))
{
IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local");
Expand Down Expand Up @@ -231,14 +238,14 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
bool isNullTerminated = false; // TODO
short? sizeParamIndex = null;
int? sizeConst = null;
if (this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), NativeArrayInfoAttribute) is CustomAttribute att)
if (this.FindInteropDecorativeAttribute(paramAttributes, NativeArrayInfoAttribute) is CustomAttribute att)
{
isArray = true;
NativeArrayInfo nativeArrayInfo = DecodeNativeArrayInfoAttribute(att);
sizeParamIndex = nativeArrayInfo.CountParamIndex;
sizeConst = nativeArrayInfo.CountConst;
}
else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), MemorySizeAttribute) is CustomAttribute att2)
else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(paramAttributes, MemorySizeAttribute) is CustomAttribute att2)
{
// A very special case as documented in https://github.com/microsoft/win32metadata/issues/1555
// where MemorySizeAttribute is applied to byte* parameters to indicate the size of the buffer.
Expand Down Expand Up @@ -490,7 +497,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
}

TypeSyntax? returnSafeHandleType = originalSignature.ReturnType is HandleTypeHandleInfo returnTypeHandleInfo
&& this.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, out string? returnReleaseMethod)
&& this.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, returnTypeAttributes, out string? returnReleaseMethod)
? this.RequestSafeHandle(returnReleaseMethod) : null;
SyntaxToken friendlyMethodName = externMethodDeclaration.Identifier;

Expand Down
21 changes: 16 additions & 5 deletions src/Microsoft.Windows.CsWin32/Generator.Handle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ public partial class Generator
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute))
.WithLeadingTrivia(ParseLeadingTrivia($@"
/// <summary>
/// Represents a Win32 handle that can be closed with <see cref=""{this.options.ClassName}.{renamedReleaseMethod ?? releaseMethod}""/>.
/// Represents a Win32 handle that can be closed with <see cref=""{this.options.ClassName}.{renamedReleaseMethod ?? releaseMethod}({releaseMethodParameterType.Type})""/>.
/// </summary>
"));

Expand All @@ -239,7 +239,7 @@ public partial class Generator
}
}

internal bool TryGetHandleReleaseMethod(EntityHandle handleStructDefHandle, [NotNullWhen(true)] out string? releaseMethod)
internal bool TryGetHandleReleaseMethod(EntityHandle handleStructDefHandle, CustomAttributeHandleCollection? handleReferenceAttributes, [NotNullWhen(true)] out string? releaseMethod)
{
if (handleStructDefHandle.IsNil)
{
Expand All @@ -251,20 +251,31 @@ internal bool TryGetHandleReleaseMethod(EntityHandle handleStructDefHandle, [Not
{
if (this.TryGetTypeDefHandle((TypeReferenceHandle)handleStructDefHandle, out TypeDefinitionHandle typeDefHandle))
{
return this.TryGetHandleReleaseMethod(typeDefHandle, out releaseMethod);
return this.TryGetHandleReleaseMethod(typeDefHandle, handleReferenceAttributes, out releaseMethod);
}
}
else if (handleStructDefHandle.Kind == HandleKind.TypeDefinition)
{
return this.TryGetHandleReleaseMethod((TypeDefinitionHandle)handleStructDefHandle, out releaseMethod);
return this.TryGetHandleReleaseMethod((TypeDefinitionHandle)handleStructDefHandle, handleReferenceAttributes, out releaseMethod);
}

releaseMethod = null;
return false;
}

internal bool TryGetHandleReleaseMethod(TypeDefinitionHandle handleStructDefHandle, [NotNullWhen(true)] out string? releaseMethod)
internal bool TryGetHandleReleaseMethod(TypeDefinitionHandle handleStructDefHandle, CustomAttributeHandleCollection? handleReferenceAttributes, [NotNullWhen(true)] out string? releaseMethod)
{
// Prefer direct attributes on the type reference over the default release method for the struct type.
if (this.FindAttribute(handleReferenceAttributes, InteropDecorationNamespace, RAIIFreeAttribute) is CustomAttribute raii)
{
CustomAttributeValue<TypeSyntax> args = raii.DecodeValue(CustomAttributeTypeProvider.Instance);
if (args.FixedArguments[0].Value is string localRelease)
{
releaseMethod = localRelease;
return true;
}
}

return this.MetadataIndex.HandleTypeReleaseMethod.TryGetValue(handleStructDefHandle, out releaseMethod);
}

Expand Down
13 changes: 13 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ public void SHGetFileInfo()
Assert.All(this.FindGeneratedMethod(name), m => Assert.Equal(5, m.ParameterList.Parameters.Count));
}

[Fact]
public void SpecializedRAIIFree()
{
const string Method = "CreateActCtx";
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate(Method, CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

MethodDeclarationSyntax method = Assert.Single(this.FindGeneratedMethod(Method), m => !IsOrContainsExternMethod(m));
Assert.Equal("ReleaseActCtxSafeHandle", Assert.IsType<IdentifierNameSyntax>(method.ReturnType).Identifier.ValueText);
}

private void Generate(string name)
{
this.compilation = this.compilation.WithOptions(this.compilation.Options.WithPlatform(Platform.X64));
Expand Down

0 comments on commit 3f59bd0

Please sign in to comment.