Skip to content

Commit

Permalink
feat!: Changing the default behavior for selecting the input type
Browse files Browse the repository at this point in the history
BREAKING CHANGE: This commit changes the default behavior of
the compiler where if the user does not specify an input data
type explicity instead of using the enabled precision, now
the compiler will inspect the model provided to infer the
data type for the input that will not cause an error if
the model was run in torch. In practice this means

- If the weights are in FP32 for the first tensor calculation
  then default input type is FP32
- If the weights are in FP16 for the first tensor calculation
  then default input type is FP16
- etc.

If the data type cannot be determined the compiler will
default to FP32.

This calculation is done per input tensor so if one input
is inferred to use FP32 and another INT32 then the expected
types will be the same (FP32, INT32)

As was the same before if the user defines the data type
explicitly or provides an example tensor the data type
specified there will be respected

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Oct 19, 2021
1 parent 19ecc64 commit a234335
Show file tree
Hide file tree
Showing 14 changed files with 310 additions and 71 deletions.
56 changes: 32 additions & 24 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,22 +287,45 @@ GraphAndMapping ConstructFallbackGraph(
return {new_g, old_to_new_g};
}


void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
// Associate input specs with inputs
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));

for (auto& in : g->inputs()) {
auto est_type_opt = first_use_type_map.find(in)->second;
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
if (est_type_opt && !spec.dtype_is_user_defined) {
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated type
LOG_INFO("Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
<< in->debugName() << " has type " << est_type_opt.value() << ". If this is incorrect explicitly set dtype for input and file a bug");
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot deterime input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec.dtype = nvinfer1::DataType::kFLOAT;
} else {
// The user defined the type so no changes are necessary
}
}
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

auto convert_cfg = std::move(cfg.convert_info);
auto g = graph_and_parameters.first;

auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());

LOG_INFO(*g << "(CompileGraph)\n");
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

// Move the user defined inputs to the convert_cfg since some might be static;
convert_cfg.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);

auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, static_params);
return std::move(engine);
}

Expand Down Expand Up @@ -331,27 +354,12 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);

auto g = graph_and_parameters.first;
LOG_INFO("Lowered Graph: " << *g);
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);

cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));

// If the user did not explicitly set the input type, then use the first
// tensor calculation to infer type.
// Infer the type of an input from the weights of the calculation
auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
for (auto& in : g->inputs()) {
auto est_type_opt = first_use_types[in];
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
if (est_type_opt && !spec.dtype_is_user_defined) {
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
LOG_WARNING(
"Cannot deterime input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec.dtype = nvinfer1::DataType::kFLOAT;
}
}

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

if (cfg.partition_info.enabled) {
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
// Is this necessary?
// lowering::LowerBlock(g->block());

LOG_INFO("Lowered Graph: " << *(graph_and_ivalues.first));
return graph_and_ivalues;
}

Expand Down
5 changes: 2 additions & 3 deletions core/util/jit_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block*
return dtype;
}

std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> get_block_first_calc_dtypes_opt(
torch::jit::Block* b) {
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> types;
InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) {
InputTypeMap types;

for (auto i : b->inputs()) {
if (i->type() == c10::TensorType::get()) {
Expand Down
5 changes: 3 additions & 2 deletions core/util/jit_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace trtorch {
namespace core {
namespace util {

using InputTypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;

inline std::string node_info(const torch::jit::Node* n) {
std::stringstream ss;
ss << *n;
Expand Down Expand Up @@ -61,8 +63,7 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
}

c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> get_block_first_calc_dtypes_opt(
torch::jit::Block* b);
InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);

} // namespace util
} // namespace core
Expand Down
2 changes: 1 addition & 1 deletion core/util/logging/TRTorchLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ namespace {

TRTorchLogger& get_global_logger() {
#ifndef NDEBUG
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
#else
static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
#endif
Expand Down
15 changes: 5 additions & 10 deletions cpp/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ struct TRTORCH_API CompileSpec {
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
*
* @param shape Input tensor shape
* @param dtype Expected data type for the input (Defaults to Float32)
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
Expand All @@ -398,7 +398,7 @@ struct TRTORCH_API CompileSpec {
* tensor format
*
* @param shape Input tensor shape
* @param dtype Expected data type for the input (Defaults to Float32)
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
Expand All @@ -421,7 +421,7 @@ struct TRTORCH_API CompileSpec {
* allow the user to configure expected input shape tensor format
*
* @param shape Input tensor shape
* @param dtype Expected data type for the input (Defaults to Float32)
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
Expand Down Expand Up @@ -451,7 +451,7 @@ struct TRTORCH_API CompileSpec {
* @param min_shape Minimum shape for input tensor
* @param opt_shape Target optimization shape for input tensor
* @param max_shape Maximum acceptible shape for input tensor
* @param dtype Expected data type for the input (Defaults to Float32)
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
Input(
Expand Down Expand Up @@ -486,7 +486,7 @@ struct TRTORCH_API CompileSpec {
* @param min_shape Minimum shape for input tensor
* @param opt_shape Target optimization shape for input tensor
* @param max_shape Maximum acceptible shape for input tensor
* @param dtype Expected data type for the input (Defaults to Float32)
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
Input(
Expand All @@ -506,14 +506,9 @@ struct TRTORCH_API CompileSpec {
*/
Input(at::Tensor tensor);

bool get_explicit_set_dtype() {
return explicit_set_dtype;
}

private:
friend std::ostream& operator<<(std::ostream& os, const Input& input);
bool input_is_dynamic;
bool explicit_set_dtype;
};

/**
Expand Down
20 changes: 5 additions & 15 deletions cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
}

nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
TRTORCH_CHECK(!(value == CompileSpec::DataType::kUnknown), "Data type is unknown");
switch (value) {
case CompileSpec::DataType::kChar:
return nvinfer1::DataType::kINT8;
Expand Down Expand Up @@ -162,8 +161,7 @@ CompileSpec::Input::Input(std::vector<int64_t> shape, TensorFormat format) {
this->min_shape = shape;
this->max_shape = shape;
this->shape = shape;
this->dtype = dtype;
this->explicit_set_dtype = false;
this->dtype = CompileSpec::DataType::kUnknown;
this->format = format;
this->input_is_dynamic = false;
}
Expand All @@ -174,7 +172,6 @@ CompileSpec::Input::Input(std::vector<int64_t> shape, DataType dtype, TensorForm
this->max_shape = shape;
this->shape = shape;
this->dtype = dtype;
this->explicit_set_dtype = true;
this->format = format;
this->input_is_dynamic = false;
}
Expand All @@ -184,8 +181,7 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, TensorFormat format) {
this->min_shape = core::util::toVec(shape);
this->max_shape = core::util::toVec(shape);
this->shape = core::util::toVec(shape);
this->dtype = DataType::kFloat;
this->explicit_set_dtype = false;
this->dtype = CompileSpec::DataType::kUnknown;
this->format = format;
this->input_is_dynamic = false;
}
Expand All @@ -196,7 +192,6 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
this->max_shape = core::util::toVec(shape);
this->shape = core::util::toVec(shape);
this->dtype = dtype;
this->explicit_set_dtype = true;
this->format = format;
this->input_is_dynamic = false;
}
Expand All @@ -210,8 +205,7 @@ CompileSpec::Input::Input(
this->min_shape = min_shape;
this->max_shape = max_shape;
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
this->dtype = dtype;
this->explicit_set_dtype = false;
this->dtype = CompileSpec::DataType::kUnknown;
this->format = format;
this->input_is_dynamic = true;
}
Expand All @@ -227,7 +221,6 @@ CompileSpec::Input::Input(
this->max_shape = max_shape;
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
this->dtype = dtype;
this->explicit_set_dtype = true;
this->format = format;
this->input_is_dynamic = true;
}
Expand All @@ -241,8 +234,7 @@ CompileSpec::Input::Input(
this->min_shape = core::util::toVec(min_shape);
this->max_shape = core::util::toVec(max_shape);
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
this->dtype = dtype;
this->explicit_set_dtype = false;
this->dtype = CompileSpec::DataType::kUnknown;
this->format = format;
this->input_is_dynamic = true;
}
Expand All @@ -258,7 +250,6 @@ CompileSpec::Input::Input(
this->max_shape = core::util::toVec(max_shape);
this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
this->dtype = dtype;
this->explicit_set_dtype = true;
this->format = format;
this->input_is_dynamic = true;
}
Expand All @@ -269,7 +260,6 @@ CompileSpec::Input::Input(at::Tensor tensor) {
this->max_shape = tensor.sizes().vec();
this->shape = tensor.sizes().vec();
this->dtype = tensor.scalar_type();
this->explicit_set_dtype = true;
TRTORCH_ASSERT(
tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
"Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
Expand All @@ -292,7 +282,7 @@ core::ir::Input to_internal_input(CompileSpec::Input& i) {
i.max_shape,
toTRTDataType(i.dtype),
toTRTTensorFormat(i.format),
i.get_explicit_set_dtype());
!(i.dtype == CompileSpec::DataType::kUnknown));
}

std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::Input>& external) {
Expand Down
44 changes: 36 additions & 8 deletions py/trtorch/Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class _ShapeMode(Enum):

shape_mode = None #: (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
shape = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
dtype = _types.dtype.float32 #: The expected data type of the input tensor (default: trtorch.dtype.float32)
dtype = _types.dtype.unknown #: The expected data type of the input tensor (default: trtorch.dtype.float32)
_explicit_set_dtype = False
format = _types.TensorFormat.contiguous #: The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)

Expand Down Expand Up @@ -133,16 +133,44 @@ def __str__(self) -> str:
def _to_internal(self) -> trtorch._C.Input:
internal_in = trtorch._C.Input()
if self.shape_mode == Input._ShapeMode.DYNAMIC:
internal_in.min = self.shape["min_shape"]
internal_in.opt = self.shape["opt_shape"]
internal_in.max = self.shape["max_shape"]
if not Input._supported_input_size_type(self.shape["min_shape"]):
raise TypeError(
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ str(type(self.shape["min_shape"])) + " for min_shape")
else:
internal_in.min = self.shape["min_shape"]

if not Input._supported_input_size_type(self.shape["opt_shape"]):
raise TypeError(
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ str(type(self.shape["opt_shape"])) + " for opt_shape")
else:
internal_in.min = self.shape["op_shape"]

if not Input._supported_input_size_type(self.shape["max_shape"]):
raise TypeError(
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ str(type(self.shape["max_shape"])) + " for max_shape")
else:
internal_in.min = self.shape["opt_shape"]
internal_in.input_is_dynamic = True
else:
internal_in.opt = self.shape
if not Input._supported_input_size_type(self.shape):
raise TypeError(
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ str(type(self.shape)) + " for shape")
else:
internal_in.opt = self.shape
internal_in.input_is_dynamic = False
internal_in.dtype = self.dtype

if self.dtype != _types.dtype.unknown:
self._explicit_set_dtype = True
else:
self._explicit_set_dtype = False

internal_in.dtype = Input._parse_dtype(self.dtype)
internal_in._explicit_set_dtype = self._explicit_set_dtype
internal_in.format = self.format
internal_in.format = Input._parse_format(self.format)
return internal_in

@staticmethod
Expand Down Expand Up @@ -172,7 +200,7 @@ def _parse_dtype(dtype: Any) -> _types.dtype:
"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
+ str(dtype))

elif isinstance(dtype, _types.DataTypes):
elif isinstance(dtype, _types.dtype):
return dtype

else:
Expand Down
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ nvinfer1::DataType toTRTDataType(DataType value) {
return nvinfer1::DataType::kBOOL;
case DataType::kFloat:
return nvinfer1::DataType::kFLOAT;
case DataType::kUnknown:
return nvinfer1::DataType::kFLOAT;
default:
TRTORCH_THROW_ERROR("Unknown data type: " << to_str(value));
}
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace pyapi {
return static_cast<int64_t>(field_name); \
}

enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };
enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
std::string to_str(DataType value);
nvinfer1::DataType toTRTDataType(DataType value);

Expand Down
1 change: 1 addition & 0 deletions py/trtorch/csrc/trtorch_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ PYBIND11_MODULE(_C, m) {
.value("int8", DataType::kChar, "8 bit integer number")
.value("int32", DataType::kInt32, "32 bit integer number")
.value("bool", DataType::kChar, "Boolean value")
.value("unknown", DataType::kUnknown, "Unknown data type")
.export_values();

py::enum_<DeviceType>(m, "DeviceType", "Enum to specify device kinds to build TensorRT engines for")
Expand Down
Loading

0 comments on commit a234335

Please sign in to comment.