Skip to content

Commit

Permalink
feat: support truncate long/double to int/float with option
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <[email protected]>
  • Loading branch information
inocsin committed Mar 19, 2021
1 parent 5b6bd4c commit 740eb54
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
1 change: 1 addition & 0 deletions core/conversion/conversionctx/ConversionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct BuilderSettings {
bool refit = false;
bool debug = false;
bool strict_types = false;
bool truncate_long_and_double = false;
Device device;
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
nvinfer1::IInt8Calibrator* calibrator = nullptr;
Expand Down
14 changes: 11 additions & 3 deletions core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());

nvinfer1::ITensor* out;

auto weights = converters::Weights();
if (isIValue()) {
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());

auto tensor = ptr_.ivalue->toTensor();
if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kInt));
LOG_WARNING("Truncate kLong to kInt for IValue");
} else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
LOG_WARNING("Truncate kDouble to kFloat for IValue");
} else {
weights = converters::Weights(ctx, tensor);
}
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");

Expand Down
5 changes: 5 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ struct TRTORCH_API CompileSpec {
*/
bool debug = false;

/**
* Truncate long/double type to int/float type
*/
bool truncate_long_and_double = false;

/**
* Restrict operating type to only set default operation precision
* (op_precision)
Expand Down
1 change: 1 addition & 0 deletions cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
internal.convert_info.engine_settings.refit = external.refit;
internal.convert_info.engine_settings.debug = external.debug;
internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
internal.convert_info.engine_settings.strict_types = external.strict_types;
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
Expand Down

0 comments on commit 740eb54

Please sign in to comment.