Merge pull request #329 from inocsin/clamp_max_min
add clamp_min/clamp_max converter
narendasan authored Mar 4, 2021
2 parents 62b077e + 684a318 commit 0a278f2
Showing 2 changed files with 79 additions and 26 deletions.
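
For context, aten::clamp_min and aten::clamp_max are the one-sided variants of aten::clamp. A minimal libtorch snippet (illustration only, not part of this commit) of the semantics the new converters must reproduce:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::tensor({-2.0f, 0.0f, 3.0f});
  // clamp_min(x, m) == max(x, m): lower bound only
  std::cout << torch::clamp_min(x, 1.5) << '\n'; // 1.5, 1.5, 3.0
  // clamp_max(x, m) == min(x, m): upper bound only
  std::cout << torch::clamp_max(x, 1.5) << '\n'; // -2.0, 0.0, 1.5
  // clamp(x, lo, hi) == min(max(x, lo), hi): both bounds
  std::cout << torch::clamp(x, 1.5, 2.5) << '\n'; // 1.5, 1.5, 2.5
  return 0;
}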
core/conversion/converters/impl/element_wise.cpp (57 additions, 22 deletions)

@@ -68,6 +68,21 @@ nvinfer1::ILayer* add_elementwise(
   return ele;
 }
 
+nvinfer1::ITensor* clamp_util(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    float limit,
+    nvinfer1::ElementWiseOperation op_type,
+    std::string str) {
+  nvinfer1::ITensor* clamp_layer_out = self;
+  auto limitTensor = tensor_to_const(ctx, torch::tensor({limit}));
+  auto limit_layer = add_elementwise(ctx, op_type, clamp_layer_out, limitTensor, util::node_info(n) + str);
+  TRTORCH_CHECK(limit_layer, "Unable to create elementwise " << str << " layer for node: " << *n);
+  clamp_layer_out = limit_layer->getOutput(0);
+  return clamp_layer_out;
+}
+
 auto element_wise_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern({"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
@@ -145,38 +160,58 @@ auto element_wise_registrations TRTORCH_UNUSED =
             return true;
           }})
       .pattern({"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
                 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                   // Compute min(max(min_threshold, input), max_threshold)
                   auto self = args[0].ITensorOrFreeze(ctx);
                   auto clamp_layer_out = self;
-                  if (args[1].isIValue() && args[1].IValue()->isScalar()) {
-                    auto minScalar = args[1].unwrapToScalar().to<float>();
-                    auto minTensor = tensor_to_const(ctx, torch::tensor({minScalar}));
-                    auto max_layer = add_elementwise(
-                        ctx,
-                        nvinfer1::ElementWiseOperation::kMAX,
-                        clamp_layer_out,
-                        minTensor,
-                        util::node_info(n) + std::string("_max"));
-                    TRTORCH_CHECK(max_layer, "Unable to create elementwise max layer for node: " << *n);
-                    clamp_layer_out = max_layer->getOutput(0);
-                  }
-
-                  if (args[2].isIValue() && args[2].IValue()->isScalar()) {
-                    auto maxScalar = args[2].unwrapToScalar().to<float>();
-                    auto maxTensor = tensor_to_const(ctx, torch::tensor({maxScalar}));
-                    auto min_layer = add_elementwise(
-                        ctx,
-                        nvinfer1::ElementWiseOperation::kMIN,
-                        clamp_layer_out,
-                        maxTensor,
-                        util::node_info(n) + std::string("_min"));
-                    TRTORCH_CHECK(min_layer, "Unable to create elementwise min layer for node: " << *n);
-                    clamp_layer_out = min_layer->getOutput(0);
-                  }
+
+                  if (args[1].isIValue() && args[1].IValue()->isScalar() && args[2].isIValue() &&
+                      args[2].IValue()->isScalar()) {
+                    auto alpha = args[1].unwrapToScalar().to<float>();
+                    auto beta = args[2].unwrapToScalar().to<float>();
+                    auto clip_layer = ctx->net->addActivation(*self, nvinfer1::ActivationType::kCLIP);
+                    TRTORCH_CHECK(clip_layer, "Unable to create clip layer for node: " << *n);
+                    clip_layer->setAlpha(alpha);
+                    clip_layer->setBeta(beta);
+                    clamp_layer_out = clip_layer->getOutput(0);
+                  } else if (args[1].isIValue() && args[1].IValue()->isScalar()) {
+                    auto limit = args[1].unwrapToScalar().to<float>();
+                    clamp_layer_out = clamp_util(ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMAX, "_max");
+                  } else if (args[2].isIValue() && args[2].IValue()->isScalar()) {
+                    auto limit = args[2].unwrapToScalar().to<float>();
+                    clamp_layer_out = clamp_util(ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMIN, "_min");
+                  }
 
                   auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out);
                   LOG_DEBUG("Clamp layer output tensor shape: " << out->getDimensions());
                   return true;
                 }})
+      .pattern({"aten::clamp_min(Tensor self, Scalar min) -> (Tensor)",
+                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                  // Compute max(min_threshold, input)
+                  auto self = args[0].ITensorOrFreeze(ctx);
+                  auto clamp_layer_out = self;
+                  if (args[1].isIValue() && args[1].IValue()->isScalar()) {
+                    auto limit = args[1].unwrapToScalar().to<float>();
+                    clamp_layer_out = clamp_util(ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMAX, "_max");
+                  }
+
+                  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out);
+                  LOG_DEBUG("clamp_min layer output tensor shape: " << out->getDimensions());
+                  return true;
+                }})
+      .pattern({"aten::clamp_max(Tensor self, Scalar max) -> (Tensor)",
+                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                  // Compute min(max_threshold, input)
+                  auto self = args[0].ITensorOrFreeze(ctx);
+                  auto clamp_layer_out = self;
+                  if (args[1].isIValue() && args[1].IValue()->isScalar()) {
+                    auto limit = args[1].unwrapToScalar().to<float>();
+                    clamp_layer_out = clamp_util(ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMIN, "_min");
+                  }
+
+                  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out);
+                  LOG_DEBUG("clamp_max layer output tensor shape: " << out->getDimensions());
+                  return true;
+                }})
       .pattern({"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
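Note on the conversion strategy: clamp_util materializes the scalar bound as a one-element constant (tensor_to_const) and applies a single elementwise kMAX (lower bound) or kMIN (upper bound) against the input, while aten::clamp with both bounds present maps to one fused kCLIP activation, which TensorRT defines as min(max(x, alpha), beta). The sketch below shows roughly what the single-bound path reduces to at the raw TensorRT API level; the function name build_clamp_min and the caller-managed lifetime of the limit value are assumptions of this sketch, and TRTorch's add_elementwise additionally handles rank mismatches and error checking that are omitted here.

#include "NvInfer.h"

// Sketch: clamp_min(x, limit) as one elementwise kMAX between the input and
// a broadcastable one-element constant (hypothetical helper, not commit code).
nvinfer1::ITensor* build_clamp_min(
    nvinfer1::INetworkDefinition& net,
    nvinfer1::ITensor& input,
    const float* limit) { // *limit must stay alive until the engine is built
  // Elementwise inputs must have equal rank, so give the constant the same
  // rank as the input with every dimension set to 1 and let TensorRT broadcast.
  nvinfer1::Dims dims = input.getDimensions();
  for (int32_t i = 0; i < dims.nbDims; i++) {
    dims.d[i] = 1;
  }
  nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, limit, 1};
  auto* constant = net.addConstant(dims, weights);
  auto* max_layer = net.addElementWise(
      input, *constant->getOutput(0), nvinfer1::ElementWiseOperation::kMAX);
  return max_layer->getOutput(0);
}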
tests/core/conversion/converters/test_element_wise.cpp (22 additions, 4 deletions)

@@ -250,7 +250,7 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
 TEST(Converters, ATenClampMinConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
-        %2 : int = prim::Constant[value=-2]()
+        %2 : float = prim::Constant[value=1.5]()
         %3 : None = prim::Constant()
         %4 : Tensor = aten::clamp(%x.1, %2, %3)
         return (%4))IR";
@@ -260,7 +260,7 @@ TEST(Converters, ATenClampMinConvertsCorrectly) {
 TEST(Converters, ATenClampMaxConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
-        %2 : int = prim::Constant[value=3]()
+        %2 : float = prim::Constant[value=3.5]()
         %3 : None = prim::Constant()
         %4 : Tensor = aten::clamp(%x.1, %3, %2)
         return (%4))IR";
@@ -270,13 +270,31 @@ TEST(Converters, ATenClampMaxConvertsCorrectly) {
 TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
-        %2 : int = prim::Constant[value=3]()
-        %3 : int = prim::Constant[value=-2]()
+        %2 : float = prim::Constant[value=3.5]()
+        %3 : float = prim::Constant[value=1.5]()
         %4 : Tensor = aten::clamp(%x.1, %3, %2)
         return (%4))IR";
   pointwise_test_helper(graph, true);
 }
+
+TEST(Converters, ATenClampMinimumConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : float = prim::Constant[value=2.5]()
+        %4 : Tensor = aten::clamp_min(%x.1, %2)
+        return (%4))IR";
+  pointwise_test_helper(graph, true);
+}
+
+TEST(Converters, ATenClampMaximumConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : float = prim::Constant[value=2.5]()
+        %4 : Tensor = aten::clamp_max(%x.1, %2)
+        return (%4))IR";
+  pointwise_test_helper(graph, true);
+}
 
 TEST(Converters, ATenGreaterThanConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor, %1 : Tensor):
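The new tests drive aten::clamp_min and aten::clamp_max directly; the earlier tests reach the same converters through aten::clamp with a None bound. pointwise_test_helper is defined elsewhere in the test suite and is not part of this diff; roughly, it parses the IR, runs the graph both through the Torch JIT interpreter and through a TRTorch-converted TensorRT engine, and asserts the outputs match. A hypothetical reconstruction under those assumptions (the utility names and signatures below follow TRTorch's test helpers but should be treated as assumed, not quoted):

#include <gtest/gtest.h>
#include "torch/csrc/jit/ir/irparser.h"

// Hypothetical shape of the helper used above; the real definition lives in
// the test utilities, not in this diff.
void pointwise_test_helper(std::string graph_ir, bool single_input) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // Random inputs matching the graph signature (one or two tensors).
  std::vector<at::Tensor> inputs;
  inputs.push_back(at::randint(1, 5, {5}, {at::kCUDA}));
  if (!single_input) {
    inputs.push_back(at::randint(1, 5, {5}, {at::kCUDA}));
  }

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, inputs);
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, inputs);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}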
