Skip to content

Commit

Permalink
refactor: Apply linting
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 21, 2021
1 parent c9aa99a commit f70bed6
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
if (isIValue()) {
LOG_DEBUG(ctx->logger, "Found IValue containing object of type " << *(ptr_.ivalue->type()));
}

TRTORCH_CHECK(
isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
Expand All @@ -100,8 +100,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
if (ptr_.ivalue->isTensor()) {
auto weights = converters::Weights();
auto tensor = ptr_.ivalue->toTensor();
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
TRTORCH_THROW_ERROR("Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) &&
!ctx->settings.truncate_long_and_double) {
TRTORCH_THROW_ERROR(
"Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
} else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kInt));
LOG_WARNING("Truncating weight (constant in the graph) from Int64 to Int32");
Expand All @@ -111,7 +113,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
} else {
weights = converters::Weights(ctx, tensor);
}

auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
out = const_layer->getOutput(0);
Expand Down

0 comments on commit f70bed6

Please sign in to comment.