Skip to content

Commit

Permalink
feat: support truncate_long_and_double in fallback subgraph input type
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <[email protected]>
  • Loading branch information
inocsin committed Oct 20, 2021
1 parent 4778b2b commit 0bc3c05
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 8 deletions.
1 change: 1 addition & 0 deletions core/partitioning/PartitionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct PartitionInfo {
// Whether automatic fallback partitioning is enabled.
bool enabled = false;
// Minimum number of consecutive TensorRT-convertible ops required to form a TRT segment.
uint64_t min_block_size = 1;
// Op names the user forces to run in Torch rather than TensorRT.
std::vector<std::string> forced_fallback_operators;
// When true, demote at::kLong -> at::kInt and at::kDouble -> at::kFloat for
// subgraph inputs during shape analysis (TensorRT has no 64-bit types).
// Default-initialized: the previous uninitialized declaration made reads of
// this flag undefined behavior when a caller forgot to set it.
bool truncate_long_and_double = false;
};

std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);
Expand Down
2 changes: 1 addition & 1 deletion core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ std::vector<SegmentedBlock> Partition(
// record which Values each segment produces so downstream segments can consume them
registerSegmentsOutputs(segmented_blocks, block);

// run shape analysis on each segmented block
// partition_info is forwarded so shape analysis can honor truncate_long_and_double
runShapeAnalysis(segmented_blocks, input_ivalues_map, partition_info);

return segmented_blocks;
}
Expand Down
20 changes: 16 additions & 4 deletions core/partitioning/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(

// Executes a segmented block's graph on the sample IValues in ivalues_maps to
// discover concrete input shapes/dtypes and produce the segment's output IValues.
// partition_info supplies the truncate_long_and_double setting used to decide
// whether 64-bit tensor inputs are demoted or rejected.
void getSegmentsOutputByRunning(
SegmentedBlock& seg_block,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
const PartitionInfo& partition_info) {
// create a module to run the graph
auto g = seg_block.g();
auto copy_g = g->copy();
Expand Down Expand Up @@ -99,10 +100,20 @@ void getSegmentsOutputByRunning(
// Record each raw input's shape and TensorRT dtype. 64-bit tensor types are
// either demoted to their 32-bit equivalents (when the user enabled
// truncate_long_and_double) or rejected with an actionable error.
for (auto& i : seg_block.raw_inputs()) {
// Look the IValue up once; repeated operator[] calls re-hash on every use.
auto& in_ivalue = ivalues_maps[i];
if (in_ivalue.isTensor()) {
// set the input_shape and data_type
at::ScalarType t = c10::optTypeMetaToScalarType(in_ivalue.toTensor().dtype()).value();
if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
TRTORCH_THROW_ERROR(
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
} else if (partition_info.truncate_long_and_double && t == at::kLong) {
// TensorRT has no INT64 input type: demote to INT32.
in_ivalue = in_ivalue.toTensor().to(at::kInt);
} else if (partition_info.truncate_long_and_double && t == at::kDouble) {
// TensorRT has no FP64 input type: demote to FP32.
in_ivalue = in_ivalue.toTensor().to(at::kFloat);
}
// Map the (possibly truncated) ATen dtype onto a TensorRT dtype.
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(in_ivalue.toTensor().dtype());
nvinfer1::DataType nv_dtype;
if (dtype == c10::nullopt) {
// Removed the dead `nv_dtype = kFLOAT` store: silently defaulting an
// unmappable dtype to FLOAT hid real type errors, so we throw instead.
TRTORCH_THROW_ERROR("Unsupported input data type " << in_ivalue.toTensor().dtype());
} else {
nv_dtype = dtype.value();
}
Expand All @@ -116,11 +127,12 @@ void getSegmentsOutputByRunning(

void runShapeAnalysis(
std::vector<SegmentedBlock>& segmented_blocks,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
const PartitionInfo& partition_info) {
// register every segment's input shape, and it's running output IValues
for (auto& seg_block : segmented_blocks) {
torch::jit::ConstantPooling(seg_block.g());
getSegmentsOutputByRunning(seg_block, ivalues_maps);
getSegmentsOutputByRunning(seg_block, ivalues_maps, partition_info);
}
return;
}
Expand Down
3 changes: 2 additions & 1 deletion core/partitioning/shape_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(

// Runs each segmented block's graph on the sample IValues in ivalues_maps to
// register input shapes and output values; partition_info supplies the
// truncate_long_and_double setting applied to 64-bit tensor inputs.
void runShapeAnalysis(
std::vector<SegmentedBlock>& segmented_blocks,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
const PartitionInfo& partition_info);

} // namespace partitioning
} // namespace core
Expand Down
3 changes: 1 addition & 2 deletions core/util/trt_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kChar, nvinfer1::DataType::kINT8},
// at::kLong is intentionally absent: TensorRT has no INT64, so long tensors
// must be truncated upstream (truncate_long_and_double) or rejected — a
// silent kLong -> kINT32 mapping here would defeat that error checking.
// (The displayed diff residue also duplicated the at::kBool key; an
// unordered_map initializer silently drops the second entry, so only one
// kBool mapping is kept.)
{at::kBool, nvinfer1::DataType::kBOOL}};
return at_trt_type_map;
}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
// Copy user-facing fallback settings into the internal partitioning config.
internal.partition_info.enabled = external.torch_fallback.enabled;
internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
// The top-level truncate_long_and_double flag is reused for fallback subgraph
// inputs so partitioning and conversion agree on 64-bit type handling.
internal.partition_info.truncate_long_and_double = external.truncate_long_and_double;
internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;

switch (external.device.device_type) {
Expand Down

0 comments on commit 0bc3c05

Please sign in to comment.