Skip to content

Commit

Permalink
feat(//core/partitioing): Adding ostream for Partition Info
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 7, 2021
1 parent fb1a299 commit b3589c5
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
Empty file.
2 changes: 2 additions & 0 deletions core/partitioning/PartitionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ struct PartitionInfo {
std::vector<std::string> forced_fallback_operators;
};

std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);

} // namespace partitioning
} // namespace core
} // namespace trtorch
2 changes: 2 additions & 0 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ std::vector<SegmentedBlock> Partition(
std::shared_ptr<torch::jit::Graph> g,
std::vector<ir::InputRange>& input_ranges,
const PartitionInfo& partition_info) {

LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);

Expand Down
11 changes: 10 additions & 1 deletion py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
info.convert_info.engine_settings.device.dla_core = device.dla_core;
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
info.convert_info.engine_settings.torch_fallback.enabled = torch_fallback.enabled;
info.partition_info.enabled = torch_fallback.enabled;
info.partition_info.min_block_size = torch_fallback.min_block_size;
info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;

info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
Expand Down Expand Up @@ -148,6 +148,15 @@ std::string CompileSpec::stringify() {
ss << " \"Workspace Size\": " << workspace_size << std::endl;
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
ss << " \"Torch Fallback: {" << std::endl;
ss << " \"enabled\": " << torch_fallback.enabled ? "True" : "False" << std::endl;
ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl;
ss << " \"forced_fallback_operators\": [" << std::endl;
for (auto i : torch_fallback.forced_fallback_operators) {
ss << " " << i << ',' << std::endl;
}
ss << " ]" << std::endl;
ss << " }" << std::endl;
ss << "}";
return ss.str();
}
Expand Down

0 comments on commit b3589c5

Please sign in to comment.