Skip to content

Commit

Permalink
Use automatic coarse grained memory management in wasm binding. (#5528)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Nov 9, 2024
1 parent 7c41446 commit 5ca37c3
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 136 deletions.
42 changes: 11 additions & 31 deletions source/slang-wasm/slang-wasm-bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,36 @@ EMSCRIPTEN_BINDINGS(slang)
{
constant("SLANG_OK", SLANG_OK);

function(
"createGlobalSession",
&slang::wgsl::createGlobalSession,
return_value_policy::take_ownership());

function("getLastError", &slang::wgsl::getLastError);

function(
"getCompileTargets",
&slang::wgsl::getCompileTargets,
return_value_policy::take_ownership());
function("getCompileTargets", &slang::wgsl::getCompileTargets);

class_<slang::wgsl::GlobalSession>("GlobalSession")
.function(
"createSession",
&slang::wgsl::GlobalSession::createSession,
return_value_policy::take_ownership());
allow_raw_pointers());

function("createGlobalSession", &slang::wgsl::createGlobalSession, allow_raw_pointers());

class_<slang::wgsl::Session>("Session")
.function(
"loadModuleFromSource",
&slang::wgsl::Session::loadModuleFromSource,
return_value_policy::take_ownership())
allow_raw_pointers())
.function(
"createCompositeComponentType",
&slang::wgsl::Session::createCompositeComponentType,
return_value_policy::take_ownership());
allow_raw_pointers());

class_<slang::wgsl::ComponentType>("ComponentType")
.function("link", &slang::wgsl::ComponentType::link, return_value_policy::take_ownership())
.function("link", &slang::wgsl::ComponentType::link, allow_raw_pointers())
.function("getEntryPointCode", &slang::wgsl::ComponentType::getEntryPointCode)
.function("getEntryPointCodeBlob", &slang::wgsl::ComponentType::getEntryPointCodeBlob)
.function("getTargetCodeBlob", &slang::wgsl::ComponentType::getTargetCodeBlob)
.function("getTargetCode", &slang::wgsl::ComponentType::getTargetCode)
.function("getLayout", &slang::wgsl::ComponentType::getLayout, allow_raw_pointers())
.function(
"loadStrings",
&slang::wgsl::ComponentType::loadStrings,
return_value_policy::take_ownership());
.function("loadStrings", &slang::wgsl::ComponentType::loadStrings, allow_raw_pointers());

class_<slang::wgsl::TypeLayoutReflection>("TypeLayoutReflection")
.function(
Expand Down Expand Up @@ -85,15 +76,15 @@ EMSCRIPTEN_BINDINGS(slang)
.function(
"findEntryPointByName",
&slang::wgsl::Module::findEntryPointByName,
return_value_policy::take_ownership())
allow_raw_pointers())
.function(
"findAndCheckEntryPoint",
&slang::wgsl::Module::findAndCheckEntryPoint,
return_value_policy::take_ownership())
allow_raw_pointers())
.function(
"getDefinedEntryPoint",
&slang::wgsl::Module::getDefinedEntryPoint,
return_value_policy::take_ownership())
allow_raw_pointers())
.function("getDefinedEntryPointCount", &slang::wgsl::Module::getDefinedEntryPointCount);

value_object<slang::wgsl::Error>("Error")
Expand All @@ -104,14 +95,6 @@ EMSCRIPTEN_BINDINGS(slang)
class_<slang::wgsl::EntryPoint, base<slang::wgsl::ComponentType>>("EntryPoint")
.function("getName", &slang::wgsl::EntryPoint::getName, allow_raw_pointers());

class_<slang::wgsl::CompileTargets>("CompileTargets")
.function(
"findCompileTarget",
&slang::wgsl::CompileTargets::findCompileTarget,
return_value_policy::take_ownership());

register_vector<slang::wgsl::ComponentType*>("ComponentTypeList");

register_vector<std::string>("StringList");
register_optional<std::vector<std::string>>();

Expand Down Expand Up @@ -251,7 +234,4 @@ EMSCRIPTEN_BINDINGS(slang)
"createLanguageServer",
&slang::wgsl::lsp::createLanguageServer,
return_value_policy::take_ownership());

class_<slang::wgsl::HashedString>("HashedString")
.function("getString", &slang::wgsl::HashedString::getString);
};
151 changes: 85 additions & 66 deletions source/slang-wasm/slang-wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace wgsl
{

Error g_error;
CompileTargets g_compileTargets;

Error getLastError()
{
Expand All @@ -25,9 +24,30 @@ Error getLastError()
return currentError;
}

CompileTargets* getCompileTargets()
emscripten::val getCompileTargets()
{
return &g_compileTargets;
struct TargetPair
{
const char* name;
SlangCompileTarget target;
};
static const TargetPair targets[] = {
{"GLSL", SLANG_GLSL},
{"HLSL", SLANG_HLSL},
{"WGSL", SLANG_WGSL},
{"SPIRV", SLANG_SPIRV},
{"METAL", SLANG_METAL},
};

std::vector<emscripten::val> result;
for (auto target : targets)
{
auto entry = emscripten::val::object();
entry.set("name", target.name);
entry.set("value", (int)target.target);
result.push_back(entry);
}
return emscripten::val::array(result);
}

GlobalSession* createGlobalSession()
Expand All @@ -46,35 +66,9 @@ GlobalSession* createGlobalSession()
return new GlobalSession(globalSession);
}

CompileTargets::CompileTargets()
{
#define MAKE_PAIR(x) {#x, SLANG_##x}

m_compileTargetMap = {
MAKE_PAIR(GLSL),
MAKE_PAIR(HLSL),
MAKE_PAIR(WGSL),
MAKE_PAIR(SPIRV),
MAKE_PAIR(METAL),
};
}

int CompileTargets::findCompileTarget(const std::string& name)
{
auto res = m_compileTargetMap.find(name);
if (res != m_compileTargetMap.end())
{
return res->second;
}
else
{
return SLANG_TARGET_UNKNOWN;
}
}

Session* GlobalSession::createSession(int compileTarget)
{
ISession* session = nullptr;
Slang::ComPtr<ISession> session;
{
SessionDesc sessionDesc = {};
sessionDesc.structureSize = sizeof(sessionDesc);
Expand All @@ -83,7 +77,7 @@ Session* GlobalSession::createSession(int compileTarget)
target.format = (SlangCompileTarget)compileTarget;
sessionDesc.targets = &target;
sessionDesc.targetCount = targetCount;
SlangResult result = m_interface->createSession(sessionDesc, &session);
SlangResult result = m_interface->createSession(sessionDesc, session.writeRef());
if (result != SLANG_OK)
{
g_error.type = std::string("USER");
Expand All @@ -95,12 +89,19 @@ Session* GlobalSession::createSession(int compileTarget)
return new Session(session);
}

Module* Session::loadModuleFromSource(
Session::~Session()
{
m_componentTypes = {};
auto refCount = static_cast<Slang::Linkage*>(m_interface.get())->debugGetReferenceCount();
m_interface = nullptr;
}

emscripten::val Session::loadModuleFromSource(
const std::string& slangCode,
const std::string& name,
const std::string& path)
{
Slang::ComPtr<IModule> module;
IModule* module = nullptr;
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
Slang::ComPtr<ISlangBlob> slangCodeBlob =
Expand All @@ -116,14 +117,13 @@ Module* Session::loadModuleFromSource(
g_error.message = std::string(
(char*)diagnosticsBlob->getBufferPointer(),
(char*)diagnosticsBlob->getBufferPointer() + diagnosticsBlob->getBufferSize());
return nullptr;
return emscripten::val::null();
}
}

return new Module(module);
return emscripten::val(Module(module, this));
}

EntryPoint* Module::findEntryPointByName(const std::string& name)
emscripten::val Module::findEntryPointByName(const std::string& name)
{
Slang::ComPtr<IEntryPoint> entryPoint;
{
Expand All @@ -133,15 +133,15 @@ EntryPoint* Module::findEntryPointByName(const std::string& name)
{
g_error.type = std::string("USER");
g_error.result = result;
return nullptr;
return emscripten::val::null();
}
}

return new EntryPoint(entryPoint);
m_session->addComponentType(entryPoint.get());
return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}


EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
emscripten::val Module::findAndCheckEntryPoint(const std::string& name, int stage)
{
Slang::ComPtr<IEntryPoint> entryPoint;
{
Expand All @@ -161,22 +161,22 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage)
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
return nullptr;
return emscripten::val::null();
}
}

return new EntryPoint(entryPoint);
m_session->addComponentType(entryPoint.get());
return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}

int Module::getDefinedEntryPointCount()
{
return moduleInterface()->getDefinedEntryPointCount();
}

EntryPoint* Module::getDefinedEntryPoint(int index)
emscripten::val Module::getDefinedEntryPoint(int index)
{
if (moduleInterface()->getDefinedEntryPointCount() <= index)
return nullptr;
return emscripten::val::null();

Slang::ComPtr<IEntryPoint> entryPoint;
{
Expand All @@ -192,21 +192,37 @@ EntryPoint* Module::getDefinedEntryPoint(int index)
char* diagnostics = (char*)diagnosticsBlob->getBufferPointer();
g_error.message = std::string(diagnostics);
}
return nullptr;
return emscripten::val::null();
}
}

return new EntryPoint(entryPoint);
m_session->addComponentType(entryPoint.get());
return emscripten::val(EntryPoint(entryPoint.get(), m_session));
}


ComponentType* Session::createCompositeComponentType(const std::vector<ComponentType*>& components)
emscripten::val Session::createCompositeComponentType(emscripten::val components)
{
if (!components.isArray())
{
g_error.type = std::string("Slang WASM Bind");
g_error.message = std::string("createCompositeComponentType: Components must be an array");
return emscripten::val::null();
}
std::vector<emscripten::val> componentsArray =
emscripten::vecFromJSArray<emscripten::val>(components);

Slang::ComPtr<IComponentType> composite;
{
std::vector<IComponentType*> nativeComponents(components.size());
for (size_t i = 0U; i < components.size(); i++)
nativeComponents[i] = components[i]->interface();
std::vector<IComponentType*> nativeComponents;
for (size_t i = 0U; i < componentsArray.size(); i++)
{
auto componentVal = componentsArray[i];
if (componentVal.instanceof (emscripten::val::module_property("ComponentType")))
{
auto componentType = componentVal.as<ComponentType>();
nativeComponents.push_back(componentType.interface());
}
}
SlangResult result = m_interface->createCompositeComponentType(
nativeComponents.data(),
(SlangInt)nativeComponents.size(),
Expand All @@ -215,14 +231,14 @@ ComponentType* Session::createCompositeComponentType(const std::vector<Component
{
g_error.type = std::string("USER");
g_error.result = result;
return nullptr;
return emscripten::val::null();
}
}

return new ComponentType(composite);
addComponentType(composite.get());
return emscripten::val(ComponentType(composite, this));
}

ComponentType* ComponentType::link()
emscripten::val ComponentType::link()
{
Slang::ComPtr<IComponentType> linkedProgram;
{
Expand All @@ -235,11 +251,11 @@ ComponentType* ComponentType::link()
g_error.message = std::string(
(char*)diagnosticBlob->getBufferPointer(),
(char*)diagnosticBlob->getBufferPointer() + diagnosticBlob->getBufferSize());
return nullptr;
return emscripten::val::null();
}
}

return new ComponentType(linkedProgram);
m_session->addComponentType(linkedProgram.get());
return emscripten::val(ComponentType(linkedProgram, m_session));
}

std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetIndex)
Expand Down Expand Up @@ -344,26 +360,26 @@ emscripten::val ComponentType::getTargetCodeBlob(int targetIndex)
return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), ptr));
}

HashedString* ComponentType::loadStrings()
emscripten::val ComponentType::loadStrings()
{
slang::ProgramLayout* slangReflection = interface()->getLayout();
if (!slangReflection)
{
g_error.type = std::string("USER");
g_error.message = std::string("Failed to get reflection data");
return nullptr;
return emscripten::val::null();
}

SlangUInt hashedStringCount = slangReflection->getHashedStringCount();
if (hashedStringCount == 0)
{
g_error.type = std::string("USER");
g_error.message = std::string("Warn: No reflection data found");
return nullptr;
return emscripten::val::null();
}

size_t stringSize = 0;
HashedString* hashedStrings = new HashedString();
std::vector<emscripten::val> result;
for (SlangUInt ii = 0; ii < hashedStringCount; ++ii)
{
// For each string we can fetch its bytes from the Slang
Expand All @@ -381,9 +397,12 @@ HashedString* ComponentType::loadStrings()
//
int hash = spComputeStringHash(stringData, stringSize);

hashedStrings->insertString(hash, std::string(stringData));
emscripten::val entry = emscripten::val::object();
entry.set("hash", hash);
entry.set("string", std::string(stringData));
result.push_back(entry);
}
return hashedStrings;
return emscripten::val::array(result);
}

ProgramLayout* ComponentType::getLayout(unsigned int targetIndex)
Expand Down
Loading

0 comments on commit 5ca37c3

Please sign in to comment.