support topk converter/test_case
Signed-off-by: inocsin <[email protected]>
inocsin authored and narendasan committed Feb 2, 2021
1 parent b228bf2 commit 1557f6e
Showing 3 changed files with 90 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
@@ -52,6 +52,7 @@ cc_library(
"impl/stack.cpp",
"impl/lstm_cell.cpp",
"impl/unsqueeze.cpp",
"impl/topk.cpp",
],
deps = [
"@tensorrt//:nvinfer",
59 changes: 59 additions & 0 deletions core/conversion/converters/impl/topk.cpp
@@ -0,0 +1,59 @@
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include <vector>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto topk_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto k = args[1].unwrapToInt();
auto dim = args[2].unwrapToInt();
auto largest = args[3].unwrapToBool();
auto sorted = args[4].unwrapToBool();

auto selfDim = util::toVec(self->getDimensions());

       // reduceAxes: the reduction dimensions for the TopK layer. The bit in position i of the
       // bitmask corresponds to explicit dimension i of the result tensor, e.g. the least
       // significant bit selects the first explicit dimension and the next-to-least significant
       // bit selects the second.

       if (dim < 0) {
         dim = selfDim.size() + dim;
       }

       uint32_t shiftDim = 1 << dim;

LOG_DEBUG("Output topk reduce dim: " << dim);

auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);

auto new_layer = ctx->net->addTopK(*self, TopKOperation, k, shiftDim);

TRTORCH_CHECK(new_layer, "Unable to create topk layer from node: " << *n);

auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));

LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());

return true;
}});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
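
For reference, here is a minimal, standalone sketch of the axis handling above: the converter normalizes a possibly negative dim against the input rank and encodes it as a single-bit reduceAxes mask for the TopK layer. The helper name topkReduceAxes is hypothetical and only illustrates the arithmetic; it is not an API of this repository or of TensorRT.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the converter's axis handling: normalize a
// (possibly negative) PyTorch dim against the input rank, then set bit `dim`
// of the reduceAxes bitmask consumed by addTopK.
static uint32_t topkReduceAxes(int64_t dim, int64_t rank) {
  if (dim < 0) {
    dim += rank; // e.g. dim = -1 on a rank-3 tensor becomes 2
  }
  assert(dim >= 0 && dim < rank);
  return 1u << dim; // bit i selects explicit dimension i
}

int main() {
  // Matches the test below: input shape {10, 10, 100}, dim = -1.
  std::printf("reduceAxes = 0x%x\n", topkReduceAxes(-1, 3)); // prints reduceAxes = 0x4
  return 0;
}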
30 changes: 30 additions & 0 deletions tests/core/converters/test_topk.cpp
@@ -0,0 +1,30 @@
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenTopKConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=20]()
        %2 : int = prim::Constant[value=-1]()
        %3 : bool = prim::Constant[value=1]()
        %4 : bool = prim::Constant[value=1]()
        %5 : Tensor, %6 : Tensor = aten::topk(%0, %1, %2, %3, %4)
        return (%5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::rand({10, 10, 100}, {at::kCUDA});

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}
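
For orientation, the IR graph in this test is equivalent to the following eager-mode libtorch call (a sketch, assuming a CUDA-enabled libtorch build); the TensorRT engine's output is compared against the same computation run through the JIT interpreter via RunGraph.

#include <iostream>
#include <tuple>
#include <torch/torch.h>

int main() {
  // Same computation the IR graph encodes: top 20 values along the last
  // dimension (dim = -1), largest first, sorted.
  auto in = torch::rand({10, 10, 100}, torch::kCUDA);
  torch::Tensor values, indices;
  std::tie(values, indices) = in.topk(/*k=*/20, /*dim=*/-1, /*largest=*/true, /*sorted=*/true);
  std::cout << values.sizes() << " " << indices.sizes() << std::endl; // [10, 10, 20] [10, 10, 20]
  return 0;
}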
