Skip to content

Commit

Permalink
feat(aten::prelu): Basic prelu support
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 23, 2020
1 parent fe06d09 commit 8bc4369
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
21 changes: 21 additions & 0 deletions core/conversion/converters/impl/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}
}).pattern({
"aten::prelu(Tensor self, Tensor weight) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto slopes = args[1].unwrapToTensor();

//if (slopes.numel() != 1) {
// auto in_dims = util::toVec(in.getDimensions());
// auto per_channel_shape = std::vector<int64_t>(in_dims.begin() + 2, in_dims.end());
// for ()
//}

auto slope_tensor = tensor_to_const(ctx, slopes);

auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor);
new_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}
Expand Down
47 changes: 47 additions & 0 deletions tests/core/converters/test_activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,50 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenPReLUConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(1)):
%3 : Tensor = aten::prelu(%0, %1)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(-5, 5, {5}, {at::kCUDA});
auto slope = at::randint(-5, 5, {1}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(10)):
%3 : Tensor = aten::prelu(%0, %1)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(-5, 5, {1,10, 1, 1}, {at::kCUDA});
auto slope = at::randint(-5, 5, {10}, {at::kCUDA});

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}


0 comments on commit 8bc4369

Please sign in to comment.