Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IComIID interface to COM structs #766

Merged
merged 1 commit into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ internal static SeparatedSyntaxList<TNode> SeparatedList<TNode>()

internal static FixedStatementSyntax FixedStatement(VariableDeclarationSyntax declaration, StatementSyntax statement) => SyntaxFactory.FixedStatement(TokenWithSpace(SyntaxKind.FixedKeyword), Token(SyntaxKind.OpenParenToken), declaration, TokenWithLineFeed(SyntaxKind.CloseParenToken), statement);

internal static ExplicitInterfaceSpecifierSyntax ExplicitInterfaceSpecifier(NameSyntax name) => SyntaxFactory.ExplicitInterfaceSpecifier(name, Token(SyntaxKind.DotToken));

internal static ThisExpressionSyntax ThisExpression() => SyntaxFactory.ThisExpression(Token(SyntaxKind.ThisKeyword));

internal static DefaultExpressionSyntax DefaultExpression(TypeSyntax type) => SyntaxFactory.DefaultExpression(Token(SyntaxKind.DefaultKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken));
Expand Down
131 changes: 116 additions & 15 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ public class Generator : IDisposable
private static readonly AttributeSyntax SupportedOSPlatformAttribute = Attribute(IdentifierName("SupportedOSPlatform"));
private static readonly AttributeSyntax UnscopedRefAttribute = Attribute(ParseName("UnscopedRef")).WithArgumentList(null);
private static readonly IdentifierNameSyntax SliceAtNullMethodName = IdentifierName("SliceAtNull");
private static readonly IdentifierNameSyntax IComIIDGuidInterfaceName = IdentifierName("IComIID");
private static readonly IdentifierNameSyntax ComIIDGuidPropertyName = IdentifierName("Guid");

/// <summary>
/// The set of libraries that are expected to be allowed next to an application instead of being required to load from System32.
Expand Down Expand Up @@ -380,6 +382,7 @@ public class Generator : IDisposable
private readonly bool canUseUnsafeAsRef;
private readonly bool canUseUnsafeNullRef;
private readonly bool unscopedRefAttributePredefined;
private readonly bool comIIDInterfacePredefined;
private readonly bool getDelegateForFunctionPointerGenericExists;
private readonly bool generateSupportedOSPlatformAttributes;
private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838)
Expand Down Expand Up @@ -437,6 +440,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true;
this.unscopedRefAttributePredefined = this.FindTypeSymbolIfAlreadyAvailable("System.Diagnostics.CodeAnalysis.UnscopedRefAttribute") is not null;
this.comIIDInterfacePredefined = this.FindTypeSymbolIfAlreadyAvailable($"{this.Namespace}.{IComIIDGuidInterfaceName}") is not null;
this.getDelegateForFunctionPointerGenericExists = this.compilation?.GetTypeByMetadataName(typeof(Marshal).FullName)?.GetMembers(nameof(Marshal.GetDelegateForFunctionPointer)).Any(m => m is IMethodSymbol { IsGenericMethod: true }) is true;
this.generateDefaultDllImportSearchPathsAttribute = this.compilation?.GetTypeByMetadataName(typeof(DefaultDllImportSearchPathsAttribute).FullName) is object;
if (this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.Versioning.SupportedOSPlatformAttribute") is { } attribute)
Expand Down Expand Up @@ -498,6 +502,11 @@ private enum FriendlyOverloadOf
InterfaceMethod,
}

private enum Feature
{
InterfaceStaticMembers,
}

/// <summary>
/// Gets the set of macros that can be generated.
/// </summary>
Expand Down Expand Up @@ -3628,24 +3637,22 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
// private void** lpVtbl; // Vtbl* (but we avoid strong typing to enable trimming the entire vtbl struct away)
members.Add(FieldDeclaration(VariableDeclaration(PointerType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))))).AddVariables(VariableDeclarator(vtblFieldName.Identifier))).AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword)));

BaseListSyntax baseList = BaseList(SeparatedList<BaseTypeSyntax>());

CustomAttribute? guidAttribute = this.FindGuidAttribute(typeDef.GetCustomAttributes());
Guid? guidAttributeValue = guidAttribute.HasValue ? DecodeGuidFromAttribute(guidAttribute.Value) : null;
if (guidAttribute.HasValue)
{
// internal static readonly Guid IID_Guid = new Guid(0x1234, ...);
TypeSyntax guidTypeSyntax = IdentifierName(nameof(Guid));
members.Add(FieldDeclaration(
VariableDeclaration(guidTypeSyntax)
.AddVariables(VariableDeclarator(Identifier("IID_Guid")).WithInitializer(EqualsValueClause(
GuidValue(guidAttribute.Value)))))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword))
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>The IID guid for this interface.</summary>\n/// <value>{guidAttributeValue!.Value:B}</value>\n")));
}
var staticMembers = this.DeclareStaticCOMInterfaceMembers(guidAttribute);
members.AddRange(staticMembers.Members);
baseList = baseList.AddTypes(staticMembers.BaseTypes.ToArray());

StructDeclarationSyntax iface = StructDeclaration(ifaceName.Identifier)
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.UnsafeKeyword), TokenWithSpace(SyntaxKind.PartialKeyword))
.AddMembers(members.ToArray());

if (baseList.Types.Count > 0)
{
iface = iface.WithBaseList(baseList);
}

if (guidAttribute.HasValue)
{
iface = iface.AddAttributeLists(AttributeList().AddAttributes(GUID(DecodeGuidFromAttribute(guidAttribute.Value))));
Expand Down Expand Up @@ -3850,20 +3857,22 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
}
}

CustomAttribute? guidAttribute = this.FindGuidAttribute(typeDef.GetCustomAttributes());

InterfaceDeclarationSyntax ifaceDeclaration = InterfaceDeclaration(ifaceName.Identifier)
.WithKeyword(TokenWithSpace(SyntaxKind.InterfaceKeyword))
.AddModifiers(TokenWithSpace(this.Visibility))
.AddMembers(members.ToArray());

if (this.FindGuidFromAttribute(typeDef) is Guid guid)
if (guidAttribute.HasValue)
{
ifaceDeclaration = ifaceDeclaration.AddAttributeLists(AttributeList().AddAttributes(GUID(guid), ifaceType, ComImportAttribute));
ifaceDeclaration = ifaceDeclaration.AddAttributeLists(AttributeList().AddAttributes(GUID(DecodeGuidFromAttribute(guidAttribute.Value)), ifaceType, ComImportAttribute));
}

if (baseTypeSyntaxList.Count > 0)
{
ifaceDeclaration = ifaceDeclaration
.WithBaseList(BaseList(SeparatedList(baseTypeSyntaxList.ToArray())));
.WithBaseList(BaseList(SeparatedList(baseTypeSyntaxList)));
}

if (this.generateSupportedOSPlatformAttributesOnInterfaces && this.GetSupportedOSPlatformAttribute(typeDef.GetCustomAttributes()) is AttributeSyntax supportedOSPlatformAttribute)
Expand All @@ -3878,6 +3887,42 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
return ifaceDeclaration;
}

private (List<MemberDeclarationSyntax> Members, List<BaseTypeSyntax> BaseTypes) DeclareStaticCOMInterfaceMembers(CustomAttribute? guidAttribute)
{
List<MemberDeclarationSyntax> members = new();
List<BaseTypeSyntax> baseTypes = new();

if (guidAttribute.HasValue)
{
Guid guidAttributeValue = DecodeGuidFromAttribute(guidAttribute.Value);

// internal static readonly Guid IID_Guid = new Guid(0x1234, ...);
IdentifierNameSyntax iidGuidFieldName = IdentifierName("IID_Guid");
TypeSyntax guidTypeSyntax = IdentifierName(nameof(Guid));
members.Add(FieldDeclaration(
VariableDeclaration(guidTypeSyntax)
.AddVariables(VariableDeclarator(iidGuidFieldName.Identifier).WithInitializer(EqualsValueClause(
GuidValue(guidAttribute.Value)))))
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword))
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>The IID guid for this interface.</summary>\n/// <value>{guidAttributeValue:B}</value>\n")));

if (this.TryDeclareCOMGuidInterfaceIfNecessary())
{
baseTypes.Add(SimpleBaseType(IComIIDGuidInterfaceName));

// static ref readonly Guid IComIID.Guid => ref IID_Guid;
PropertyDeclarationSyntax guidProperty = PropertyDeclaration(IdentifierName(nameof(Guid)).WithTrailingTrivia(Space), ComIIDGuidPropertyName.Identifier)
.WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier(IComIIDGuidInterfaceName))
.AddModifiers(TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.RefKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword))
.WithExpressionBody(ArrowExpressionClause(RefExpression(iidGuidFieldName)))
.WithSemicolonToken(Semicolon);
members.Add(guidProperty);
}
}

return (members, baseTypes);
}

private ISet<string> GetDeclarableProperties(IEnumerable<MethodDefinition> methods, bool allowNonConsecutiveAccessors)
{
Dictionary<string, (TypeSyntax Type, int Index)> goodProperties = new(StringComparer.Ordinal);
Expand Down Expand Up @@ -6134,6 +6179,62 @@ private void DeclareUnscopedRefAttributeIfNecessary()
});
}

private bool TryDeclareCOMGuidInterfaceIfNecessary()
{
// Static interface members require C# 11 and .NET 7 at minimum.
if (!this.IsFeatureAvailable(Feature.InterfaceStaticMembers))
{
return false;
}

if (this.comIIDInterfacePredefined)
{
return true;
}

this.volatileCode.GenerateSpecialType(IComIIDGuidInterfaceName.Identifier.ValueText, delegate
{
// internal static abstract ref readonly Guid Guid { get; }
PropertyDeclarationSyntax guidProperty = PropertyDeclaration(IdentifierName(nameof(Guid)).WithTrailingTrivia(Space), ComIIDGuidPropertyName.Identifier)
.AddModifiers(
TokenWithSpace(this.Visibility),
TokenWithSpace(SyntaxKind.StaticKeyword),
TokenWithSpace(SyntaxKind.AbstractKeyword),
TokenWithSpace(SyntaxKind.RefKeyword),
TokenWithSpace(SyntaxKind.ReadOnlyKeyword))
.WithAccessorList(AccessorList().AddAccessors(AccessorDeclaration(SyntaxKind.GetAccessorDeclaration).WithSemicolonToken(Semicolon)));

// internal interface IComIID { ... }
InterfaceDeclarationSyntax ifaceDecl = InterfaceDeclaration(IComIIDGuidInterfaceName.Identifier)
.AddModifiers(Token(this.Visibility))
.AddMembers(guidProperty);

this.volatileCode.AddSpecialType(IComIIDGuidInterfaceName.Identifier.ValueText, ifaceDecl);
});

return true;
}

private bool IsFeatureAvailable(Feature feature)
{
return feature switch
{
Feature.InterfaceStaticMembers => (int)this.LanguageVersion >= 1100 && this.IsTargetFrameworkAtLeastDotNetVersion(7),
_ => throw new NotImplementedException(),
};
}

private bool TryGetTargetDotNetVersion([NotNullWhen(true)] out Version? dotNetVersion)
{
dotNetVersion = this.compilation?.ReferencedAssemblyNames.FirstOrDefault(id => string.Equals(id.Name, "System.Runtime", StringComparison.OrdinalIgnoreCase))?.Version;
return dotNetVersion is not null;
}

private bool IsTargetFrameworkAtLeastDotNetVersion(int majorVersion)
{
return this.TryGetTargetDotNetVersion(out Version? actualVersion) && actualVersion.Major >= majorVersion;
}

private bool IsTypeDefStruct(TypeHandleInfo? typeHandleInfo)
{
if (typeHandleInfo is HandleTypeHandleInfo handleInfo)
Expand Down
57 changes: 50 additions & 7 deletions test/Microsoft.Windows.CsWin32.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class GeneratorTests : IDisposable, IAsyncLifetime

private readonly ITestOutputHelper logger;
private readonly Dictionary<string, CSharpCompilation> starterCompilations = new();
private readonly Dictionary<string, string[]> preprocessorSymbolsByTfm = new();
private readonly Dictionary<string, ImmutableArray<string>> preprocessorSymbolsByTfm = new();
private CSharpCompilation compilation;
private CSharpParseOptions parseOptions;
private Generator? generator;
Expand Down Expand Up @@ -108,19 +108,30 @@ public async Task InitializeAsync()
this.starterCompilations.Add("net6.0", await this.CreateCompilationAsync(MyReferenceAssemblies.Net.Net60));
this.starterCompilations.Add("net6.0-x86", await this.CreateCompilationAsync(MyReferenceAssemblies.Net.Net60, Platform.X86));
this.starterCompilations.Add("net6.0-x64", await this.CreateCompilationAsync(MyReferenceAssemblies.Net.Net60, Platform.X64));
this.starterCompilations.Add("net7.0", await this.CreateCompilationAsync(MyReferenceAssemblies.Net.Net70));

foreach (string tfm in this.starterCompilations.Keys)
{
if (tfm.StartsWith("net6"))
if (tfm.StartsWith("net6") || tfm.StartsWith("net7"))
{
AddSymbols("NET5_0_OR_GREATER", "NET6_0_OR_GREATER", "NET6_0");
}
else

if (tfm.StartsWith("net7"))
{
AddSymbols();
AddSymbols("NET7_0_OR_GREATER", "NET7_0");
}

void AddSymbols(params string[] symbols) => this.preprocessorSymbolsByTfm.Add(tfm, symbols);
// Guarantee we have at least an empty list of symbols for each TFM.
AddSymbols();

void AddSymbols(params string[] symbols)
{
if (!this.preprocessorSymbolsByTfm.TryAdd(tfm, symbols.ToImmutableArray()))
{
this.preprocessorSymbolsByTfm[tfm] = this.preprocessorSymbolsByTfm[tfm].AddRange(symbols);
}
}
}

this.compilation = this.starterCompilations["netstandard2.0"];
Expand Down Expand Up @@ -1227,6 +1238,37 @@ public void PartialStructsAllowUserContributions()
Assert.True(hasValueProperty, "Projected members not found.");
}

[Theory]
[CombinatorialData]
public void COMInterfaceIIDInterfaceOnAppropriateTFMs(
bool allowMarshaling,
[CombinatorialValues(LanguageVersion.CSharp10, LanguageVersion.CSharp11)] LanguageVersion langVersion,
[CombinatorialValues("net6.0", "net7.0")] string tfm)
{
const string structName = "IEnumBstr";
this.compilation = this.starterCompilations[tfm];
this.parseOptions = this.parseOptions.WithLanguageVersion(langVersion);
this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = allowMarshaling });
Assert.True(this.generator.TryGenerate(structName, CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

BaseTypeDeclarationSyntax type = this.FindGeneratedType(structName).Single();
IEnumerable<BaseTypeSyntax> actual = type.BaseList?.Types ?? Enumerable.Empty<BaseTypeSyntax>();
Predicate<BaseTypeSyntax> predicate = t => t.Type.ToString().Contains("IComIID");

// Static interface members requires C# 11 and .NET 7.
// And COM *interfaces* are not allowed to have them, so assert we only generate them on structs.
if (tfm == "net7.0" && langVersion >= LanguageVersion.CSharp11 && type is StructDeclarationSyntax)
{
Assert.Contains(actual, predicate);
}
else
{
Assert.DoesNotContain(actual, predicate);
}
}

[Fact]
public void PROC_GeneratedAsStruct()
{
Expand Down Expand Up @@ -1301,7 +1343,7 @@ internal enum FILE_CREATE_FLAGS
}
}
";
this.compilation = this.compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(test, path: "test.cs"));
this.compilation = this.compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(test, this.parseOptions, "test.cs"));
this.generator = this.CreateGenerator();
Assert.True(this.generator.TryGenerate("CreateFile", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
Expand Down Expand Up @@ -1593,7 +1635,7 @@ static unsafe void Main()
}
}
";
this.compilation = this.compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(programCsSource, path: "Program.cs"));
this.compilation = this.compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(programCsSource, this.parseOptions, "Program.cs"));

this.AssertNoDiagnostics();

Expand Down Expand Up @@ -3176,6 +3218,7 @@ internal static class NetFramework
internal static class Net
{
internal static readonly ReferenceAssemblies Net60 = ReferenceAssemblies.Net.Net60.AddPackages(AdditionalModernPackages);
internal static readonly ReferenceAssemblies Net70 = ReferenceAssemblies.Net.Net70.AddPackages(AdditionalModernPackages);
}
#pragma warning restore SA1202 // Elements should be ordered by access
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
<ItemGroup>
<PackageReference Include="coverlet.msbuild" Version="3.2.0" />
<PackageReference Include="MessagePack" Version="2.2.85" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="3.10.0" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.SourceGenerators.Testing.XUnit" Version="1.1.1" />
<PackageReference Include="NuGet.Protocol" Version="6.3.1" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.4.0-4.final" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.SourceGenerators.Testing.XUnit" Version="1.1.2-beta1.22512.1" />
<PackageReference Include="NuGet.Protocol" Version="6.4.0" />
<!-- <PackageReference Include="Microsoft.Dia.Win32Metadata" Version="$(DiaMetadataVersion)" PrivateAssets="none" /> -->
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="System.Collections.Immutable" Version="5.0.0" />
<PackageReference Include="System.Reflection.Metadata" Version="5.0.0" />
<PackageReference Include="System.Text.Json" Version="5.0.2" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.4.0" />
<PackageReference Include="System.Collections.Immutable" Version="6.0.0" />
<PackageReference Include="System.Reflection.Metadata" Version="6.0.1" />
<PackageReference Include="System.Text.Json" Version="6.0.7" />
<PackageReference Include="xunit" Version="2.4.2" />
<PackageReference Include="Xunit.Combinatorial" Version="1.5.25" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5" />
Expand Down