diff --git a/core/partitioning/PartitionInfo.cpp b/core/partitioning/PartitionInfo.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/partitioning/PartitionInfo.h b/core/partitioning/PartitionInfo.h index d1a8aca321..7747c00503 100644 --- a/core/partitioning/PartitionInfo.h +++ b/core/partitioning/PartitionInfo.h @@ -14,6 +14,8 @@ struct PartitionInfo { std::vector forced_fallback_operators; }; +std::ostream& operator<<(std::ostream& os, const PartitionInfo& s); + } // namespace partitioning } // namespace core } // namespace trtorch \ No newline at end of file diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index dde3c674ef..924a9d1ed7 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -204,6 +204,8 @@ std::vector Partition( std::shared_ptr g, std::vector& input_ranges, const PartitionInfo& partition_info) { + + LOG_DEBUG(partition_info); // segment lowering global graph into blocks std::vector segmented_blocks = segment_graph(g, partition_info); diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp index 2686d3c650..a0b22f8367 100644 --- a/py/trtorch/csrc/tensorrt_classes.cpp +++ b/py/trtorch/csrc/tensorrt_classes.cpp @@ -108,10 +108,10 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() { info.convert_info.engine_settings.device.gpu_id = device.gpu_id; info.convert_info.engine_settings.device.dla_core = device.dla_core; info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback; - info.convert_info.engine_settings.torch_fallback.enabled = torch_fallback.enabled; info.partition_info.enabled = torch_fallback.enabled; info.partition_info.min_block_size = torch_fallback.min_block_size; info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators; + info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double; info.convert_info.engine_settings.capability = toTRTEngineCapability(capability); TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater"); @@ -148,6 +148,15 @@ std::string CompileSpec::stringify() { ss << " \"Workspace Size\": " << workspace_size << std::endl; ss << " \"Max Batch Size\": " << max_batch_size << std::endl; ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl; + ss << " \"Torch Fallback: {" << std::endl; + ss << " \"enabled\": " << torch_fallback.enabled ? "True" : "False" << std::endl; + ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl; + ss << " \"forced_fallback_operators\": [" << std::endl; + for (auto i : torch_fallback.forced_fallback_operators) { + ss << " " << i << ',' << std::endl; + } + ss << " ]" << std::endl; + ss << " }" << std::endl; ss << "}"; return ss.str(); }