diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index 3dd483d4..49e744c5 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -53,6 +53,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi }; MethodSignature originalSignature = methodDefinition.DecodeSignature(SignatureHandleProvider.Instance, null); + CustomAttributeHandleCollection? returnTypeAttributes = null; var parameters = externMethodDeclaration.ParameterList.Parameters.Select(StripAttributes).ToList(); var lengthParamUsedBy = new Dictionary(); var parametersToRemove = new List(); @@ -67,17 +68,23 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi foreach (ParameterHandle paramHandle in methodDefinition.GetParameters()) { Parameter param = this.Reader.GetParameter(paramHandle); + if (param.SequenceNumber == 0) + { + returnTypeAttributes = param.GetCustomAttributes(); + } + if (param.SequenceNumber == 0 || param.SequenceNumber - 1 >= parameters.Count) { continue; } bool isOptional = (param.Attributes & ParameterAttributes.Optional) == ParameterAttributes.Optional; - bool isReserved = this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), "ReservedAttribute") is not null; + CustomAttributeHandleCollection paramAttributes = param.GetCustomAttributes(); + bool isReserved = this.FindInteropDecorativeAttribute(paramAttributes, "ReservedAttribute") is not null; isOptional |= isReserved; // Per metadata decision made at https://github.com/microsoft/win32metadata/issues/1421#issuecomment-1372608090 bool isIn = (param.Attributes & ParameterAttributes.In) == ParameterAttributes.In; - bool isConst = this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), "ConstAttribute") is not null; - bool isComOutPtr = this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), "ComOutPtrAttribute") is not null; + bool isConst = this.FindInteropDecorativeAttribute(paramAttributes, "ConstAttribute") is not null; + bool isComOutPtr = this.FindInteropDecorativeAttribute(paramAttributes, "ComOutPtrAttribute") is not null; bool isOut = isComOutPtr || (param.Attributes & ParameterAttributes.Out) == ParameterAttributes.Out; // TODO: @@ -105,7 +112,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi bool hasOut = externParam.Modifiers.Any(SyntaxKind.OutKeyword); arguments[param.SequenceNumber - 1] = arguments[param.SequenceNumber - 1].WithRefKindKeyword(TokenWithSpace(hasOut ? SyntaxKind.OutKeyword : SyntaxKind.RefKeyword)); } - else if (isOut && !isIn && !isReleaseMethod && parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElementInfo } && this.TryGetHandleReleaseMethod(pointedElementInfo.Handle, out string? outReleaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, outReleaseMethod)) + else if (isOut && !isIn && !isReleaseMethod && parameterTypeInfo is PointerTypeHandleInfo { ElementType: HandleTypeHandleInfo pointedElementInfo } && this.TryGetHandleReleaseMethod(pointedElementInfo.Handle, paramAttributes, out string? outReleaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, outReleaseMethod)) { if (this.RequestSafeHandle(outReleaseMethod) is TypeSyntax safeHandleType) { @@ -134,7 +141,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))))); } } - else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod) + else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod) && !(this.TryGetTypeDefFieldType(parameterHandleTypeInfo, out TypeHandleInfo? fieldType) && !this.IsSafeHandleCompatibleTypeDefFieldType(fieldType))) { IdentifierNameSyntax typeDefHandleName = IdentifierName(externParam.Identifier.ValueText + "Local"); @@ -231,14 +238,14 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi bool isNullTerminated = false; // TODO short? sizeParamIndex = null; int? sizeConst = null; - if (this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), NativeArrayInfoAttribute) is CustomAttribute att) + if (this.FindInteropDecorativeAttribute(paramAttributes, NativeArrayInfoAttribute) is CustomAttribute att) { isArray = true; NativeArrayInfo nativeArrayInfo = DecodeNativeArrayInfoAttribute(att); sizeParamIndex = nativeArrayInfo.CountParamIndex; sizeConst = nativeArrayInfo.CountConst; } - else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(param.GetCustomAttributes(), MemorySizeAttribute) is CustomAttribute att2) + else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(paramAttributes, MemorySizeAttribute) is CustomAttribute att2) { // A very special case as documented in https://github.com/microsoft/win32metadata/issues/1555 // where MemorySizeAttribute is applied to byte* parameters to indicate the size of the buffer. @@ -490,7 +497,7 @@ private IEnumerable DeclareFriendlyOverloads(MethodDefi } TypeSyntax? returnSafeHandleType = originalSignature.ReturnType is HandleTypeHandleInfo returnTypeHandleInfo - && this.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, out string? returnReleaseMethod) + && this.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, returnTypeAttributes, out string? returnReleaseMethod) ? this.RequestSafeHandle(returnReleaseMethod) : null; SyntaxToken friendlyMethodName = externMethodDeclaration.Identifier; diff --git a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs index 9dcd5c67..b47676b7 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Handle.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Handle.cs @@ -226,7 +226,7 @@ public partial class Generator .AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute)) .WithLeadingTrivia(ParseLeadingTrivia($@" /// -/// Represents a Win32 handle that can be closed with . +/// Represents a Win32 handle that can be closed with . /// ")); @@ -239,7 +239,7 @@ public partial class Generator } } - internal bool TryGetHandleReleaseMethod(EntityHandle handleStructDefHandle, [NotNullWhen(true)] out string? releaseMethod) + internal bool TryGetHandleReleaseMethod(EntityHandle handleStructDefHandle, CustomAttributeHandleCollection? handleReferenceAttributes, [NotNullWhen(true)] out string? releaseMethod) { if (handleStructDefHandle.IsNil) { @@ -251,20 +251,31 @@ internal bool TryGetHandleReleaseMethod(EntityHandle handleStructDefHandle, [Not { if (this.TryGetTypeDefHandle((TypeReferenceHandle)handleStructDefHandle, out TypeDefinitionHandle typeDefHandle)) { - return this.TryGetHandleReleaseMethod(typeDefHandle, out releaseMethod); + return this.TryGetHandleReleaseMethod(typeDefHandle, handleReferenceAttributes, out releaseMethod); } } else if (handleStructDefHandle.Kind == HandleKind.TypeDefinition) { - return this.TryGetHandleReleaseMethod((TypeDefinitionHandle)handleStructDefHandle, out releaseMethod); + return this.TryGetHandleReleaseMethod((TypeDefinitionHandle)handleStructDefHandle, handleReferenceAttributes, out releaseMethod); } releaseMethod = null; return false; } - internal bool TryGetHandleReleaseMethod(TypeDefinitionHandle handleStructDefHandle, [NotNullWhen(true)] out string? releaseMethod) + internal bool TryGetHandleReleaseMethod(TypeDefinitionHandle handleStructDefHandle, CustomAttributeHandleCollection? handleReferenceAttributes, [NotNullWhen(true)] out string? releaseMethod) { + // Prefer direct attributes on the type reference over the default release method for the struct type. + if (this.FindAttribute(handleReferenceAttributes, InteropDecorationNamespace, RAIIFreeAttribute) is CustomAttribute raii) + { + CustomAttributeValue args = raii.DecodeValue(CustomAttributeTypeProvider.Instance); + if (args.FixedArguments[0].Value is string localRelease) + { + releaseMethod = localRelease; + return true; + } + } + return this.MetadataIndex.HandleTypeReleaseMethod.TryGetValue(handleStructDefHandle, out releaseMethod); } diff --git a/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs b/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs index 1e44695d..c547926f 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs @@ -27,6 +27,19 @@ public void SHGetFileInfo() Assert.All(this.FindGeneratedMethod(name), m => Assert.Equal(5, m.ParameterList.Parameters.Count)); } + [Fact] + public void SpecializedRAIIFree() + { + const string Method = "CreateActCtx"; + this.generator = this.CreateGenerator(); + Assert.True(this.generator.TryGenerate(Method, CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics(); + + MethodDeclarationSyntax method = Assert.Single(this.FindGeneratedMethod(Method), m => !IsOrContainsExternMethod(m)); + Assert.Equal("ReleaseActCtxSafeHandle", Assert.IsType(method.ReturnType).Identifier.ValueText); + } + private void Generate(string name) { this.compilation = this.compilation.WithOptions(this.compilation.Options.WithPlatform(Platform.X64));