Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Libraryimport src gen audit #69619

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public AnsiStringMarshaller(string? str, Span<byte> buffer)
}

// >= for null terminator
// Use the cast to long to avoid the checked operation
if ((long)Marshal.SystemMaxDBCSCharSize * str.Length >= buffer.Length)
{
// Calculate accurate byte count when the provided stack-allocated buffer is not sufficient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public ArrayMarshaller(T[]? array, Span<byte> buffer, int sizeOfNativeElement)
_managedArray = array;

// Always allocate at least one byte when the array is zero-length.
int spaceToAllocate = Math.Max(array.Length * _sizeOfNativeElement, 1);
int bufferSize = checked(array.Length * _sizeOfNativeElement);
int spaceToAllocate = Math.Max(bufferSize, 1);
if (spaceToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
Expand Down Expand Up @@ -107,7 +108,12 @@ public ArrayMarshaller(T[]? array, Span<byte> buffer, int sizeOfNativeElement)
/// </remarks>
public ReadOnlySpan<byte> GetNativeValuesSource(int length)
{
return _allocatedMemory == IntPtr.Zero ? default : _span = new Span<byte>((void*)_allocatedMemory, length * _sizeOfNativeElement);
if (_allocatedMemory == IntPtr.Zero)
return default;
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

int allocatedSize = checked(length * _sizeOfNativeElement);
_span = new Span<byte>((void*)_allocatedMemory, allocatedSize);
return _span;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public PointerArrayMarshaller(T*[]? array, Span<byte> buffer, int sizeOfNativeEl
_managedArray = array;

// Always allocate at least one byte when the array is zero-length.
int spaceToAllocate = Math.Max(array.Length * _sizeOfNativeElement, 1);
int bufferSize = checked(array.Length * _sizeOfNativeElement);
int spaceToAllocate = Math.Max(bufferSize, 1);
if (spaceToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
Expand Down Expand Up @@ -117,7 +118,8 @@ public ReadOnlySpan<byte> GetNativeValuesSource(int length)
if (_allocatedMemory == IntPtr.Zero)
return default;

_span = new Span<byte>((void*)_allocatedMemory, length * _sizeOfNativeElement);
int allocatedSize = checked(length * _sizeOfNativeElement);
_span = new Span<byte>((void*)_allocatedMemory, allocatedSize);
return _span;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ public Utf8StringMarshaller(string? str, Span<byte> buffer)
}

const int MaxUtf8BytesPerChar = 3;
int maxBytesNeeded = checked(MaxUtf8BytesPerChar * str.Length);
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

// >= for null terminator
if ((long)MaxUtf8BytesPerChar * str.Length >= buffer.Length)
if (maxBytesNeeded >= buffer.Length)
{
// Calculate accurate byte count when the provided stack-allocated buffer is not sufficient
int exactByteCount = checked(Encoding.UTF8.GetByteCount(str) + 1); // + 1 for null terminator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,11 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou

ImmutableArray<AttributeSyntax> forwardedAttributes = pinvokeStub.ForwardedAttributes;

const string innerPInvokeName = "__PInvoke__";
const string innerPInvokeName = "__PInvoke";

BlockSyntax code = stubGenerator.GeneratePInvokeBody(innerPInvokeName);

LocalFunctionStatementSyntax dllImport = CreateTargetFunctionAsLocalStatement(
LocalFunctionStatementSyntax dllImport = CreateTargetDllImportAsLocalStatement(
stubGenerator,
options,
pinvokeStub.LibraryImportData,
Expand All @@ -428,10 +428,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou
dllImport = dllImport.AddAttributeLists(AttributeList(SeparatedList(forwardedAttributes)));
}

dllImport = dllImport.WithLeadingTrivia(
Comment("//"),
Comment("// Local P/Invoke"),
Comment("//"));
dllImport = dllImport.WithLeadingTrivia(Comment("// Local P/Invoke"));
code = code.AddStatements(dllImport);

return (pinvokeStub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(PrintGeneratedSource(pinvokeStub.StubMethodSyntaxTemplate, pinvokeStub.SignatureContext, code)), pinvokeStub.Diagnostics.AddRange(diagnostics.Diagnostics));
Expand Down Expand Up @@ -472,14 +469,14 @@ private static MemberDeclarationSyntax PrintForwarderStub(ContainingSyntax userD
.AddAttributeLists(
AttributeList(
SingletonSeparatedList(
CreateDllImportAttributeFromLibraryImportAttributeData(pinvokeData))));
CreateForwarderDllImport(pinvokeData))));

MemberDeclarationSyntax toPrint = stub.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(stubMethod);

return toPrint;
}

private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement(
private static LocalFunctionStatementSyntax CreateTargetDllImportAsLocalStatement(
PInvokeStubCodeGenerator stubGenerator,
LibraryImportGeneratorOptions options,
LibraryImportData libraryImportData,
Expand All @@ -491,8 +488,8 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
(ParameterListSyntax parameterList, TypeSyntax returnType, AttributeListSyntax returnTypeAttributes) = stubGenerator.GenerateTargetMethodSignatureData();
LocalFunctionStatementSyntax localDllImport = LocalFunctionStatement(returnType, stubTargetName)
.AddModifiers(
Token(SyntaxKind.ExternKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.ExternKeyword),
Token(SyntaxKind.UnsafeKeyword))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithAttributeLists(
Expand All @@ -516,8 +513,7 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
AttributeArgument(
NameEquals(nameof(DllImportAttribute.ExactSpelling)),
null,
LiteralExpression(
libraryImportData.SetLastError ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression))
LiteralExpression(SyntaxKind.TrueLiteralExpression))
}
)))))))
.WithParameterList(parameterList);
Expand All @@ -528,7 +524,7 @@ private static LocalFunctionStatementSyntax CreateTargetFunctionAsLocalStatement
return localDllImport;
}

private static AttributeSyntax CreateDllImportAttributeFromLibraryImportAttributeData(LibraryImportData target)
private static AttributeSyntax CreateForwarderDllImport(LibraryImportData target)
{
var newAttributeArgs = new List<AttributeArgumentSyntax>
{
Expand All @@ -542,7 +538,7 @@ private static AttributeSyntax CreateDllImportAttributeFromLibraryImportAttribut
AttributeArgument(
NameEquals(nameof(DllImportAttribute.ExactSpelling)),
null,
CreateBoolExpressionSyntax(true))
LiteralExpression(SyntaxKind.TrueLiteralExpression))
};

if (target.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling))
Expand All @@ -552,6 +548,7 @@ private static AttributeSyntax CreateDllImportAttributeFromLibraryImportAttribut
ExpressionSyntax value = CreateEnumExpressionSyntax(CharSet.Unicode);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}

if (target.IsUserDefined.HasFlag(InteropAttributeMember.SetLastError))
{
NameEqualsSyntax name = NameEquals(nameof(DllImportAttribute.SetLastError));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ private static ImmutableArray<StatementSyntax> GenerateStatementsForStubContext(
if (statementsToUpdate.Count > 0)
{
// Comment separating each stage
SyntaxTriviaList newLeadingTrivia = TriviaList(
Comment($"//"),
Comment($"// {context.CurrentStage}"),
Comment($"//"));
SyntaxTriviaList newLeadingTrivia = GenerateStageTrivia(context.CurrentStage);
StatementSyntax firstStatementInStage = statementsToUpdate[0];
newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia());
statementsToUpdate[0] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia);
Expand Down Expand Up @@ -108,5 +105,24 @@ private static StatementSyntax GenerateStatementForNativeInvoke(BoundGenerators
IdentifierName(context.GetIdentifiers(marshallers.NativeReturnMarshaller.TypeInfo).native),
invoke));
}

private static SyntaxTriviaList GenerateStageTrivia(StubCodeContext.Stage stage)
{
string comment = stage switch
{
StubCodeContext.Stage.Setup => "Perform required setup.",
StubCodeContext.Stage.Marshal => "Convert managed data to native data.",
StubCodeContext.Stage.Pin => "Pin data in preparation for calling the P/Invoke.",
StubCodeContext.Stage.Invoke => "Call the P/Invoke.",
StubCodeContext.Stage.Unmarshal => "Convert native data to managed data.",
StubCodeContext.Stage.Cleanup => "Perform required cleanup.",
StubCodeContext.Stage.KeepAlive => "Keep alive any managed objects that need to stay alive across the call.",
StubCodeContext.Stage.GuaranteedUnmarshal => "Convert native data to managed data even in the case of an exception during the non-cleanup phases.",
_ => throw new ArgumentOutOfRangeException(nameof(stage))
};
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

// Comment separating each stage
return TriviaList(Comment($"// {stage} - {comment}"));
}
}
}