Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mono][wasm] Force interpreter to initialize the pointers #100288

Merged
merged 13 commits into from
Mar 28, 2024
23 changes: 23 additions & 0 deletions src/mono/browser/runtime/runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ int monoeg_g_setenv(const char *variable, const char *value, int overwrite);
int32_t mini_parse_debug_option (const char *option);
char *mono_method_get_full_name (MonoMethod *method);
void mono_trace_init (void);
MonoMethod *mono_marshal_get_managed_wrapper (MonoMethod *method, MonoClass *delegate_klass, MonoGCHandle target_handle, MonoError *error);

/* Not part of public headers */
#define MONO_ICALL_TABLE_CALLBACKS_VERSION 3
Expand Down Expand Up @@ -356,3 +357,25 @@ mono_wasm_assembly_find_method (MonoClass *klass, const char *name, int argument
MONO_EXIT_GC_UNSAFE;
return result;
}

/*
* mono_wasm_marshal_get_managed_wrapper:
* Creates a wrapper for a function pointer to a method marked with
* UnamangedCallersOnlyAttribute.
* This wrapper ensures that the interpreter initializes the pointers.
*/
void
mono_wasm_marshal_get_managed_wrapper (const char* assemblyName, const char* typeName, const char* methodName, int num_params)
{
MonoError error;
mono_error_init (&error);
MonoAssembly* assembly = mono_wasm_assembly_load (assemblyName);
assert (assembly);
MonoClass* class = mono_wasm_assembly_find_class (assembly, "", typeName);
assert (class);
MonoMethod* method = mono_wasm_assembly_find_method (class, methodName, num_params);
assert (method);
MonoMethod *managedWrapper = mono_marshal_get_managed_wrapper (method, NULL, 0, &error);
assert (managedWrapper);
mono_compile_method (managedWrapper);
}
2 changes: 2 additions & 0 deletions src/mono/browser/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ extern int mono_wasm_enable_gc;
MonoDomain *mono_wasm_load_runtime_common (int debug_level, MonoLogCallback log_callback, const char *interp_opts);
MonoAssembly *mono_wasm_assembly_load (const char *name);
MonoClass *mono_wasm_assembly_find_class (MonoAssembly *assembly, const char *namespace, const char *name);
MonoMethod *mono_wasm_assembly_find_method (MonoClass *klass, const char *name, int arguments);
void mono_wasm_marshal_get_managed_wrapper (const char* assemblyName, const char* typeName, const char* methodName, int num_params);

#endif
5 changes: 0 additions & 5 deletions src/mono/sample/wasi/native/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ public static int MyExport(int number)
public unsafe static int Main(string[] args)
{
Console.WriteLine($"main: {args.Length}");
// workaround to force the interpreter to initialize wasm_native_to_interp_ftndesc for MyExport
if (args.Length > 10000) {
((IntPtr)(delegate* unmanaged<int,int>)&MyExport).ToString();
}

MyImport();
return 0;
}
Expand Down
76 changes: 76 additions & 0 deletions src/mono/wasi/Wasi.Build.Tests/PInvokeTableGeneratorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using Xunit;
using Xunit.Abstractions;
using Wasm.Build.Tests;

#nullable enable

namespace Wasi.Build.Tests;

public class PInvokeTableGeneratorTests : BuildTestBase
{
public PInvokeTableGeneratorTests(ITestOutputHelper output, SharedBuildPerTestClassFixture buildContext)
: base(output, buildContext)
{
}

[Fact]
public void InteropSupportForUnmanagedEntryPointWithoutDelegate()
{
string config = "Release";
string id = $"{config}_{GetRandomId()}";
string projectFile = CreateWasmTemplateProject(id, "wasiconsole");
string code =
"""
using System;
using System.Runtime.InteropServices;
public unsafe class Test
{
[UnmanagedCallersOnly(EntryPoint = "ManagedFunc")]
public static int MyExport(int number)
{
// called from MyImport aka UnmanagedFunc
Console.WriteLine($"MyExport({number}) -> 42");
return 42;
}

[DllImport("*", EntryPoint = "UnmanagedFunc")]
public static extern void MyImport(); // calls ManagedFunc aka MyExport

public unsafe static int Main(string[] args)
{
Console.WriteLine($"main: {args.Length}");
MyImport();
return 0;
}
}
""";

string projectName = Path.GetFileNameWithoutExtension(projectFile);

var buildArgs = new BuildArgs(projectName, config, AOT: true, ProjectFileContents: id, ExtraBuildArgs: null);
buildArgs = ExpandBuildArgs(buildArgs);
AddItemsPropertiesToProject("<NativeFileReference Include=\"local.c\" />");
AddItemsPropertiesToProject(projectFile, "<WasmSingleFileBundle>true</WasmSingleFileBundle>");

BuildProject(buildArgs,
id: id,
new BuildProjectOptions(
InitProject: () =>
{
File.WriteAllText(Path.Combine(_projectDir!, "Program.cs"), code);
mkhamoyan marked this conversation as resolved.
Show resolved Hide resolved
},
DotnetWasmFromRuntimePack: false,
Publish: true,
TargetFramework: BuildTestBase.DefaultTargetFramework));

CommandResult res = new RunCommand(s_buildEnv, _testOutput)
.WithWorkingDirectory(_projectDir!)
.ExecuteWithCapturedOutput($"run --no-silent --no-build -c {config}")
.EnsureSuccessful();
Assert.Contains("MyExport(123) -> 42", res.Output);
}
}
11 changes: 11 additions & 0 deletions src/mono/wasm/templates/templates/wasi-console/local.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include <stdio.h>

int ManagedFunc(int number);

void UnmanagedFunc()
{
int ret = 0;
printf("UnmanagedFunc calling ManagedFunc\n");
ret = ManagedFunc(123);
printf("ManagedFunc returned %d\n", ret);
}
19 changes: 18 additions & 1 deletion src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ private void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> callbacks)
// Only blittable parameter/return types are supposed.
int cb_index = 0;

w.Write(@"#include <mono/utils/details/mono-error-types.h>
#include <mono/metadata/assembly.h>
#include <mono/utils/mono-error.h>
#include <mono/metadata/object.h>
#include <mono/utils/details/mono-logger-types.h>
#include ""runtime.h""
");

// Arguments to interp entry functions in the runtime
w.WriteLine($"InterpFtnDesc wasm_native_to_interp_ftndescs[{callbacks.Count}] = {{}};");

Expand Down Expand Up @@ -371,7 +379,16 @@ private void EmitNativeToInterp(StreamWriter w, List<PInvokeCallback> callbacks)
if (!is_void)
sb.Append($" {MapType(method.ReturnType)} res;\n");

//sb.Append($" printf(\"{entry_name} called\\n\");\n");
// In case when null force interpreter to initialize the pointers
sb.Append($" if (!(WasmInterpEntrySig_{cb_index})wasm_native_to_interp_ftndescs [{cb_index}].func) {{\n");
var assemblyFullName = cb.Method.DeclaringType == null ? "" : cb.Method.DeclaringType.Assembly.FullName;
var assemblyName = assemblyFullName != null && assemblyFullName.Split(',').Length > 0 ? assemblyFullName.Split(',')[0].Trim() : "";
var typeName = cb.Method.DeclaringType == null || cb.Method.DeclaringType.FullName == null ? "" : cb.Method.DeclaringType.FullName;
var methodName = cb.Method.Name;
int numParams = method.GetParameters().Length;
sb.Append($" mono_wasm_marshal_get_managed_wrapper (\"{assemblyName}\", \"{typeName}\", \"{methodName}\", {numParams});\n");
sb.Append($" }}\n");

sb.Append($" ((WasmInterpEntrySig_{cb_index})wasm_native_to_interp_ftndescs [{cb_index}].func) (");
if (!is_void)
{
Expand Down
Loading