Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not use an array for a pointer field that doesn't actually refer to an array #907

Merged
merged 1 commit into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,12 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
{
TypeDefinition typeDef = this.Reader.GetTypeDefinition(typeDefHandle);
string originalIfaceName = this.Reader.GetString(typeDef.Name);
IdentifierNameSyntax ifaceName = IdentifierName(this.GetMangledIdentifier(originalIfaceName, context.AllowMarshaling, isManagedType: true));
bool isManagedType = this.IsManagedType(typeDefHandle);
IdentifierNameSyntax ifaceName = IdentifierName(this.GetMangledIdentifier(originalIfaceName, context.AllowMarshaling, isManagedType));
IdentifierNameSyntax vtblFieldName = IdentifierName("lpVtbl");
var members = new List<MemberDeclarationSyntax>();
var vtblMembers = new List<MemberDeclarationSyntax>();
TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings;
TypeSyntaxSettings typeSettings = context.Filter(this.comSignatureTypeSettings);
IdentifierNameSyntax pThisLocal = IdentifierName("pThis");
ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
List<QualifiedMethodDefinitionHandle> ccwMethodsToSkip = new();
Expand Down Expand Up @@ -132,8 +133,9 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
ISet<string> declaredProperties = this.GetDeclarableProperties(
allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)),
originalIfaceName,
allowNonConsecutiveAccessors: true);
ISet<string>? ifaceDeclaredProperties = ccwThisParameter is not null ? this.GetDeclarableProperties(allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)), originalIfaceName, allowNonConsecutiveAccessors: false) : null;
allowNonConsecutiveAccessors: true,
context);
ISet<string>? ifaceDeclaredProperties = ccwThisParameter is not null ? this.GetDeclarableProperties(allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)), originalIfaceName, allowNonConsecutiveAccessors: false, context) : null;

foreach (QualifiedMethodDefinitionHandle methodDefHandle in allMethods)
{
Expand Down Expand Up @@ -183,7 +185,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type

// We can declare this method as a property accessor if it represents a property.
// We must also confirm that the property type is the same in both cases, because sometimes they aren't (e.g. IUIAutomationProxyFactoryEntry.ClassName).
if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) &&
if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, context, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) &&
declaredProperties.Contains(propertyName.Identifier.ValueText))
{
StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression(
Expand Down Expand Up @@ -343,7 +345,7 @@ StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression

if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
{
if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, out propertyName, out accessorKind, out propertyType) &&
if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, context, out propertyName, out accessorKind, out propertyType) &&
ifaceDeclaredProperties!.Contains(propertyName.Identifier.ValueText))
{
switch (accessorKind)
Expand Down Expand Up @@ -601,7 +603,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)

var members = new List<MemberDeclarationSyntax>();
var friendlyOverloads = new List<MethodDeclarationSyntax>();
ISet<string> declaredProperties = this.GetDeclarableProperties(allMethods.Select(this.Reader.GetMethodDefinition), actualIfaceName, allowNonConsecutiveAccessors: false);
ISet<string> declaredProperties = this.GetDeclarableProperties(allMethods.Select(this.Reader.GetMethodDefinition), actualIfaceName, allowNonConsecutiveAccessors: false, context);

foreach (MethodDefinitionHandle methodDefHandle in allMethods)
{
Expand All @@ -617,7 +619,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
// Even if it could be represented as a property accessor, we cannot do so if a property by the same name was already declared in anything other than the previous row.
// Adding an accessor to a property later than the very next row would screw up the virtual method table ordering.
// We must also confirm that the property type is the same in both cases, because sometimes they aren't (e.g. IUIAutomationProxyFactoryEntry.ClassName).
if (this.TryGetPropertyAccessorInfo(methodDefinition, actualIfaceName, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) && declaredProperties.Contains(propertyName.Identifier.ValueText))
if (this.TryGetPropertyAccessorInfo(methodDefinition, actualIfaceName, context, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) && declaredProperties.Contains(propertyName.Identifier.ValueText))
{
AccessorDeclarationSyntax accessor = AccessorDeclaration(accessorKind.Value).WithSemicolonToken(Semicolon);

Expand Down Expand Up @@ -842,15 +844,15 @@ private bool UsePreserveSigForComMethod(MethodDefinition methodDefinition, Metho
|| this.options.ComInterop.PreserveSigMethods.Contains(ifaceName.ToString());
}

private ISet<string> GetDeclarableProperties(IEnumerable<MethodDefinition> methods, string ifaceName, bool allowNonConsecutiveAccessors)
private ISet<string> GetDeclarableProperties(IEnumerable<MethodDefinition> methods, string ifaceName, bool allowNonConsecutiveAccessors, Context context)
{
Dictionary<string, (TypeSyntax Type, int Index)> goodProperties = new(StringComparer.Ordinal);
HashSet<string> badProperties = new(StringComparer.Ordinal);
int rowIndex = -1;
foreach (MethodDefinition methodDefinition in methods)
{
rowIndex++;
if (this.TryGetPropertyAccessorInfo(methodDefinition, ifaceName, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType))
if (this.TryGetPropertyAccessorInfo(methodDefinition, ifaceName, context, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType))
{
if (badProperties.Contains(propertyName.Identifier.ValueText))
{
Expand Down Expand Up @@ -886,11 +888,12 @@ void ReportBadProperty()
return goodProperties.Count == 0 ? ImmutableHashSet<string>.Empty : new HashSet<string>(goodProperties.Keys, StringComparer.Ordinal);
}

private bool TryGetPropertyAccessorInfo(MethodDefinition methodDefinition, string ifaceName, [NotNullWhen(true)] out IdentifierNameSyntax? propertyName, [NotNullWhen(true)] out SyntaxKind? accessorKind, [NotNullWhen(true)] out TypeSyntax? propertyType)
private bool TryGetPropertyAccessorInfo(MethodDefinition methodDefinition, string ifaceName, Context context, [NotNullWhen(true)] out IdentifierNameSyntax? propertyName, [NotNullWhen(true)] out SyntaxKind? accessorKind, [NotNullWhen(true)] out TypeSyntax? propertyType)
{
propertyName = null;
accessorKind = null;
propertyType = null;
TypeSyntaxSettings syntaxSettings = context.Filter(this.comSignatureTypeSettings);

if ((methodDefinition.Attributes & MethodAttributes.SpecialName) != MethodAttributes.SpecialName)
{
Expand Down Expand Up @@ -934,7 +937,8 @@ private bool TryGetPropertyAccessorInfo(MethodDefinition methodDefinition, strin
}

Parameter propertyTypeParameter = this.Reader.GetParameter(parameters.Skip(1).Single());
propertyType = signature.ParameterTypes[0].ToTypeSyntax(this.comSignatureTypeSettings, propertyTypeParameter.GetCustomAttributes(), propertyTypeParameter.Attributes).Type;
TypeHandleInfo propertyTypeInfo = signature.ParameterTypes[0];
propertyType = propertyTypeInfo.ToTypeSyntax(syntaxSettings, propertyTypeParameter.GetCustomAttributes(), propertyTypeParameter.Attributes).Type;

if (isGetter)
{
Expand All @@ -946,7 +950,7 @@ private bool TryGetPropertyAccessorInfo(MethodDefinition methodDefinition, strin
return false;
}

if (propertyType is PointerTypeSyntax propertyTypePointer)
if (propertyType is PointerTypeSyntax propertyTypePointer && (syntaxSettings.AllowMarshaling || !this.IsManagedType(propertyTypeInfo)))
{
propertyType = propertyTypePointer.ElementType;
}
Expand Down
25 changes: 23 additions & 2 deletions src/Microsoft.Windows.CsWin32/Generator.Struct.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,28 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
CustomAttributeHandleCollection fieldAttributes = fieldDef.GetCustomAttributes();
TypeHandleInfo fieldTypeInfo = fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null);
hasUtf16CharField |= fieldTypeInfo is PrimitiveTypeHandleInfo { PrimitiveTypeCode: PrimitiveTypeCode.Char };
TypeSyntaxAndMarshaling fieldTypeSyntax = fieldTypeInfo.ToTypeSyntax(typeSettings, fieldAttributes);
TypeSyntaxSettings thisFieldTypeSettings = typeSettings;

// Do not qualify names of a type nested inside *this* struct, since this struct may or may not have a mangled name.
if (thisFieldTypeSettings.QualifyNames && fieldTypeInfo is HandleTypeHandleInfo fieldHandleTypeInfo && this.IsNestedType(fieldHandleTypeInfo.Handle))
{
if (fieldHandleTypeInfo.Handle.Kind == HandleKind.TypeReference)
{
if (this.TryGetTypeDefHandle((TypeReferenceHandle)fieldHandleTypeInfo.Handle, out QualifiedTypeDefinitionHandle fieldTypeDefHandle) && fieldTypeDefHandle.Generator == this)
{
foreach (TypeDefinitionHandle nestedTypeHandle in typeDef.GetNestedTypes())
{
if (fieldTypeDefHandle.DefinitionHandle == nestedTypeHandle)
{
thisFieldTypeSettings = thisFieldTypeSettings with { QualifyNames = false };
break;
}
}
}
}
}

TypeSyntaxAndMarshaling fieldTypeSyntax = fieldTypeInfo.ToTypeSyntax(thisFieldTypeSettings, fieldAttributes);
(TypeSyntax FieldType, SyntaxList<MemberDeclarationSyntax> AdditionalMembers, AttributeSyntax? MarshalAsAttribute) fieldInfo = this.ReinterpretFieldType(fieldDef, fieldTypeSyntax.Type, fieldAttributes, context);
additionalMembers = additionalMembers.AddRange(fieldInfo.AdditionalMembers);

Expand Down Expand Up @@ -147,7 +168,7 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle

// If the field is a pointer to a COM interface (and we're using bona fide interfaces),
// then we must type it as an array.
if (context.AllowMarshaling && fieldTypeHandleInfo is PointerTypeHandleInfo ptr3 && this.IsManagedType(ptr3.ElementType))
if (context.AllowMarshaling && fieldTypeHandleInfo is PointerTypeHandleInfo ptr3 && this.IsInterface(ptr3.ElementType))
{
return (ArrayType(ptr3.ElementType.ToTypeSyntax(typeSettings, null).Type).AddRankSpecifiers(ArrayRankSpecifier()), default(SyntaxList<MemberDeclarationSyntax>), marshalAs);
}
Expand Down
15 changes: 12 additions & 3 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void AddSymbolIf(bool condition, string symbol)
this.externReleaseSignatureTypeSettings = this.externSignatureTypeSettings with { PreferNativeInt = false, PreferMarshaledTypes = false };
this.comSignatureTypeSettings = this.generalTypeSettings with { QualifyNames = true, PreferInOutRef = options.AllowMarshaling };
this.extensionMethodSignatureTypeSettings = this.generalTypeSettings with { QualifyNames = true };
this.functionPointerTypeSettings = this.generalTypeSettings with { QualifyNames = true };
this.functionPointerTypeSettings = this.generalTypeSettings with { QualifyNames = true, AllowMarshaling = false };
this.errorMessageTypeSettings = this.generalTypeSettings with { QualifyNames = true, Generator = null }; // Avoid risk of infinite recursion from errors in ToTypeSyntax

this.methodsAndConstantsClassName = IdentifierName(options.ClassName);
Expand Down Expand Up @@ -1159,9 +1159,17 @@ private bool TryGetRenamedMethod(string methodName, [NotNullWhen(true)] out stri
StructDeclarationSyntax structDeclaration = this.DeclareStruct(typeDefHandle, context);

// Proactively generate all nested types as well.
// If the outer struct is using ExplicitLayout, generate the nested types as unmanaged structs since that's what will be needed.
Context nestedContext = context;
bool explicitLayout = (typeDef.Attributes & TypeAttributes.ExplicitLayout) == TypeAttributes.ExplicitLayout;
if (context.AllowMarshaling && explicitLayout)
{
nestedContext = nestedContext with { AllowMarshaling = false };
}

foreach (TypeDefinitionHandle nestedHandle in typeDef.GetNestedTypes())
{
if (this.RequestInteropTypeHelper(nestedHandle, context) is { } nestedType)
if (this.RequestInteropTypeHelper(nestedHandle, nestedContext) is { } nestedType)
{
structDeclaration = structDeclaration.AddMembers(nestedType);
}
Expand Down Expand Up @@ -1201,7 +1209,7 @@ private bool TryGetRenamedMethod(string methodName, [NotNullWhen(true)] out stri
}
catch (Exception ex)
{
throw new GenerationFailedException("Failed to generate " + this.Reader.GetString(typeDef.Name), ex);
throw new GenerationFailedException($"Failed to generate {this.Reader.GetString(typeDef.Name)}{(context.AllowMarshaling ? string.Empty : " (unmanaged)")}", ex);
}
}

Expand Down Expand Up @@ -1365,6 +1373,7 @@ private IEnumerable<NamespaceMetadata> GetNamespacesToSearch(string? @namespace)
}
}

[DebuggerDisplay($"AllowMarshaling: {{{nameof(AllowMarshaling)}}}")]
internal record struct Context
{
/// <summary>
Expand Down
13 changes: 6 additions & 7 deletions src/Microsoft.Windows.CsWin32/PointerTypeHandleInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ internal record PointerTypeHandleInfo(TypeHandleInfo ElementType) : TypeHandleIn

internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs, CustomAttributeHandleCollection? customAttributes, ParameterAttributes parameterAttributes)
{
// We can't marshal a pointer exposed as a field, unless it's a pointer to an array.
if (inputs.AllowMarshaling && inputs.IsField && (customAttributes is null || inputs.Generator?.FindNativeArrayInfoAttribute(customAttributes.Value) is null))
{
inputs = inputs with { AllowMarshaling = false };
}

TypeSyntaxAndMarshaling elementTypeDetails = this.ElementType.ToTypeSyntax(inputs with { PreferInOutRef = false }, customAttributes);
if (elementTypeDetails.MarshalAsAttribute is object || inputs.Generator?.IsManagedType(this.ElementType) is true || (inputs.PreferInOutRef && this.ElementType is PrimitiveTypeHandleInfo { PrimitiveTypeCode: not PrimitiveTypeCode.Void }))
{
Expand Down Expand Up @@ -68,13 +74,6 @@ internal override TypeSyntaxAndMarshaling ToTypeSyntax(TypeSyntaxSettings inputs
return new TypeSyntaxAndMarshaling(PredefinedType(Token(SyntaxKind.ObjectKeyword)), new MarshalAsAttribute(UnmanagedType.IUnknown), null);
}

// Since we'll be using pointers, we have to ensure the element type does not require any marshaling.
if (inputs.AllowMarshaling)
{
// Evidently all tests pass without actually doing this, so we'll leave it out for now.
////elementTypeDetails = this.ElementType.ToTypeSyntax(inputs with { AllowMarshaling = false }, customAttributes);
}

return new TypeSyntaxAndMarshaling(PointerType(elementTypeDetails.Type));
}

Expand Down
8 changes: 8 additions & 0 deletions test/GenerationSandbox.Tests/GeneratedForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Windows.Win32.Foundation;
using Windows.Win32.Networking.ActiveDirectory;
using Windows.Win32.System.Diagnostics.Debug;
using Windows.Win32.System.Threading;

#pragma warning disable CA1812 // dead code

Expand Down Expand Up @@ -65,4 +66,11 @@ private static void ZZStringUsed()
Windows.Win32.UI.Shell.SHFILEOPSTRUCTW s = default;
PCZZWSTR from = s.pFrom;
}

private static void PROCESS_BASIC_INFORMATION_PebBaseAddressIsPointer()
{
PROCESS_BASIC_INFORMATION info = default;
PEB_unmanaged* p = null;
info.PebBaseAddress = p;
}
}
1 change: 1 addition & 0 deletions test/GenerationSandbox.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ MainAVIHeader
MAX_PATH
NTSTATUS
PathParseIconLocation
PROCESS_BASIC_INFORMATION
PZZSTR
PZZWSTR
RECT
Expand Down
2 changes: 2 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ public void InterestingComInterfaces(
"IPicture", // An interface with properties that cannot be represented as properties.
"ID2D1DeviceContext2", // CreateLookupTable3D takes fixed length arrays as parameters
"IVPBaseConfig", // GetConnectInfo has a CountParamIndex that points to an [In, Out] parameter.
"IXAudio2SourceVoice", // Requires switch to unmanaged IXAudio2Voice struct which verifies type names retain the _unmanaged suffix everywhere required.
"MSP_EVENT_INFO", // Generates ITStream_unmanaged and ITTerminal_unmanaged
"IWMDMDevice2")] // The GetSpecifyPropertyPages method has an NativeArrayInfo.CountParamIndex pointing at an [Out] parameter.
string api,
bool allowMarshaling)
Expand Down
31 changes: 31 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/StructTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,35 @@ public void FARPROC_GeneratedAsStruct(string tfm)
BaseTypeDeclarationSyntax type = Assert.Single(this.FindGeneratedType("FARPROC"));
Assert.IsType<StructDeclarationSyntax>(type);
}

[Fact]
public void PointerFieldIsDeclaredAsPointer()
{
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate("PROCESS_BASIC_INFORMATION", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

var type = (StructDeclarationSyntax)Assert.Single(this.FindGeneratedType("PROCESS_BASIC_INFORMATION"));
FieldDeclarationSyntax field = Assert.Single(type.Members.OfType<FieldDeclarationSyntax>(), m => m.Declaration.Variables.Any(v => v.Identifier.ValueText == "PebBaseAddress"));
Assert.IsType<PointerTypeSyntax>(field.Declaration.Type);
}

[Theory]
[CombinatorialData]
public void InterestingStructs(
[CombinatorialValues(
"WSD_EVENT")] // has a pointer field to a managed struct
string name,
bool allowMarshaling)
{
var options = DefaultTestGeneratorOptions with
{
AllowMarshaling = allowMarshaling,
};
this.generator = this.CreateGenerator(options);
Assert.True(this.generator.TryGenerate(name, CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}
}