feat: allow users to set fallback block size and ops
Signed-off-by: Bo Wang <[email protected]>
bowang007 committed Mar 10, 2021
1 parent f4c29b4 commit 6d3064a
Showing 4 changed files with 62 additions and 14 deletions.
9 changes: 6 additions & 3 deletions core/compiler.cpp
@@ -184,7 +184,7 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
}

LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
// LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
return;
}

@@ -208,7 +208,10 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo


// segment the graph and convert segmented TensorRT block
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges, convert_cfg.engine_settings.torch_fallback);
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
return mod;
}

int trt_engine_id = 0;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -233,7 +236,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
new_g->registerOutput(old_to_new_g[output]);
}

LOG_INFO(*new_g << "(After CompileGraph)\n");
LOG_INFO(*new_g << "(StitchSegmentedGraph)\n");

auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
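The torch_fallback settings consumed by CompileGraphWithFallback above carry a minimum block size and a list of operators that must always run in Torch; segment_graph turns the operator list into a lookup set. The following is a minimal, self-contained sketch of those settings and of that lookup set. The field names enabled, min_block_size, and forced_fallback_operators come from this diff; the member types, defaults, and sample values here are guesses made for illustration only.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for conversion::TorchFallback; only the field names are taken from the diff.
struct TorchFallback {
  bool enabled = false;
  size_t min_block_size = 1;
  std::vector<std::string> forced_fallback_operators;
};

int main() {
  TorchFallback fallback;
  fallback.enabled = true;
  fallback.min_block_size = 3;                                     // TensorRT runs shorter than this stay in Torch
  fallback.forced_fallback_operators = {"aten::relu", "prim::If"}; // always execute these ops in Torch

  // segment_graph builds a lookup set from the vector, roughly like this:
  std::unordered_set<std::string> forced(fallback.forced_fallback_operators.begin(),
                                         fallback.forced_fallback_operators.end());
  std::cout << "aten::relu forced to Torch? " << forced.count("aten::relu") << "\n";  // prints 1
  return 0;
}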
8 changes: 7 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
@@ -39,7 +39,13 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {

os << "\n Torch Fallback: " << s.torch_fallback.enabled;
if (s.torch_fallback.enabled) {
os << "\n Fallback min block size: " << s.torch_fallback.min_block_size;
os << "\n Fallback Min Block Size: " << s.torch_fallback.min_block_size;
if (!s.torch_fallback.forced_fallback_operators.empty()) {
os << "\n Forced Fallback Operators:";
for (auto it = s.torch_fallback.forced_fallback_operators.begin(); it != s.torch_fallback.forced_fallback_operators.end(); ++it) {
os << " " << *it;
}
}
}
return os;
}
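The iterator loop added here only changes how the build settings are logged: each forced-fallback operator is appended to one line. A dependency-free sketch of the same formatting, with a made-up FallbackSettings struct and sample values standing in for the real BuilderSettings (only the label strings match the diff); a range-based for is equivalent to the explicit iterator loop used above.

#include <iostream>
#include <string>
#include <vector>

// Stand-in for the fallback portion of BuilderSettings; illustrative only.
struct FallbackSettings {
  bool enabled = true;
  size_t min_block_size = 3;
  std::vector<std::string> forced_fallback_operators = {"aten::relu", "aten::max_pool2d"};
};

int main() {
  FallbackSettings s;
  std::cout << "\n    Torch Fallback: " << s.enabled;
  if (s.enabled) {
    std::cout << "\n        Fallback Min Block Size: " << s.min_block_size;
    if (!s.forced_fallback_operators.empty()) {
      std::cout << "\n        Forced Fallback Operators:";
      for (const auto& op : s.forced_fallback_operators) {  // equivalent to the iterator loop above
        std::cout << " " << op;
      }
    }
  }
  std::cout << "\n";
  // Prints roughly:
  //     Torch Fallback: 1
  //         Fallback Min Block Size: 3
  //         Forced Fallback Operators: aten::relu aten::max_pool2d
  return 0;
}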
44 changes: 36 additions & 8 deletions core/partitioning/partitioning.cpp
@@ -62,6 +62,14 @@ void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<tor
// create a module to run the graph
auto g = seg_block.g();
auto copy_g = g->copy();
if (seg_block.raw_outputs().size() > 1) {
auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs()));
for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) {
copy_g->eraseOutput(idx);
}
copy_g->registerOutput(new_output_node->outputs()[0]);
}

torch::jit::script::Module cur_mod(c10::QualifiedName("module"));

auto self = copy_g->insertInput(0, "self_1");
@@ -140,25 +148,45 @@ void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks
return;
}

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges) {
void merge_nodes(std::vector<torch::jit::Node*> &pytorch_nodes, std::vector<torch::jit::Node*> &tensorrt_nodes,
std::vector<SegmentedBlock> &segmented_blocks, size_t min_block_size) {
if (!tensorrt_nodes.empty()) {
if (tensorrt_nodes.size() < min_block_size) {
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
} else {
if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
pytorch_nodes.clear();
}
tensorrt_nodes.clear();
}
}

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
std::vector<conversion::InputRange>& input_ranges,
const conversion::TorchFallback &fallback_info) {
auto min_block_size = fallback_info.min_block_size;
std::unordered_set<std::string> forced_fallback_operators(fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
std::vector<SegmentedBlock> segmented_blocks;

auto nodes = g->block()->nodes();

// segment the nodes
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;

for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) continue;
std::string node_string(n->kind().toQualString());

auto block_target = conversion::OpSupported(n) ? SegmentedBlock::kTensorRT : SegmentedBlock::kTorch;

if (segmented_blocks.empty() || block_target != segmented_blocks.back().target()) {
SegmentedBlock cur_block(block_target);
cur_block.appendNode(n);
segmented_blocks.push_back(cur_block);
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
tensorrt_nodes.push_back(n);
} else {
segmented_blocks.back().appendNode(n);
merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
pytorch_nodes.push_back(n);
}
}
merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);

registerSegmentsInputsOutputs(segmented_blocks, g);

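The effect of min_block_size and forced_fallback_operators is easiest to see on toy data. The sketch below reimplements the same greedy policy as merge_nodes and the segment_graph loop above, using plain strings in place of torch::jit::Node*; the op names, support set, and Block type are invented purely for illustration.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-ins: a "node" is just its operator name, a block is a tagged run of nodes.
enum class Target { kTorch, kTensorRT };
struct Block {
  Target target;
  std::vector<std::string> nodes;
};

// Same policy as merge_nodes() above: a pending TensorRT run shorter than
// min_block_size is folded back into the surrounding Torch nodes instead of
// becoming its own engine segment.
void merge_nodes(std::vector<std::string>& torch_nodes, std::vector<std::string>& trt_nodes,
                 std::vector<Block>& blocks, size_t min_block_size) {
  if (trt_nodes.empty()) return;
  if (trt_nodes.size() < min_block_size) {
    torch_nodes.insert(torch_nodes.end(), trt_nodes.begin(), trt_nodes.end());
  } else {
    if (!torch_nodes.empty()) blocks.push_back({Target::kTorch, torch_nodes});
    blocks.push_back({Target::kTensorRT, trt_nodes});
    torch_nodes.clear();
  }
  trt_nodes.clear();
}

int main() {
  // Hypothetical op sequence, support set, and forced-fallback set (all invented for the demo).
  std::vector<std::string> graph = {"aten::conv2d", "aten::relu",   "my::custom_op",
                                    "aten::add",    "aten::matmul", "aten::softmax"};
  std::unordered_set<std::string> supported = {"aten::conv2d", "aten::relu", "aten::add",
                                               "aten::matmul", "aten::softmax"};
  std::unordered_set<std::string> forced_fallback = {"aten::add"};
  size_t min_block_size = 2;

  std::vector<Block> blocks;
  std::vector<std::string> torch_nodes, trt_nodes;
  for (const auto& n : graph) {
    if (supported.count(n) && !forced_fallback.count(n)) {
      trt_nodes.push_back(n);
    } else {
      merge_nodes(torch_nodes, trt_nodes, blocks, min_block_size);
      torch_nodes.push_back(n);
    }
  }
  merge_nodes(torch_nodes, trt_nodes, blocks, min_block_size);
  if (!torch_nodes.empty()) blocks.push_back({Target::kTorch, torch_nodes});

  // Prints:
  //   [TensorRT] aten::conv2d aten::relu
  //   [Torch]    my::custom_op aten::add
  //   [TensorRT] aten::matmul aten::softmax
  // With min_block_size = 3 both TensorRT runs are too short, everything collapses into a
  // single Torch block, and CompileGraphWithFallback above would return the module unchanged.
  for (const auto& b : blocks) {
    std::cout << (b.target == Target::kTensorRT ? "[TensorRT]" : "[Torch]   ");
    for (const auto& n : b.nodes) std::cout << " " << n;
    std::cout << "\n";
  }
  return 0;
}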
15 changes: 13 additions & 2 deletions core/partitioning/partitioning.h
@@ -20,8 +20,17 @@ struct SegmentedBlock {
kTensorRT,
};

SegmentedBlock() = default;

SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}

SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*> &nodes) :
target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
for (auto &node : nodes) {
appendNode(node);
}
}

SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}

enum SegmentedBlockTarget target() {
@@ -40,6 +49,7 @@ struct SegmentedBlock {

void registerOutput(torch::jit::Value* raw_input) {
outputs_.push_back(raw_input);

g_->registerOutput(old_to_new_[raw_input]);
}

@@ -108,8 +118,9 @@ struct SegmentedBlock {

};

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges);

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
std::vector<conversion::InputRange>& input_ranges,
const conversion::TorchFallback &fallback_info);
}
}
}
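The header change adds a default constructor and a constructor that builds a block directly from a node list, which is what lets merge_nodes call segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes). Below is a dependency-free sketch of that constructor pattern; the mock Node and Graph types stand in for the torch::jit types, and everything except the constructor shape is invented for illustration.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Mock stand-ins for torch::jit::Node and torch::jit::Graph; illustrative only.
struct Node { std::string kind; };
struct Graph { std::vector<std::string> kinds; };

struct SegmentedBlock {
  enum SegmentedBlockTarget { kTorch, kTensorRT };

  SegmentedBlock() = default;
  explicit SegmentedBlock(SegmentedBlockTarget t) : target_(t), g_(std::make_shared<Graph>()) {}

  // Mirrors the new node-list constructor: start from an empty graph, then append each node.
  SegmentedBlock(SegmentedBlockTarget t, const std::vector<Node*>& nodes)
      : target_(t), g_(std::make_shared<Graph>()) {
    for (auto* n : nodes) appendNode(n);
  }

  void appendNode(Node* n) { g_->kinds.push_back(n->kind); }  // the real block clones n into g_
  SegmentedBlockTarget target() const { return target_; }
  const Graph& g() const { return *g_; }

 private:
  SegmentedBlockTarget target_ = kTorch;
  std::shared_ptr<Graph> g_;
};

int main() {
  Node a{"aten::conv2d"}, b{"aten::relu"};
  std::vector<Node*> run = {&a, &b};

  // The shape of the call merge_nodes() makes: build a block straight from an accumulated node list.
  SegmentedBlock block(SegmentedBlock::kTensorRT, run);
  for (const auto& k : block.g().kinds) std::cout << k << "\n";
  return 0;
}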
