Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Libraryimport src gen audit #69619

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public AnsiStringMarshaller(string? str, Span<byte> buffer)
}

// >= for null terminator
// Use the cast to long to avoid the checked operation
if ((long)Marshal.SystemMaxDBCSCharSize * str.Length >= buffer.Length)
{
// Calculate accurate byte count when the provided stack-allocated buffer is not sufficient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public ArrayMarshaller(T[]? array, Span<byte> buffer, int sizeOfNativeElement)
_managedArray = array;

// Always allocate at least one byte when the array is zero-length.
int spaceToAllocate = Math.Max(array.Length * _sizeOfNativeElement, 1);
int bufferSize = checked(array.Length * _sizeOfNativeElement);
int spaceToAllocate = Math.Max(bufferSize, 1);
if (spaceToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
Expand Down Expand Up @@ -107,7 +108,12 @@ public ArrayMarshaller(T[]? array, Span<byte> buffer, int sizeOfNativeElement)
/// </remarks>
public ReadOnlySpan<byte> GetNativeValuesSource(int length)
{
return _allocatedMemory == IntPtr.Zero ? default : _span = new Span<byte>((void*)_allocatedMemory, length * _sizeOfNativeElement);
if (_allocatedMemory == IntPtr.Zero)
return default;
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

int allocatedSize = checked(length * _sizeOfNativeElement);
_span = new Span<byte>((void*)_allocatedMemory, allocatedSize);
return _span;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public PointerArrayMarshaller(T*[]? array, Span<byte> buffer, int sizeOfNativeEl
_managedArray = array;

// Always allocate at least one byte when the array is zero-length.
int spaceToAllocate = Math.Max(array.Length * _sizeOfNativeElement, 1);
int bufferSize = checked(array.Length * _sizeOfNativeElement);
int spaceToAllocate = Math.Max(bufferSize, 1);
if (spaceToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
Expand Down Expand Up @@ -117,7 +118,8 @@ public ReadOnlySpan<byte> GetNativeValuesSource(int length)
if (_allocatedMemory == IntPtr.Zero)
return default;

_span = new Span<byte>((void*)_allocatedMemory, length * _sizeOfNativeElement);
int allocatedSize = checked(length * _sizeOfNativeElement);
_span = new Span<byte>((void*)_allocatedMemory, allocatedSize);
return _span;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public Utf8StringMarshaller(string? str, Span<byte> buffer)
const int MaxUtf8BytesPerChar = 3;

// >= for null terminator
// Use the cast to long to avoid the checked operation
if ((long)MaxUtf8BytesPerChar * str.Length >= buffer.Length)
{
// Calculate accurate byte count when the provided stack-allocated buffer is not sufficient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,24 +190,6 @@ private static MemberDeclarationSyntax PrintGeneratedSource(
.WithBody(stubCode);
}

private static TargetFramework DetermineTargetFramework(Compilation compilation, out Version version)
{
IAssemblySymbol systemAssembly = compilation.GetSpecialType(SpecialType.System_Object).ContainingAssembly;
version = systemAssembly.Identity.Version;

return systemAssembly.Identity.Name switch
{
// .NET Framework
"mscorlib" => TargetFramework.Framework,
// .NET Standard
"netstandard" => TargetFramework.Standard,
// .NET Core (when version < 5.0) or .NET
"System.Runtime" or "System.Private.CoreLib" =>
(version.Major < 5) ? TargetFramework.Core : TargetFramework.Net,
_ => TargetFramework.Unknown,
};
}

private static LibraryImportData? ProcessLibraryImportAttribute(AttributeData attrData)
{
// Found the LibraryImport, but it has an error so report the error.
Expand Down Expand Up @@ -412,11 +394,11 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou

ImmutableArray<AttributeSyntax> forwardedAttributes = pinvokeStub.ForwardedAttributes;

const string innerPInvokeName = "__PInvoke__";
const string innerPInvokeName = "__PInvoke";

BlockSyntax code = stubGenerator.GeneratePInvokeBody(innerPInvokeName);

LocalFunctionStatementSyntax dllImport = CreateTargetFunctionAsLocalStatement(
LocalFunctionStatementSyntax dllImport = CreateTargetDllImportAsLocalStatement(
stubGenerator,
options,
pinvokeStub.LibraryImportData,
Expand All @@ -428,10 +410,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou
dllImport = dllImport.AddAttributeLists(AttributeList(SeparatedList(forwardedAttributes)));
}

dllImport = dllImport.WithLeadingTrivia(
Comment("//"),
Comment("// Local P/Invoke"),
Comment("//"));
dllImport = dllImport.WithLeadingTrivia(Comment("// Local P/Invoke"));
code = code.AddStatements(dllImport);

return (pinvokeStub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(PrintGeneratedSource(pinvokeStub.StubMethodSyntaxTemplate, pinvokeStub.SignatureContext, code)), pinvokeStub.Diagnostics.AddRange(diagnostics.Diagnostics));
Expand Down Expand Up @@ -472,14 +451,14 @@ private static MemberDeclarationSyntax PrintForwarderStub(ContainingSyntax userD
.AddAttributeLists(
AttributeList(
SingletonSeparatedList(
CreateDllImportAttributeFromLibraryImportAttributeData(pinvokeData))));
CreateForwarderDllImport(pinvokeData))));

MemberDeclarationSyntax toPrint = stub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(stubMethod);

return toPrint;
}

private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement(
private static LocalFunctionStatementSyntax CreateTargetDllImportAsLocalStatement(
PInvokeStubCodeGenerator stubGenerator,
LibraryImportGeneratorOptions options,
LibraryImportData libraryImportData,
Expand All @@ -491,8 +470,8 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
(ParameterListSyntax parameterList, TypeSyntax returnType, AttributeListSyntax returnTypeAttributes) = stubGenerator.GenerateTargetMethodSignatureData();
LocalFunctionStatementSyntax localDllImport = LocalFunctionStatement(returnType, stubTargetName)
.AddModifiers(
Token(SyntaxKind.ExternKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.ExternKeyword),
Token(SyntaxKind.UnsafeKeyword))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithAttributeLists(
Expand All @@ -516,8 +495,7 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
AttributeArgument(
NameEquals(nameof(DllImportAttribute.ExactSpelling)),
null,
LiteralExpression(
libraryImportData.SetLastError ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression))
LiteralExpression(SyntaxKind.TrueLiteralExpression))
}
)))))))
.WithParameterList(parameterList);
Expand All @@ -528,7 +506,7 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
return localDllImport;
}

private static AttributeSyntax CreateDllImportAttributeFromLibraryImportAttributeData(LibraryImportData target)
private static AttributeSyntax CreateForwarderDllImport(LibraryImportData target)
{
var newAttributeArgs = new List<AttributeArgumentSyntax>
{
Expand All @@ -542,7 +520,7 @@ private static AttributeSyntax CreateDllImportAttributeFromLibraryImportAttribut
AttributeArgument(
NameEquals(nameof(DllImportAttribute.ExactSpelling)),
null,
CreateBoolExpressionSyntax(true))
LiteralExpression(SyntaxKind.TrueLiteralExpression))
};

if (target.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling))
Expand All @@ -552,6 +530,7 @@ private static AttributeSyntax CreateDllImportAttributeFromLibraryImportAttribut
ExpressionSyntax value = CreateEnumExpressionSyntax(CharSet.Unicode);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}

if (target.IsUserDefined.HasFlag(InteropAttributeMember.SetLastError))
{
NameEqualsSyntax name = NameEquals(nameof(DllImportAttribute.SetLastError));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,13 @@ public PInvokeStubCodeGenerator(
SupportsTargetFramework = true;
}


_context = new ManagedToNativeStubCodeContext(ReturnIdentifier, ReturnIdentifier);
_context = new ManagedToNativeStubCodeContext(environment, ReturnIdentifier, ReturnIdentifier);
_marshallers = new BoundGenerators(argTypes, CreateGenerator);

if (_marshallers.ManagedReturnMarshaller.Generator.UsesNativeIdentifier(_marshallers.ManagedReturnMarshaller.TypeInfo, _context))
{
// If we need a different native return identifier, then recreate the context with the correct identifier before we generate any code.
_context = new ManagedToNativeStubCodeContext(ReturnIdentifier, $"{ReturnIdentifier}{StubCodeContext.GeneratedNativeIdentifierSuffix}");
_context = new ManagedToNativeStubCodeContext(environment, ReturnIdentifier, $"{ReturnIdentifier}{StubCodeContext.GeneratedNativeIdentifierSuffix}");
}

bool noMarshallingNeeded = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ private static ImmutableArray<StatementSyntax> GenerateStatementsForStubContext(
if (statementsToUpdate.Count > 0)
{
// Comment separating each stage
SyntaxTriviaList newLeadingTrivia = TriviaList(
Comment($"//"),
Comment($"// {context.CurrentStage}"),
Comment($"//"));
SyntaxTriviaList newLeadingTrivia = GenerateStageTrivia(context.CurrentStage);
StatementSyntax firstStatementInStage = statementsToUpdate[0];
newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia());
statementsToUpdate[0] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia);
Expand Down Expand Up @@ -108,5 +105,24 @@ private static StatementSyntax GenerateStatementForNativeInvoke(BoundGenerators
IdentifierName(context.GetIdentifiers(marshallers.NativeReturnMarshaller.TypeInfo).native),
invoke));
}

private static SyntaxTriviaList GenerateStageTrivia(StubCodeContext.Stage stage)
{
string comment = stage switch
{
StubCodeContext.Stage.Setup => "Perform required setup.",
StubCodeContext.Stage.Marshal => "Convert managed data to native data.",
StubCodeContext.Stage.Pin => "Pin data in preparation for calling the P/Invoke.",
StubCodeContext.Stage.Invoke => "Call the P/Invoke.",
StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.",
StubCodeContext.Stage.Cleanup => "Perform required cleanup.",
StubCodeContext.Stage.KeepAlive => "Keep alive any managed objects that need to stay alive across the call.",
StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.",
_ => throw new ArgumentOutOfRangeException(nameof(stage))
};
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

// Comment separating each stage
return TriviaList(Comment($"// {stage} - {comment}"));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;

namespace Microsoft.Interop
{
internal sealed record LinearCollectionElementMarshallingCodeContext : StubCodeContext
Expand Down Expand Up @@ -34,6 +36,9 @@ public LinearCollectionElementMarshallingCodeContext(
ParentContext = parentContext;
}

public override (TargetFramework framework, Version version) GetTargetFramework()
=> ParentContext!.GetTargetFramework();

/// <summary>
/// Get managed and native instance identifiers for the <paramref name="info"/>
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,27 @@ public sealed record ManagedToNativeStubCodeContext : StubCodeContext

public override bool AdditionalTemporaryStateLivesAcrossStages => true;

public bool SupportsTargetFramework { get; init; }

public bool StubIsBasicForwarder { get; init; }
private readonly TargetFramework _framework;
private readonly Version _frameworkVersion;

private const string InvokeReturnIdentifier = "__invokeRetVal";
private readonly string _returnIdentifier;
private readonly string _nativeReturnIdentifier;

public ManagedToNativeStubCodeContext(string returnIdentifier, string nativeReturnIdentifier)
public ManagedToNativeStubCodeContext(
StubEnvironment environment,
string returnIdentifier,
string nativeReturnIdentifier)
{
_framework = environment.TargetFramework;
_frameworkVersion = environment.TargetFrameworkVersion;
_returnIdentifier = returnIdentifier;
_nativeReturnIdentifier = nativeReturnIdentifier;
}

public override (TargetFramework framework, Version version) GetTargetFramework()
=> (_framework, _frameworkVersion);

public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
// If the info is in the managed return position, then we need to generate a name to use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type)
{
return new DelegateTypeInfo(typeName, diagonsticFormattedName);
}
return new SimpleManagedTypeInfo(typeName, diagonsticFormattedName);
if (type.IsValueType)
{
return new ValueTypeInfo(typeName, diagonsticFormattedName, type.IsRefLikeType);
}
return new ReferenceTypeInfo(typeName, diagonsticFormattedName);
}
}

Expand Down Expand Up @@ -74,5 +78,7 @@ public sealed record SzArrayType(ManagedTypeInfo ElementTypeInfo) : ManagedTypeI

public sealed record DelegateTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);

public sealed record SimpleManagedTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);
public sealed record ValueTypeInfo(string FullTypeName, string DiagnosticFormattedName, bool IsByRefLike) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);

public sealed record ReferenceTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ public CustomNativeTypeWithToFromNativeValueContext(StubCodeContext parentContex
CurrentStage = parentContext.CurrentStage;
}

public override (TargetFramework framework, Version version) GetTargetFramework()
=> ParentContext!.GetTargetFramework();

public override bool SingleFrameSpansNativeContext => ParentContext!.SingleFrameSpansNativeContext;

public override bool AdditionalTemporaryStateLivesAcrossStages => ParentContext!.AdditionalTemporaryStateLivesAcrossStages;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,17 @@ public enum Stage
GuaranteedUnmarshal
}

/// <summary>
/// The current stage being generated.
/// </summary>
public Stage CurrentStage { get; init; } = Stage.Invalid;

/// <summary>
/// Gets the currently targeted framework and version for stub code generation.
/// </summary>
/// <returns>A framework value and version.</returns>
public abstract (TargetFramework framework, Version version) GetTargetFramework();

/// <summary>
/// The stub emits code that runs in a single stack frame and the frame spans over the native context.
/// </summary>
Expand All @@ -91,6 +100,9 @@ public enum Stage
/// </summary>
public StubCodeContext? ParentContext { get; protected init; }

/// <summary>
/// Suffix for all generated native identifiers.
/// </summary>
public const string GeneratedNativeIdentifierSuffix = "_native";

/// <summary>
Expand All @@ -103,6 +115,12 @@ public virtual (string managed, string native) GetIdentifiers(TypePositionInfo i
return (info.InstanceIdentifier, $"__{info.InstanceIdentifier.TrimStart('@')}{GeneratedNativeIdentifierSuffix}");
}

/// <summary>
/// Compute identifiers that are unique for this generator
/// </summary>
/// <param name="info">TypePositionInfo the new identifier is used in service of.</param>
/// <param name="name">Name of variable.</param>
/// <returns>New identifier name for use.</returns>
public virtual string GetAdditionalIdentifier(TypePositionInfo info, string name)
{
return $"{GetIdentifiers(info).native}__{name}";
Expand Down
Loading