Skip to content

Commit

Permalink
feat: update truncate long/double python api
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <[email protected]>
  • Loading branch information
inocsin committed Mar 22, 2021
1 parent 740eb54 commit 69e49e8
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 4 deletions.
8 changes: 5 additions & 3 deletions core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
auto weights = converters::Weights();
if (isIValue()) {
auto tensor = ptr_.ivalue->toTensor();
if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
TRTORCH_CHECK(0, "Unable to freeze tensor of type kLong/kDouble into constant layer, try to compile model with truncate_long_and_double ON");
} else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kInt));
LOG_WARNING("Truncate kLong to kInt for IValue");
LOG_WARNING("Warning: Truncating weight (constant in the graph) from kLong to kInt to indicate that only constants are affected.");
} else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
LOG_WARNING("Truncate kDouble to kFloat for IValue");
LOG_WARNING("Warning: Truncating weight (constant in the graph) from kDouble to kFloat to indicate that only constants are affected.");
} else {
weights = converters::Weights(ctx, tensor);
}
Expand Down
6 changes: 6 additions & 0 deletions py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
if "max_batch_size" in compile_spec:
assert type(compile_spec["max_batch_size"]) is int
info.max_batch_size = compile_spec["max_batch_size"]

if "truncate_long_and_double" in compile_spec:
assert type(compile_spec["truncate_long_and_double"]) is bool
info.truncate_long_and_double = compile_spec["truncate_long_and_double"]

return info

Expand Down Expand Up @@ -217,6 +221,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
"num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
"workspace_size": 0, # Maximum size of workspace given to TensorRT
"max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
"truncate_long_and_double": False, # Truncate long and double into int and float
})
}
Expand Down Expand Up @@ -257,6 +262,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
backend_spec.set_workspace_size(parsed_spec.workspace_size)
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())

return backend_spec
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
}

struct TRTTSRegistrations {
Expand Down
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
info.convert_info.engine_settings.device.dla_core = device.dla_core;
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;

info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
Expand Down Expand Up @@ -143,6 +144,7 @@ std::string CompileSpec::stringify() {
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
ss << " \"Workspace Size\": " << workspace_size << std::endl;
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
ss << "}";
return ss.str();
}
Expand Down
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ struct CompileSpec : torch::CustomClassHolder {
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
ADD_FIELD_GET_SET(workspace_size, int64_t);
ADD_FIELD_GET_SET(truncate_long_and_double, bool);
ADD_FIELD_GET_SET(max_batch_size, int64_t);
ADD_FIELD_GET_SET(device, Device);
ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
Expand All @@ -126,6 +127,7 @@ struct CompileSpec : torch::CustomClassHolder {
bool refit = false;
bool debug = false;
bool strict_types = false;
bool truncate_long_and_double = false;
Device device;
EngineCapability capability = EngineCapability::kDEFAULT;
int64_t num_min_timing_iters = 2;
Expand Down
3 changes: 2 additions & 1 deletion py/trtorch/csrc/trtorch_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("num_min_timing_iters", &CompileSpec::num_min_timing_iters)
.def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters)
.def_readwrite("workspace_size", &CompileSpec::workspace_size)
.def_readwrite("max_batch_size", &CompileSpec::max_batch_size);
.def_readwrite("max_batch_size", &CompileSpec::max_batch_size)
.def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);

py::class_<Device>(m, "Device")
.def(py::init<>())
Expand Down

0 comments on commit 69e49e8

Please sign in to comment.