feat!: Turning on partial compilation by default
BREAKING CHANGE: This commit turns on partial compilation
by default. Modules containing unsupported operations will now
be run partially in PyTorch and partially in TensorRT.
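To keep the previous fail-fast behavior, partitioning can be switched back off on the compile spec. Below is a minimal sketch against the core API; the partition_info field comes from this diff, but the CompileSpec constructor and include path are assumptions, not part of this commit.

```cpp
#include "core/compiler.h"

// Hedged sketch: partition_info.enabled comes from this diff; the constructor
// taking a vector of ir::Input and the include path are assumed, not verified.
torch::jit::Module compile_requiring_full_support(
    const torch::jit::Module& mod,
    std::vector<trtorch::core::ir::Input> inputs) {
  trtorch::core::CompileSpec cfg(std::move(inputs));
  // Restore the pre-commit behavior: unsupported ops become a hard error
  // instead of falling back to PyTorch.
  cfg.partition_info.enabled = false;
  return trtorch::core::CompileGraph(mod, cfg);
}
```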

Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Oct 19, 2021
1 parent a234335 commit 52e2f05
Showing 17 changed files with 244 additions and 82 deletions.
45 changes: 40 additions & 5 deletions core/compiler.cpp
@@ -253,6 +253,7 @@ GraphAndMapping ConstructFallbackGraph(
}
// update the input ranges for each segment
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
@@ -288,7 +289,7 @@ GraphAndMapping ConstructFallbackGraph(
}


void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, ir::TypeMap& first_use_type_map) {
// Associate input specs with inputs
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));

@@ -303,9 +304,31 @@ void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::G
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot deterime input type from calcuations in graph for input "
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec.dtype = nvinfer1::DataType::kFLOAT;
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
if (!est_type_opt) {
LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
std::stringstream ss;
ss <<"For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.inputs.find(in)->second.dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt.value() << std::endl;
ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
}
}
} else {
// The user defined the type so no changes are necessary
}
@@ -317,10 +340,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

auto g = graph_and_parameters.first;
TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

@@ -357,11 +381,21 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

if (cfg.partition_info.enabled) {
if (cfg.partition_info.enabled
&& (cfg.lower_info.forced_fallback_modules.size() == 0
&& cfg.partition_info.forced_fallback_operators.size() == 0
&& conversion::VerifyConverterSupportForBlock(g->block(), true))) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled
&& !(cfg.lower_info.forced_fallback_modules.size() == 0
&& cfg.partition_info.forced_fallback_operators.size() == 0
&& conversion::VerifyConverterSupportForBlock(g->block(), false))) {
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
new_g = graph_and_mapping.first;
@@ -374,6 +408,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
return mod;
}
} else {
TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
auto device_spec = cfg.convert_info.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
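The partitioning decision introduced above is split across two mirrored conditions. Restated as a single predicate (an editor's sketch using the field names from this diff, not code from the commit), the logic reads:

```cpp
// Editor's restatement of the CompileGraph branch above: partition only when
// partitioning is enabled AND something actually has to run in PyTorch.
bool should_partition(const trtorch::core::CompileSpec& cfg, const torch::jit::Block* block) {
  if (!cfg.partition_info.enabled) {
    return false;  // require_full_compilation path
  }
  bool nothing_forced_to_torch = cfg.lower_info.forced_fallback_modules.size() == 0 &&
      cfg.partition_info.forced_fallback_operators.size() == 0;
  bool fully_supported =
      trtorch::core::conversion::VerifyConverterSupportForBlock(block, /*suppress_errors=*/true);
  return !(nothing_forced_to_torch && fully_supported);
}
```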
18 changes: 12 additions & 6 deletions core/conversion/conversion.cpp
@@ -491,7 +491,7 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
return convertable_ops;
}

bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
auto unsupported_ops = GetUnsupportedOpsInBlock(b);

if (unsupported_ops.size() != 0) {
@@ -506,16 +506,20 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
unsupported_msg << std::endl << "In Module:" << std::endl;

LOG_ERROR(unsupported_msg.str());
if (!suppress_errors) {
LOG_ERROR(unsupported_msg.str());
}

for (const auto n : b->nodes()) {
auto schema = n->maybeSchema();
if (schema) {
for (const auto& x : unsupported_ops) {
if (x.first == schema->operator_name()) {
LOG_ERROR(
"Unsupported operator: " << *schema << std::endl
<< trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
if (!suppress_errors) {
LOG_ERROR(
"Unsupported operator: " << *schema << std::endl
<< trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
}
}
}
}
@@ -531,7 +535,9 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
unsupported_msg
<< "This may be because there are no operators that can be added to the TensorRT graph or all operators have a resolved compile time value."
<< std::endl;
LOG_ERROR(unsupported_msg.str());
if (!suppress_errors) {
LOG_ERROR(unsupported_msg.str());
}
return false;
}

2 changes: 1 addition & 1 deletion core/conversion/conversion.h
@@ -25,7 +25,7 @@ std::string ConvertBlockToEngine(

bool OpSupported(const torch::jit::Node* n);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors=false);

c10::optional<torch::jit::IValue> EvaluateNode(
ConversionCtx* ctx,
103 changes: 103 additions & 0 deletions core/ir/ir.cpp
@@ -45,6 +45,109 @@ std::vector<const torch::jit::Value*> get_tensor_inputs(
return input_tensors;
}

c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in) {
TRTORCH_ASSERT(in->owningGraph() == b->owningGraph(), "Provided input is not part of the provided graph");
c10::optional<at::ScalarType> dtype = {};

auto b_ins = b->inputs();
std::unordered_set<torch::jit::Value*> b_in_set(b_ins.begin(), b_ins.end());

TRTORCH_ASSERT(
in->type() == c10::TensorType::get(), "Input is not a tensor, cannot check for dtype based on calculation");

auto consumers = in->uses();
auto search_list = std::vector<torch::jit::Use>(consumers.begin(), consumers.end());

for (auto& u : search_list) {
auto n = u.user;
LOG_GRAPH("Node we are looking at: " << util::node_info(n));
auto ins = n->inputs();
auto outs = n->outputs();

bool outputs_tensor = false;
for (auto o : outs) {
if (o->type() == c10::TensorType::get()) {
outputs_tensor = true;
break;
}
}

if (!outputs_tensor) {
LOG_GRAPH("Node " << util::node_info(n) << " does not output a tensor, skipping");
continue;
}

LOG_GRAPH("Node " << util::node_info(n) << " outputs a tensor");

// If all input tensors are block inputs then this node will not give us useful type info so move to the next one
bool all_n_ins_are_b_ins = true;
for (auto in : ins) {
if (b_in_set.find(in) == b_in_set.end()) {
all_n_ins_are_b_ins = false;
break;
}
}

if (all_n_ins_are_b_ins) {
LOG_GRAPH(
"All inputs to Node " << util::node_info(n) << " are graph inputs, cannot be used to determine input type");
for (auto o : outs) {
if (o->type() == c10::TensorType::get()) {
auto o_uses = o->uses();
search_list.insert(search_list.end(), o_uses.begin(), o_uses.end());
}
}
continue;
}

// If node outputs a Tensor it might be a result of a tensor calculation so check to see
// if any inputs to the calculation can give us hints
c10::optional<torch::jit::Node*> const_tensor_n = {};

// Backtrace to constants which will immediately give us the Tensor type if possible
for (auto in : ins) {
LOG_GRAPH("Input to node: " << util::node_info(in->node()));
if (in->type()->isSubtypeOf(torch::jit::TensorType::get())) {
LOG_GRAPH("Input outputs a Tensor");
if (in->node()->kind() == torch::jit::prim::Constant) {
LOG_GRAPH("Input is a constant");
auto const_val = in->node()->t(c10::attr::value);
LOG_GRAPH("Found that constant tensor has type: " << const_val.scalar_type());
dtype = {const_val.scalar_type()};
goto exit_first_calc_dtype;
}
}
}

// Add all tensor outputs to search list if we still don't know
for (auto o : outs) {
if (o->type() == c10::TensorType::get()) {
auto o_uses = o->uses();
search_list.insert(search_list.end(), o_uses.begin(), o_uses.end());
}
}
}
exit_first_calc_dtype:
if (dtype) {
LOG_GRAPH("Estimated input type is " << dtype.value());
} else {
LOG_GRAPH("Cannot determine input types from graph");
}
return dtype;
}

TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) {
TypeMap types;

for (auto i : b->inputs()) {
if (i->type() == c10::TensorType::get()) {
torch::jit::Value* in = i;
types.insert({in, get_value_first_calc_dtype_opt(b, i)});
}
}
return types;
}

} // namespace ir
} // namespace core
} // namespace trtorch
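As a concrete illustration of the backtracking above (an editor's example, not part of the commit): if the first calculation consuming a block input multiplies it by a constant FP16 weight, the helper should report kHalf for that input.

```cpp
#include <ATen/ATen.h>
#include <torch/csrc/jit/ir/ir.h>
#include "core/ir/ir.h"

// Editor's illustration: build "y = x * w" with an FP16 constant weight and ask
// the new helper for x's inferred dtype. The graph-construction calls are
// standard torch::jit APIs; the expected result is an assumption based on the
// backtracking logic above.
c10::optional<at::ScalarType> infer_example_input_dtype() {
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->addInput("x");
  x->setType(c10::TensorType::get());

  auto w = g->create(torch::jit::prim::Constant);
  w->t_(c10::attr::value, at::ones({4}, at::kHalf));
  w->output()->setType(c10::TensorType::get());
  g->insertNode(w);

  auto mul = g->create(c10::Symbol::fromQualString("aten::mul"), {x, w->output()});
  g->insertNode(mul);
  g->registerOutput(mul->output());

  // Expected: at::kHalf, backtraced from the constant weight.
  return trtorch::core::ir::get_value_first_calc_dtype_opt(g->block(), x);
}
```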
5 changes: 5 additions & 0 deletions core/ir/ir.h
@@ -52,6 +52,11 @@ std::vector<const torch::jit::Value*> get_tensor_inputs(
std::shared_ptr<torch::jit::Graph>& g,
StaticParams& static_params);

using TypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;

c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
ir::TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);

} // namespace ir
} // namespace core
} // namespace trtorch
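For reference, a short sketch of how these relocated helpers are consumed, mirroring the CompileGraph call sites earlier in this diff (the surrounding function and log text are illustrative only):

```cpp
// Sketch only: the helper names and TypeMap alias come from this diff;
// everything else is illustrative.
void log_first_use_types(std::shared_ptr<torch::jit::Graph>& g) {
  trtorch::core::ir::TypeMap first_use_types =
      trtorch::core::ir::get_block_first_calc_dtypes_opt(g->block());
  for (const auto& entry : first_use_types) {
    const torch::jit::Value* in = entry.first;
    const c10::optional<at::ScalarType>& est_type_opt = entry.second;
    if (est_type_opt) {
      LOG_GRAPH("Estimated type for input " << in->debugName() << " is " << est_type_opt.value());
    } else {
      LOG_GRAPH("Could not infer a type for input " << in->debugName());
    }
  }
}
```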
2 changes: 1 addition & 1 deletion core/lowering/LowerInfo.cpp
@@ -10,7 +10,7 @@ namespace lowering {

std::ostream& operator<<(std::ostream& os, const LowerInfo& l) {
os << "Settings requested for Lowering:" << std::endl;
os << " Forced Fallback Modules: [" << std::endl;
os << " torch_executed_modules: [" << std::endl;
for (auto i : l.forced_fallback_modules) {
os << " " << i << std::endl;
}
2 changes: 1 addition & 1 deletion core/partitioning/PartitionInfo.cpp
@@ -14,7 +14,7 @@ std::ostream& operator<<(std::ostream& os, const PartitionInfo& s) {
if (s.enabled) {
os << "True";
os << "\n \"min_block_size\": " << s.min_block_size \
<< "\n \"forced_fallback_operators\": [";
<< "\n \"torch_executed_operators\": [";
for (auto i : s.forced_fallback_operators) {
os <<"\n " << i << ',';
}
3 changes: 0 additions & 3 deletions core/util/BUILD
@@ -27,9 +27,6 @@ cc_library(
hdrs = [
"jit_util.h",
],
srcs = [
"jit_util.cpp"
],
deps = [
":macros"
] + select({
4 changes: 0 additions & 4 deletions core/util/jit_util.h
@@ -9,7 +9,6 @@ namespace trtorch {
namespace core {
namespace util {

using InputTypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;

inline std::string node_info(const torch::jit::Node* n) {
std::stringstream ss;
@@ -62,9 +61,6 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
return source_code;
}

c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);

} // namespace util
} // namespace core
} // namespace trtorch
4 changes: 2 additions & 2 deletions core/util/logging/TRTorchLogger.cpp
@@ -125,9 +125,9 @@ namespace {

TRTorchLogger& get_global_logger() {
#ifndef NDEBUG
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
#else
static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kWARNING, false);
#endif
return global_logger;
}
30 changes: 17 additions & 13 deletions cpp/bin/trtorchc/README.md
@@ -36,8 +36,8 @@ OPTIONS:
--allow-gpu-fallback (Only used when targeting DLA
(device-type)) Lets engine run layers on
GPU if they are not supported on DLA
--allow-torch-fallback Enable layers to run in torch if they
are not supported in TensorRT
--require-full-compilation Require that the entire model be
compiled to TensorRT or throw an error
--disable-tf32 Prevent Float32 layers from using the
TF32 data format
--sparse-weights Enable sparsity for weights of conv and
@@ -63,18 +63,22 @@ OPTIONS:
--calibration-cache-file=[file_path]
Path to calibration cache file to use
for post training quantization
--ffo=[forced_fallback_ops...],
--forced-fallback-op=[forced_fallback_ops...]
--teo=[torch-executed-ops...],
--torch-executed-ops=[torch-executed-ops...]
(Repeatable) Operator in the graph that
should be forced to fallback to Pytorch
for execution (allow torch fallback must
be set)
--ffm=[forced_fallback_mods...],
--forced-fallback-mod=[forced_fallback_mods...]
(Repeatable) Module that should be
forced to fallback to Pytorch for
execution (allow torch fallback must be
set)
should always be run in PyTorch for
execution (partial compilation must be
enabled)
--tem=[torch-executed-mods...],
--torch-executed-mods=[torch-executed-mods...]
(Repeatable) Module that should always
be run in PyTorch for execution (partial
compilation must be enabled)
--mbs=[min-block-size],
--min-block-size=[min-block-size]
Minimum number of contiguous TensorRT
supported ops to compile a subgraph to
TensorRT
--embed-engine Whether to treat input file as a
serialized TensorRT engine and embed it
into a TorchScript module (device spec