Skip to content

Commit

Permalink
feat: Add converter files for torch::max
Browse files Browse the repository at this point in the history
Signed-off-by: hongwei03 <[email protected]>
  • Loading branch information
p517332051 committed Mar 23, 2022
1 parent 0b5673b commit dd7a44e
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions core/conversion/converters/impl/max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,27 @@ namespace converters {
namespace impl {
namespace {
auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto selfDim = util::toVec(self->getDimensions());
if (dim < 0) {
dim = selfDim.size() + dim;
}
uint32_t shiftDim = 1 << dim;
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto selfDim = util::toVec(self->getDimensions());
if (dim < 0) {
dim = selfDim.size() + dim;
}
uint32_t shiftDim = 1 << dim;
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);

auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));

LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());

return true;
}});
return true;
}});
} // namespace
} // namespace impl
} // namespace converters
Expand Down

0 comments on commit dd7a44e

Please sign in to comment.