Skip to content

Commit

Permalink
[wasm] more cases when looking up unmanaged delegates (dotnet#107113)
Browse files Browse the repository at this point in the history
Make the association between the wasm_native_to_interp_ftndescs generation and the lookup from unmanaged more robust so that we don't see problems like dotnet#107212 where the same slot was being reused for multiple methods with different signatures. To do this we change the Key(s) we use and fix the string escaping it relies on, and attempt to lookup by token first.

Next , we rewrite the C code generation to make it easier to read and modify and mitigate some potentially negative memory side effects of that we introduce a gratuitous custom text writer that understands the idea of concatenated strings and use that where possible when building the output.

Next, we change the import code generation to use binary rather than linear search for both the module and symbol. And finally, we update the ICall table generation to use the extensions.

part of dotnet#104391 and dotnet#107212
  • Loading branch information
lewing authored and jtschuster committed Sep 17, 2024
1 parent 06e771e commit 1be0042
Show file tree
Hide file tree
Showing 11 changed files with 549 additions and 227 deletions.
62 changes: 47 additions & 15 deletions src/mono/browser/runtime/pinvoke.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,71 @@ mono_wasm_pinvoke_vararg_stub (void)
/* This is just a stub used to mark vararg pinvokes */
}

int
table_compare_name (const void *t1, const void *t2)
{
return strcmp (((PinvokeTable*)t1)->name, ((PinvokeTable*)t2)->name);
}

void*
wasm_dl_lookup_pinvoke_table (const char *name)
{
for (int i = 0; i < sizeof (pinvoke_tables) / sizeof (void*); ++i) {
if (!strcmp (name, pinvoke_names [i]))
return pinvoke_tables [i];
}
return NULL;
PinvokeImport needle = { name, NULL };
return bsearch (&needle, pinvoke_tables, (sizeof (pinvoke_tables) / sizeof (PinvokeTable)), sizeof (PinvokeTable), table_compare_name);
}

int
wasm_dl_is_pinvoke_table (void *handle)
{
for (int i = 0; i < sizeof (pinvoke_tables) / sizeof (void*); ++i) {
if (pinvoke_tables [i] == handle) {
for (int i = 0; i < sizeof (pinvoke_tables) / sizeof (PinvokeTable); ++i) {
if (&pinvoke_tables[i] == handle) {
return 1;
}
}
return 0;
}

static int
export_compare_key (const void *k1, const void *k2)
{
return strcmp (((UnmanagedExport*)k1)->key, ((UnmanagedExport*)k2)->key);
}

static int
export_compare_key_and_token (const void *k1, const void *k2)
{
UnmanagedExport *e1 = (UnmanagedExport*)k1;
UnmanagedExport *e2 = (UnmanagedExport*)k2;

// first compare by key
int compare = strcmp (e1->key, e2->key);
if (compare)
return compare;

// then by token
return (int)(e1->token - e2->token);
}

void*
wasm_dl_get_native_to_interp (const char *key, void *extra_arg)
wasm_dl_get_native_to_interp (uint32_t token, const char *key, void *extra_arg)
{
#ifdef GEN_PINVOKE
for (int i = 0; i < sizeof (wasm_native_to_interp_map) / sizeof (void*); ++i) {
if (!strcmp (wasm_native_to_interp_map [i], key)) {
void *addr = wasm_native_to_interp_funcs [i];
wasm_native_to_interp_ftndescs [i] = *(InterpFtnDesc*)extra_arg;
return addr;
}
UnmanagedExport needle = { key, token, NULL };
int count = (sizeof (wasm_native_to_interp_table) / sizeof (UnmanagedExport));

// comparison must match the one used in the PInvokeTableGenerator to ensure the same order
UnmanagedExport *result = bsearch (&needle, wasm_native_to_interp_table, count, sizeof (UnmanagedExport), export_compare_key_and_token);
if (!result) {
// assembly may have been trimmed / modified, try to find by key only
result = bsearch (&needle, wasm_native_to_interp_table, count, sizeof (UnmanagedExport), export_compare_key);
}
return NULL;

if (!result)
return NULL;

void *addr = result->func;
wasm_native_to_interp_ftndescs [result - wasm_native_to_interp_table] = *(InterpFtnDesc*)extra_arg;
return addr;
#else
return NULL;
#endif
Expand Down
16 changes: 14 additions & 2 deletions src/mono/browser/runtime/pinvoke.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ typedef struct {
void *func;
} PinvokeImport;

typedef struct {
const char *name;
PinvokeImport *imports;
int count;
} PinvokeTable;

typedef struct {
const char *key;
uint32_t token;
void *func;
} UnmanagedExport;

typedef struct {
void *func;
void *arg;
Expand All @@ -20,7 +32,7 @@ int
wasm_dl_is_pinvoke_table (void *handle);

void*
wasm_dl_get_native_to_interp (const char *key, void *extra_arg);
wasm_dl_get_native_to_interp (uint32_t token, const char *key, void *extra_arg);

void
mono_wasm_pinvoke_vararg_stub (void);
Expand All @@ -45,6 +57,6 @@ double
mono_wasm_interp_method_args_get_darg (MonoInterpMethodArguments *margs, int i);

void*
mono_wasm_interp_method_args_get_retval (MonoInterpMethodArguments *margs);
mono_wasm_interp_method_args_get_retval (MonoInterpMethodArguments *margs);

#endif
95 changes: 72 additions & 23 deletions src/mono/browser/runtime/runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,29 +199,43 @@ init_icall_table (void)
static void*
get_native_to_interp (MonoMethod *method, void *extra_arg)
{
void *addr;

void *addr = NULL;
MONO_ENTER_GC_UNSAFE;
MonoClass *klass = mono_method_get_class (method);
MonoImage *image = mono_class_get_image (klass);
MonoAssembly *assembly = mono_image_get_assembly (image);
MonoAssemblyName *aname = mono_assembly_get_name (assembly);
const char *name = mono_assembly_name_get_name (aname);
const char *namespace = mono_class_get_namespace (klass);
const char *class_name = mono_class_get_name (klass);
const char *method_name = mono_method_get_name (method);
char key [128];
MonoMethodSignature *sig = mono_method_signature (method);
uint32_t param_count = mono_signature_get_param_count (sig);
uint32_t token = mono_method_get_token (method);

char buf [128];
char *key = buf;
int len;
if (name != NULL) {
// the key must match the one used in PInvokeTableGenerator
len = snprintf (key, sizeof(buf), "%s#%d:%s:%s:%s", method_name, param_count, name, namespace, class_name);

if (len >= sizeof (buf)) {
// The key is too long, try again with a larger buffer
key = g_new (char, len + 1);
snprintf (key, len + 1, "%s#%d:%s:%s:%s", method_name, param_count, name, namespace, class_name);
}

assert (strlen (name) < 100);
snprintf (key, sizeof(key), "%s_%s_%s", name, class_name, method_name);
char *fixedName = mono_fixup_symbol_name ("", key, "");
addr = wasm_dl_get_native_to_interp (fixedName, extra_arg);
free (fixedName);
addr = wasm_dl_get_native_to_interp (token, key, extra_arg);

if (key != buf)
free (key);
}
MONO_EXIT_GC_UNSAFE;
return addr;
}

static void *sysglobal_native_handle;
static void *sysglobal_native_handle = (void *)0xDeadBeef;

static void*
wasm_dl_load (const char *name, int flags, char **err, void *user_data)
Expand All @@ -248,24 +262,33 @@ wasm_dl_load (const char *name, int flags, char **err, void *user_data)
return NULL;
}

int
import_compare_name (const void *k1, const void *k2)
{
const PinvokeImport *e1 = (const PinvokeImport*)k1;
const PinvokeImport *e2 = (const PinvokeImport*)k2;

return strcmp (e1->name, e2->name);
}

static void*
wasm_dl_symbol (void *handle, const char *name, char **err, void *user_data)
{
if (handle == sysglobal_native_handle)
assert (0);
assert (handle != sysglobal_native_handle);

#if WASM_SUPPORTS_DLOPEN
if (!wasm_dl_is_pinvoke_tables (handle)) {
return dlsym (handle, name);
}
#endif

PinvokeImport *table = (PinvokeImport*)handle;
for (int i = 0; table [i].name; ++i) {
if (!strcmp (table [i].name, name))
return table [i].func;
}
return NULL;
PinvokeTable* index = (PinvokeTable*)handle;
PinvokeImport key = { name, NULL };
PinvokeImport* result = (PinvokeImport *)bsearch(&key, index->imports, index->count, sizeof(PinvokeImport), import_compare_name);
if (!result) {
// *err = g_strdup_printf ("Symbol not found: %s", name);
return NULL;
}
return result->func;
}

MonoDomain *
Expand Down Expand Up @@ -363,24 +386,50 @@ mono_wasm_assembly_find_method (MonoClass *klass, const char *name, int argument
return result;
}

MonoMethod*
mono_wasm_get_method_matching (MonoImage *image, uint32_t token, MonoClass *klass, const char* name, int param_count)
{
MonoMethod *result = NULL;
MONO_ENTER_GC_UNSAFE;
MonoMethod *method = mono_get_method (image, token, klass);
MonoMethodSignature *sig = mono_method_signature (method);
// Lookp by token but verify the name and param count in case assembly was trimmed
if (mono_signature_get_param_count (sig) == param_count) {
const char *method_name = mono_method_get_name (method);
if (!strcmp (method_name, name)) {
result = method;
}
}
// If the token lookup failed, try to find the method by name and param count
if (!result) {
result = mono_class_get_method_from_name (klass, name, param_count);
}
MONO_EXIT_GC_UNSAFE;
return result;
}

/*
* mono_wasm_marshal_get_managed_wrapper:
* Creates a wrapper for a function pointer to a method marked with
* UnamangedCallersOnlyAttribute.
* This wrapper ensures that the interpreter initializes the pointers.
*/
void
mono_wasm_marshal_get_managed_wrapper (const char* assemblyName, const char* namespaceName, const char* typeName, const char* methodName, int num_params)
mono_wasm_marshal_get_managed_wrapper (const char* assemblyName, const char* namespaceName, const char* typeName, const char* methodName, uint32_t token, int param_count)
{
MonoError error;
mono_error_init (&error);
MONO_ENTER_GC_UNSAFE;
MonoAssembly* assembly = mono_wasm_assembly_load (assemblyName);
assert (assembly);
MonoClass* class = mono_wasm_assembly_find_class (assembly, namespaceName, typeName);
assert (class);
MonoMethod* method = mono_wasm_assembly_find_method (class, methodName, num_params);
MonoImage *image = mono_assembly_get_image (assembly);
assert (image);
MonoClass* klass = mono_class_from_name (image, namespaceName, typeName);
assert (klass);
MonoMethod *method = mono_wasm_get_method_matching (image, token, klass, methodName, param_count);
assert (method);
MonoMethod *managedWrapper = mono_marshal_get_managed_wrapper (method, NULL, 0, &error);
assert (managedWrapper);
mono_compile_method (managedWrapper);
}
MONO_EXIT_GC_UNSAFE;
}
2 changes: 1 addition & 1 deletion src/mono/browser/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ MonoDomain *mono_wasm_load_runtime_common (int debug_level, MonoLogCallback log_
MonoAssembly *mono_wasm_assembly_load (const char *name);
MonoClass *mono_wasm_assembly_find_class (MonoAssembly *assembly, const char *namespace, const char *name);
MonoMethod *mono_wasm_assembly_find_method (MonoClass *klass, const char *name, int arguments);
void mono_wasm_marshal_get_managed_wrapper (const char* assemblyName, const char* namespaceName, const char* typeName, const char* methodName, int num_params);
void mono_wasm_marshal_get_managed_wrapper (const char* assemblyName, const char* namespaceName, const char* typeName, const char* methodName, uint32_t token, int param_count);
int initialize_runtime ();

#endif
66 changes: 66 additions & 0 deletions src/mono/wasm/Wasm.Build.Tests/PInvokeTableGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,71 @@ file class Foo
Assert.Contains("Main running", output);
}

[Theory]
[BuildAndRun(host: RunHost.Chrome)]
public void UnmanagedCallersOnly_Namespaced(BuildArgs buildArgs, RunHost host, string id)
{
string code =
"""
using System;
using System.Runtime.InteropServices;

public class Test
{
public unsafe static int Main()
{
((delegate* unmanaged<void>)&A.Conflict.C)();
((delegate* unmanaged<void>)&B.Conflict.C)();
((delegate* unmanaged<void>)&A.Conflict.C\u733f)();
((delegate* unmanaged<void>)&B.Conflict.C\u733f)();
return 42;
}
}

namespace A {
public class Conflict {
[UnmanagedCallersOnly(EntryPoint = "A_Conflict_C")]
public static void C() {
Console.WriteLine("A.Conflict.C");
}

[UnmanagedCallersOnly(EntryPoint = "A_Conflict_C\u733f")]
public static void C\u733f() {
Console.WriteLine("A.Conflict.C\U0001F412");
}
}
}

namespace B {
public class Conflict {
[UnmanagedCallersOnly(EntryPoint = "B_Conflict_C")]
public static void C() {
Console.WriteLine("B.Conflict.C");
}

[UnmanagedCallersOnly(EntryPoint = "B_Conflict_C\u733f")]
public static void C\u733f() {
Console.WriteLine("B.Conflict.C\U0001F412");
}
}
}
""";

(buildArgs, string output) = BuildForVariadicFunctionTests(
code,
buildArgs with { ProjectName = $"cb_namespace_{buildArgs.Config}" },
id
);

Assert.DoesNotMatch(".*(warning|error).*>[A-Z0-9]+__Foo", output);

output = RunAndTestWasmApp(buildArgs, buildDir: _projectDir, expectedExitCode: 42, host: host, id: id);
Assert.Contains("A.Conflict.C", output);
Assert.Contains("B.Conflict.C", output);
Assert.Contains("A.Conflict.C\U0001F412", output);
Assert.Contains("B.Conflict.C\U0001F412", output);
}

[Theory]
[BuildAndRun(host: RunHost.None)]
public void IcallWithOverloadedParametersAndEnum(BuildArgs buildArgs, string id)
Expand Down Expand Up @@ -951,6 +1016,7 @@ public void UCOWithSpecialCharacters(BuildArgs buildArgs, RunHost host, string i
DotnetWasmFromRuntimePack: false));

var runOutput = RunAndTestWasmApp(buildArgs, buildDir: _projectDir, expectedExitCode: 42, host: host, id: id);
Assert.DoesNotContain("Conflict.A.Managed8\u4F60Func(123) -> 123", runOutput);
Assert.Contains("ManagedFunc returned 42", runOutput);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@ public unsafe partial class Test
public unsafe static int Main(string[] args)
{
((IntPtr)(delegate* unmanaged<int,int>)&Interop.Managed8\u4F60Func).ToString();

Console.WriteLine($"main: {args.Length}");
Interop.UnmanagedFunc();

return 42;
}
}

namespace Conflict.A {
file class Interop {
[UnmanagedCallersOnly(EntryPoint = "ConflictManagedFunc")]
public static int Managed8\u4F60Func(int number)
{
Console.WriteLine($"Conflict.A.Managed8\u4F60Func({number}) -> {number}");
return number;
}
}
}

file partial class Interop
{
[UnmanagedCallersOnly(EntryPoint = "ManagedFunc")]
Expand Down
2 changes: 1 addition & 1 deletion src/mono/wasm/testassets/native-libs/local.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ void UnmanagedFunc()
printf("UnmanagedFunc calling ManagedFunc\n");
ret = ManagedFunc(123);
printf("ManagedFunc returned %d\n", ret);
}
}
Loading

0 comments on commit 1be0042

Please sign in to comment.