feat: support true_divide, floor_divide, max, min, rsub
Signed-off-by: inocsin <[email protected]>
inocsin authored and narendasan committed Jan 30, 2021
1 parent 4d3ac4f commit a35fbf1
Showing 3 changed files with 180 additions and 3 deletions.
112 changes: 112 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
@@ -200,6 +200,61 @@ auto element_wise_registrations TRTORCH_UNUSED =
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Computes other - alpha * self
auto self = args[0].ITensorOrFreeze(ctx);
auto otherScalar = args[1].unwrapToScalar().to<float>();
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
auto scalar = args[2].unwrapToScalar().to<float>();

if (1 != scalar) {
auto scaleW = Weights(ctx, scalar);
auto unuse = Weights();
// IScaleLayer requires shift, scale and power to have
// the same dtype
auto scaleLayer = ctx->net->addScale(
*self, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
self = scaleLayer->getOutput(0);
}

auto rsub =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);

rsub->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Computes other - alpha * self
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto scalar = args[2].unwrapToScalar().to<float>();

if (1 != scalar) {
auto scaleW = Weights(ctx, scalar);
auto unuse = Weights();
// IScaleLayer requires shift, scale and power to have
// the same dtype
auto scaleLayer = ctx->net->addScale(
*self, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
self = scaleLayer->getOutput(0);
}

auto rsub =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);

rsub->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
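// Worked example for the two rsub converters above (illustrative comment,
// not part of the original change): with self = {1, 2}, other = 4 and
// alpha = 2, aten::rsub yields other - alpha * self = {4 - 2*1, 4 - 2*2}
// = {2, 0}. When alpha == 1 the IScaleLayer branch is skipped and the
// converter reduces to a plain elementwise subtraction with swapped operands.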
.pattern({"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Should implement self / other
@@ -352,6 +407,63 @@
pow->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::floor_divide(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto floor_divide =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);

floor_divide->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto otherScalar = args[1].unwrapToScalar().to<float>();
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
auto floor_divide =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);

floor_divide->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
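// Semantics note (illustrative comment, not part of the original change):
// TensorRT's kFLOOR_DIV computes floor(a / b), rounding toward negative
// infinity, e.g. 7 / 2 -> 3 and -7 / 2 -> -4.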
.pattern({"aten::max.other(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto max =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, other, util::node_info(n));
TRTORCH_CHECK(max, "Unable to create max layer from node: " << *n);

max->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], max->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::min.other(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto min =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, self, other, util::node_info(n));
TRTORCH_CHECK(min, "Unable to create min layer from node: " << *n);

min->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], min->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}});
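All six converters above funnel their operands through the add_elementwise helper defined earlier in element_wise.cpp, whose body falls outside this hunk. The sketch below shows the broadcasting behavior these call sites rely on; the function name, signature, and rank-padding strategy here are assumptions for illustration, not the actual TRTorch implementation:

#include <string>
#include "NvInfer.h"

// Sketch only: pad the lower-rank operand with leading size-1 axes via an
// IShuffleLayer, then add an IElementWiseLayer so TensorRT can broadcast.
nvinfer1::ILayer* add_elementwise_sketch(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ElementWiseOperation op,
    nvinfer1::ITensor* a,
    nvinfer1::ITensor* b,
    const std::string& name) {
  // Reshape t to targetRank dimensions by prepending size-1 axes.
  auto pad = [&](nvinfer1::ITensor* t, int targetRank) -> nvinfer1::ITensor* {
    auto dims = t->getDimensions();
    nvinfer1::Dims padded{};
    padded.nbDims = targetRank;
    int offset = targetRank - dims.nbDims;
    for (int i = 0; i < targetRank; i++) {
      padded.d[i] = (i < offset) ? 1 : dims.d[i - offset];
    }
    auto* shuffle = net->addShuffle(*t);
    shuffle->setReshapeDimensions(padded);
    return shuffle->getOutput(0);
  };
  auto aRank = a->getDimensions().nbDims;
  auto bRank = b->getDimensions().nbDims;
  if (aRank < bRank) {
    a = pad(a, bRank);
  } else if (bRank < aRank) {
    b = pad(b, aRank);
  }
  // As the hunk above shows, the callers verify the returned layer with
  // TRTORCH_CHECK and then overwrite its name with util::node_info(n).
  auto* layer = net->addElementWise(*a, *b, op);
  if (layer) {
    layer->setName(name.c_str());
  }
  return layer;
}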
3 changes: 1 addition & 2 deletions core/lowering/passes/BUILD
@@ -40,5 +40,4 @@ pkg_tar(
name = "include",
package_dir = "core/lowering/passes/",
srcs = ["passes.h"],
-)
-
+)
68 changes: 67 additions & 1 deletion tests/core/conversion/converters/test_element_wise.cpp
@@ -161,5 +161,71 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
%3 : Tensor = aten::ne(%x.1, %2)
return (%3))IR";
pointwise_test_helper(graph, true, false, {3, 4, 2});
-;
pointwise_test_helper(graph, true);
}


TEST(Converters, ATenFloorDivideConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::floor_divide(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false);
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}
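// Inferred from the call sites in this file (the helper itself is defined
// in the shared test utilities, outside this diff):
// pointwise_test_helper(graph, singleInput, dynamicInput, shape1, shape2)
// runs the graph through both Torch and TRTorch and compares the outputs;
// the paired two-shape calls exercise broadcasting in both operand orders,
// with dynamicInput = true using dynamic input shapes.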


TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%scalar : float = prim::Constant[value=2.4]()
%1 : Tensor = aten::floor_divide(%0, %scalar)
return (%1))IR";
pointwise_test_helper(graph, true);
}

TEST(Converters, ATenMaxConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::max(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false);
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}

TEST(Converters, ATenMinConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::min(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false);
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}

TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : int = prim::Constant[value=2]()
%3 : Tensor = aten::rsub(%0, %1, %2)
return (%3))IR";
pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
}

TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%2 : int = prim::Constant[value=2]()
%scalar : float = prim::Constant[value=2.4]()
%3 : Tensor = aten::rsub(%0, %scalar, %2)
return (%3))IR";
pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
}
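// For reference: with alpha = 2, both rsub graphs above compute
// other - 2 * self elementwise; in the scalar variant every element x of
// the input maps to 2.4 - 2 * x.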
