-
Notifications
You must be signed in to change notification settings - Fork 352
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: inocsin <[email protected]>
- Loading branch information
1 parent
b228bf2
commit 1557f6e
Showing
3 changed files
with
90 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#include "NvInfer.h" | ||
#include "core/conversion/converters/converters.h" | ||
#include "core/conversion/tensorcontainer/TensorContainer.h" | ||
#include "core/util/prelude.h" | ||
#include "torch/torch.h" | ||
|
||
#include <ATen/ATen.h> | ||
#include <vector> | ||
|
||
namespace trtorch { | ||
namespace core { | ||
namespace conversion { | ||
namespace converters { | ||
namespace impl { | ||
namespace { | ||
|
||
auto topk_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern( | ||
{"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", | ||
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { | ||
auto self = args[0].ITensorOrFreeze(ctx); | ||
auto k = args[1].unwrapToInt(); | ||
auto dim = args[2].unwrapToInt(); | ||
auto largest = args[3].unwrapToBool(); | ||
auto sorted = args[4].unwrapToBool(); | ||
|
||
auto selfDim = util::toVec(self->getDimensions()); | ||
|
||
//reduceAxes The reduction dimensions. The bit in position i of bitmask reduceAxes corresponds to explicit dimension i of the result. | ||
//E.g., the least significant bit corresponds to the first explicit dimension and the next to least significant bit corresponds to the second explicit dimension. | ||
|
||
if (dim < 0) { | ||
dim = selfDim.size() + dim; | ||
} | ||
|
||
uint32_t shiftDim = 1 << dim; | ||
|
||
LOG_DEBUG("Output topk reduce dim: " << dim); | ||
|
||
auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN); | ||
|
||
auto new_layer = ctx->net->addTopK(*self, TopKOperation, k, shiftDim); | ||
|
||
TRTORCH_CHECK(new_layer, "Unable to create topk layer from node: " << *n); | ||
|
||
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); | ||
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1)); | ||
|
||
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); | ||
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); | ||
|
||
return true; | ||
}}); | ||
|
||
} // namespace | ||
} // namespace impl | ||
} // namespace converters | ||
} // namespace conversion | ||
} // namespace core | ||
} // namespace trtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include <string> | ||
#include "core/compiler.h" | ||
#include "gtest/gtest.h" | ||
#include "tests/util/util.h" | ||
#include "torch/csrc/jit/ir/irparser.h" | ||
|
||
// Runs aten::topk (k=20, dim=-1, largest and sorted both set to 1) on a random 10x10x100 CUDA
// tensor through the TorchScript graph and through the converted engine, then compares both
// outputs of the op.
TEST(Converters, ATenTopKConvertsCorrectly) {
  const auto ir_source = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=20]()
%2 : int = prim::Constant[value=-1]()
%3 : bool = prim::Constant[value=1]()
%4 : bool = prim::Constant[value=1]()
%5 : Tensor, %6 : Tensor = aten::topk(%0, %1, %2, %3, %4)
return (%5, %6))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir_source, parsed_graph.get());

  auto input = at::rand({10, 10, 100}, {at::kCUDA});

  // Reference run on the TorchScript graph.
  auto jit_params = trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_graph, jit_params, {input});

  // Run through the engine path with a freshly built parameter map.
  auto trt_params = trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_graph, trt_params, {input});

  // Engine outputs are reshaped to the reference shape before comparison.
  auto trt_first = trt_results[0].reshape_as(jit_results[0]);
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_first, 2e-6));

  auto trt_second = trt_results[1].reshape_as(jit_results[1]);
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[1], trt_second, 2e-6));
}