From 19811c279a5f97bbda203530a26d9e7244faeaa4 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Fri, 1 Jul 2022 12:34:46 -0700 Subject: [PATCH] Basic support for stateless linear collection marshalling (#71473) Basic stateless linear collection marshalling for blittable elements Not handled: - caller-allocated buffer - guaranteed unmarshal - pinnable reference - non-blittable element marshalling - element scenarios on custom marshallers --- .../UserTypeMarshallingV2.md | 54 +-- .../ManagedTypeInfo.cs | 6 + .../ManualTypeMarshallingHelper.cs | 154 +++++-- .../MarshallerShape.cs | 307 +++++++++++--- ...ributedMarshallingModelGeneratorFactory.cs | 62 ++- .../ICustomTypeMarshallingStrategy.cs | 280 +++++++++++++ .../MarshallingAttributeInfo.cs | 79 +++- .../Resources/Strings.resx | 3 + .../Resources/xlf/Strings.cs.xlf | 5 + .../Resources/xlf/Strings.de.xlf | 5 + .../Resources/xlf/Strings.es.xlf | 5 + .../Resources/xlf/Strings.fr.xlf | 5 + .../Resources/xlf/Strings.it.xlf | 5 + .../Resources/xlf/Strings.ja.xlf | 5 + .../Resources/xlf/Strings.ko.xlf | 5 + .../Resources/xlf/Strings.pl.xlf | 5 + .../Resources/xlf/Strings.pt-BR.xlf | 5 + .../Resources/xlf/Strings.ru.xlf | 5 + .../Resources/xlf/Strings.tr.xlf | 5 + .../Resources/xlf/Strings.zh-Hans.xlf | 5 + .../Resources/xlf/Strings.zh-Hant.xlf | 5 + .../TypeNames.cs | 2 + .../TypeSymbolExtensions.cs | 9 + .../ElementUnmanagedTypeAttribute.cs | 17 + .../ArrayTests.Custom.cs | 142 +++++++ .../CollectionTests.V1.cs | 246 +++++++++++ .../CollectionTests.cs | 140 +------ .../CodeSnippets.cs | 386 ++++++++++++------ .../CompileFails.cs | 7 +- .../Compiles.cs | 133 ++++-- .../TestAssets/SharedTypes/NonBlittable.V1.cs | 32 +- .../TestAssets/SharedTypes/NonBlittable.cs | 89 +++- 32 files changed, 1759 insertions(+), 454 deletions(-) create mode 100644 src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ElementUnmanagedTypeAttribute.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.Custom.cs create mode 100644 src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.V1.cs diff --git a/docs/design/libraries/LibraryImportGenerator/UserTypeMarshallingV2.md b/docs/design/libraries/LibraryImportGenerator/UserTypeMarshallingV2.md index 072a97ab18ee5..ee1da9ff467a1 100644 --- a/docs/design/libraries/LibraryImportGenerator/UserTypeMarshallingV2.md +++ b/docs/design/libraries/LibraryImportGenerator/UserTypeMarshallingV2.md @@ -41,7 +41,7 @@ namespace System.Runtime.InteropServices.Marshalling; - ManagedType = managedType; - MarshallerKind = marshallerKind; - } -- +- - public Type ManagedType { get; } - public CustomTypeMarshallerKind MarshallerKind { get; } - public int BufferSize { get; set; } @@ -51,13 +51,13 @@ namespace System.Runtime.InteropServices.Marshalling; - { - } - } -- +- - public enum CustomTypeMarshallerKind - { - Value, - LinearCollection - } -- +- - [Flags] - public enum CustomTypeMarshallerFeatures - { @@ -108,8 +108,8 @@ namespace System.Runtime.InteropServices.Marshalling; + /// + public int BufferSize { get; set; } + } -+ -+ ++ ++ + /// + /// Base class attribute for custom marshaller attributes. + /// @@ -125,7 +125,7 @@ namespace System.Runtime.InteropServices.Marshalling; + /// + public sealed class GenericPlaceholder { } + } -+ ++ + /// + /// Specify marshallers used in the managed to unmanaged direction (that is, P/Invoke) + /// @@ -137,23 +137,23 @@ namespace System.Runtime.InteropServices.Marshalling; + /// + /// Managed type to marshal + public ManagedToUnmanagedMarshallersAttribute(Type managedType) { } -+ ++ + /// + /// Marshaller to use when a parameter of the managed type is passed by-value or with the in keyword. + /// + public Type? InMarshaller { get; set; } -+ ++ + /// + /// Marshaller to use when a parameter of the managed type is passed by-value or with the ref keyword. + /// + public Type? RefMarshaller { get; set; } -+ ++ + /// + /// Marshaller to use when a parameter of the managed type is passed by-value or with the out keyword. + /// + public Type? OutMarshaller { get; set; } + } -+ ++ + /// + /// Specify marshallers used in the unmanaged to managed direction (that is, Reverse P/Invoke) + /// @@ -165,23 +165,23 @@ namespace System.Runtime.InteropServices.Marshalling; + /// + /// Managed type to marshal + public UnmanagedToManagedMarshallersAttribute(Type managedType) { } -+ ++ + /// + /// Marshaller to use when a parameter of the managed type is passed by-value or with the in keyword. + /// + public Type? InMarshaller { get; set; } -+ ++ + /// + /// Marshaller to use when a parameter of the managed type is passed by-value or with the ref keyword. + /// + public Type? RefMarshaller { get; set; } -+ ++ + /// + /// Marshaller to use when a parameter of the managed type is passed by-value or with the out keyword. + /// + public Type? OutMarshaller { get; set; } + } -+ ++ + /// + /// Specify marshaller for array-element marshalling and default struct field marshalling. + /// @@ -195,7 +195,7 @@ namespace System.Runtime.InteropServices.Marshalling; + /// Marshaller type to use for marshalling . + public ElementMarshallerAttribute(Type managedType, Type elementMarshaller) { } + } -+ ++ + /// + /// Specifies that a particular generic parameter is the collection element's unmanaged type. + /// @@ -470,10 +470,10 @@ static class TMarshaller w public static class ManagedToNative { public static TNative AllocateContainerForUnmanagedElements(TCollection managed, out int numElements); // Can throw exceptions - + public static ReadOnlySpan GetManagedValuesSource(TCollection managed); // Can throw exceptions - public static Span GetUnmanagedValuesDestination(TNative nativeValue, int numElements); // Can throw exceptions + public static Span GetUnmanagedValuesDestination(TNative unmanaged, int numElements); // Can throw exceptions public static ref TOther GetPinnableReference(TManaged managed); // Optional. Can throw exceptions. Result pinnned and passed to Invoke. @@ -495,10 +495,10 @@ static class TMarshaller w public static class ManagedToNative { public static TNative AllocateContainerForUnmanagedElements(TCollection managed, Span buffer, out int numElements); // Can throw exceptions - + public static ReadOnlySpan GetManagedValuesSource(TCollection managed); // Can throw exceptions - public static Span GetUnmanagedValuesDestination(TNative nativeValue, int numElements); // Can throw exceptions + public static Span GetUnmanagedValuesDestination(TNative unmanaged, int numElements); // Can throw exceptions public static ref TOther GetPinnableReference(TManaged managed); // Optional. Can throw exceptions. Result pinnned and passed to Invoke. @@ -517,11 +517,11 @@ static class TMarshaller w { public static class NativeToManaged { - public static TCollection AllocateContainerForManagedElements(int length); // Can throw exceptions + public static TCollection AllocateContainerForManagedElements(TNative unmanaged, int length); // Can throw exceptions public static Span GetManagedValuesDestination(T[] managed) => managed; // Can throw exceptions - public static ReadOnlySpan GetUnmanagedValuesSource(TNative nativeValue, int numElements); // Can throw exceptions + public static ReadOnlySpan GetUnmanagedValuesSource(TNative unmanaged, int numElements); // Can throw exceptions public static void Free(TNative native); // Optional. Should not throw exceptions. } @@ -540,11 +540,11 @@ static class TMarshaller w { public static class NativeToManaged { - public static TCollection AllocateContainerForManagedElementsGuaranteed(int length); // Should not throw exceptions other than OutOfMemoryException. + public static TCollection AllocateContainerForManagedElementsGuaranteed(TNative unmanaged, int length); // Should not throw exceptions other than OutOfMemoryException. public static Span GetManagedValuesDestination(T[] managed) => managed; // Can throw exceptions - public static ReadOnlySpan GetUnmanagedValuesSource(TNative nativeValue, int numElements); // Can throw exceptions + public static ReadOnlySpan GetUnmanagedValuesSource(TNative unmanaged, int numElements); // Can throw exceptions public static void Free(TNative native); // Optional. Should not throw exceptions. } @@ -584,7 +584,7 @@ static class TMarshaller w public ReadOnlySpan GetManagedValuesSource(); // Can throw exceptions. - public Span GetNativeValuesDestination(); // Can throw exceptions. + public Span GetUnmanagedValuesDestination(); // Can throw exceptions. public ref TIgnored GetPinnableReference(); // Optional. Can throw exceptions. @@ -615,7 +615,7 @@ static class TMarshaller w public ReadOnlySpan GetManagedValuesSource(); // Can throw exceptions. - public Span GetNativeValuesDestination(); // Can throw exceptions. + public Span GetUnmanagedValuesDestination(); // Can throw exceptions. public ref TIgnored GetPinnableReference(); // Optional. Can throw exceptions. @@ -642,7 +642,7 @@ static class TMarshaller w public void FromUnmanaged(TNative value); // Should not throw exceptions. - public ReadOnlySpan GetNativeValuesSource(int length); // Can throw exceptions. + public ReadOnlySpan GetUnmanagedValuesSource(int length); // Can throw exceptions. public Span GetManagedValuesDestination(int length); // Can throw exceptions. @@ -667,7 +667,7 @@ static class TMarshaller w public void FromUnmanaged(TNative value); // Should not throw exceptions. - public ReadOnlySpan GetNativeValuesSource(int length); // Can throw exceptions. + public ReadOnlySpan GetUnmanagedValuesSource(int length); // Can throw exceptions. public Span GetManagedValuesDestination(int length); // Can throw exceptions. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManagedTypeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManagedTypeInfo.cs index 1eb6bd51082f1..6f2347900df3a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManagedTypeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManagedTypeInfo.cs @@ -45,6 +45,10 @@ public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type) { return new DelegateTypeInfo(typeName, diagonsticFormattedName); } + if (type.TypeKind == TypeKind.TypeParameter) + { + return new TypeParameterTypeInfo(typeName, diagonsticFormattedName); + } if (type.IsValueType) { return new ValueTypeInfo(typeName, diagonsticFormattedName, type.IsRefLikeType); @@ -80,6 +84,8 @@ public sealed record SzArrayType(ManagedTypeInfo ElementTypeInfo) : ManagedTypeI public sealed record DelegateTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); + public sealed record TypeParameterTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); + public sealed record ValueTypeInfo(string FullTypeName, string DiagnosticFormattedName, bool IsByRefLike) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); public sealed record ReferenceTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs index 7a0cbac591a4a..2f45d23393e00 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ManualTypeMarshallingHelper.cs @@ -19,7 +19,9 @@ public readonly record struct CustomTypeMarshallerData( bool HasState, MarshallerShape Shape, bool IsStrictlyBlittable, - ManagedTypeInfo? BufferElementType); + ManagedTypeInfo? BufferElementType, + ManagedTypeInfo? CollectionElementType, + MarshallingInfo? CollectionElementMarshallingInfo); public readonly record struct CustomTypeMarshallers( ImmutableDictionary Scenarios) @@ -68,10 +70,10 @@ private enum MarshallingDirection Bidirectional = ManagedToUnmanaged | UnmanagedToManaged } - public static bool IsLinearCollectionEntryPoint(ITypeSymbol entryPointType) + public static bool IsLinearCollectionEntryPoint(INamedTypeSymbol entryPointType) { - // TODO: Check for linear collection marshaller - ElementUnmanagedType attribute on last generic parameter - return false; + return entryPointType.IsGenericType + && entryPointType.TypeParameters.Last().GetAttributes().Any(attr => attr.AttributeClass.ToDisplayString() == TypeNames.ElementUnmanagedTypeAttribute); } public static bool HasEntryPointMarshallerAttribute(ITypeSymbol entryPointType) @@ -79,11 +81,31 @@ public static bool HasEntryPointMarshallerAttribute(ITypeSymbol entryPointType) return entryPointType.GetAttributes().Any(attr => attr.AttributeClass.ToDisplayString() == TypeNames.CustomMarshallerAttribute); } - public static bool TryGetMarshallersFromEntryType( + public static bool TryGetValueMarshallersFromEntryType( + INamedTypeSymbol entryPointType, + ITypeSymbol managedType, + Compilation compilation, + out CustomTypeMarshallers? marshallers) + { + return TryGetMarshallersFromEntryType(entryPointType, managedType, isLinearCollectionMarshalling: false, compilation, getMarshallingInfoForElement: null, out marshallers); + } + + public static bool TryGetLinearCollectionMarshallersFromEntryType( + INamedTypeSymbol entryPointType, + ITypeSymbol managedType, + Compilation compilation, + Func getMarshallingInfo, + out CustomTypeMarshallers? marshallers) + { + return TryGetMarshallersFromEntryType(entryPointType, managedType, isLinearCollectionMarshalling: true, compilation, getMarshallingInfo, out marshallers); + } + + private static bool TryGetMarshallersFromEntryType( INamedTypeSymbol entryPointType, ITypeSymbol managedType, bool isLinearCollectionMarshalling, Compilation compilation, + Func getMarshallingInfoForElement, out CustomTypeMarshallers? marshallers) { marshallers = null; @@ -91,6 +113,9 @@ public static bool TryGetMarshallersFromEntryType( if (attrs is null || attrs.Length == 0) return false; + // We expect a callback for getting the element marshalling info when handling linear collection marshalling + Debug.Assert(!isLinearCollectionMarshalling || getMarshallingInfoForElement is not null); + Dictionary scenarios = new(); foreach (AttributeData attr in attrs) { @@ -98,11 +123,6 @@ public static bool TryGetMarshallersFromEntryType( // Verify the defined marshaller is for the managed type. ITypeSymbol? managedTypeOnAttr = attr.ConstructorArguments[0].Value as ITypeSymbol; - if (!SymbolEqualityComparer.Default.Equals(managedType, managedTypeOnAttr) - && !compilation.HasImplicitConversion(managedType, managedTypeOnAttr)) - { - continue; - } // Verify any instantiation of Generic parameters is provided by entry point. // TODO: Hard failure based on previous implementation @@ -112,7 +132,8 @@ public static bool TryGetMarshallersFromEntryType( // Verify any instantiated managed types are derived properly. // TODO: Hard failure based on previous implementation - if (!TypeSymbolsConstructedFromEqualTypes(managedType, managedTypeInst)) + if (!managedType.IsConstructedFromEqualTypes(managedTypeInst) + && !compilation.HasImplicitConversion(managedType, managedTypeInst)) return false; var marshallerScenario = (Scenario)attr.ConstructorArguments[1].Value!; @@ -121,6 +142,31 @@ public static bool TryGetMarshallersFromEntryType( if (marshallerTypeOnAttr is null) continue; + ITypeSymbol marshallerType = marshallerTypeOnAttr; + if (isLinearCollectionMarshalling && marshallerTypeOnAttr is INamedTypeSymbol namedMarshallerType) + { + // Update the marshaller type with resolved type arguments based on the entry point type + // We expect the entry point to already have its type arguments updated based on the managed type + Stack nestedTypeNames = new Stack(); + INamedTypeSymbol currentType = namedMarshallerType; + while (currentType is not null) + { + if (currentType.IsConstructedFromEqualTypes(entryPointType)) + break; + + nestedTypeNames.Push(currentType.Name); + currentType = currentType.ContainingType; + } + + currentType = entryPointType; + foreach (string name in nestedTypeNames) + { + currentType = currentType.GetTypeMembers(name).First(); + } + + marshallerType = currentType; + } + // TODO: We can probably get rid of MarshallingDirection and just use Scenario instead MarshallingDirection direction = marshallerScenario switch { @@ -149,7 +195,7 @@ or Scenario.ElementOut // TODO: Report invalid shape for scenario // Skip checking for bidirectional support for Default scenario - always take / store marshaller data - CustomTypeMarshallerData? data = GetMarshallerDataForType(marshallerTypeOnAttr, direction, managedTypeOnAttr, compilation); + CustomTypeMarshallerData? data = GetMarshallerDataForType(marshallerType, direction, managedType, isLinearCollectionMarshalling, compilation, getMarshallingInfoForElement); // TODO: Should we fire a diagnostic for duplicated scenarios or just take the last one? if (data is null @@ -170,15 +216,6 @@ or Scenario.ElementOut }; return true; - - static bool TypeSymbolsConstructedFromEqualTypes(ITypeSymbol left, ITypeSymbol right) - { - return (left, right) switch - { - (INamedTypeSymbol namedLeft, INamedTypeSymbol namedRight) => SymbolEqualityComparer.Default.Equals(namedLeft.ConstructedFrom, namedRight.ConstructedFrom), - _ => SymbolEqualityComparer.Default.Equals(left, right) - }; - } } /// @@ -212,7 +249,8 @@ static bool TypeSymbolsConstructedFromEqualTypes(ITypeSymbol left, ITypeSymbol r } } - if (innerType.ToDisplayString() != TypeNames.CustomTypeMarshallerAttributeGenericPlaceholder) + if (innerType.ToDisplayString() != TypeNames.CustomTypeMarshallerAttributeGenericPlaceholder + && innerType.ToDisplayString() != TypeNames.CustomMarshallerAttributeGenericPlaceholder) { return managedType; } @@ -292,11 +330,17 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault ({ ReturnsByRef: true } or { ReturnsByRefReadonly: true })); } - private static CustomTypeMarshallerData? GetMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation) + private static CustomTypeMarshallerData? GetMarshallerDataForType( + ITypeSymbol marshallerType, + MarshallingDirection direction, + ITypeSymbol managedType, + bool isLinearCollectionMarshaller, + Compilation compilation, + Func getMarshallingInfo) { if (marshallerType is { IsStatic: true, TypeKind: TypeKind.Class }) { - return GetStatelessMarshallerDataForType(marshallerType, direction, managedType, compilation); + return GetStatelessMarshallerDataForType(marshallerType, direction, managedType, isLinearCollectionMarshaller, compilation, getMarshallingInfo); } if (marshallerType.IsValueType) { @@ -305,24 +349,31 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault return null; } - private static CustomTypeMarshallerData? GetStatelessMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation) + private static CustomTypeMarshallerData? GetStatelessMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, bool isLinearCollectionMarshaller, Compilation compilation, Func? getMarshallingInfo) { - (MarshallerShape shape, Dictionary methodsByShape) = StatelessMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, compilation); + (MarshallerShape shape, StatelessMarshallerShapeHelper.MarshallerMethods methods) = StatelessMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, isLinearCollectionMarshaller, compilation); + ITypeSymbol? collectionElementType = null; ITypeSymbol? nativeType = null; if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged)) { if (!shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged)) return null; - IMethodSymbol method; - if (methodsByShape.TryGetValue(MarshallerShape.CallerAllocatedBuffer, out method)) + if (isLinearCollectionMarshaller) { - nativeType = method.ReturnType; + // Element type is the type parameter of the ReadOnlySpan returned by GetManagedValuesSource + collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesSource.ReturnType).TypeArguments[0]; } - else if (methodsByShape.TryGetValue(MarshallerShape.ToUnmanaged, out method)) + + // Native type is the return type of ConvertToUnmanaged / AllocateContainerForUnmanagedElement + if (methods.ToUnmanagedWithBuffer is not null) { - nativeType = method.ReturnType; + nativeType = methods.ToUnmanagedWithBuffer.ReturnType; + } + else if (methods.ToUnmanaged is not null) + { + nativeType = methods.ToUnmanaged.ReturnType; } } @@ -331,14 +382,25 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault if (!shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged)) return null; - IMethodSymbol method; - if (methodsByShape.TryGetValue(MarshallerShape.GuaranteedUnmarshal, out method)) + if (isLinearCollectionMarshaller) { - nativeType = method.Parameters[0].Type; + // Native type is the first parameter of GetUnmanagedValuesSource + nativeType = methods.UnmanagedValuesSource.Parameters[0].Type; + + // Element type is the type parameter of the Span returned by GetManagedValuesDestination + collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0]; } - else if (methodsByShape.TryGetValue(MarshallerShape.ToManaged, out method)) + else { - nativeType = method.Parameters[0].Type; + // Native type is the first parameter of ConvertToManaged or ConvertToManagedGuaranteed + if (methods.ToManagedGuaranteed is not null) + { + nativeType = methods.ToManagedGuaranteed.Parameters[0].Type; + } + else if (methods.ToManaged is not null) + { + nativeType = methods.ToManaged.Parameters[0].Type; + } } } @@ -350,9 +412,17 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault return null; ManagedTypeInfo bufferElementType = null; - if (methodsByShape.TryGetValue(MarshallerShape.CallerAllocatedBuffer, out IMethodSymbol methodWithBuffer)) + if (methods.ToUnmanagedWithBuffer is not null) + { + bufferElementType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(((INamedTypeSymbol)methods.ToUnmanagedWithBuffer.Parameters[1].Type).TypeArguments[0]); + } + + ManagedTypeInfo? collectionElementTypeInfo = null; + MarshallingInfo? collectionElementMarshallingInfo = null; + if (collectionElementType is not null) { - bufferElementType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(((INamedTypeSymbol)methodWithBuffer.Parameters[1].Type).TypeArguments[0]); + collectionElementTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(collectionElementType); + collectionElementMarshallingInfo = getMarshallingInfo(collectionElementType); } return new CustomTypeMarshallerData( @@ -361,7 +431,9 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault HasState: false, shape, nativeType.IsStrictlyBlittable(), - bufferElementType); + bufferElementType, + collectionElementTypeInfo, + collectionElementMarshallingInfo); } private static CustomTypeMarshallerData? GetStatefulMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, Compilation compilation) @@ -410,7 +482,9 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault HasState: true, shape, nativeType.IsStrictlyBlittable(), - bufferElementType); + bufferElementType, + CollectionElementType: null, + CollectionElementMarshallingInfo: null); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs index d115e7fbb9305..ec37715128bdf 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallerShape.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; @@ -92,47 +93,116 @@ public static class Stateful public static class StatelessMarshallerShapeHelper { - public static (MarshallerShape, Dictionary) GetShapeForType(ITypeSymbol marshallerType, ITypeSymbol managedType, Compilation compilation) + public record MarshallerMethods { - MarshallerShape shape = MarshallerShape.None; - var methodsByShape = new Dictionary(); + public IMethodSymbol? ToUnmanaged; + public IMethodSymbol? ToUnmanagedWithBuffer; + public IMethodSymbol? ToManaged; + public IMethodSymbol? ToManagedGuaranteed; + + // Linear collection + public IMethodSymbol? ManagedValuesSource; + public IMethodSymbol? UnmanagedValuesDestination; + public IMethodSymbol? ManagedValuesDestination; + public IMethodSymbol? UnmanagedValuesSource; + } - IMethodSymbol? method = GetConvertToUnmanagedMethod(marshallerType, managedType); - if (method is not null) - AddMethod(MarshallerShape.ToUnmanaged, method); + public static (MarshallerShape, MarshallerMethods) GetShapeForType(ITypeSymbol marshallerType, ITypeSymbol managedType, bool isLinearCollectionMarshaller, Compilation compilation) + { + MarshallerShape shape = MarshallerShape.None; + MarshallerMethods methods = new(); INamedTypeSymbol spanOfT = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!; - method = GetConvertToUnmanagedWithCallerAllocatedBufferMethod(marshallerType, managedType, spanOfT, out _); - if (method is not null) - AddMethod(MarshallerShape.CallerAllocatedBuffer, method); - - method = GetConvertToManagedMethod(marshallerType, managedType); - if (method is not null) - AddMethod(MarshallerShape.ToManaged, method); + if (isLinearCollectionMarshaller) + { + // Managed -> Unmanaged + INamedTypeSymbol readOnlySpanOfT = compilation.GetTypeByMetadataName(TypeNames.System_ReadOnlySpan_Metadata)!; + IMethodSymbol? allocateUnmanaged = LinearCollection.AllocateContainerForUnmanagedElements(marshallerType, managedType); + IMethodSymbol? allocateUnmanagedWithBuffer = LinearCollection.AllocateContainerForUnmanagedElementsWithCallerAllocatedBuffer(marshallerType, managedType, spanOfT); + IMethodSymbol? managedSource = LinearCollection.GetManagedValuesSource(marshallerType, managedType, readOnlySpanOfT); + IMethodSymbol? unmanagedDestination = LinearCollection.GetUnmanagedValuesDestination(marshallerType, spanOfT); + if ((allocateUnmanaged is not null || allocateUnmanagedWithBuffer is not null) + && managedSource is not null + && unmanagedDestination is not null) + { + if (allocateUnmanaged is not null) + shape |= MarshallerShape.ToUnmanaged; - method = GetConvertToManagedGuaranteedMethod(marshallerType, managedType); - if (method is not null) - AddMethod(MarshallerShape.GuaranteedUnmarshal, method); + if (allocateUnmanagedWithBuffer is not null) + shape |= MarshallerShape.CallerAllocatedBuffer; - method = GetStatelessGetPinnableReference(marshallerType, managedType); - if (method is not null) - AddMethod(MarshallerShape.StatelessPinnableReference, method); + methods = methods with + { + ToUnmanaged = allocateUnmanaged, + ToUnmanagedWithBuffer = allocateUnmanagedWithBuffer, + ManagedValuesSource = managedSource, + UnmanagedValuesDestination = unmanagedDestination + }; + } - method = GetStatelessFree(marshallerType); - if (method is not null) - AddMethod(MarshallerShape.Free, method); + // Unmanaged -> Managed + IMethodSymbol? allocateManaged = LinearCollection.AllocateContainerForManagedElements(marshallerType, managedType); + IMethodSymbol? allocateManagedGuaranteed = LinearCollection.AllocateContainerForManagedElementsGuaranteed(marshallerType, managedType, spanOfT); + IMethodSymbol? managedDestination = LinearCollection.GetManagedValuesDestination(marshallerType, managedType, spanOfT); + IMethodSymbol? unmanagedSource = LinearCollection.GetUnmanagedValuesSource(marshallerType, readOnlySpanOfT); + if ((allocateManaged is not null || allocateManagedGuaranteed is not null) + && managedDestination is not null + && unmanagedSource is not null) + { + if (allocateManaged is not null) + shape |= MarshallerShape.ToManaged; - return (shape, methodsByShape); + if (allocateManagedGuaranteed is not null) + shape |= MarshallerShape.GuaranteedUnmarshal; - void AddMethod(MarshallerShape shapeToAdd, IMethodSymbol methodToAdd) + methods = methods with + { + ToManaged = allocateManaged, + ToManagedGuaranteed = allocateManagedGuaranteed, + ManagedValuesDestination = managedDestination, + UnmanagedValuesSource = unmanagedSource + }; + } + } + else { - methodsByShape.Add(shapeToAdd, methodToAdd); - shape |= shapeToAdd; + IMethodSymbol? toUnmanaged = Value.ConvertToUnmanaged(marshallerType, managedType); + if (toUnmanaged is not null) + shape |= MarshallerShape.ToUnmanaged; + + IMethodSymbol? toUnmanagedWithBuffer = Value.ConvertToUnmanagedWithCallerAllocatedBuffer(marshallerType, managedType, spanOfT); + if (toUnmanagedWithBuffer is not null) + shape |= MarshallerShape.CallerAllocatedBuffer; + + IMethodSymbol? toManaged = Value.ConvertToManaged(marshallerType, managedType); + if (toManaged is not null) + shape |= MarshallerShape.ToManaged; + + IMethodSymbol? toManagedGuaranteed = Value.ConvertToManagedGuaranteed(marshallerType, managedType); + if (toManagedGuaranteed is not null) + shape |= MarshallerShape.GuaranteedUnmarshal; + + methods = methods with + { + ToUnmanaged = toUnmanaged, + ToUnmanagedWithBuffer = toUnmanagedWithBuffer, + ToManaged = toManaged, + ToManagedGuaranteed = toManagedGuaranteed + }; } + + if (GetStatelessGetPinnableReference(marshallerType, managedType) is not null) + shape |= MarshallerShape.StatelessPinnableReference; + + if (GetStatelessFree(marshallerType) is not null) + shape |= MarshallerShape.Free; + + return (shape, methods); } private static IMethodSymbol? GetStatelessFree(ITypeSymbol type) { + // static void Free(TNative unmanaged) return type.GetMembers(ShapeMemberNames.Free) .OfType() .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: true }); @@ -140,6 +210,9 @@ void AddMethod(MarshallerShape shapeToAdd, IMethodSymbol methodToAdd) private static IMethodSymbol? GetStatelessGetPinnableReference(ITypeSymbol type, ITypeSymbol managedType) { + // static ref TOther GetPinnableReference(TManaged managed) + // or + // static ref readonly TOther GetPinnableReference(TManaged managed) return type.GetMembers(ShapeMemberNames.GetPinnableReference) .OfType() .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1 } and @@ -147,66 +220,164 @@ void AddMethod(MarshallerShape shapeToAdd, IMethodSymbol methodToAdd) && SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, managedType)); } - private static IMethodSymbol? GetConvertToUnmanagedMethod(ITypeSymbol type, ITypeSymbol managedType) + private static bool IsSpanOfUnmanagedType(ITypeSymbol typeToCheck, ITypeSymbol spanOfT) { - return type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToUnmanaged) - .OfType() - .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false } - && SymbolEqualityComparer.Default.Equals(managedType, m.Parameters[0].Type)); + if (typeToCheck is INamedTypeSymbol namedType + && SymbolEqualityComparer.Default.Equals(spanOfT, namedType.ConstructedFrom) + && namedType.TypeArguments.Length == 1 + && namedType.TypeArguments[0].IsUnmanagedType) + { + return true; + } + + return false; } - private static IMethodSymbol? GetConvertToUnmanagedWithCallerAllocatedBufferMethod( - ITypeSymbol type, - ITypeSymbol managedType, - ITypeSymbol spanOfT, - out ITypeSymbol? spanElementType) + private static class Value { - spanElementType = null; - IEnumerable methods = type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToUnmanaged) - .OfType() - .Where(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false } - && SymbolEqualityComparer.Default.Equals(managedType, m.Parameters[0].Type)); + internal static IMethodSymbol? ConvertToUnmanaged(ITypeSymbol type, ITypeSymbol managedType) + { + // static TNative ConvertToUnmanaged(TManaged managed) + return type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToUnmanaged) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false } + && SymbolEqualityComparer.Default.Equals(managedType, m.Parameters[0].Type)); + } - foreach (IMethodSymbol method in methods) + internal static IMethodSymbol? ConvertToUnmanagedWithCallerAllocatedBuffer( + ITypeSymbol type, + ITypeSymbol managedType, + ITypeSymbol spanOfT) { - if (IsSpanOfUnmanagedType(method.Parameters[1].Type, spanOfT, out spanElementType)) + // static TNative ConvertToUnmanaged(TManaged managed, Span buffer) + IEnumerable methods = type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToUnmanaged) + .OfType() + .Where(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false } + && SymbolEqualityComparer.Default.Equals(managedType, m.Parameters[0].Type)); + + foreach (IMethodSymbol method in methods) { - return method; + if (IsSpanOfUnmanagedType(method.Parameters[1].Type, spanOfT)) + { + return method; + } } + + return null; } - return null; + internal static IMethodSymbol? ConvertToManaged(ITypeSymbol type, ITypeSymbol managedType) + { + // static TManaged ConvertToManaged(TNative unmanaged) + return type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToManaged) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false } + && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); + } - static bool IsSpanOfUnmanagedType(ITypeSymbol typeToCheck, ITypeSymbol spanOfT, out ITypeSymbol? typeArgument) + internal static IMethodSymbol? ConvertToManagedGuaranteed(ITypeSymbol type, ITypeSymbol managedType) { - typeArgument = null; - if (typeToCheck is INamedTypeSymbol namedType - && SymbolEqualityComparer.Default.Equals(spanOfT, namedType.ConstructedFrom) - && namedType.TypeArguments.Length == 1 - && namedType.TypeArguments[0].IsUnmanagedType) + // static TManaged ConvertToManagedGuaranteed(TNative unmanaged) + return type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToManagedGuaranteed) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false } + && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); + } + } + + private static class LinearCollection + { + internal static IMethodSymbol? AllocateContainerForUnmanagedElements(ITypeSymbol type, ITypeSymbol managedType) + { + // static TNative AllocateContainerForUnmanagedElements(TCollection managed, out int numElements) + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.AllocateContainerForUnmanagedElements) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false } + && managedType.IsConstructedFromEqualTypes(m.Parameters[0].Type) + && m.Parameters[1].Type.SpecialType == SpecialType.System_Int32 + && m.Parameters[1].RefKind == RefKind.Out); + } + + internal static IMethodSymbol? AllocateContainerForUnmanagedElementsWithCallerAllocatedBuffer(ITypeSymbol type, ITypeSymbol managedType, ITypeSymbol spanOfT) + { + // static TNative AllocateContainerForUnmanagedElements(TCollection managed, Span buffer, out int numElements) + IEnumerable methods = type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.AllocateContainerForUnmanagedElements) + .OfType() + .Where(m => m is { IsStatic: true, Parameters.Length: 3, ReturnsVoid: false } + && managedType.IsConstructedFromEqualTypes(m.Parameters[0].Type) + && m.Parameters[2].Type.SpecialType == SpecialType.System_Int32 + && m.Parameters[2].RefKind == RefKind.Out); + + foreach (IMethodSymbol method in methods) { - typeArgument = namedType.TypeArguments[0]; - return true; + if (IsSpanOfUnmanagedType(method.Parameters[1].Type, spanOfT)) + { + return method; + } } - return false; + return null; } - } - private static IMethodSymbol? GetConvertToManagedMethod(ITypeSymbol type, ITypeSymbol managedType) - { - return type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToManaged) - .OfType() - .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false } - && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); - } + internal static IMethodSymbol? GetManagedValuesSource(ITypeSymbol type, ITypeSymbol managedType, ITypeSymbol readOnlySpanOfT) + { + // static ReadOnlySpan GetManagedValuesSource(TCollection managed) + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesSource) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false, ReturnType: INamedTypeSymbol returnType } + && managedType.IsConstructedFromEqualTypes(m.Parameters[0].Type) + && SymbolEqualityComparer.Default.Equals(readOnlySpanOfT, returnType.ConstructedFrom)); + } - private static IMethodSymbol? GetConvertToManagedGuaranteedMethod(ITypeSymbol type, ITypeSymbol managedType) - { - return type.GetMembers(ShapeMemberNames.Value.Stateless.ConvertToManagedGuaranteed) - .OfType() - .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false } - && SymbolEqualityComparer.Default.Equals(managedType, m.ReturnType)); + internal static IMethodSymbol? GetUnmanagedValuesDestination(ITypeSymbol type, ITypeSymbol spanOfT) + { + // static Span GetUnmanagedValuesDestination(TNative unmanaged, int numElements) + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesDestination) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false, ReturnType: INamedTypeSymbol returnType } + && m.Parameters[1].Type.SpecialType == SpecialType.System_Int32 + && SymbolEqualityComparer.Default.Equals(spanOfT, returnType.ConstructedFrom)); + } + + internal static IMethodSymbol? AllocateContainerForManagedElements(ITypeSymbol type, ITypeSymbol managedType) + { + // static TCollection AllocateContainerForManagedElements(TNative unmanaged, int length); + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.AllocateContainerForManagedElements) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false } + && m.Parameters[1].Type.SpecialType == SpecialType.System_Int32 + && managedType.IsConstructedFromEqualTypes(m.ReturnType)); + } + + internal static IMethodSymbol? AllocateContainerForManagedElementsGuaranteed(ITypeSymbol type, ITypeSymbol managedType, ITypeSymbol spanOfT) + { + // static TCollection AllocateContainerForManagedElementsGuaranteed(TNative unmanaged, int length); + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.AllocateContainerForManagedElements) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false } + && m.Parameters[1].Type.SpecialType == SpecialType.System_Int32 + && managedType.IsConstructedFromEqualTypes(m.ReturnType)); + } + + internal static IMethodSymbol? GetManagedValuesDestination(ITypeSymbol type, ITypeSymbol managedType, ITypeSymbol spanOfT) + { + // static Span GetManagedValuesDestination(TCollection managed) + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesDestination) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 1, ReturnsVoid: false, ReturnType: INamedTypeSymbol returnType } + && managedType.IsConstructedFromEqualTypes(m.Parameters[0].Type) + && SymbolEqualityComparer.Default.Equals(spanOfT, returnType.ConstructedFrom)); + } + + internal static IMethodSymbol? GetUnmanagedValuesSource(ITypeSymbol type, ITypeSymbol readOnlySpanOfT) + { + // static ReadOnlySpan GetUnmanagedValuesSource(TNative nativeValue, int numElements) + return type.GetMembers(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesSource) + .OfType() + .FirstOrDefault(m => m is { IsStatic: true, Parameters.Length: 2, ReturnsVoid: false, ReturnType: INamedTypeSymbol returnType } + && m.Parameters[1].Type.SpecialType == SpecialType.System_Int32 + && SymbolEqualityComparer.Default.Equals(readOnlySpanOfT, returnType.ConstructedFrom)); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 5800716f2bd68..b1bcb725d5565 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -244,6 +245,10 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo } else { + // Collections have extra configuration, so handle them separately. + if (marshalInfo is NativeLinearCollectionMarshallingInfo collectionMarshallingInfo) + return CreateNativeCollectionMarshaller(info, context, marshallerData, collectionMarshallingInfo); + marshallingStrategy = new StatelessValueMarshalling(marshallerData.MarshallerType.Syntax, marshallerData.NativeType.Syntax, marshallerData.Shape); if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax); @@ -256,6 +261,59 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo : marshallingGenerator; } + private IMarshallingGenerator CreateNativeCollectionMarshaller( + TypePositionInfo info, + StubCodeContext context, + CustomTypeMarshallerData marshallerData, + NativeLinearCollectionMarshallingInfo marshalInfo) + { + var elementInfo = new TypePositionInfo(marshallerData.CollectionElementType, marshallerData.CollectionElementMarshallingInfo) { ManagedIndex = info.ManagedIndex }; + IMarshallingGenerator elementMarshaller = _elementMarshallingGenerator.Create( + elementInfo, + new LinearCollectionElementMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, string.Empty, context)); + + ExpressionSyntax numElementsExpression = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)); + if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + { + // In this case, we need a numElementsExpression supplied from metadata, so we'll calculate it here. + numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, marshalInfo.ElementCountInfo, context); + } + + // Insert the unmanaged element type into the marshaller type + TypeSyntax unmanagedElementType = elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax(); + TypeSyntax marshallerTypeSyntax = marshallerData.MarshallerType.Syntax; + marshallerTypeSyntax = marshallerTypeSyntax.ReplaceNodes( + marshallerTypeSyntax.DescendantNodesAndSelf().OfType().Where(t => t.IsEquivalentTo(marshalInfo.PlaceholderTypeParameter.Syntax)), + (_, _) => unmanagedElementType); + + ICustomTypeMarshallingStrategy marshallingStrategy; + bool elementIsBlittable = elementMarshaller is BlittableMarshaller; + if (elementIsBlittable) + { + marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallerTypeSyntax, marshallerData.NativeType.Syntax, marshallerData.Shape, marshallerData.CollectionElementType.Syntax, unmanagedElementType, numElementsExpression); + } + else + { + // TODO: Handle linear collection marshalling with non-blittable elements + throw new MarshallingNotSupportedException(info, context); + } + + if (marshalInfo.UseDefaultMarshalling && info.ManagedType is SzArrayType) + { + return new ArrayMarshaller( + new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: true), + elementInfo, + elementIsBlittable); + } + + IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); + + // Elements in the collection must be blittable to use the pinnable marshaller. + return marshalInfo.IsPinnableManagedType && elementIsBlittable + ? new PinnableManagedValueMarshaller(marshallingGenerator) + : marshallingGenerator; + } + private void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo) { // Marshalling out or return parameter, but no out marshaller is specified @@ -323,7 +381,7 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller_V1(TypePositionIn // Collections have extra configuration, so handle them here. if (marshalInfo is NativeLinearCollectionMarshallingInfo_V1 collectionMarshallingInfo) { - return CreateNativeCollectionMarshaller(info, context, collectionMarshallingInfo, marshallingStrategy); + return CreateNativeCollectionMarshaller_V1(info, context, collectionMarshallingInfo, marshallingStrategy); } else if (marshalInfo.NativeValueType is not null) { @@ -399,7 +457,7 @@ private static void ValidateCustomNativeTypeMarshallingSupported_V1(TypePosition } } - private IMarshallingGenerator CreateNativeCollectionMarshaller( + private IMarshallingGenerator CreateNativeCollectionMarshaller_V1( TypePositionInfo info, StubCodeContext context, NativeLinearCollectionMarshallingInfo_V1 collectionInfo, diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs index 1b52e8712cca3..6fbc03c93d063 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ICustomTypeMarshallingStrategy.cs @@ -545,4 +545,284 @@ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); } + + /// + /// Marshaller that enables support for marshalling blittable elements of a collection via a native type that implements the LinearCollection marshalling spec. + /// + internal sealed class StatelessLinearCollectionMarshalling : ICustomTypeMarshallingStrategy + { + private readonly TypeSyntax _marshallerTypeSyntax; + private readonly TypeSyntax _nativeTypeSyntax; + private readonly MarshallerShape _shape; + private readonly TypeSyntax _managedElementType; + private readonly TypeSyntax _unmanagedElementType; + private readonly ExpressionSyntax _numElementsExpression; + + public StatelessLinearCollectionMarshalling(TypeSyntax marshallerTypeSyntax, TypeSyntax nativeTypeSyntax, MarshallerShape shape, TypeSyntax managedElementType, TypeSyntax unmanagedElementType, ExpressionSyntax numElementsExpression) + { + _marshallerTypeSyntax = marshallerTypeSyntax; + _nativeTypeSyntax = nativeTypeSyntax; + _shape = shape; + _managedElementType = managedElementType; + _unmanagedElementType = unmanagedElementType; + _numElementsExpression = numElementsExpression; + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return _nativeTypeSyntax; + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + if (!_shape.HasFlag(MarshallerShape.ToUnmanaged)) + yield break; + + (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); + string numElementsIdentifier = GetNumElementsIdentifier(info, context); + + // = .AllocateContainerForUnmanagedElements(, out ); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(nativeIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.AllocateContainerForUnmanagedElements)), + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(managedIdentifier)), + Argument(IdentifierName(numElementsIdentifier)) + .WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)) + }))))); + + // .GetUnmanagedValuesDestination(, ) + ExpressionSyntax destination = + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesDestination)), + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(nativeIdentifier)), + Argument(IdentifierName(numElementsIdentifier)), + }))); + + if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + { + // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. + // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content. + // .GetUnmanagedValuesDestination(, ).Clear(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + destination, + IdentifierName("Clear")))); + yield break; + } + + // Skip the cast if the managed and unmanaged element types are the same + if (!_unmanagedElementType.IsEquivalentTo(_managedElementType)) + { + // MemoryMarshal.Cast<, >() + destination = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + GenericName( + Identifier("Cast")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new[] + { + _unmanagedElementType, + _managedElementType + })))), + ArgumentList(SingletonSeparatedList( + Argument(destination)))); + } + + // .GetManagedValuesSource().CopyTo(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesSource)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifier))))), + IdentifierName("CopyTo"))) + .AddArgumentListArguments( + Argument(destination))); + } + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.IntKeyword)), + SingletonSeparatedList( + VariableDeclarator(GetNumElementsIdentifier(info, context))))); + } + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + if (!_shape.HasFlag(MarshallerShape.ToManaged)) + yield break; + + (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); + string numElementsIdentifier = GetNumElementsIdentifier(info, context); + + ExpressionSyntax copySource; + ExpressionSyntax copyDestination; + if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + { + // .GetUnmanagedValuesDestination(, ) + copySource = + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesDestination)), + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(nativeIdentifier)), + Argument(IdentifierName(numElementsIdentifier)), + }))); + + // MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(.GetManagedValuesSource()), .GetManagedValuesSource().Length) + copyDestination = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + IdentifierName("CreateSpan")), + ArgumentList( + SeparatedList(new[] + { + Argument( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + IdentifierName("GetReference")), + ArgumentList(SingletonSeparatedList( + Argument( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesSource)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifier)))))))))) + .WithRefKindKeyword( + Token(SyntaxKind.RefKeyword)), + Argument( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesSource)), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifier))))), + IdentifierName("Length"))) + }))); + + } + else + { + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(numElementsIdentifier), + _numElementsExpression)); + + // = .AllocateContainerForManagedElements(, ); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.AllocateContainerForManagedElements)), + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(nativeIdentifier)), + Argument(IdentifierName(numElementsIdentifier)) + }))))); + + // .GetUnmanagedValuesSource(, ) + copySource = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetUnmanagedValuesSource)), + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(nativeIdentifier)), + Argument(IdentifierName(numElementsIdentifier)) + }))); + + // .GetManagedValuesDestination() + copyDestination = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _marshallerTypeSyntax, + IdentifierName(ShapeMemberNames.LinearCollection.Stateless.GetManagedValuesDestination)), + ArgumentList(SingletonSeparatedList(Argument(IdentifierName(managedIdentifier))))); + } + + // Skip the cast if the managed and unmanaged element types are the same + if (!_unmanagedElementType.IsEquivalentTo(_managedElementType)) + { + // MemoryMarshal.Cast<, >() + copySource = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + GenericName( + Identifier("Cast"), + TypeArgumentList(SeparatedList( + new[] + { + _unmanagedElementType, + _managedElementType + })))), + ArgumentList(SingletonSeparatedList( + Argument(copySource)))); + } + + // .CopyTo(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + copySource, + IdentifierName("CopyTo"))) + .AddArgumentListArguments( + Argument(copyDestination))); + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) => Array.Empty(); + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; + + private static string GetNumElementsIdentifier(TypePositionInfo info, StubCodeContext context) + => context.GetAdditionalIdentifier(info, "numElements"); + } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index e4e72e0535d4d..7269e4f444caa 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -145,6 +145,20 @@ public record NativeMarshallingAttributeInfo( CustomTypeMarshallers Marshallers, bool IsPinnableManagedType) : MarshallingInfo; + /// + /// Custom type marshalling via MarshalUsingAttribute or NativeMarshallingAttribute for a linear collection + /// + public sealed record NativeLinearCollectionMarshallingInfo( + ManagedTypeInfo EntryPointType, + CustomTypeMarshallers Marshallers, + bool IsPinnableManagedType, + CountInfo ElementCountInfo, + ManagedTypeInfo PlaceholderTypeParameter, + bool UseDefaultMarshalling) : NativeMarshallingAttributeInfo( + EntryPointType, + Marshallers, + IsPinnableManagedType); + /// /// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute /// @@ -583,7 +597,6 @@ private MarshallingInfo CreateNativeMarshallingInfo( ImmutableHashSet inspectedElements, ref int maxIndirectionDepthUsed) { - bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(entryPointType)) { if (!entryPointType.IsStatic) @@ -592,15 +605,65 @@ private MarshallingInfo CreateNativeMarshallingInfo( return NoMarshallingInfo.Instance; } - if (ManualTypeMarshallingHelper.TryGetMarshallersFromEntryType(entryPointType, type, isLinearCollectionMarshalling, _compilation, out CustomTypeMarshallers? marshallers)) + ManagedTypeInfo entryPointTypeInfo = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType); + bool isPinnableManagedType = !isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null; + + bool isLinearCollectionMarshalling = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointType); + if (isLinearCollectionMarshalling) { - bool isPinnableManagedType = !isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null; - return isLinearCollectionMarshalling - ? NoMarshallingInfo.Instance // TODO: handle linear collection marshallers - : new NativeMarshallingAttributeInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType), marshallers.Value, isPinnableManagedType); - } + // Update the entry point type with the type arguments based on the managed type + if (type is IArrayTypeSymbol arrayManagedType) + { + if (entryPointType.Arity != 2) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } - return NoMarshallingInfo.Instance; + entryPointType = entryPointType.ConstructedFrom.Construct( + arrayManagedType.ElementType, + entryPointType.TypeArguments.Last()); + } + else if (type is INamedTypeSymbol namedManagedType) + { + // Entry point type for linear collection marshalling must have the arity of the managed type + 1 + // for the [ElementUnmanagedType] placeholder + if (entryPointType.Arity != namedManagedType.Arity + 1) + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + entryPointType = entryPointType.ConstructedFrom.Construct( + namedManagedType.TypeArguments.Add(entryPointType.TypeArguments.Last()).ToArray()); + } + else + { + _diagnostics.ReportInvalidMarshallingAttributeInfo(attrData, nameof(SR.MarshallerEntryPointTypeMustMatchArity), entryPointType.ToDisplayString(), type.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed; + Func getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionDepthUsedLocal); + if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? marshallers)) + { + maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal; + return new NativeLinearCollectionMarshallingInfo( + entryPointTypeInfo, + marshallers.Value, + isPinnableManagedType, + parsedCountInfo, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(entryPointType.TypeParameters.Last()), + UseDefaultMarshalling: !isMarshalUsingAttribute); + } + } + else + { + if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(entryPointType, type, _compilation, out CustomTypeMarshallers? marshallers)) + { + return new NativeMarshallingAttributeInfo(entryPointTypeInfo, marshallers.Value, isPinnableManagedType); + } + } } return CreateNativeMarshallingInfo_V1(type, entryPointType, attrData, isMarshalUsingAttribute, indirectionLevel, parsedCountInfo, useSiteAttributes, inspectedElements, ref maxIndirectionDepthUsed); diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/Strings.resx b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/Strings.resx index 54047927230a9..ab2d865565c7a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/Strings.resx +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/Strings.resx @@ -216,4 +216,7 @@ The marshaller type '{0}' for managed type '{1}' must be static. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.cs.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.cs.xlf index 25c6a40666236..29fe10cb5691e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.cs.xlf @@ -102,6 +102,11 @@ Určený parametr musí být zařazený ze spravovaného do nespravovaného, ale zařazovací typ {0} to nepodporuje. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. Typ zařazovače {0} pro spravovaný typ {1} musí být statický. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.de.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.de.xlf index 45a1edaf2fbe6..1504b218a5cd2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.de.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.de.xlf @@ -102,6 +102,11 @@ Der angegebene Parameter muss von verwaltet zu nicht verwaltet gemarshallt werden, aber der Marshaller-Typ ‚{0}‘ unterstützt dies nicht. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. Der Marshaller-Typ ‚{0}‘ für den verwalteten Typ ‚{1}‘ muss statisch sein. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.es.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.es.xlf index 858d693b5f05e..bad966c9b768b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.es.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.es.xlf @@ -102,6 +102,11 @@ El parámetro especificado debe serializarse de administrado a no administrado, pero el tipo no administrado “{0}” no lo admite. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. El tipo de serializador "{0}" para el tipo administrado "{1}" debe ser estático. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.fr.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.fr.xlf index b0120073f669e..09c9eedd0c1d9 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.fr.xlf @@ -102,6 +102,11 @@ Le paramètre spécifié doit être marshalé de managé à non managé, mais le type marshaleur « {0} » ne le prend pas en charge. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. Le type marshaleur « {0} » pour le type managé « {1} » doit être statique. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.it.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.it.xlf index af89ec7231a1b..ffb03f2cd11be 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.it.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.it.xlf @@ -102,6 +102,11 @@ È necessario effettuare il marshalling del parametro specificato da gestito a non gestito, ma il tipo di gestore del marshalling '{0}' non lo supporta. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. Il tipo di gestore del marshalling '{0}' per il tipo gestito '{1}' deve essere statico. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ja.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ja.xlf index f869a30bc7c02..9f5e5a7fee52c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ja.xlf @@ -102,6 +102,11 @@ 指定されたパラメーターはマネージドからアンマネージドにマーシャリングする必要がありますが、マーシャラー型 '{0}' ではそれはサポートされていません。 + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. マネージド型 '{1}' のマーシャラー型 '{0}' は静的である必要があります。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ko.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ko.xlf index 3898fe3a6ff47..d3eb44d2b4bb3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ko.xlf @@ -102,6 +102,11 @@ 지정된 매개 변수를 관리형에서 비관리형으로 마샬링해야 하지만 마샬러 유형 '{0}'은(는) 지원하지 않습니다. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. 관리 유형 '{1}'에 대한 마샬러 유형 '{0}'은(는) 정적이어야 합니다. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pl.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pl.xlf index 86f2f1c1ed2b9..4b1649bde70fb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pl.xlf @@ -102,6 +102,11 @@ Określony parametr musi być kierowany z zarządzanego do niezarządzanego, ale typ marszałka „{0}” go nie obsługuje. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. Typ marszałka „{0}” dla typu zarządzanego „{1}” musi być statyczny. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pt-BR.xlf index 6b3833c214a3f..1560c58c14353 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.pt-BR.xlf @@ -102,6 +102,11 @@ O parâmetro especificado precisa ser marshalled de gerenciado para não gerenciado, mas o tipo de marshaller '{0}' não dá suporte a ele. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. O tipo de marshaller '{0}' do tipo gerenciado '{1}' deve ser estático. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ru.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ru.xlf index 6271065d4f4bf..d586077f4d795 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.ru.xlf @@ -102,6 +102,11 @@ Указанный параметр необходимо маршалировать из управляемого в неуправляемый, но тип маршаллера "{0}" не поддерживает это. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. Тип маршаллера "{0}" для управляемого типа "{1}" должен быть статическим. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.tr.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.tr.xlf index e647fe89cde3d..3d0d853df0308 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.tr.xlf @@ -102,6 +102,11 @@ Belirtilen parametrenin yönetilenden yönetilmeyene doğru hazırlanması gerekiyor, ancak '{0}' hazırlayıcı türü bunu desteklemiyor. + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. '{1}' yönetilen türü için '{0}' hazırlayıcı türünün statik olması gerekir. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hans.xlf index 12d07bfb18ef6..06ee106c04496 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hans.xlf @@ -102,6 +102,11 @@ 需要将指定的参数从托管封送到非托管,但封送程序类型“{0}”不支持它。 + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. 托管类型“{1}”的封送程序类型“{0}”必须是静态的。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hant.xlf index 7b21148b921c3..3816706352722 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Resources/xlf/Strings.zh-Hant.xlf @@ -102,6 +102,11 @@ 指定的參數必須從受控封送處理到非受控,但封送處理程式類型 '{0}' 不支援。 + + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + The marshaller entry point type '{0}' for managed type '{1}' must have an arity of one greater than the managed type. + + The marshaller type '{0}' for managed type '{1}' must be static. 受控類型 '{1}' 的封送處理程式類型 '{0}' 必須是靜態。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index f9c3128f77243..5428b9715596b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -17,6 +17,8 @@ public static class TypeNames public const string CustomTypeMarshallerAttributeGenericPlaceholder = "System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder"; public const string CustomMarshallerAttribute = "System.Runtime.InteropServices.Marshalling.CustomMarshallerAttribute"; + public const string CustomMarshallerAttributeGenericPlaceholder = CustomMarshallerAttribute + ".GenericPlaceholder"; + public const string ElementUnmanagedTypeAttribute = "System.Runtime.InteropServices.Marshalling.ElementUnmanagedTypeAttribute"; public const string AnsiStringMarshaller = "System.Runtime.InteropServices.Marshalling.AnsiStringMarshaller"; public const string BStrStringMarshaller = "System.Runtime.InteropServices.Marshalling.BStrStringMarshaller"; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeSymbolExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeSymbolExtensions.cs index 513b961c59088..2ff90bf93b91b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeSymbolExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeSymbolExtensions.cs @@ -176,5 +176,14 @@ or SpecialType.System_UIntPtr or SpecialType.System_Single or SpecialType.System_Double; } + + public static bool IsConstructedFromEqualTypes(this ITypeSymbol type, ITypeSymbol other) + { + return (type, other) switch + { + (INamedTypeSymbol namedType, INamedTypeSymbol namedOther) => SymbolEqualityComparer.Default.Equals(namedType.ConstructedFrom, namedOther.ConstructedFrom), + _ => SymbolEqualityComparer.Default.Equals(type, other) + }; + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ElementUnmanagedTypeAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ElementUnmanagedTypeAttribute.cs new file mode 100644 index 0000000000000..d0f508904a1ea --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ElementUnmanagedTypeAttribute.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Runtime.InteropServices.Marshalling +{ + /// + /// Specifies that a particular generic parameter is the collection element's unmanaged type. + /// + /// + /// If this attribute is provided on a generic parameter of a marshaller, then the generator will assume + /// that it is a linear collection marshaller. + /// + [AttributeUsage(AttributeTargets.GenericParameter)] + public sealed class ElementUnmanagedTypeAttribute : Attribute + { + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.Custom.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.Custom.cs new file mode 100644 index 0000000000000..a1b8789fa161f --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.Custom.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using SharedTypes; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; + +using Xunit; + +namespace LibraryImportGenerator.IntegrationTests +{ + partial class NativeExportsNE + { + public partial class Arrays + { + // TODO: All these tests can be removed once we switch the array marshaller in runtime libraries + // to V2 of custom type marshalling shapes + public partial class Custom + { + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum([MarshalUsing(typeof(CustomArrayMarshaller<,>))] int[] values, int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")] + public static partial int SumInArray([MarshalUsing(typeof(CustomArrayMarshaller<,>))] in int[] values, int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "duplicate_int_array")] + public static partial void Duplicate([MarshalUsing(typeof(CustomArrayMarshaller<,>), CountElementName = "numValues")] ref int[] values, int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "create_range_array")] + [return: MarshalUsing(typeof(CustomArrayMarshaller<,>), CountElementName = "numValues")] + public static partial int[] CreateRange(int start, int end, out int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] + public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(CustomArrayMarshaller<,>), CountElementName = "numValues")] out int[] res); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_char_array", StringMarshalling = StringMarshalling.Utf16)] + public static partial int SumChars([MarshalUsing(typeof(CustomArrayMarshaller<,>))] char[] chars, int numElements); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_char_array", StringMarshalling = StringMarshalling.Utf16)] + public static partial void ReverseChars([MarshalUsing(typeof(CustomArrayMarshaller<,>), CountElementName = "numElements")] ref char[] chars, int numElements); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] + [return: MarshalUsing(typeof(CustomArrayMarshaller<,>), ConstantElementCount = sizeof(long))] + public static partial byte[] GetLongBytes(long l); + } + } + } + + public class ArrayTests_Custom + { + private int[] GetIntArray() => new[] { 1, 5, 79, 165, 32, 3 }; + + [Fact] + public void IntArray_ByValue() + { + int[] array = GetIntArray(); + Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Custom.Sum(array, array.Length)); + } + + [Fact] + public void NullIntArray_ByValue() + { + int[] array = null; + Assert.Equal(-1, NativeExportsNE.Arrays.Custom.Sum(array, 0)); + } + + [Fact] + public void ZeroLengthArray_MarshalledAsNonNull() + { + var array = new int[0]; + Assert.Equal(0, NativeExportsNE.Arrays.Custom.Sum(array, array.Length)); + } + + [Fact] + public void IntArray_In() + { + int[] array = GetIntArray(); + Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Custom.SumInArray(array, array.Length)); + } + + [Fact] + public void IntArray_Ref() + { + int[] array = GetIntArray(); + var newArray = array; + NativeExportsNE.Arrays.Custom.Duplicate(ref newArray, array.Length); + Assert.Equal((IEnumerable)array, newArray); + } + + [Fact] + public void CharArray_ByValue() + { + char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray(); + Assert.Equal(array.Sum(c => c), NativeExportsNE.Arrays.Custom.SumChars(array, array.Length)); + } + + [Fact] + public void CharArray_Ref() + { + char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray(); + var newArray = array; + NativeExportsNE.Arrays.Custom.ReverseChars(ref newArray, array.Length); + Assert.Equal(array.Reverse(), newArray); + } + + [Fact] + public void IntArray_Return() + { + int start = 5; + int end = 20; + + IEnumerable expected = Enumerable.Range(start, end - start); + Assert.Equal(expected, NativeExportsNE.Arrays.Custom.CreateRange(start, end, out _)); + + int[] res; + NativeExportsNE.Arrays.Custom.CreateRange_Out(start, end, out _, out res); + Assert.Equal(expected, res); + } + + [Fact] + public void NullArray_Return() + { + Assert.Null(NativeExportsNE.Arrays.Custom.CreateRange(1, 0, out _)); + + int[] res; + NativeExportsNE.Arrays.Custom.CreateRange_Out(1, 0, out _, out res); + Assert.Null(res); + } + + [Fact] + public void ConstantSizeArray() + { + var longVal = 0x12345678ABCDEF10L; + + Assert.Equal(longVal, MemoryMarshal.Read(NativeExportsNE.Arrays.Custom.GetLongBytes(longVal))); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.V1.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.V1.cs new file mode 100644 index 0000000000000..ea039de946ead --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.V1.cs @@ -0,0 +1,246 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using SharedTypes; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; + +using Xunit; + +namespace LibraryImportGenerator.IntegrationTests +{ + partial class NativeExportsNE + { + public partial class Collections + { + public partial class V1 + { + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum([MarshalUsing(typeof(ListMarshaller_V1))] List values, int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")] + public static partial int SumInArray([MarshalUsing(typeof(ListMarshaller_V1))] in List values, int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "duplicate_int_array")] + public static partial void Duplicate([MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numValues")] ref List values, int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "create_range_array")] + [return: MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numValues")] + public static partial List CreateRange(int start, int end, out int numValues); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] + public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numValues")] out List res); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] + public static partial int SumStringLengths([MarshalUsing(typeof(ListMarshaller_V1)), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] + public static partial int SumStringLengths([MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] WrappedList_V1 strArray); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_replace")] + public static partial void ReverseStrings_Ref([MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] ref List strArray, out int numElements); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_return")] + [return: MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] + public static partial List ReverseStrings_Return([MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray, out int numElements); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_out")] + public static partial void ReverseStrings_Out( + [MarshalUsing(typeof(ListMarshaller_V1)), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray, + out int numElements, + [MarshalUsing(typeof(ListMarshaller_V1), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] out List res); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] + [return: MarshalUsing(typeof(ListMarshaller_V1), ConstantElementCount = sizeof(long))] + public static partial List GetLongBytes(long l); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool AndAllMembers([MarshalUsing(typeof(ListMarshaller_V1))] List pArray, int length); + } + } + } + + public class CollectionTests_V1 + { + [Fact] + public void BlittableElementColllectionMarshalledToNativeAsExpected() + { + var list = new List { 1, 5, 79, 165, 32, 3 }; + Assert.Equal(list.Sum(), NativeExportsNE.Collections.V1.Sum(list, list.Count)); + } + + [Fact] + public void NullBlittableElementColllectionMarshalledToNativeAsExpected() + { + Assert.Equal(-1, NativeExportsNE.Collections.V1.Sum(null, 0)); + } + + [Fact] + public void BlittableElementColllectionInParameter() + { + var list = new List { 1, 5, 79, 165, 32, 3 }; + Assert.Equal(list.Sum(), NativeExportsNE.Collections.V1.SumInArray(list, list.Count)); + } + + [Fact] + public void BlittableElementCollectionRefParameter() + { + var list = new List { 1, 5, 79, 165, 32, 3 }; + var newList = list; + NativeExportsNE.Collections.V1.Duplicate(ref newList, list.Count); + Assert.Equal((IEnumerable)list, newList); + } + + [Fact] + public void BlittableElementCollectionReturnedFromNative() + { + int start = 5; + int end = 20; + + IEnumerable expected = Enumerable.Range(start, end - start); + Assert.Equal(expected, NativeExportsNE.Collections.V1.CreateRange(start, end, out _)); + + List res; + NativeExportsNE.Collections.V1.CreateRange_Out(start, end, out _, out res); + Assert.Equal(expected, res); + } + + [Fact] + public void NullBlittableElementCollectionReturnedFromNative() + { + Assert.Null(NativeExportsNE.Collections.V1.CreateRange(1, 0, out _)); + + List res; + NativeExportsNE.Collections.V1.CreateRange_Out(1, 0, out _, out res); + Assert.Null(res); + } + + private static List GetStringList() + { + return new() + { + "ABCdef 123$%^", + "🍜 !! 🍜 !!", + "🌲 木 🔥 火 🌾 土 🛡 金 🌊 水" , + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed vitae posuere mauris, sed ultrices leo. Suspendisse potenti. Mauris enim enim, blandit tincidunt consequat in, varius sit amet neque. Morbi eget porttitor ex. Duis mattis aliquet ante quis imperdiet. Duis sit.", + string.Empty, + null + }; + } + + [Fact] + public void ByValueCollectionWithNonBlittableElements() + { + var strings = GetStringList(); + Assert.Equal(strings.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.V1.SumStringLengths(strings)); + } + + [Fact] + public void ByValueNullCollectionWithNonBlittableElements() + { + Assert.Equal(0, NativeExportsNE.Collections.V1.SumStringLengths(null)); + } + + [Fact] + public void ByValueCollectionWithNonBlittableElements_WithDefaultMarshalling() + { + var strings = new WrappedList_V1(GetStringList()); + Assert.Equal(strings.Wrapped.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.V1.SumStringLengths(strings)); + } + + [Fact] + public void ByRefCollectionWithNonBlittableElements() + { + var strings = GetStringList(); + var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); + NativeExportsNE.Collections.V1.ReverseStrings_Ref(ref strings, out _); + + Assert.Equal((IEnumerable)expectedStrings, strings); + } + + [Fact] + public void ReturnCollectionWithNonBlittableElements() + { + var strings = GetStringList(); + var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); + Assert.Equal(expectedStrings, NativeExportsNE.Collections.V1.ReverseStrings_Return(strings, out _)); + + List res; + NativeExportsNE.Collections.V1.ReverseStrings_Out(strings, out _, out res); + Assert.Equal(expectedStrings, res); + } + + [Fact] + public void ByRefNullCollectionWithNonBlittableElements() + { + List strings = null; + NativeExportsNE.Collections.V1.ReverseStrings_Ref(ref strings, out _); + + Assert.Null(strings); + } + + [Fact] + public void ReturnNullCollectionWithNonBlittableElements() + { + List strings = null; + Assert.Null(NativeExportsNE.Collections.V1.ReverseStrings_Return(strings, out _)); + + List res; + NativeExportsNE.Collections.V1.ReverseStrings_Out(strings, out _, out res); + Assert.Null(res); + } + + [Fact] + public void ConstantSizeCollection() + { + var longVal = 0x12345678ABCDEF10L; + + Assert.Equal(longVal, MemoryMarshal.Read(CollectionsMarshal.AsSpan(NativeExportsNE.Collections.V1.GetLongBytes(longVal)))); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CollectionWithSimpleNonBlittableTypeMarshalling(bool result) + { + var boolValues = new List + { + new BoolStruct_V1 + { + b1 = true, + b2 = true, + b3 = true, + }, + new BoolStruct_V1 + { + b1 = true, + b2 = true, + b3 = true, + }, + new BoolStruct_V1 + { + b1 = true, + b2 = true, + b3 = result, + }, + }; + + Assert.Equal(result, NativeExportsNE.Collections.V1.AndAllMembers(boolValues, boolValues.Count)); + } + + private static string ReverseChars(string value) + { + if (value == null) + return null; + + var chars = value.ToCharArray(); + Array.Reverse(chars); + return new string(chars); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs index cf058810d434b..b4851346917a2 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs @@ -18,50 +18,24 @@ partial class NativeExportsNE public partial class Collections { [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] - public static partial int Sum([MarshalUsing(typeof(ListMarshaller))] List values, int numValues); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] - public static partial int Sum(ref int values, int numValues); + public static partial int Sum([MarshalUsing(typeof(ListMarshaller<,>))] List values, int numValues); [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")] - public static partial int SumInArray([MarshalUsing(typeof(ListMarshaller))] in List values, int numValues); + public static partial int SumInArray([MarshalUsing(typeof(ListMarshaller<,>))] in List values, int numValues); [LibraryImport(NativeExportsNE_Binary, EntryPoint = "duplicate_int_array")] - public static partial void Duplicate([MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] ref List values, int numValues); + public static partial void Duplicate([MarshalUsing(typeof(ListMarshaller<,>), CountElementName = "numValues")] ref List values, int numValues); [LibraryImport(NativeExportsNE_Binary, EntryPoint = "create_range_array")] - [return:MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] + [return: MarshalUsing(typeof(ListMarshaller<,>), CountElementName = "numValues")] public static partial List CreateRange(int start, int end, out int numValues); [LibraryImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] - public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] out List res); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] - public static partial int SumStringLengths([MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] - public static partial int SumStringLengths([MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] WrappedList strArray); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_replace")] - public static partial void ReverseStrings_Ref([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] ref List strArray, out int numElements); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_return")] - [return: MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] - public static partial List ReverseStrings_Return([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray, out int numElements); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_out")] - public static partial void ReverseStrings_Out( - [MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray, - out int numElements, - [MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] out List res); + public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(ListMarshaller<,>), CountElementName = "numValues")] out List res); [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] - [return:MarshalUsing(typeof(ListMarshaller), ConstantElementCount = sizeof(long))] + [return: MarshalUsing(typeof(ListMarshaller<,>), ConstantElementCount = sizeof(long))] public static partial List GetLongBytes(long l); - - [LibraryImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")] - [return: MarshalAs(UnmanagedType.U1)] - public static partial bool AndAllMembers([MarshalUsing(typeof(ListMarshaller))] List pArray, int length); } } @@ -133,68 +107,6 @@ private static List GetStringList() }; } - [Fact] - public void ByValueCollectionWithNonBlittableElements() - { - var strings = GetStringList(); - Assert.Equal(strings.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.SumStringLengths(strings)); - } - - [Fact] - public void ByValueNullCollectionWithNonBlittableElements() - { - Assert.Equal(0, NativeExportsNE.Collections.SumStringLengths(null)); - } - - [Fact] - public void ByValueCollectionWithNonBlittableElements_WithDefaultMarshalling() - { - var strings = new WrappedList(GetStringList()); - Assert.Equal(strings.Wrapped.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.SumStringLengths(strings)); - } - - [Fact] - public void ByRefCollectionWithNonBlittableElements() - { - var strings = GetStringList(); - var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); - NativeExportsNE.Collections.ReverseStrings_Ref(ref strings, out _); - - Assert.Equal((IEnumerable)expectedStrings, strings); - } - - [Fact] - public void ReturnCollectionWithNonBlittableElements() - { - var strings = GetStringList(); - var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); - Assert.Equal(expectedStrings, NativeExportsNE.Collections.ReverseStrings_Return(strings, out _)); - - List res; - NativeExportsNE.Collections.ReverseStrings_Out(strings, out _, out res); - Assert.Equal(expectedStrings, res); - } - - [Fact] - public void ByRefNullCollectionWithNonBlittableElements() - { - List strings = null; - NativeExportsNE.Collections.ReverseStrings_Ref(ref strings, out _); - - Assert.Null(strings); - } - - [Fact] - public void ReturnNullCollectionWithNonBlittableElements() - { - List strings = null; - Assert.Null(NativeExportsNE.Collections.ReverseStrings_Return(strings, out _)); - - List res; - NativeExportsNE.Collections.ReverseStrings_Out(strings, out _, out res); - Assert.Null(res); - } - [Fact] public void ConstantSizeCollection() { @@ -202,45 +114,5 @@ public void ConstantSizeCollection() Assert.Equal(longVal, MemoryMarshal.Read(CollectionsMarshal.AsSpan(NativeExportsNE.Collections.GetLongBytes(longVal)))); } - - [Theory] - [InlineData(true)] - [InlineData(false)] - public void CollectionWithSimpleNonBlittableTypeMarshalling(bool result) - { - var boolValues = new List - { - new BoolStruct_V1 - { - b1 = true, - b2 = true, - b3 = true, - }, - new BoolStruct_V1 - { - b1 = true, - b2 = true, - b3 = true, - }, - new BoolStruct_V1 - { - b1 = true, - b2 = true, - b3 = result, - }, - }; - - Assert.Equal(result, NativeExportsNE.Collections.AndAllMembers(boolValues, boolValues.Count)); - } - - private static string ReverseChars(string value) - { - if (value == null) - return null; - - var chars = value.ToCharArray(); - Array.Reverse(chars); - return new string(chars); - } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs index d3e5b0c2e1fa9..eb3c4a66c00ea 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs @@ -1379,49 +1379,162 @@ struct RecursiveStruct2 int i; }"; - public static string CollectionByValue(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling) + @" -[NativeMarshalling(typeof(Marshaller<>))] -class TestCollection {} - -[CustomTypeMarshaller(typeof(TestCollection<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.TwoStageMarshalling)] -ref struct Marshaller -{ - public Marshaller(int nativeElementSize) : this() {} - public Marshaller(TestCollection managed, int nativeElementSize) : this() {} - public System.ReadOnlySpan GetManagedValuesSource() => throw null; - public System.Span GetManagedValuesDestination(int length) => throw null; - public System.ReadOnlySpan GetNativeValuesSource(int length) => throw null; - public System.Span GetNativeValuesDestination() => throw null; - public System.IntPtr ToNativeValue() => throw null; - public void FromNativeValue(System.IntPtr value) => throw null; - public TestCollection ToManaged() => throw null; -} + public static class CustomCollectionMarshalling + { + public static string TestCollection(bool defineNativeMarshalling = true) => $@" +{(defineNativeMarshalling ? "[NativeMarshalling(typeof(Marshaller<,>))]" : string.Empty)} +class TestCollection {{}} "; - public static string CollectionByValue() => CollectionByValue(typeof(T).ToString()); - - public static string MarshalUsingCollectionCountInfoParametersAndModifiers(string collectionType) => $@" + public static string CollectionOutParameter(string collectionType, string predeclaration = "") => $@" using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; -{DisableRuntimeMarshalling} +{predeclaration} partial class Test {{ [LibraryImport(""DoesNotExist"")] - [return:MarshalUsing(ConstantElementCount=10)] - public static partial {collectionType} Method( - {collectionType} p, - in {collectionType} pIn, - int pRefSize, - [MarshalUsing(CountElementName = ""pRefSize"")] ref {collectionType} pRef, - [MarshalUsing(CountElementName = ""pOutSize"")] out {collectionType} pOut, - out int pOutSize - ); -}}"; + public static partial int Method( + [MarshalUsing(ConstantElementCount = 10)] out {collectionType} pOut); +}} +"; + public static string CollectionReturnType(string collectionType, string predeclaration = "") => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{predeclaration} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + [return: MarshalUsing(ConstantElementCount = 10)] + public static partial {collectionType} Method(); +}} +"; + public static class Stateless + { + public const string In = @" +[CustomMarshaller(typeof(TestCollection<>), Scenario.ManagedToUnmanagedIn, typeof(Marshaller<,>))] +static unsafe class Marshaller +{ + public static byte* AllocateContainerForUnmanagedElements(TestCollection managed, out int numElements) => throw null; + public static System.ReadOnlySpan GetManagedValuesSource(TestCollection managed) => throw null; + public static System.Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null; +} +"; + public const string InBuffer = @" +[CustomMarshaller(typeof(TestCollection<>), Scenario.ManagedToUnmanagedIn, typeof(Marshaller<,>))] +static unsafe class Marshaller +{ + public const int BufferSize = 0x100; + public static byte* AllocateContainerForUnmanagedElements(TestCollection managed, System.Span buffer, out int numElements) => throw null; + public static System.ReadOnlySpan GetManagedValuesSource(TestCollection managed) => throw null; + public static System.Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null; +} +"; + public const string Ref = @" +[CustomMarshaller(typeof(TestCollection<>), Scenario.Default, typeof(Marshaller<,>))] +static unsafe class Marshaller +{ + public static byte* AllocateContainerForUnmanagedElements(TestCollection managed, out int numElements) => throw null; + public static System.ReadOnlySpan GetManagedValuesSource(TestCollection managed) => throw null; + public static System.Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null; - public static string CustomCollectionWithMarshaller(bool enableDefaultMarshalling) + public static TestCollection AllocateContainerForManagedElements(byte* unmanaged, int length) => throw null; + public static System.Span GetManagedValuesDestination(TestCollection managed) => throw null; + public static System.ReadOnlySpan GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null; +} +"; + public const string RefNested = @" +[CustomMarshaller(typeof(TestCollection<>), Scenario.Default, typeof(Marshaller<,>.Ref.Nested))] +static unsafe class Marshaller +{ + static class Nested + { + static class Ref { - string nativeMarshallingAttribute = enableDefaultMarshalling ? "[NativeMarshalling(typeof(Marshaller<>))]" : string.Empty; - return nativeMarshallingAttribute + @"class TestCollection {} + public static byte* AllocateContainerForUnmanagedElements(TestCollection managed, out int numElements) => throw null; + public static System.ReadOnlySpan GetManagedValuesSource(TestCollection managed) => throw null; + public static System.Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null; + + public static TestCollection AllocateContainerForManagedElements(byte* unmanaged, int length) => throw null; + public static System.Span GetManagedValuesDestination(TestCollection managed) => throw null; + public static System.ReadOnlySpan GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null; + } + } +} +"; + public const string Out = @" +[CustomMarshaller(typeof(TestCollection<>), Scenario.ManagedToUnmanagedOut, typeof(Marshaller<,>))] +static unsafe class Marshaller +{ + public static TestCollection AllocateContainerForManagedElements(byte* unmanaged, int length) => throw null; + public static System.Span GetManagedValuesDestination(TestCollection managed) => throw null; + public static System.ReadOnlySpan GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null; +} +"; + public static string ByValue() => ByValue(typeof(T).ToString()); + public static string ByValue(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling) + + TestCollection() + + In; + + public static string ByValueCallerAllocatedBuffer() => ByValueCallerAllocatedBuffer(typeof(T).ToString()); + public static string ByValueCallerAllocatedBuffer(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling) + + TestCollection() + + In; + + public static string DefaultMarshallerParametersAndModifiers() => DefaultMarshallerParametersAndModifiers(typeof(T).ToString()); + public static string DefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + + TestCollection() + + Ref; + + public static string CustomMarshallerParametersAndModifiers() => CustomMarshallerParametersAndModifiers(typeof(T).ToString()); + public static string CustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<,>") + + TestCollection(defineNativeMarshalling: false) + + Ref; + + public static string CustomMarshallerReturnValueLength() => CustomMarshallerReturnValueLength(typeof(T).ToString()); + public static string CustomMarshallerReturnValueLength(string elementType) => MarshalUsingCollectionReturnValueLength($"TestCollection<{elementType}>", $"Marshaller<,>") + + TestCollection(defineNativeMarshalling: false) + + Ref; + + public static string NativeToManagedOnlyOutParameter() => NativeToManagedOnlyOutParameter(typeof(T).ToString()); + public static string NativeToManagedOnlyOutParameter(string elementType) => CollectionOutParameter($"TestCollection<{elementType}>") + + TestCollection() + + Out; + + public static string NativeToManagedOnlyReturnValue() => NativeToManagedOnlyReturnValue(typeof(T).ToString()); + public static string NativeToManagedOnlyReturnValue(string elementType) => CollectionReturnType($"TestCollection<{elementType}>") + + TestCollection() + + Out; + + public static string NestedMarshallerParametersAndModifiers() => DefaultMarshallerParametersAndModifiers(typeof(T).ToString()); + public static string NestedMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + + TestCollection() + + RefNested; + + public static string GenericCollectionMarshallingArityMismatch => BasicParameterByValue("TestCollection", DisableRuntimeMarshalling) + + @" +[NativeMarshalling(typeof(Marshaller<,,>))] +class TestCollection {} + +[CustomMarshaller(typeof(TestCollection<>), Scenario.Default, typeof(Marshaller<,,>))] +static unsafe class Marshaller +{ + public static byte* AllocateContainerForUnmanagedElements(TestCollection managed, out int numElements) => throw null; + public static System.ReadOnlySpan GetManagedValuesSource(TestCollection managed) => throw null; + public static System.Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null; + + public static TestCollection AllocateContainerForManagedElements(byte* unmanaged, int length) => throw null; + public static System.Span GetManagedValuesDestination(TestCollection managed) => throw null; + public static System.ReadOnlySpan GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null; +} +"; + } + } + + public static class CustomCollectionMarshalling_V1 + { + public static string ByValue(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling) + @" +[NativeMarshalling(typeof(Marshaller<>))] +class TestCollection {} [CustomTypeMarshaller(typeof(TestCollection<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.TwoStageMarshalling)] ref struct Marshaller @@ -1435,95 +1548,44 @@ public Marshaller(TestCollection managed, int nativeElementSize) : this() {} public System.IntPtr ToNativeValue() => throw null; public void FromNativeValue(System.IntPtr value) => throw null; public TestCollection ToManaged() => throw null; -}"; - } - - public static string MarshalUsingCollectionCountInfoParametersAndModifiers() => MarshalUsingCollectionCountInfoParametersAndModifiers(typeof(T).ToString()); - - public static string CustomCollectionDefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); - - public static string CustomCollectionDefaultMarshallerParametersAndModifiers() => CustomCollectionDefaultMarshallerParametersAndModifiers(typeof(T).ToString()); - - public static string MarshalUsingCollectionParametersAndModifiers(string collectionType, string marshallerType) => $@" -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; -{DisableRuntimeMarshalling} -partial class Test -{{ - [LibraryImport(""DoesNotExist"")] - [return:MarshalUsing(typeof({marshallerType}), ConstantElementCount=10)] - public static partial {collectionType} Method( - [MarshalUsing(typeof({marshallerType}))] {collectionType} p, - [MarshalUsing(typeof({marshallerType}))] in {collectionType} pIn, - int pRefSize, - [MarshalUsing(typeof({marshallerType}), CountElementName = ""pRefSize"")] ref {collectionType} pRef, - [MarshalUsing(typeof({marshallerType}), CountElementName = ""pOutSize"")] out {collectionType} pOut, - out int pOutSize - ); -}}"; +} +"; - public static string CustomCollectionCustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: false); + public static string ByValue() => ByValue(typeof(T).ToString()); - public static string CustomCollectionCustomMarshallerParametersAndModifiers() => CustomCollectionCustomMarshallerParametersAndModifiers(typeof(T).ToString()); + public static string CustomCollectionWithMarshaller(bool enableDefaultMarshalling) + { + string nativeMarshallingAttribute = enableDefaultMarshalling ? "[NativeMarshalling(typeof(Marshaller<>))]" : string.Empty; + return nativeMarshallingAttribute + @"class TestCollection {} - public static string MarshalUsingCollectionReturnValueLength(string collectionType, string marshallerType) => $@" -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; -{DisableRuntimeMarshalling} -partial class Test -{{ - [LibraryImport(""DoesNotExist"")] - public static partial int Method( - [MarshalUsing(typeof({marshallerType}), CountElementName = MarshalUsingAttribute.ReturnsCountValue)] out {collectionType} pOut - ); -}}"; + [CustomTypeMarshaller(typeof(TestCollection<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.TwoStageMarshalling)] + ref struct Marshaller + { + public Marshaller(int nativeElementSize) : this() {} + public Marshaller(TestCollection managed, int nativeElementSize) : this() {} + public System.ReadOnlySpan GetManagedValuesSource() => throw null; + public System.Span GetManagedValuesDestination(int length) => throw null; + public System.ReadOnlySpan GetNativeValuesSource(int length) => throw null; + public System.Span GetNativeValuesDestination() => throw null; + public System.IntPtr ToNativeValue() => throw null; + public void FromNativeValue(System.IntPtr value) => throw null; + public TestCollection ToManaged() => throw null; + }"; + } - public static string CustomCollectionCustomMarshallerReturnValueLength(string elementType) => MarshalUsingCollectionReturnValueLength($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: false); + public static string DefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); - public static string CustomCollectionCustomMarshallerReturnValueLength() => CustomCollectionCustomMarshallerReturnValueLength(typeof(T).ToString()); + public static string DefaultMarshallerParametersAndModifiers() => DefaultMarshallerParametersAndModifiers(typeof(T).ToString()); - public static string MarshalUsingArrayParameterWithSizeParam(string sizeParamType, bool isByRef) => $@" -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; -{DisableRuntimeMarshalling} -partial class Test -{{ - [LibraryImport(""DoesNotExist"")] - public static partial void Method( - {(isByRef ? "ref" : "")} {sizeParamType} pRefSize, - [MarshalUsing(CountElementName = ""pRefSize"")] ref int[] pRef - ); -}}"; + public static string CustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: false); - public static string MarshalUsingArrayParameterWithSizeParam(bool isByRef) => MarshalUsingArrayParameterWithSizeParam(typeof(T).ToString(), isByRef); + public static string CustomMarshallerParametersAndModifiers() => CustomMarshallerParametersAndModifiers(typeof(T).ToString()); - public static string MarshalUsingCollectionWithConstantAndElementCount => $@" -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; -{DisableRuntimeMarshalling} -partial class Test -{{ - [LibraryImport(""DoesNotExist"")] - public static partial void Method( - int pRefSize, - [MarshalUsing(ConstantElementCount = 10, CountElementName = ""pRefSize"")] ref int[] pRef - ); -}}"; + public static string CustomMarshallerReturnValueLength(string elementType) => MarshalUsingCollectionReturnValueLength($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: false); - public static string MarshalUsingCollectionWithNullElementName => $@" -using System.Runtime.InteropServices; -using System.Runtime.InteropServices.Marshalling; -{DisableRuntimeMarshalling} -partial class Test -{{ - [LibraryImport(""DoesNotExist"")] - public static partial void Method( - int pRefSize, - [MarshalUsing(CountElementName = null)] ref int[] pRef - ); -}}"; + public static string CustomMarshallerReturnValueLength() => CustomMarshallerReturnValueLength(typeof(T).ToString()); - public static string GenericCollectionMarshallingArityMismatch => BasicParameterByValue("TestCollection", DisableRuntimeMarshalling) + @" + public static string GenericCollectionMarshallingArityMismatch => BasicParameterByValue("TestCollection", DisableRuntimeMarshalling) + @" [NativeMarshalling(typeof(Marshaller<,>))] class TestCollection {} @@ -1542,7 +1604,7 @@ public Marshaller(TestCollection managed, int nativeElementSize) : this() {} public TestCollection ToManaged() => throw null; }"; - public static string GenericCollectionWithCustomElementMarshalling => $@" + public static string GenericCollectionWithCustomElementMarshalling => $@" using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; {DisableRuntimeMarshalling} @@ -1569,7 +1631,7 @@ public IntWrapper(int i){{}} " + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); - public static string GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionDepth => $@" + public static string GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionDepth => $@" using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; {DisableRuntimeMarshalling} @@ -1588,7 +1650,7 @@ public IntWrapper(int i){{}} " + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); - public static string GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionDepth => $@" + public static string GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionDepth => $@" using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; {DisableRuntimeMarshalling} @@ -1606,6 +1668,98 @@ public IntWrapper(int i){{}} }} " + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); + } + + public static string MarshalUsingCollectionCountInfoParametersAndModifiers(string collectionType) => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{DisableRuntimeMarshalling} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + [return:MarshalUsing(ConstantElementCount=10)] + public static partial {collectionType} Method( + {collectionType} p, + in {collectionType} pIn, + int pRefSize, + [MarshalUsing(CountElementName = ""pRefSize"")] ref {collectionType} pRef, + [MarshalUsing(CountElementName = ""pOutSize"")] out {collectionType} pOut, + out int pOutSize + ); +}}"; + + public static string MarshalUsingCollectionCountInfoParametersAndModifiers() => MarshalUsingCollectionCountInfoParametersAndModifiers(typeof(T).ToString()); + + public static string MarshalUsingCollectionParametersAndModifiers(string collectionType, string marshallerType) => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{DisableRuntimeMarshalling} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + [return:MarshalUsing(typeof({marshallerType}), ConstantElementCount=10)] + public static partial {collectionType} Method( + [MarshalUsing(typeof({marshallerType}))] {collectionType} p, + [MarshalUsing(typeof({marshallerType}))] in {collectionType} pIn, + int pRefSize, + [MarshalUsing(typeof({marshallerType}), CountElementName = ""pRefSize"")] ref {collectionType} pRef, + [MarshalUsing(typeof({marshallerType}), CountElementName = ""pOutSize"")] out {collectionType} pOut, + out int pOutSize + ); +}}"; + + public static string MarshalUsingCollectionReturnValueLength(string collectionType, string marshallerType) => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{DisableRuntimeMarshalling} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + public static partial int Method( + [MarshalUsing(typeof({marshallerType}), CountElementName = MarshalUsingAttribute.ReturnsCountValue)] out {collectionType} pOut + ); +}}"; + + public static string MarshalUsingArrayParameterWithSizeParam(string sizeParamType, bool isByRef) => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{DisableRuntimeMarshalling} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + public static partial void Method( + {(isByRef ? "ref" : "")} {sizeParamType} pRefSize, + [MarshalUsing(CountElementName = ""pRefSize"")] ref int[] pRef + ); +}}"; + + public static string MarshalUsingArrayParameterWithSizeParam(bool isByRef) => MarshalUsingArrayParameterWithSizeParam(typeof(T).ToString(), isByRef); + + public static string MarshalUsingCollectionWithConstantAndElementCount => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{DisableRuntimeMarshalling} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + public static partial void Method( + int pRefSize, + [MarshalUsing(ConstantElementCount = 10, CountElementName = ""pRefSize"")] ref int[] pRef + ); +}}"; + + public static string MarshalUsingCollectionWithNullElementName => $@" +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +{DisableRuntimeMarshalling} +partial class Test +{{ + [LibraryImport(""DoesNotExist"")] + public static partial void Method( + int pRefSize, + [MarshalUsing(CountElementName = null)] ref int[] pRef + ); +}}"; public static string MarshalAsAndMarshalUsingOnReturnValue => $@" using System.Runtime.InteropServices; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs index 935980dbc7395..539a5116bcf0c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs @@ -117,11 +117,12 @@ public static IEnumerable CodeSnippetsToCompile() yield return new object[] { CodeSnippets.MarshalUsingCollectionWithNullElementName, 2, 0 }; // Generic collection marshaller has different arity than collection. - yield return new object[] { CodeSnippets.GenericCollectionMarshallingArityMismatch, 2, 0 }; + yield return new object[] { CodeSnippets.CustomCollectionMarshalling.Stateless.GenericCollectionMarshallingArityMismatch, 2, 0 }; + yield return new object[] { CodeSnippets.CustomCollectionMarshalling_V1.GenericCollectionMarshallingArityMismatch, 2, 0 }; yield return new object[] { CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 2, 0 }; - yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionDepth, 2, 0 }; - yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionDepth, 1, 0 }; + yield return new object[] { CodeSnippets.CustomCollectionMarshalling_V1.GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionDepth, 2, 0 }; + yield return new object[] { CodeSnippets.CustomCollectionMarshalling_V1.GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionDepth, 1, 0 }; yield return new object[] { CodeSnippets.RecursiveCountElementNameOnReturnValue, 2, 0 }; yield return new object[] { CodeSnippets.RecursiveCountElementNameOnParameter, 2, 0 }; yield return new object[] { CodeSnippets.MutuallyRecursiveCountElementNameOnParameter, 4, 0 }; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index 5536cfd4a24b5..4a971cd7f1d79 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -222,20 +222,47 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.MaybeBlittableGenericTypeParametersAndModifiers() }; yield return new[] { CodeSnippets.MaybeBlittableGenericTypeParametersAndModifiers() }; yield return new[] { CodeSnippets.MaybeBlittableGenericTypeParametersAndModifiers() }; + } + public static IEnumerable CustomCollections() + { // Custom collection marshalling - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; - yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.ByValueCallerAllocatedBuffer() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.ByValue() }; yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; @@ -248,37 +275,66 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; - yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerReturnValueLength() }; - yield return new[] { CodeSnippets.GenericCollectionWithCustomElementMarshalling }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.CustomMarshallerReturnValueLength() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.NativeToManagedOnlyOutParameter() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.NativeToManagedOnlyReturnValue() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling.Stateless.NestedMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.DefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.CustomMarshallerReturnValueLength() }; + yield return new[] { CodeSnippets.CustomCollectionMarshalling_V1.GenericCollectionWithCustomElementMarshalling }; yield return new[] { CodeSnippets.CollectionsOfCollectionsStress }; } [Theory] [MemberData(nameof(CodeSnippetsToCompile))] + [MemberData(nameof(CustomCollections))] public async Task ValidateSnippets(string source) { Compilation comp = await TestUtils.CreateCompilation(source); @@ -301,7 +357,6 @@ public static IEnumerable CodeSnippetsToCompileWithPreprocessorSymbols yield return new object[] { CodeSnippets.PreprocessorIfAfterAttributeAroundFunctionAdditionalFunctionAfter("Foo"), new string[] { "Foo" } }; yield return new object[] { CodeSnippets.PreprocessorIfAfterAttributeAroundFunctionAdditionalFunctionAfter("Foo"), Array.Empty() }; } - [Theory] [MemberData(nameof(CodeSnippetsToCompileWithPreprocessorSymbols))] public async Task ValidateSnippetsWithPreprocessorDefintions(string source, IEnumerable preprocessorSymbols) diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.V1.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.V1.cs index 7f8d5cc0f72ef..78b8d1bda0203 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.V1.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.V1.cs @@ -144,24 +144,24 @@ public IntStructWrapperNative(IntStructWrapper managed) } [CustomTypeMarshaller(typeof(List<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer, BufferSize = 0x200)] - public unsafe ref struct ListMarshaller + public unsafe ref struct ListMarshaller_V1 { private List managedList; private readonly int sizeOfNativeElement; private IntPtr allocatedMemory; - public ListMarshaller(int sizeOfNativeElement) + public ListMarshaller_V1(int sizeOfNativeElement) : this() { this.sizeOfNativeElement = sizeOfNativeElement; } - public ListMarshaller(List managed, int sizeOfNativeElement) + public ListMarshaller_V1(List managed, int sizeOfNativeElement) :this(managed, Span.Empty, sizeOfNativeElement) { } - public ListMarshaller(List managed, Span stackSpace, int sizeOfNativeElement) + public ListMarshaller_V1(List managed, Span stackSpace, int sizeOfNativeElement) { allocatedMemory = default; this.sizeOfNativeElement = sizeOfNativeElement; @@ -228,10 +228,10 @@ public void FreeNative() } } - [NativeMarshalling(typeof(WrappedListMarshaller<>))] - public struct WrappedList + [NativeMarshalling(typeof(WrappedListMarshaller_V1<>))] + public struct WrappedList_V1 { - public WrappedList(List list) + public WrappedList_V1(List list) { Wrapped = list; } @@ -241,25 +241,25 @@ public WrappedList(List list) public ref T GetPinnableReference() => ref CollectionsMarshal.AsSpan(Wrapped).GetPinnableReference(); } - [CustomTypeMarshaller(typeof(WrappedList<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer, BufferSize = 0x200)] - public unsafe ref struct WrappedListMarshaller + [CustomTypeMarshaller(typeof(WrappedList_V1<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer, BufferSize = 0x200)] + public unsafe ref struct WrappedListMarshaller_V1 { - private ListMarshaller _marshaller; + private ListMarshaller_V1 _marshaller; - public WrappedListMarshaller(int sizeOfNativeElement) + public WrappedListMarshaller_V1(int sizeOfNativeElement) : this() { - this._marshaller = new ListMarshaller(sizeOfNativeElement); + this._marshaller = new ListMarshaller_V1(sizeOfNativeElement); } - public WrappedListMarshaller(WrappedList managed, int sizeOfNativeElement) + public WrappedListMarshaller_V1(WrappedList_V1 managed, int sizeOfNativeElement) : this(managed, Span.Empty, sizeOfNativeElement) { } - public WrappedListMarshaller(WrappedList managed, Span stackSpace, int sizeOfNativeElement) + public WrappedListMarshaller_V1(WrappedList_V1 managed, Span stackSpace, int sizeOfNativeElement) { - this._marshaller = new ListMarshaller(managed.Wrapped, stackSpace, sizeOfNativeElement); + this._marshaller = new ListMarshaller_V1(managed.Wrapped, stackSpace, sizeOfNativeElement); } public ReadOnlySpan GetManagedValuesSource() => _marshaller.GetManagedValuesSource(); @@ -276,7 +276,7 @@ public WrappedListMarshaller(WrappedList managed, Span stackSpace, int public void FromNativeValue(byte* value) => _marshaller.FromNativeValue(value); - public WrappedList ToManaged() => new(_marshaller.ToManaged()); + public WrappedList_V1 ToManaged() => new(_marshaller.ToManaged()); public void FreeNative() => _marshaller.FreeNative(); } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs index f9f34a6e4dec5..b01337b9973c0 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs @@ -166,7 +166,6 @@ public static void Free(int* unmanaged) } } - [CustomMarshaller(typeof(IntWrapper), Scenario.Default, typeof(Marshaller))] public static unsafe class IntWrapperMarshallerStateful { @@ -282,4 +281,92 @@ public BoolStruct ToManaged() } } } + + [CustomMarshaller(typeof(List<>), Scenario.Default, typeof(ListMarshaller<,>))] + public unsafe static class ListMarshaller where TUnmanagedElement : unmanaged + { + public static byte* AllocateContainerForUnmanagedElements(List managed, out int numElements) + => AllocateContainerForUnmanagedElements(managed, Span.Empty, out numElements); + + public static byte* AllocateContainerForUnmanagedElements(List managed, Span buffer, out int numElements) + { + if (managed is null) + { + numElements = 0; + return null; + } + + numElements = managed.Count; + + // Always allocate at least one byte when the list is zero-length. + int spaceToAllocate = Math.Max(checked(sizeof(TUnmanagedElement) * numElements), 1); + if (spaceToAllocate <= buffer.Length) + { + return (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer)); + } + else + { + return (byte*)Marshal.AllocCoTaskMem(spaceToAllocate); + } + } + + public static ReadOnlySpan GetManagedValuesSource(List managed) + => CollectionsMarshal.AsSpan(managed); + + public static Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) + => new Span(unmanaged, numElements); + + public static List AllocateContainerForManagedElements(byte* unmanaged, int length) + { + if (unmanaged is null) + return null; + + var list = new List(length); + for (int i = 0; i < length; i++) + { + list.Add(default); + } + + return list; + } + + public static Span GetManagedValuesDestination(List managed) + => CollectionsMarshal.AsSpan(managed); + + public static ReadOnlySpan GetUnmanagedValuesSource(byte* nativeValue, int numElements) + => new ReadOnlySpan(nativeValue, numElements); + + public static void Free(byte* unmanaged) + => Marshal.FreeCoTaskMem((IntPtr)unmanaged); + } + + [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), Scenario.Default, typeof(CustomArrayMarshaller<,>))] + public unsafe static class CustomArrayMarshaller where TUnmanagedElement : unmanaged + { + public static byte* AllocateContainerForUnmanagedElements(T[]? managed, out int numElements) + { + if (managed is null) + { + numElements = 0; + return null; + } + + numElements = managed.Length; + return (byte*)Marshal.AllocCoTaskMem(checked(sizeof(TUnmanagedElement) * numElements)); + } + + public static ReadOnlySpan GetManagedValuesSource(T[] managed) => managed; + + public static Span GetUnmanagedValuesDestination(byte* unmanaged, int numElements) + => new Span(unmanaged, numElements); + + public static T[] AllocateContainerForManagedElements(byte* unmanaged, int length) => unmanaged is null ? null : new T[length]; + + public static Span GetManagedValuesDestination(T[] managed) => managed; + + public static ReadOnlySpan GetUnmanagedValuesSource(byte* unmanaged, int numElements) + => new Span(unmanaged, numElements); + + public static void Free(byte* unmanaged) => Marshal.FreeCoTaskMem((IntPtr)unmanaged); + } }