Skip to content

Commit

Permalink
Clean up of the multithreaded benchmark (pytorch#12905)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#12905

This diff does some clean up of the multithread benchmark code:
1. Split implementation to `.cc` file to separate implementation and improve build
2. Make `MutatingNetSupplier` more generic by providing the mutating function as an argument instead of virtual method.
3. Fix AI benchmark by sticking to the original option names

Reviewed By: highker

Differential Revision: D10479238

fbshipit-source-id: afa201fc287e3fdbb232db24513ecf8024501f66
  • Loading branch information
Yinghai Lu authored and facebook-github-bot committed Oct 22, 2018
1 parent 1b530fd commit 56bf485
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 39 deletions.
16 changes: 0 additions & 16 deletions caffe2/predictor/emulator/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,6 @@ C10_DEFINE_string(
"Each element of the array is a mapping from "
"operator index to its input types.");

// aysnc net related params
C10_DEFINE_string(
mutating_net_type,
"",
"If used, we will use async_scheduling instead simple net for predict net");

C10_DEFINE_bool(
mutating_net_async_deferrable_mode,
true,
"If used, use deferrable_mode for DFS scheduling in async_scheduling net");

C10_DEFINE_int(
mutating_net_async_workers,
-1,
"Set number of worker threads for the thread pool in asyn_scheduling net");

namespace caffe2 {
namespace emulator {

Expand Down
31 changes: 8 additions & 23 deletions caffe2/predictor/emulator/net_supplier.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#pragma once
#include <functional>

#include "caffe2/predictor/emulator/data_filler.h"
#include "caffe2/predictor/emulator/utils.h"

C10_DECLARE_string(mutating_net_type);
C10_DECLARE_bool(mutating_net_async_deferrable_mode);
C10_DECLARE_int(mutating_net_async_workers);

namespace caffe2 {
namespace emulator {

Expand Down Expand Up @@ -49,8 +47,10 @@ class SingleNetSupplier : public NetSupplier {

class MutatingNetSupplier : public NetSupplier {
public:
MutatingNetSupplier(std::unique_ptr<NetSupplier>&& core)
: core_(std::move(core)) {}
explicit MutatingNetSupplier(
std::unique_ptr<NetSupplier>&& core,
std::function<void(NetDef*)> m)
: core_(std::move(core)), mutator_(m) {}

RunnableNet next() override {
RunnableNet orig = core_->next();
Expand All @@ -60,30 +60,15 @@ class MutatingNetSupplier : public NetSupplier {
nets_.push_back(orig.netdef);
new_net = &nets_.back();
}
mutate(new_net);
mutator_(new_net);
return RunnableNet(*new_net, orig.filler);
}

protected:
virtual void mutate(NetDef* net) {
// Using async scheduling net if specified
if (!FLAGS_mutating_net_type.empty()) {
net->set_type(FLAGS_mutating_net_type);
if (FLAGS_mutating_net_async_workers > 0) {
net->set_num_workers(FLAGS_mutating_net_async_workers);
}
if (FLAGS_mutating_net_async_deferrable_mode) {
auto* arg = net->add_arg();
arg->set_name("deferrable_mode");
arg->set_i(1);
}
}
}

private:
std::mutex lock_;
std::unique_ptr<NetSupplier> core_;
std::vector<NetDef> nets_;
std::function<void(NetDef*)> mutator_;
};

} // namespace emulator
Expand Down

0 comments on commit 56bf485

Please sign in to comment.