Skip to content

Commit

Permalink
Stop requiring bidirectional support for MarshalMode.Default (#71977)
Browse files Browse the repository at this point in the history
  • Loading branch information
elinor-fung authored Jul 12, 2022
1 parent 5ab4175 commit a3a106f
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ public static class MarshalUsingProperties
public const string ConstantElementCount = nameof(ConstantElementCount);
}

[Flags]
private enum MarshallingDirection
{
ManagedToUnmanaged = 0x1,
UnmanagedToManaged = 0x2,
Bidirectional = ManagedToUnmanaged | UnmanagedToManaged
}

public static bool IsLinearCollectionEntryPoint(INamedTypeSymbol entryPointType)
{
return entryPointType.IsGenericType
Expand Down Expand Up @@ -114,7 +106,8 @@ private static bool TryGetMarshallersFromEntryType(
return false;

// We expect a callback for getting the element marshalling info when handling linear collection marshalling
Debug.Assert(!isLinearCollectionMarshalling || getMarshallingInfoForElement is not null);
if (isLinearCollectionMarshalling && getMarshallingInfoForElement is null)
return false;

Dictionary<MarshalMode, CustomTypeMarshallerData> modes = new();

Expand All @@ -127,7 +120,12 @@ private static bool TryGetMarshallersFromEntryType(
// We don't report a diagnostic here since Roslyn will report a diagnostic anyway.
continue;
}
Debug.Assert(attr.ConstructorArguments.Length == 3);

if (attr.ConstructorArguments.Length != 3)
{
Debug.WriteLine($"{attr} has {attr.ConstructorArguments.Length} constructor arguments - expected 3");
continue;
}

// Verify the defined marshaller is for the managed type.
ITypeSymbol? managedTypeOnAttr = attr.ConstructorArguments[0].Value as ITypeSymbol;
Expand Down Expand Up @@ -175,33 +173,9 @@ private static bool TryGetMarshallersFromEntryType(
marshallerType = currentType;
}

// TODO: We can probably get rid of MarshallingDirection and just use MarshalMode instead
MarshallingDirection direction = marshalMode switch
{
MarshalMode.Default
=> MarshallingDirection.Bidirectional,

MarshalMode.ManagedToUnmanagedIn
or MarshalMode.UnmanagedToManagedOut
or MarshalMode.ElementIn
=> MarshallingDirection.ManagedToUnmanaged,

MarshalMode.ManagedToUnmanagedOut
or MarshalMode.UnmanagedToManagedIn
or MarshalMode.ElementOut
=> MarshallingDirection.UnmanagedToManaged,

MarshalMode.ManagedToUnmanagedRef
or MarshalMode.UnmanagedToManagedRef
or MarshalMode.ElementRef
=> MarshallingDirection.Bidirectional,

_ => throw new UnreachableException()
};

// TODO: Report invalid shape for mode
// Skip checking for bidirectional support for Default mode - always take / store marshaller data
CustomTypeMarshallerData? data = GetMarshallerDataForType(marshallerType, direction, managedType, isLinearCollectionMarshalling, compilation, getMarshallingInfoForElement);
CustomTypeMarshallerData? data = GetMarshallerDataForType(marshallerType, marshalMode, managedType, isLinearCollectionMarshalling, compilation, getMarshallingInfoForElement);

// TODO: Should we fire a diagnostic for duplicated modes or just take the last one?
if (data is null
Expand Down Expand Up @@ -338,32 +312,50 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault

private static CustomTypeMarshallerData? GetMarshallerDataForType(
ITypeSymbol marshallerType,
MarshallingDirection direction,
MarshalMode mode,
ITypeSymbol managedType,
bool isLinearCollectionMarshaller,
Compilation compilation,
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfo)
{
if (marshallerType is { IsStatic: true, TypeKind: TypeKind.Class })
{
return GetStatelessMarshallerDataForType(marshallerType, direction, managedType, isLinearCollectionMarshaller, compilation, getMarshallingInfo);
return GetStatelessMarshallerDataForType(marshallerType, mode, managedType, isLinearCollectionMarshaller, compilation, getMarshallingInfo);
}
if (marshallerType.IsValueType)
{
return GetStatefulMarshallerDataForType(marshallerType, direction, managedType, isLinearCollectionMarshaller, compilation, getMarshallingInfo);
return GetStatefulMarshallerDataForType(marshallerType, mode, managedType, isLinearCollectionMarshaller, compilation, getMarshallingInfo);
}
return null;
}

private static CustomTypeMarshallerData? GetStatelessMarshallerDataForType(ITypeSymbol marshallerType, MarshallingDirection direction, ITypeSymbol managedType, bool isLinearCollectionMarshaller, Compilation compilation, Func<ITypeSymbol, MarshallingInfo>? getMarshallingInfo)
private static bool ModeUsesManagedToUnmanagedShape(MarshalMode mode)
=> mode is MarshalMode.Default
or MarshalMode.ManagedToUnmanagedIn
or MarshalMode.UnmanagedToManagedOut
or MarshalMode.ElementIn
or MarshalMode.ManagedToUnmanagedRef
or MarshalMode.UnmanagedToManagedRef
or MarshalMode.ElementRef;

private static bool ModeUsesUnmanagedToManagedShape(MarshalMode mode)
=> mode is MarshalMode.Default
or MarshalMode.ManagedToUnmanagedOut
or MarshalMode.UnmanagedToManagedIn
or MarshalMode.ElementOut
or MarshalMode.ManagedToUnmanagedRef
or MarshalMode.UnmanagedToManagedRef
or MarshalMode.ElementRef;

private static CustomTypeMarshallerData? GetStatelessMarshallerDataForType(ITypeSymbol marshallerType, MarshalMode mode, ITypeSymbol managedType, bool isLinearCollectionMarshaller, Compilation compilation, Func<ITypeSymbol, MarshallingInfo>? getMarshallingInfo)
{
(MarshallerShape shape, StatelessMarshallerShapeHelper.MarshallerMethods methods) = StatelessMarshallerShapeHelper.GetShapeForType(marshallerType, managedType, isLinearCollectionMarshaller, compilation);

ITypeSymbol? collectionElementType = null;
ITypeSymbol? nativeType = null;
if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged))
if (ModeUsesManagedToUnmanagedShape(mode))
{
if (!shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
if (mode != MarshalMode.Default && !shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (isLinearCollectionMarshaller)
Expand All @@ -383,9 +375,10 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault
}
}

if (direction.HasFlag(MarshallingDirection.UnmanagedToManaged))
if (ModeUsesUnmanagedToManagedShape(mode))
{
if (!shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged))
// Unmanaged to managed requires ToManaged either with or without guaranteed unmarshal
if (mode != MarshalMode.Default && !shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged))
return null;

if (isLinearCollectionMarshaller)
Expand All @@ -411,7 +404,7 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault
}

// Bidirectional requires ToUnmanaged without the caller-allocated buffer
if (direction.HasFlag(MarshallingDirection.Bidirectional) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
if (mode != MarshalMode.Default && ModeUsesManagedToUnmanagedShape(mode) && ModeUsesUnmanagedToManagedShape(mode) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (nativeType is null)
Expand Down Expand Up @@ -444,7 +437,7 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault

private static CustomTypeMarshallerData? GetStatefulMarshallerDataForType(
ITypeSymbol marshallerType,
MarshallingDirection direction,
MarshalMode mode,
ITypeSymbol managedType,
bool isLinearCollectionMarshaller,
Compilation compilation,
Expand All @@ -454,9 +447,10 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault

ITypeSymbol? nativeType = null;
ITypeSymbol? collectionElementType = null;
if (direction.HasFlag(MarshallingDirection.ManagedToUnmanaged))
if (ModeUsesManagedToUnmanagedShape(mode))
{
if (!shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
// Managed to unmanaged requires ToUnmanaged either with or without the caller-allocated buffer
if (mode != MarshalMode.Default && !shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (methods.ToUnmanaged is not null)
Expand All @@ -471,25 +465,26 @@ public static (AttributeData? attribute, INamedTypeSymbol? entryType) GetDefault
}
}

if (nativeType is null && direction.HasFlag(MarshallingDirection.UnmanagedToManaged))
if (ModeUsesUnmanagedToManagedShape(mode))
{
if (!shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged))
// Unmanaged to managed requires ToManaged either with or without guaranteed unmarshal
if (mode != MarshalMode.Default && !shape.HasFlag(MarshallerShape.GuaranteedUnmarshal) && !shape.HasFlag(MarshallerShape.ToManaged))
return null;

if (methods.FromUnmanaged is not null)
if (methods.FromUnmanaged is not null && nativeType is null)
{
nativeType = methods.FromUnmanaged.Parameters[0].Type;
}

if (isLinearCollectionMarshaller)
if (isLinearCollectionMarshaller && collectionElementType is null)
{
// Element type is the type parameter of the Span returned by GetManagedValuesDestination
collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0];
}
}

// Bidirectional requires ToUnmanaged without the caller-allocated buffer
if (direction.HasFlag(MarshallingDirection.Bidirectional) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
if (mode != MarshalMode.Default && ModeUsesManagedToUnmanagedShape(mode) && ModeUsesUnmanagedToManagedShape(mode) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (nativeType is null)
Expand Down
Loading

0 comments on commit a3a106f

Please sign in to comment.