diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp index 2f591ab797..f18ef91c1b 100644 --- a/core/conversion/converters/impl/pooling.cpp +++ b/core/conversion/converters/impl/pooling.cpp @@ -37,7 +37,7 @@ bool AdaptivePoolingConverter( ConversionCtx* ctx, const torch::jit::Node* n, args& args, - nvinfer1::PoolingType pool_type) { + nvinfer1::PoolingType pool_type, const std::string& mode) { auto in = args[0].ITensorOrFreeze(ctx); auto out_size = util::toDims(args[1].unwrapToIntList()); @@ -48,15 +48,7 @@ bool AdaptivePoolingConverter( } auto orig_dims = in->getDimensions(); - bool expandDims = (orig_dims.nbDims < 4); - TORCHTRT_CHECK(orig_dims.nbDims > 2, "Unable to create pooling layer from node: " << *n); - if (expandDims) { - in = addPadding(ctx, n, in, 4, false, false); - } - - if (out_size.nbDims == 1) { - out_size = util::unsqueezeDims(out_size, 0, 1); - } + TORCHTRT_CHECK(orig_dims.nbDims > 1, "Unable to create pooling layer from node: " << *n); auto in_shape = util::toVec(in->getDimensions()); nvinfer1::ILayer* new_layer = nullptr; @@ -90,10 +82,6 @@ bool AdaptivePoolingConverter( int32_t use_scales_casted = 0; f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1)); - std::string mode = "adaptive_avg_pool2d"; - if (pool_type == nvinfer1::PoolingType::kMAX) { - mode = "adaptive_max_pool2d"; - } f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1)); fc.nbFields = f.size(); @@ -110,7 +98,7 @@ bool AdaptivePoolingConverter( TORCHTRT_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n); new_layer->setName(util::node_info(n).c_str()); - auto layer_output = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims, false, false); + auto layer_output = new_layer->getOutput(0); ctx->AssociateValueAndTensor(n->outputs()[0], layer_output); LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); @@ -238,15 +226,15 @@ auto pooling_registrations TORCHTRT_UNUSED = }}) .pattern({"aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE); + return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool1d"); }}) .pattern({"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE); + return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool2d"); }}) .pattern({"aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX); + return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool2d"); }}); } // namespace } // namespace impl diff --git a/core/plugins/impl/interpolate_plugin.cpp b/core/plugins/impl/interpolate_plugin.cpp index ee1d77df1c..3568a7481e 100644 --- a/core/plugins/impl/interpolate_plugin.cpp +++ b/core/plugins/impl/interpolate_plugin.cpp @@ -289,6 +289,8 @@ int InterpolatePlugin::enqueue( out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_); } else if (mode_ == "trilinear") { out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_); + } else if(mode_ == "adaptive_avg_pool1d"){ + out = at::adaptive_avg_pool1d(input, {size_[0]}); } else if (mode_ == "adaptive_avg_pool2d") { out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]}); } else if (mode_ == "adaptive_max_pool2d") { diff --git a/tests/core/conversion/converters/test_pooling.cpp b/tests/core/conversion/converters/test_pooling.cpp index a8c1cad760..55bf88506b 100644 --- a/tests/core/conversion/converters/test_pooling.cpp +++ b/tests/core/conversion/converters/test_pooling.cpp @@ -540,6 +540,32 @@ TEST(Converters, ATenAdaptiveAvgPool1DGlobalPoolingConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } +TEST(Converters, ATenAdaptiveAvgPool1DUsingPluginConvertsCorrectly) { + const auto graph = + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=3]() + %6 : int[] = prim::ListConstruct(%2) + %10 : Tensor = aten::adaptive_avg_pool1d(%0, %6) + return (%10))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + // PyTorch adaptive_avg_pool1d needs a 3D input or a 2D input + auto in = at::randint(-5, 5, {1, 3, 16}, at::kCUDA); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor):