Skip to content

Commit

Permalink
Add the ability to have multiple wasm intrinsic modules (e.g. for Ems…
Browse files Browse the repository at this point in the history
…cripten support). STACKED on Emscripten version PR. (envoyproxy#32)

* Add support for many intrinsic modules (e.g. for Emscripten support).
  • Loading branch information
jplevyak authored Mar 6, 2019
1 parent 96e552d commit 12290b3
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 58 deletions.
5 changes: 3 additions & 2 deletions source/extensions/common/wasm/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ Wasm::Wasm(absl::string_view vm, absl::string_view id, absl::string_view initial
}

void Wasm::registerFunctions() {
#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), #_fn, &_fn##Handler);
#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), "envoy", #_fn, &_fn##Handler);
if (is_emscripten_) {
_REGISTER(getTotalMemory);
_REGISTER(_emscripten_get_heap_size);
Expand All @@ -961,7 +961,8 @@ void Wasm::registerFunctions() {
#undef _REGISTER

// Calls with the "_proxy_" prefix.
#define _REGISTER_PROXY(_fn) registerCallback(wasm_vm_.get(), "_proxy_" #_fn, &_fn##Handler);
#define _REGISTER_PROXY(_fn) \
registerCallback(wasm_vm_.get(), "envoy", "_proxy_" #_fn, &_fn##Handler);
_REGISTER_PROXY(log);

_REGISTER_PROXY(getRequestStreamInfoProtocol);
Expand Down
10 changes: 7 additions & 3 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ class WasmVm : public Logger::Loggable<Logger::Id::wasm> {
virtual absl::string_view getMemory(uint32_t pointer, uint32_t size) PURE;
// Set a block of memory in the VM, returns true on success, false if the pointer/size is invalid.
virtual bool setMemory(uint32_t pointer, uint32_t size, void* data) PURE;
// Make a new intrinsic module (e.g. for Emscripten support).
virtual void makeModule(absl::string_view name) PURE;

// Get the contents of the user section with the given name or "" if it does not exist and
// optionally a presence indicator.
Expand Down Expand Up @@ -506,15 +508,17 @@ inline Context::Context(Wasm* wasm) : wasm_(wasm), id_(wasm->allocContextId()) {

// Forward declarations for VM implemenations.
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, R (*)(Args...));
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>*);

template <typename R, typename... Args>
void registerCallback(WasmVm* vm, absl::string_view functionName, R (*f)(Args...)) {
void registerCallback(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*f)(Args...)) {
if (vm->vm() == WasmVmNames::get().Wavm) {
registerCallbackWavm(vm, functionName, f);
registerCallbackWavm(vm, moduleName, functionName, f);
} else {
throw WasmVmException("unsupoorted wasm vm");
}
Expand Down
1 change: 1 addition & 0 deletions source/extensions/common/wasm/wavm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ envoy_cc_library(
"wavm_with_llvm",
],
deps = [
"//external:abseil_node_hash_map",
"//include/envoy/server:wasm_interface",
"//include/envoy/thread_local:thread_local_interface",
"//source/common/common:assert_lib",
Expand Down
128 changes: 75 additions & 53 deletions source/extensions/common/wasm/wavm/wavm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "WAVM/Runtime/RuntimeData.h"
#include "WAVM/WASM/WASM.h"
#include "WAVM/WASTParse/WASTParse.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/match.h"

using namespace WAVM;
Expand Down Expand Up @@ -189,13 +190,9 @@ struct Wavm : public WasmVm {
void* allocMemory(uint32_t size, uint32_t* pointer) override;
absl::string_view getMemory(uint32_t pointer, uint32_t size) override;
bool setMemory(uint32_t pointer, uint32_t size, void* data) override;
void makeModule(absl::string_view name) override;
absl::string_view getUserSection(absl::string_view name, bool* present) override;

WAVM::Runtime::Memory* memory() { return memory_; }
WAVM::Runtime::Context* context() { return context_; }
WAVM::Runtime::ModuleInstance* moduleInstance() { return moduleInstance_; }
WAVM::Runtime::ModuleInstance* envoyModuleInstance() { return moduleInstance_; }

void GetFunctions();
void RegisterCallbacks();

Expand All @@ -207,9 +204,10 @@ struct Wavm : public WasmVm {
Emscripten::Instance* emscriptenInstance_ = nullptr;
WAVM::Runtime::GCPointer<WAVM::Runtime::Compartment> compartment_;
WAVM::Runtime::GCPointer<WAVM::Runtime::Context> context_;
Intrinsics::Module envoy_module_;
WAVM::Runtime::GCPointer<WAVM::Runtime::ModuleInstance> envoyModuleInstance_ = nullptr;
std::vector<std::unique_ptr<Intrinsics::Function>> envoy_functions_;
absl::node_hash_map<std::string, Intrinsics::Module> intrinsicModules_;
absl::node_hash_map<std::string, WAVM::Runtime::GCPointer<WAVM::Runtime::ModuleInstance>>
intrinsicModuleInstances_;
std::vector<std::unique_ptr<Intrinsics::Function>> envoyFunctions_;
};

Wavm::~Wavm() {
Expand All @@ -222,8 +220,9 @@ Wavm::~Wavm() {
delete emscriptenInstance_;
}
context_ = nullptr;
envoyModuleInstance_ = nullptr;
envoy_functions_.clear();
intrinsicModuleInstances_.clear();
intrinsicModules_.clear();
envoyFunctions_.clear();
if (compartment_) {
ASSERT(tryCollectCompartment(std::move(compartment_)));
}
Expand All @@ -234,8 +233,10 @@ std::unique_ptr<WasmVm> Wavm::clone() {
wavm->compartment_ = WAVM::Runtime::cloneCompartment(compartment_);
wavm->memory_ = WAVM::Runtime::remapToClonedCompartment(memory_, wavm->compartment_);
wavm->context_ = WAVM::Runtime::createContext(wavm->compartment_);
wavm->envoyModuleInstance_ =
WAVM::Runtime::remapToClonedCompartment(envoyModuleInstance_, wavm->compartment_);
for (auto& p : intrinsicModuleInstances_) {
wavm->intrinsicModuleInstances_.emplace(
p.first, WAVM::Runtime::remapToClonedCompartment(p.second, wavm->compartment_));
}
wavm->moduleInstance_ =
WAVM::Runtime::remapToClonedCompartment(moduleInstance_, wavm->compartment_);
return wavm;
Expand Down Expand Up @@ -264,13 +265,18 @@ bool Wavm::load(const std::string& code, bool allow_precompiled) {
} else {
module_ = WAVM::Runtime::loadPrecompiledModule(irModule_, precompiledObjectSection->data);
}
makeModule("envoy");
return true;
}

void Wavm::link(absl::string_view name, bool needs_emscripten) {
RootResolver rootResolver(compartment_);
envoyModuleInstance_ = Intrinsics::instantiateModule(compartment_, envoy_module_, "envoy");
rootResolver.moduleNameToInstanceMap().set("envoy", envoyModuleInstance_);
for (auto& p : intrinsicModules_) {
auto instance = Intrinsics::instantiateModule(compartment_, intrinsicModules_[p.first],
std::string(p.first));
intrinsicModuleInstances_.emplace(p.first, instance);
rootResolver.moduleNameToInstanceMap().set(p.first, instance);
}
if (needs_emscripten) {
emscriptenInstance_ = Emscripten::instantiate(compartment_, irModule_);
rootResolver.moduleNameToInstanceMap().set("env", emscriptenInstance_->env);
Expand All @@ -283,6 +289,10 @@ void Wavm::link(absl::string_view name, bool needs_emscripten) {
memory_ = getDefaultMemory(moduleInstance_);
}

void Wavm::makeModule(absl::string_view name) {
intrinsicModules_.emplace(std::piecewise_construct, std::make_tuple(name), std::make_tuple());
}

void Wavm::start(Context* context) {
auto f = getStartFunction(moduleInstance_);
if (f) {
Expand All @@ -301,7 +311,7 @@ void Wavm::start(Context* context) {
}

void* Wavm::allocMemory(uint32_t size, uint32_t* address) {
auto f = asFunctionNullable(getInstanceExport(moduleInstance(), "_malloc"));
auto f = asFunctionNullable(getInstanceExport(moduleInstance_, "_malloc"));
if (!f)
return nullptr;
auto values = invokeFunctionChecked(context_, f, {size});
Expand All @@ -310,18 +320,18 @@ void* Wavm::allocMemory(uint32_t size, uint32_t* address) {
ASSERT(v.type == ValueType::i32);
*address = v.u32;
return reinterpret_cast<char*>(
WAVM::Runtime::memoryArrayPtr<U8>(memory(), v.u32, static_cast<U64>(size)));
WAVM::Runtime::memoryArrayPtr<U8>(memory_, v.u32, static_cast<U64>(size)));
}

absl::string_view Wavm::getMemory(uint32_t pointer, uint32_t size) {
return {reinterpret_cast<char*>(
WAVM::Runtime::memoryArrayPtr<U8>(memory(), pointer, static_cast<U64>(size))),
WAVM::Runtime::memoryArrayPtr<U8>(memory_, pointer, static_cast<U64>(size))),
static_cast<size_t>(size)};
}

bool Wavm::setMemory(uint32_t pointer, uint32_t size, void* data) {
auto p = reinterpret_cast<char*>(
WAVM::Runtime::memoryArrayPtr<U8>(memory(), pointer, static_cast<U64>(size)));
WAVM::Runtime::memoryArrayPtr<U8>(memory_, pointer, static_cast<U64>(size)));
if (p) {
memcpy(p, data, size);
return true;
Expand Down Expand Up @@ -357,69 +367,83 @@ IR::FunctionType inferEnvoyFunctionType(R (*)(void*, Args...)) {
using namespace Wavm;

template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, R (*f)(Args...)) {
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*f)(Args...)) {
auto wavm = static_cast<Common::Wasm::Wavm::Wavm*>(vm);
wavm->envoy_functions_.emplace_back(
new Intrinsics::Function(wavm->envoy_module_, functionName.data(), reinterpret_cast<void*>(f),
inferEnvoyFunctionType(f), IR::CallingConvention::intrinsic));
wavm->envoyFunctions_.emplace_back(new Intrinsics::Function(
wavm->intrinsicModules_[moduleName], functionName.data(), reinterpret_cast<void*>(f),
inferEnvoyFunctionType(f), IR::CallingConvention::intrinsic));
}

template void registerCallbackWavm<void, void*>(WasmVm* vm, absl::string_view functionName,
void (*f)(void*));
template void registerCallbackWavm<void, void*, U32>(WasmVm* vm, absl::string_view functionName,
template void registerCallbackWavm<void, void*>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName, void (*f)(void*));
template void registerCallbackWavm<void, void*, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
void (*f)(void*, U32));
template void registerCallbackWavm<void, void*, U32, U32>(WasmVm* vm,
template void registerCallbackWavm<void, void*, U32, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
void (*f)(void*, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32>(WasmVm* vm,
absl::string_view moduleName,
absl::string_view functionName,
void (*f)(void*, U32, U32, U32));
template void
registerCallbackWavm<void, void*, U32, U32, U32, U32>(WasmVm* vm, absl::string_view functionName,
registerCallbackWavm<void, void*, U32, U32, U32, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32));
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32, U32));
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<void, void*, U32, U32, U32, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
void (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32, U32));

template void registerCallbackWavm<U32, void*>(WasmVm* vm, absl::string_view functionName,
U32 (*f)(void*));
template void registerCallbackWavm<U32, void*, U32>(WasmVm* vm, absl::string_view functionName,
template void registerCallbackWavm<U32, void*>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName, U32 (*f)(void*));
template void registerCallbackWavm<U32, void*, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
U32 (*f)(void*, U32));
template void registerCallbackWavm<U32, void*, U32, U32>(WasmVm* vm, absl::string_view functionName,
template void registerCallbackWavm<U32, void*, U32, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
U32 (*f)(void*, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32>(WasmVm* vm,
absl::string_view moduleName,
absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32));
template void
registerCallbackWavm<U32, void*, U32, U32, U32, U32>(WasmVm* vm, absl::string_view functionName,
registerCallbackWavm<U32, void*, U32, U32, U32, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32));
template void
registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32>(WasmVm* vm, absl::string_view moduleName,
absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32));
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32));
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32));
template void registerCallbackWavm<U32, void*, U32, U32, U32, U32, U32, U32, U32, U32, U32, U32>(
WasmVm* vm, absl::string_view functionName,
WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32, U32));

template <typename R, typename... Args>
Expand All @@ -436,10 +460,9 @@ template <typename R, typename... Args>
void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>* function, uint32_t) {
auto wavm = static_cast<Common::Wasm::Wavm::Wavm*>(vm);
auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance(), std::string(functionName)));
auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName)));
if (!f)
f = asFunctionNullable(
getInstanceExport(wavm->envoyModuleInstance(), std::string(functionName)));
f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName)));
if (!f) {
*function = nullptr;
return;
Expand All @@ -449,7 +472,7 @@ void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName,
}
*function = [wavm, f](Context* context, Args... args) -> R {
UntaggedValue values[] = {args...};
CALL_WITH_CONTEXT_RETURN(invokeFunctionUnchecked(wavm->context(), f, &values[0]), context,
CALL_WITH_CONTEXT_RETURN(invokeFunctionUnchecked(wavm->context_, f, &values[0]), context,
uint32_t, i32);
};
}
Expand All @@ -460,10 +483,9 @@ template <typename R, typename... Args>
void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>* function, Void) {
auto wavm = static_cast<Common::Wasm::Wavm::Wavm*>(vm);
auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance(), std::string(functionName)));
auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName)));
if (!f)
f = asFunctionNullable(
getInstanceExport(wavm->envoyModuleInstance(), std::string(functionName)));
f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName)));
if (!f) {
*function = nullptr;
return;
Expand All @@ -473,7 +495,7 @@ void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName,
}
*function = [wavm, f](Context* context, Args... args) -> R {
UntaggedValue values[] = {args...};
CALL_WITH_CONTEXT(invokeFunctionUnchecked(wavm->context(), f, &values[0]), context);
CALL_WITH_CONTEXT(invokeFunctionUnchecked(wavm->context_, f, &values[0]), context);
};
}

Expand Down

0 comments on commit 12290b3

Please sign in to comment.