diff --git a/src/webnn_native/xnnpack/BackendXNN.cpp b/src/webnn_native/xnnpack/BackendXNN.cpp index 9a235a02d..9998e0ebe 100644 --- a/src/webnn_native/xnnpack/BackendXNN.cpp +++ b/src/webnn_native/xnnpack/BackendXNN.cpp @@ -18,13 +18,39 @@ #include "webnn_native/Instance.h" #include "webnn_native/xnnpack/ContextXNN.h" +#include + namespace webnn_native::xnnpack { Backend::Backend(InstanceBase* instance) : BackendConnection(instance, wnn::BackendType::XNNPACK) { } + Backend::~Backend() { + xnn_status status = xnn_deinitialize(); + if (status != xnn_status_success) { + dawn::ErrorLog() << "xnn_deinitialize failed: " << status; + return; + } + if (mThreadpool != NULL) { + pthreadpool_destroy(mThreadpool); + } + } + MaybeError Backend::Initialize() { + xnn_status status = xnn_initialize(NULL); + if (status != xnn_status_success) { + dawn::ErrorLog() << "xnn_initialize failed: " << status; + return DAWN_INTERNAL_ERROR("Failed to intialize XNNPACK."); + } + // Create a thread pool with as half of the logical processors in the system. + mThreadpool = pthreadpool_create(std::thread::hardware_concurrency() / 2); + if (mThreadpool == NULL) { + dawn::ErrorLog() << "pthreadpool_create failed"; + return DAWN_INTERNAL_ERROR("Failed to create thread pool."); + } + dawn::InfoLog() << "backend XNNPACK backend thread numbers: " + << pthreadpool_get_threads_count(mThreadpool); return {}; } @@ -33,12 +59,7 @@ namespace webnn_native::xnnpack { dawn::ErrorLog() << "XNNPACK backend only supports CPU device."; return nullptr; } - Ref context = AcquireRef(new Context(options)); - xnn_status status = reinterpret_cast(context.Get())->Init(); - if (status != xnn_status_success) { - dawn::ErrorLog() << "Failed to init XNNPACK:" << status; - return nullptr; - } + Ref context = AcquireRef(new Context(mThreadpool)); return context.Detach(); } diff --git a/src/webnn_native/xnnpack/BackendXNN.h b/src/webnn_native/xnnpack/BackendXNN.h index 1a06c9cb2..7f0279599 100644 --- a/src/webnn_native/xnnpack/BackendXNN.h +++ b/src/webnn_native/xnnpack/BackendXNN.h @@ -19,6 +19,8 @@ #include "webnn_native/Context.h" #include "webnn_native/Error.h" +#include + #include namespace webnn_native::xnnpack { @@ -26,9 +28,13 @@ namespace webnn_native::xnnpack { class Backend : public BackendConnection { public: Backend(InstanceBase* instance); + virtual ~Backend() override; MaybeError Initialize(); ContextBase* CreateContext(ContextOptions const* options = nullptr) override; + + private: + pthreadpool_t mThreadpool; }; } // namespace webnn_native::xnnpack diff --git a/src/webnn_native/xnnpack/ContextXNN.cpp b/src/webnn_native/xnnpack/ContextXNN.cpp index dcb508a0b..e0843c6e4 100644 --- a/src/webnn_native/xnnpack/ContextXNN.cpp +++ b/src/webnn_native/xnnpack/ContextXNN.cpp @@ -14,43 +14,13 @@ #include "webnn_native/xnnpack/ContextXNN.h" -#include - #include "common/Log.h" #include "common/RefCounted.h" #include "webnn_native/xnnpack/GraphXNN.h" namespace webnn_native::xnnpack { - Context::Context(ContextOptions const* options) { - } - - Context::~Context() { - xnn_status status = xnn_deinitialize(); - if (status != xnn_status_success) { - dawn::ErrorLog() << "xnn_deinitialize failed: " << status; - return; - } - if (mThreadpool != NULL) { - pthreadpool_destroy(mThreadpool); - } - } - - xnn_status Context::Init() { - xnn_status status = xnn_initialize(NULL); - if (status != xnn_status_success) { - dawn::ErrorLog() << "xnn_initialize failed: " << status; - return status; - } - // Create a thread pool with as half of the logical processors in the system. - mThreadpool = pthreadpool_create(std::thread::hardware_concurrency() / 2); - if (mThreadpool == NULL) { - dawn::ErrorLog() << "pthreadpool_create failed"; - return xnn_status_out_of_memory; - } - dawn::InfoLog() << "XNNPACK backend thread numbers: " - << pthreadpool_get_threads_count(mThreadpool); - return xnn_status_success; + Context::Context(pthreadpool_t threadpool) : mThreadpool(threadpool) { } pthreadpool_t Context::GetThreadpool() { diff --git a/src/webnn_native/xnnpack/ContextXNN.h b/src/webnn_native/xnnpack/ContextXNN.h index 82b27b96a..d7a18a188 100644 --- a/src/webnn_native/xnnpack/ContextXNN.h +++ b/src/webnn_native/xnnpack/ContextXNN.h @@ -23,10 +23,8 @@ namespace webnn_native::xnnpack { class Context : public ContextBase { public: - explicit Context(ContextOptions const* options); - ~Context() override; - - xnn_status Init(); + explicit Context(pthreadpool_t threadpool); + ~Context() override = default; pthreadpool_t GetThreadpool();