diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 85b6a254e..8f64f72b2 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -767,11 +767,43 @@ private void MarshalRefClass(Class @class) { if (Context.Parameter.IsIndirect) { - Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); - Context.Before.WriteLineIndent( - $@"throw new global::System.ArgumentNullException(""{ - Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); - Context.Return.Write(paramInstance); + Method cctor = @class.HasNonTrivialCopyConstructor ? @class.Methods.First(c => c.IsCopyConstructor) : null; + if (cctor != null && cctor.IsGenerated) + { + Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); + Context.Before.WriteLineIndent( + $@"throw new global::System.ArgumentNullException(""{ + Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); + + var nativeClass = typePrinter.PrintNative(@class); + + var cctorName = CSharpSources.GetFunctionNativeIdentifier(Context.Context, cctor); + + var defaultValue = ""; + var TypePrinter = new CSharpTypePrinter(Context.Context); + var ExpressionPrinter = new CSharpExpressionPrinter(TypePrinter); + if (cctor.Parameters.Count > 1) + defaultValue = $", {ExpressionPrinter.VisitParameter(cctor.Parameters.Last())}"; + + Context.Before.WriteLine($"byte* __{Context.Parameter.Name}Memory = stackalloc byte[sizeof({nativeClass})];"); + Context.Before.WriteLine($"__IntPtr __{Context.Parameter.Name}Ptr = (__IntPtr)__{Context.Parameter.Name}Memory;"); + Context.Before.WriteLine($"{nativeClass}.{cctorName}(__{Context.Parameter.Name}Ptr, {Context.Parameter.Name}.__Instance{defaultValue});"); + Context.Return.Write($"__{Context.Parameter.Name}Ptr"); + + if (Context.Context.ParserOptions.IsItaniumLikeAbi && @class.HasNonTrivialDestructor) + { + Method dtor = @class.Destructors.FirstOrDefault(); + if (dtor != null) + { + // todo: virtual destructors? + Context.Cleanup.WriteLine($"{nativeClass}.dtor(__{Context.Parameter.Name}Ptr);"); + } + } + } + else + { + Context.Return.Write(paramInstance); + } } else { diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 8cedbf44d..c992bd5dc 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -3443,6 +3443,12 @@ public static string GetFunctionIdentifier(Function function) public string GetFunctionNativeIdentifier(Function function, bool isForDelegate = false) + { + return GetFunctionNativeIdentifier(Context, function, isForDelegate); + } + + public static string GetFunctionNativeIdentifier(BindingContext context, Function function, + bool isForDelegate = false) { var identifier = new StringBuilder(); @@ -3473,12 +3479,12 @@ public string GetFunctionNativeIdentifier(Function function, identifier.Append(Helpers.GetSuffixFor(specialization)); var internalParams = function.GatherInternalParams( - Context.ParserOptions.IsItaniumLikeAbi); + context.ParserOptions.IsItaniumLikeAbi); var overloads = function.Namespace.GetOverloads(function) .Where(f => (!f.Ignore || (f.OriginalFunction != null && !f.OriginalFunction.Ignore)) && (isForDelegate || internalParams.SequenceEqual( - f.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi), + f.GatherInternalParams(context.ParserOptions.IsItaniumLikeAbi), new MarshallingParamComparer()))).ToList(); var index = -1; if (overloads.Count > 1) diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index aa7a0532c..185c1d518 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -1984,4 +1984,19 @@ public void TestCallByValueCppToCSharpPointer() Assert.That(RuleOfThreeTester.CopyConstructorCalls, Is.EqualTo(0)); Assert.That(RuleOfThreeTester.CopyAssignmentCalls, Is.EqualTo(0)); } + + [Test] + public void TestCallByValueCopyConstructor() + { + using (var s = new CallByValueCopyConstructor()) + { + s.A = 500; + CSharp.CSharp.CallByValueCopyConstructorFunction(s); + Assert.That(s.A, Is.EqualTo(500)); + } + + Assert.That(CallByValueCopyConstructor.ConstructorCalls, Is.EqualTo(1)); + Assert.That(CallByValueCopyConstructor.CopyConstructorCalls, Is.EqualTo(1)); + Assert.That(CallByValueCopyConstructor.DestructorCalls, Is.EqualTo(2)); + } } diff --git a/tests/CSharp/CSharp.cpp b/tests/CSharp/CSharp.cpp index 8cb552aef..33ff64b04 100644 --- a/tests/CSharp/CSharp.cpp +++ b/tests/CSharp/CSharp.cpp @@ -1750,3 +1750,29 @@ void CallCallByValueInterfacePointer(CallByValueInterface* interface) RuleOfThreeTester value; interface->CallByPointer(&value); } + +int CallByValueCopyConstructor::constructorCalls = 0; +int CallByValueCopyConstructor::destructorCalls = 0; +int CallByValueCopyConstructor::copyConstructorCalls = 0; + +CallByValueCopyConstructor::CallByValueCopyConstructor() +{ + a = 0; + constructorCalls++; +} + +CallByValueCopyConstructor::CallByValueCopyConstructor(const CallByValueCopyConstructor& other) +{ + a = other.a; + copyConstructorCalls++; +} + +CallByValueCopyConstructor::~CallByValueCopyConstructor() +{ + destructorCalls++; +} + +void CallByValueCopyConstructorFunction(CallByValueCopyConstructor s) +{ + s.a = 99999; +} diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h index bc9593886..6bd021b68 100644 --- a/tests/CSharp/CSharp.h +++ b/tests/CSharp/CSharp.h @@ -1582,3 +1582,16 @@ struct DLL_API CallByValueInterface { void DLL_API CallCallByValueInterfaceValue(CallByValueInterface*); void DLL_API CallCallByValueInterfaceReference(CallByValueInterface*); void DLL_API CallCallByValueInterfacePointer(CallByValueInterface*); + +struct DLL_API CallByValueCopyConstructor { + int a; + static int constructorCalls; + static int destructorCalls; + static int copyConstructorCalls; + + CallByValueCopyConstructor(); + ~CallByValueCopyConstructor(); + CallByValueCopyConstructor(const CallByValueCopyConstructor& other); +}; + +DLL_API void CallByValueCopyConstructorFunction(CallByValueCopyConstructor s);