From 569bcdee704667d5ef5f9781897455a05cd28db4 Mon Sep 17 00:00:00 2001 From: hongwei03 Date: Fri, 18 Mar 2022 13:05:19 +0800 Subject: [PATCH] feat: Add converter files for torch::max Signed-off-by: hongwei03 --- core/conversion/converters/impl/max.cpp | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 core/conversion/converters/impl/max.cpp diff --git a/core/conversion/converters/impl/max.cpp b/core/conversion/converters/impl/max.cpp new file mode 100644 index 0000000000..83f5aefd80 --- /dev/null +++ b/core/conversion/converters/impl/max.cpp @@ -0,0 +1,47 @@ +#include "NvInfer.h" +#include "core/conversion/converters/converters.h" +#include "core/conversion/tensorcontainer/TensorContainer.h" +#include "core/util/prelude.h" +#include "torch/torch.h" + +#include +#include + +namespace torch_tensorrt { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { +auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( + {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto self = args[0].ITensorOrFreeze(ctx); + auto k = 1; + auto dim = args[1].unwrapToInt(); + auto largest = true; + auto selfDim = util::toVec(self->getDimensions()); + if (dim < 0) { + dim = selfDim.size() + dim; + } + uint32_t shiftDim = 1 << dim; + + auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN); + + auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); + TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n); + + auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1)); + + LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); + LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); + + return true; + }}); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace torch_tensorrt