diff --git a/DllImportGenerator/Benchmarks/Strings.cs b/DllImportGenerator/Benchmarks/Strings.cs new file mode 100644 index 000000000000..bcf9bf9abae1 --- /dev/null +++ b/DllImportGenerator/Benchmarks/Strings.cs @@ -0,0 +1,335 @@ +using BenchmarkDotNet.Attributes; +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; + +namespace Benchmarks +{ + partial class NativeExportsNE + { + private class EntryPoints + { + private const string ReturnLength = "return_length"; + private const string ReverseReturn = "reverse_return"; + private const string ReverseOut = "reverse_out"; + private const string ReverseInplace = "reverse_inplace_ref"; + private const string ReverseReplace = "reverse_replace_ref"; + + private const string UShortSuffix = "_ushort"; + private const string ByteSuffix = "_byte"; + + public class Byte + { + public const string ReturnLength = EntryPoints.ReturnLength + ByteSuffix; + public const string ReverseReturn = EntryPoints.ReverseReturn + ByteSuffix; + public const string ReverseOut = EntryPoints.ReverseOut + ByteSuffix; + public const string ReverseInplace = EntryPoints.ReverseInplace + ByteSuffix; + public const string ReverseReplace = EntryPoints.ReverseReplace + ByteSuffix; + } + + public class UShort + { + public const string ReturnLength = EntryPoints.ReturnLength + UShortSuffix; + public const string ReverseReturn = EntryPoints.ReverseReturn + UShortSuffix; + public const string ReverseOut = EntryPoints.ReverseOut + UShortSuffix; + public const string ReverseInplace = EntryPoints.ReverseInplace + UShortSuffix; + public const string ReverseReplace = EntryPoints.ReverseReplace + UShortSuffix; + } + } + + public partial class Unicode + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReturnLength, CharSet = CharSet.Unicode)] + public static partial int ReturnLength(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReturn, CharSet = CharSet.Unicode)] + public static partial string Reverse_Return(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReplace, CharSet = CharSet.Unicode)] + public static partial void Reverse_Replace_Ref(ref string s); + } + + public partial class LPTStr + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReturnLength)] + public static partial int ReturnLength([MarshalAs(UnmanagedType.LPTStr)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReturn)] + [return: MarshalAs(UnmanagedType.LPTStr)] + public static partial string Reverse_Return([MarshalAs(UnmanagedType.LPTStr)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReplace)] + public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPTStr)] ref string s); + } + + public partial class LPWStr + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReturnLength)] + public static partial int ReturnLength([MarshalAs(UnmanagedType.LPWStr)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReturn)] + [return: MarshalAs(UnmanagedType.LPWStr)] + public static partial string Reverse_Return([MarshalAs(UnmanagedType.LPWStr)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReplace)] + public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPWStr)] ref string s); + } + + public partial class LPUTF8Str + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReturnLength)] + public static partial int ReturnLength([MarshalAs(UnmanagedType.LPUTF8Str)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReturn)] + [return: MarshalAs(UnmanagedType.LPUTF8Str)] + public static partial string Reverse_Return([MarshalAs(UnmanagedType.LPUTF8Str)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReplace)] + public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPUTF8Str)] ref string s); + } + + public partial class Ansi + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReturnLength, CharSet = CharSet.Ansi)] + public static partial int ReturnLength(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReturn, CharSet = CharSet.Ansi)] + public static partial string Reverse_Return(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReplace, CharSet = CharSet.Ansi)] + public static partial void Reverse_Replace_Ref(ref string s); + } + + public partial class LPStr + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReturnLength)] + public static partial int ReturnLength([MarshalAs(UnmanagedType.LPStr)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReturn)] + [return: MarshalAs(UnmanagedType.LPStr)] + public static partial string Reverse_Return([MarshalAs(UnmanagedType.LPStr)] string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReplace)] + public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPStr)] ref string s); + } + + public partial class Auto + { + public partial class Unix + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReturnLength, CharSet = CharSet.Auto)] + public static partial int ReturnLength(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReturn, CharSet = CharSet.Auto)] + public static partial string Reverse_Return(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.Byte.ReverseReplace, CharSet = CharSet.Auto)] + public static partial void Reverse_Replace_Ref(ref string s); + } + + public partial class Windows + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReturnLength, CharSet = CharSet.Auto)] + public static partial int ReturnLength(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReturn, CharSet = CharSet.Auto)] + public static partial string Reverse_Return(string s); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.UShort.ReverseReplace, CharSet = CharSet.Auto)] + public static partial void Reverse_Replace_Ref(ref string s); + } + } + } + + public class Strings + { + public static IEnumerable UnicodeStrings{ get; } = new [] + { + "ABCdef 123$%^", + "🍜 !! 🍜 !!", + "🌲 木 πŸ”₯ 火 🌾 土 πŸ›‘ 金 🌊 ζ°΄", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed vitae posuere mauris, sed ultrices leo. Suspendisse potenti. Mauris enim enim, blandit tincidunt consequat in, varius sit amet neque. Morbi eget porttitor ex. Duis mattis aliquet ante quis imperdiet. Duis sit.", + string.Empty, + null, + }; + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_CharSetUnicode(string str) + { + return NativeExportsNE.Unicode.ReturnLength(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_LPWStr(string str) + { + return NativeExportsNE.LPWStr.ReturnLength(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_LPTStr(string str) + { + return NativeExportsNE.LPTStr.ReturnLength(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_LPUTF8Str(string str) + { + return NativeExportsNE.LPUTF8Str.ReturnLength(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_LPStr(string str) + { + return NativeExportsNE.LPStr.ReturnLength(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_CharSetAnsi(string str) + { + return NativeExportsNE.Ansi.ReturnLength(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public int StringByValue_Auto(string str) + { + if (OperatingSystem.IsWindows()) + { + return NativeExportsNE.Auto.Windows.ReturnLength(str); + } + else + { + return NativeExportsNE.Auto.Unix.ReturnLength(str); + } + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_CharSetUnicode(string str) + { + return NativeExportsNE.Unicode.Reverse_Return(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_LPWStr(string str) + { + return NativeExportsNE.LPWStr.Reverse_Return(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_LPTStr(string str) + { + return NativeExportsNE.LPTStr.Reverse_Return(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_LPUTF8Str(string str) + { + return NativeExportsNE.LPUTF8Str.Reverse_Return(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_LPStr(string str) + { + return NativeExportsNE.LPStr.Reverse_Return(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_CharSetAnsi(string str) + { + return NativeExportsNE.Ansi.Reverse_Return(str); + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringReturn_Auto(string str) + { + if (OperatingSystem.IsWindows()) + { + return NativeExportsNE.Auto.Windows.Reverse_Return(str); + } + else + { + return NativeExportsNE.Auto.Unix.Reverse_Return(str); + } + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_CharSetUnicode(string str) + { + NativeExportsNE.Unicode.Reverse_Replace_Ref(ref str); + return str; + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_LPWStr(string str) + { + NativeExportsNE.LPWStr.Reverse_Replace_Ref(ref str); + return str; + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_LPTStr(string str) + { + NativeExportsNE.LPTStr.Reverse_Replace_Ref(ref str); + return str; + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_LPUTF8Str(string str) + { + NativeExportsNE.LPUTF8Str.Reverse_Replace_Ref(ref str); + return str; + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_LPStr(string str) + { + NativeExportsNE.LPStr.Reverse_Replace_Ref(ref str); + return str; + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_CharSetAnsi(string str) + { + NativeExportsNE.Ansi.Reverse_Replace_Ref(ref str); + return str; + } + + [Benchmark] + [ArgumentsSource(nameof(UnicodeStrings))] + public string StringByRef_Auto(string str) + { + if (OperatingSystem.IsWindows()) + { + NativeExportsNE.Auto.Windows.Reverse_Replace_Ref(ref str); + } + else + { + NativeExportsNE.Auto.Unix.Reverse_Replace_Ref(ref str); + } + return str; + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/Forwarder.cs b/DllImportGenerator/DllImportGenerator/Marshalling/Forwarder.cs index 4921ae92ec64..385339bcfc2c 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/Forwarder.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/Forwarder.cs @@ -1,23 +1,100 @@ ο»Ώusing System; using System.Collections.Generic; - +using System.Runtime.InteropServices; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop { - internal class Forwarder : IMarshallingGenerator + internal class Forwarder : IMarshallingGenerator, IAttributedReturnTypeMarshallingGenerator { public TypeSyntax AsNativeType(TypePositionInfo info) { return info.ManagedType.AsTypeSyntax(); } + private bool TryRehydrateMarshalAsAttribute(TypePositionInfo info, out AttributeSyntax marshalAsAttribute) + { + marshalAsAttribute = null!; + // If the parameter has [MarshalAs] marshalling, we resurface that + // in the forwarding target since the built-in system understands it. + // ICustomMarshaller marshalling requires additional information that we throw away earlier since it's unsupported, + // so explicitly do not resurface a [MarshalAs(UnmanagdType.CustomMarshaler)] attribute. + if (info.MarshallingAttributeInfo is MarshalAsInfo { UnmanagedType: not UnmanagedType.CustomMarshaler } marshalAs) + { + marshalAsAttribute = Attribute(ParseName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)) + .WithArgumentList(AttributeArgumentList(SingletonSeparatedList(AttributeArgument( + CastExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_UnmanagedType), + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal((int)marshalAs.UnmanagedType))))))); + return true; + } + + if (info.MarshallingAttributeInfo is NativeContiguousCollectionMarshallingInfo collectionMarshalling + && collectionMarshalling.UseDefaultMarshalling + && collectionMarshalling.ElementCountInfo is NoCountInfo or SizeAndParamIndexInfo + && collectionMarshalling.ElementMarshallingInfo is NoMarshallingInfo or MarshalAsInfo { UnmanagedType: not UnmanagedType.CustomMarshaler } + && info.ManagedType is IArrayTypeSymbol) + { + List marshalAsArguments = new List(); + marshalAsArguments.Add( + AttributeArgument( + CastExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_UnmanagedType), + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal((int)UnmanagedType.LPArray)))) + ); + + if (collectionMarshalling.ElementCountInfo is SizeAndParamIndexInfo countInfo) + { + if (countInfo.ConstSize != SizeAndParamIndexInfo.UnspecifiedData) + { + marshalAsArguments.Add( + AttributeArgument(NameEquals("SizeConst"), null, + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal(countInfo.ConstSize))) + ); + } + if (countInfo.ParamIndex != SizeAndParamIndexInfo.UnspecifiedData) + { + marshalAsArguments.Add( + AttributeArgument(NameEquals("SizeParamIndex"), null, + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal(countInfo.ParamIndex))) + ); + } + } + + if (collectionMarshalling.ElementMarshallingInfo is MarshalAsInfo elementMarshalAs) + { + marshalAsArguments.Add( + AttributeArgument(NameEquals("ArraySubType"), null, + CastExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_UnmanagedType), + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal((int)elementMarshalAs.UnmanagedType)))) + ); + } + marshalAsAttribute = Attribute(ParseName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)) + .WithArgumentList(AttributeArgumentList(SeparatedList(marshalAsArguments))); + return true; + } + + return false; + } + public ParameterSyntax AsParameter(TypePositionInfo info) { - return Parameter(Identifier(info.InstanceIdentifier)) + ParameterSyntax param = Parameter(Identifier(info.InstanceIdentifier)) .WithModifiers(TokenList(Token(info.RefKindSyntax))) .WithType(info.ManagedType.AsTypeSyntax()); + + if (TryRehydrateMarshalAsAttribute(info, out AttributeSyntax marshalAsAttribute)) + { + param = param.AddAttributeLists(AttributeList(SingletonSeparatedList(marshalAsAttribute))); + } + + return param; } public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) @@ -34,5 +111,14 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => false; public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) => true; + + public AttributeListSyntax? GenerateAttributesForReturnType(TypePositionInfo info) + { + if (!TryRehydrateMarshalAsAttribute(info, out AttributeSyntax marshalAsAttribute)) + { + return null; + } + return AttributeList(SingletonSeparatedList(marshalAsAttribute)); + } } } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index 2790aefac458..40e8a7411a93 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -72,6 +72,19 @@ internal interface IMarshallingGenerator bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context); } + /// + /// Interface for generating attributes for native return types. + /// + internal interface IAttributedReturnTypeMarshallingGenerator : IMarshallingGenerator + { + /// + /// Gets any attributes that should be applied to the return type for this . + /// + /// Object to marshal + /// Attributes for the return type for this , or null if no attributes should be added. + AttributeListSyntax? GenerateAttributesForReturnType(TypePositionInfo info); + } + /// /// Exception used to indicate marshalling isn't supported. /// diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index a603a244a312..31d13feef318 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -421,8 +421,16 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .WithAttributeLists( SingletonList(AttributeList( - SingletonSeparatedList( - CreateDllImportAttributeForTarget(GetTargetDllImportDataFromStubData()))))); + SingletonSeparatedList(CreateDllImportAttributeForTarget(GetTargetDllImportDataFromStubData()))))); + + if (retMarshaller.Generator is IAttributedReturnTypeMarshallingGenerator retGenerator) + { + AttributeListSyntax? returnAttribute = retGenerator.GenerateAttributesForReturnType(retMarshaller.TypeInfo); + if (returnAttribute is not null) + { + dllImport = dllImport.AddAttributeLists(returnAttribute.WithTarget(AttributeTargetSpecifier(Identifier("return")))); + } + } if (forwardedAttributes is not null) { diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index babe52af8ef9..678040e4e6b7 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -34,6 +34,8 @@ static class TypeNames public const string System_Runtime_InteropServices_MarshalAsAttribute = "System.Runtime.InteropServices.MarshalAsAttribute"; + public const string System_Runtime_InteropServices_UnmanagedType = "System.Runtime.InteropServices.UnmanagedType"; + public const string System_Runtime_InteropServices_Marshal = "System.Runtime.InteropServices.Marshal"; private const string System_Runtime_InteropServices_MarshalEx = "System.Runtime.InteropServices.MarshalEx"; diff --git a/DllImportGenerator/TestAssets/NativeExports/Handles.cs b/DllImportGenerator/TestAssets/NativeExports/Handles.cs index e17447659afa..4ab5b0b8445e 100644 --- a/DllImportGenerator/TestAssets/NativeExports/Handles.cs +++ b/DllImportGenerator/TestAssets/NativeExports/Handles.cs @@ -29,13 +29,13 @@ public static void AllocateHandleOut(nint* handle) [UnmanagedCallersOnly(EntryPoint = "release_handle")] public static byte ReleaseHandle(nint handle) { - return (byte)(ActiveHandles.Remove(handle) ? 1 : 0); + return (byte)(ReleaseHandleCore(handle) ? 1 : 0); } [UnmanagedCallersOnly(EntryPoint = "is_handle_alive")] public static byte IsHandleAlive(nint handle) { - return (byte)(ActiveHandles.Contains(handle) ? 1 : 0); + return (byte)(IsHandleAliveCore(handle) ? 1 : 0); } [UnmanagedCallersOnly(EntryPoint = "modify_handle")] @@ -47,16 +47,37 @@ public static void ModifyHandle(nint* handle, byte newHandle) } } + private static object m_lock = new object(); + private static nint AllocateHandleCore() { - if (LastHandle == int.MaxValue) + lock (m_lock) { - return InvalidHandle; + if (LastHandle == int.MaxValue) + { + return InvalidHandle; + } + + nint newHandle = ++LastHandle; + ActiveHandles.Add(newHandle); + return newHandle; } + } - nint newHandle = ++LastHandle; - ActiveHandles.Add(newHandle); - return newHandle; + private static bool IsHandleAliveCore(nint handle) + { + lock (m_lock) + { + return ActiveHandles.Contains(handle); + } + } + + private static bool ReleaseHandleCore(nint handle) + { + lock (m_lock) + { + return ActiveHandles.Remove(handle); + } } } } \ No newline at end of file diff --git a/eng/Versions.props b/eng/Versions.props index 1c79b54f2620..cd3c8ca95a31 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -25,7 +25,7 @@ 2.0.0-beta1.21118.1 1.8.0 - 0.12.1.1528 + 0.13.0 1.3.0 1.0.22