Skip to content

Commit

Permalink
feat: support Int/Bool and other constants' inputs/outputs for TensorRT segments
Browse files Browse the repository at this point in the history

Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 25, 2021
1 parent 6147d4f commit 54e407e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
2 changes: 1 addition & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ auto aten_registrations TRTORCH_UNUSED =
if (args.at(n->input(0)).IValue()->isInt()) {
auto a = args.at(n->input(0)).unwrapToInt();
auto b = args.at(n->input(1)).unwrapToInt();
return std::floor(a / b);
return static_cast<int>(std::floor(a / b));
} else if (args.at(n->input(0)).IValue()->isDouble()) {
auto a = args.at(n->input(0)).unwrapToDouble();
auto b = args.at(n->input(1)).unwrapToDouble();
Expand Down
41 changes: 39 additions & 2 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/ir/constants.h"

namespace trtorch {
namespace core {
Expand Down Expand Up @@ -67,6 +68,7 @@ void registerSegmentInOutIValues(
// create a module to run the graph
auto g = seg_block.g();
auto copy_g = g->copy();
// LOG_INFO(*copy_g << "(copy graph)\n");

// create tuple for multiple outputs
if (seg_block.raw_outputs().size() > 1) {
Expand Down Expand Up @@ -163,19 +165,53 @@ void registerSegmentsInputsOutputs(
input_values.insert(graph_output);
}

for (auto& mini_graph_input : input_values) {
for (auto& seg_block : segmented_blocks) {
// should be careful here because some in-place operations don't return any values
for (auto& seg_block : segmented_blocks) {
for (auto& mini_graph_input : input_values) {
if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
seg_block.raw_inputs().end() &&
seg_block.contain_raw_input(mini_graph_input)) {
seg_block.registerOutput(mini_graph_input);
}
}
if (seg_block.raw_outputs().empty()) {
seg_block.registerOutput(seg_block.raw_inputs()[0]);
}
}

return;
}

// Rewrites a TensorRT-targeted segment so that only Tensor (or Tensor-list)
// values cross its boundary: non-Tensor inputs are replaced by prim::Constant
// nodes baked from their evaluated IValues, and non-Tensor outputs are dropped
// (the consuming segment materializes them the same way). A Torch-targeted
// segment can pass any IValue, so it is left untouched.
//
// seg_block:    the segment to rewrite (mutated in place).
// ivalues_maps: evaluation results keyed by graph Value, produced earlier by
//               registerSegmentInOutIValues.
void eraseNonTensorInputsOutputs(
    SegmentedBlock& seg_block,
    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
  if (seg_block.target() == SegmentedBlock::kTorch)
    return;
  auto mini_graph = seg_block.g();

  // True when a value may legally cross a TensorRT segment boundary.
  auto is_tensor_like = [](torch::jit::Value* val) {
    return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
        val->type()->isSubtypeOf(c10::ListType::ofTensors());
  };

  // Iterate in reverse so erasing index i does not shift indices not yet visited.
  for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
    torch::jit::Value* raw_input = seg_block.raw_inputs()[i];
    // erase this input and prepend a prim::Constant if it's not Tensor
    if (!is_tensor_like(raw_input)) {
      // Use at() rather than operator[]: a missing evaluation result should
      // fail loudly instead of silently inserting a default (None) constant.
      auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps.at(raw_input));
      seg_block.inputs()[i]->replaceAllUsesWith(new_val);
      seg_block.eraseInput(i);
    }
  }

  for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
    if (!is_tensor_like(seg_block.raw_outputs()[i])) {
      seg_block.eraseOutput(i);
    }
  }

  // If every output was non-Tensor, TensorRT has nothing to return for this
  // segment; fall back to running it in PyTorch.
  // TODO(review): not sure whether to delete this block or just fallback to pytorch.
  if (seg_block.raw_outputs().empty()) {
    seg_block.update_target(SegmentedBlock::kTorch);
  }
}

void construct_segments(
std::vector<torch::jit::Node*>& pytorch_nodes,
std::vector<torch::jit::Node*>& tensorrt_nodes,
Expand Down Expand Up @@ -240,6 +276,7 @@ std::vector<SegmentedBlock> segment_graph(
// register every segment's input shape, and it's running output Ivalues
for (auto& seg_block : segmented_blocks) {
registerSegmentInOutIValues(seg_block, ivalues_maps);
eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
}

return segmented_blocks;
Expand Down
14 changes: 14 additions & 0 deletions core/partitioning/partitioning.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,20 @@ struct SegmentedBlock {
return g_->inputs();
}

// Drop input i from both the recorded raw-input list and the owned
// mini-graph, keeping the two representations in lockstep.
void eraseInput(size_t i) {
  auto pos = inputs_.begin() + i;
  inputs_.erase(pos);
  g_->eraseInput(i);
}

// Non-owning view of the mini-graph's current output values.
c10::ArrayRef<torch::jit::Value*> outputs() {
return g_->outputs();
}

// Drop output i from both the cached raw-output list and the owned
// mini-graph so the two stay consistent.
void eraseOutput(size_t i) {
  auto pos = outputs_.begin() + i;
  outputs_.erase(pos);
  g_->eraseOutput(i);
}

// Read-only access to the segment's recorded input values (the values from
// the original graph, as opposed to the copies owned by the mini-graph).
const std::vector<torch::jit::Value*>& raw_inputs() const {
return inputs_;
}
Expand Down Expand Up @@ -102,6 +112,10 @@ struct SegmentedBlock {
g_ = new_g;
}

// Re-target this segment; used to demote a TensorRT segment to kTorch when
// erasing non-Tensor outputs leaves it with nothing TensorRT can return.
void update_target(SegmentedBlockTarget new_target) {
target_ = new_target;
}

private:
SegmentedBlockTarget target_;
std::vector<nvinfer1::Dims> in_shape_;
Expand Down

0 comments on commit 54e407e

Please sign in to comment.