Skip to content

Commit

Permalink
Merge pull request #241 from huningxin/fix_xnnpack_pthreadpool
Browse files Browse the repository at this point in the history
[XNNPACK] create pthreadpool per backend
  • Loading branch information
fujunwei authored Apr 27, 2022
2 parents b54c872 + 54d2b2b commit 93db15e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 41 deletions.
33 changes: 27 additions & 6 deletions src/webnn_native/xnnpack/BackendXNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,39 @@
#include "webnn_native/Instance.h"
#include "webnn_native/xnnpack/ContextXNN.h"

#include <thread>

namespace webnn_native::xnnpack {

Backend::Backend(InstanceBase* instance)
: BackendConnection(instance, wnn::BackendType::XNNPACK) {
}

Backend::~Backend() {
xnn_status status = xnn_deinitialize();
if (status != xnn_status_success) {
dawn::ErrorLog() << "xnn_deinitialize failed: " << status;
return;
}
if (mThreadpool != NULL) {
pthreadpool_destroy(mThreadpool);
}
}

MaybeError Backend::Initialize() {
xnn_status status = xnn_initialize(NULL);
if (status != xnn_status_success) {
dawn::ErrorLog() << "xnn_initialize failed: " << status;
return DAWN_INTERNAL_ERROR("Failed to intialize XNNPACK.");
}
// Create a thread pool with as half of the logical processors in the system.
mThreadpool = pthreadpool_create(std::thread::hardware_concurrency() / 2);
if (mThreadpool == NULL) {
dawn::ErrorLog() << "pthreadpool_create failed";
return DAWN_INTERNAL_ERROR("Failed to create thread pool.");
}
dawn::InfoLog() << "backend XNNPACK backend thread numbers: "
<< pthreadpool_get_threads_count(mThreadpool);
return {};
}

Expand All @@ -33,12 +59,7 @@ namespace webnn_native::xnnpack {
dawn::ErrorLog() << "XNNPACK backend only supports CPU device.";
return nullptr;
}
Ref<ContextBase> context = AcquireRef(new Context(options));
xnn_status status = reinterpret_cast<Context*>(context.Get())->Init();
if (status != xnn_status_success) {
dawn::ErrorLog() << "Failed to init XNNPACK:" << status;
return nullptr;
}
Ref<ContextBase> context = AcquireRef(new Context(mThreadpool));
return context.Detach();
}

Expand Down
6 changes: 6 additions & 0 deletions src/webnn_native/xnnpack/BackendXNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,22 @@
#include "webnn_native/Context.h"
#include "webnn_native/Error.h"

#include <xnnpack.h>

#include <memory>

namespace webnn_native::xnnpack {

class Backend : public BackendConnection {
public:
Backend(InstanceBase* instance);
virtual ~Backend() override;

MaybeError Initialize();
ContextBase* CreateContext(ContextOptions const* options = nullptr) override;

private:
pthreadpool_t mThreadpool;
};

} // namespace webnn_native::xnnpack
Expand Down
32 changes: 1 addition & 31 deletions src/webnn_native/xnnpack/ContextXNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,13 @@

#include "webnn_native/xnnpack/ContextXNN.h"

#include <thread>

#include "common/Log.h"
#include "common/RefCounted.h"
#include "webnn_native/xnnpack/GraphXNN.h"

namespace webnn_native::xnnpack {

Context::Context(ContextOptions const* options) {
}

Context::~Context() {
xnn_status status = xnn_deinitialize();
if (status != xnn_status_success) {
dawn::ErrorLog() << "xnn_deinitialize failed: " << status;
return;
}
if (mThreadpool != NULL) {
pthreadpool_destroy(mThreadpool);
}
}

xnn_status Context::Init() {
xnn_status status = xnn_initialize(NULL);
if (status != xnn_status_success) {
dawn::ErrorLog() << "xnn_initialize failed: " << status;
return status;
}
// Create a thread pool with as half of the logical processors in the system.
mThreadpool = pthreadpool_create(std::thread::hardware_concurrency() / 2);
if (mThreadpool == NULL) {
dawn::ErrorLog() << "pthreadpool_create failed";
return xnn_status_out_of_memory;
}
dawn::InfoLog() << "XNNPACK backend thread numbers: "
<< pthreadpool_get_threads_count(mThreadpool);
return xnn_status_success;
Context::Context(pthreadpool_t threadpool) : mThreadpool(threadpool) {
}

pthreadpool_t Context::GetThreadpool() {
Expand Down
6 changes: 2 additions & 4 deletions src/webnn_native/xnnpack/ContextXNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ namespace webnn_native::xnnpack {

class Context : public ContextBase {
public:
explicit Context(ContextOptions const* options);
~Context() override;

xnn_status Init();
explicit Context(pthreadpool_t threadpool);
~Context() override = default;

pthreadpool_t GetThreadpool();

Expand Down

0 comments on commit 93db15e

Please sign in to comment.