From 6f4aa40d2fceea59be53fbd4919333e06c2ffd90 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 18 May 2021 11:48:14 -0700 Subject: [PATCH] feat(//core/plugins): Add adaptive_max_pool2d plugin, enable the plugins to run on GPU Signed-off-by: Dheeraj Peri Signed-off-by: Dheeraj Peri Signed-off-by: Dheeraj Peri --- core/conversion/converters/impl/pooling.cpp | 98 ++++++++---------- core/plugins/impl/interpolate_plugin.cpp | 99 ++++++------------- core/plugins/impl/interpolate_plugin.h | 1 - core/plugins/impl/normalize_plugin.cpp | 45 ++------- core/plugins/impl/normalize_plugin.h | 1 - .../conversion/converters/test_pooling.cpp | 70 +++++++++++-- 6 files changed, 137 insertions(+), 177 deletions(-) diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp index 61da6085e5..2813fe6089 100644 --- a/core/conversion/converters/impl/pooling.cpp +++ b/core/conversion/converters/impl/pooling.cpp @@ -60,77 +60,53 @@ bool AdaptivePoolingConverter( auto in_shape = util::toVec(in->getDimensions()); nvinfer1::ILayer* new_layer = nullptr; - if (ctx->input_is_dynamic) { -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - LOG_WARNING( - "Adaptive pooling layer will be run through ATen, via not TensorRT, performace will be lower than expected. Consider switching either to static input shape or moving to non adaptive pooling if this is an issue"); -#else - LOG_WARNING( - "Adaptive pooling layer will be run through ATen (on CPU), via not TensorRT, performace will suffer. Consider switching either to static input shape or moving to non adaptive pooling"); -#endif + /*======CONFIGURE PLUGIN PARAMETERS======*/ + nvinfer1::PluginFieldCollection fc; + std::vector f; - TRTORCH_CHECK( - pool_type == nvinfer1::PoolingType::kAVERAGE, - "Unable to create MAX pooling (interpolation) plugin from node" << *n); - - nvinfer1::PluginFieldCollection fc; - std::vector f; - - auto out_shape = in_shape; - std::copy_n(out_size.d, out_size.nbDims, out_shape.begin() + (in_shape.size() - out_size.nbDims)); + auto out_shape = in_shape; + auto out_size_vec = util::toVec(out_size); - std::vector in_shape_casted(in_shape.begin(), in_shape.end()); - f.emplace_back( - nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size())); + std::copy(out_size_vec.begin(), out_size_vec.end(), out_shape.begin() + (in_shape.size() - out_size_vec.size())); - std::vector out_shape_casted(out_shape.begin(), out_shape.end()); - f.emplace_back(nvinfer1::PluginField( - "out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size())); + std::vector in_shape_casted(in_shape.begin(), in_shape.end()); + f.emplace_back( + nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size())); - auto out_size_vec = util::toVec(out_size); - std::vector out_size_casted(out_size_vec.begin(), out_size_vec.end()); - f.emplace_back(nvinfer1::PluginField( - "out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size())); + std::vector out_shape_casted(out_shape.begin(), out_shape.end()); + f.emplace_back( + nvinfer1::PluginField("out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size())); - f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0)); + std::vector out_size_casted(out_size_vec.begin(), out_size_vec.end()); + f.emplace_back(nvinfer1::PluginField( + "out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size())); - std::string mode = "adaptive_pool2d"; - f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1)); + f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0)); - int32_t align_corners_casted = 0; - f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1)); + int32_t align_corners_casted = 0; + f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1)); - int32_t use_scales_casted = 0; - f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1)); + int32_t use_scales_casted = 0; + f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1)); - fc.nbFields = f.size(); - fc.fields = f.data(); - auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch"); - auto interpolate_plugin = creator->createPlugin("adaptive_pool2d", &fc); - - new_layer = ctx->net->addPluginV2(reinterpret_cast(&in), 1, *interpolate_plugin); - TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n); + std::string mode = "adaptive_avg_pool2d"; + if (pool_type == nvinfer1::PoolingType::kMAX) { + mode = "adaptive_max_pool2d"; + } + f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1)); - } else { - std::vector stride(out_size.nbDims); - for (int64_t i = 0; i < out_size.nbDims; i++) { - stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_size.d[(out_size.nbDims - 1) - i]; - } - LOG_DEBUG("Stride: " << util::toDims(stride)); + fc.nbFields = f.size(); + fc.fields = f.data(); + /*====== PLUGIN PARAMETERS CONFIGURATION COMPLETED ======*/ - std::vector window(out_size.nbDims); - for (int64_t i = 0; i < out_size.nbDims; i++) { - window[window.size() - 1 - i] = - in_shape[in_shape.size() - 1 - i] - (out_size.d[out_size.nbDims - 1 - i] - 1) * stride[stride.size() - 1 - i]; - } + LOG_WARNING( + "Adaptive pooling layer will be using Aten library kernels in pytorch for execution. TensorRT does not support adaptive pooling natively. Consider switching to non-adaptive pooling if this is an issue"); - LOG_DEBUG("Window: " << util::toDims(window)); + auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch"); + auto interpolate_plugin = creator->createPlugin(mode.c_str(), &fc); - auto pooling_layer = ctx->net->addPoolingNd(*in, pool_type, util::toDims(window)); - TRTORCH_CHECK(pooling_layer, "Unable to create average pooling layer from node: " << *n); - pooling_layer->setStrideNd(util::toDims(stride)); - new_layer = pooling_layer; - } + new_layer = ctx->net->addPluginV2(reinterpret_cast(&in), 1, *interpolate_plugin); + TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n); new_layer->setName(util::node_info(n).c_str()); auto layer_output = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims, false, false); @@ -156,7 +132,7 @@ bool PoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args, auto padding = util::toDims(args[3].unwrapToIntList()); auto stride = util::toDims(args[2].unwrapToIntList()); if (stride.nbDims == 0) { - LOG_DEBUG("Stride not providied, using kernel_size as stride"); + LOG_DEBUG("Stride not provided, using kernel_size as stride"); stride = util::toDims(args[1].unwrapToIntList()); } @@ -265,6 +241,10 @@ auto pooling_registrations TRTORCH_UNUSED = .pattern({"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE); + }}) + .pattern({"aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX); }}); } // namespace } // namespace impl diff --git a/core/plugins/impl/interpolate_plugin.cpp b/core/plugins/impl/interpolate_plugin.cpp index bd4feaae64..89e32dd55d 100644 --- a/core/plugins/impl/interpolate_plugin.cpp +++ b/core/plugins/impl/interpolate_plugin.cpp @@ -27,7 +27,7 @@ InterpolatePlugin::InterpolatePlugin( align_corners_(align_corners), use_scales_(use_scales) { if (use_scales) { - TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d"); + TRTORCH_ASSERT(mode_ != "adaptive_avg_pool2d", "use_scales is not valid for adaptive_avg_pool2d"); TRTORCH_ASSERT( scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true"); at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA}); @@ -106,7 +106,11 @@ std::vector InterpolatePlugin::getOutputSize() { } int InterpolatePlugin::getNbOutputs() const { - return 1; + if (mode_ == "adaptive_max_pool2d") { + return 2; + } else { + return 1; + } } const char* InterpolatePlugin::getPluginType() const { @@ -166,15 +170,6 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer } int InterpolatePlugin::initialize() { -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - tensor_options_ = tensor_options_.device(c10::kCUDA); -#else - tensor_options_ = tensor_options_.device(c10::kCPU); -#endif - - // c10::kFloat = FLOAT32 - tensor_options_ = tensor_options_.dtype(c10::kFloat); - return 0; } @@ -211,9 +206,15 @@ bool InterpolatePlugin::supportsFormatCombination( const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) { - TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output"); TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin"); - TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin"); + + if (mode_ == "adaptive_max_pool2d") { + TRTORCH_ASSERT(nbOutputs == 2, "Expected 2 tensors as output to interpolate plugin"); + TRTORCH_ASSERT(0 <= pos && pos <= 2, "There should be exactly 3 connections to the plugin - 1 input, 2 output"); + } else { + TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin"); + TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output"); + } const nvinfer1::PluginTensorDesc& in = inOut[0]; @@ -250,10 +251,10 @@ int InterpolatePlugin::enqueue( void* const* outputs, void* workspace, cudaStream_t stream) { -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_); - at::Tensor output = at::from_blob( - outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_); + at::Tensor input = + at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat); + at::Tensor output = + at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat); at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); at::cuda::CUDAStreamGuard torch_guard(torch_stream); @@ -263,27 +264,30 @@ int InterpolatePlugin::enqueue( cudaEventRecord(event, stream); cudaStreamWaitEvent(torch_stream.stream(), event, 0); - + at::Tensor out; if (use_scales_) { if (mode_ == "linear") { - at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]); + out = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]}); } else if (mode_ == "bilinear") { - at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]); + out = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_); } else if (mode_ == "trilinear") { - at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]); + out = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_); } } else { if (mode_ == "linear") { - at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_); + out = at::upsample_linear1d(input, {size_[0]}, align_corners_); } else if (mode_ == "bilinear") { - at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_); + out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_); } else if (mode_ == "trilinear") { - at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_); - } else if (mode_ == "adaptive_pool2d") { - at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]}); + out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_); + } else if (mode_ == "adaptive_avg_pool2d") { + out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]}); + } else if (mode_ == "adaptive_max_pool2d") { + out = std::get<0>(at::adaptive_max_pool2d(input, {size_[0], size_[1]})); } } + output.copy_(out); cudaEvent_t torch_event; cudaEventCreate(&torch_event); cudaEventRecord(torch_event, torch_stream.stream()); @@ -294,49 +298,6 @@ int InterpolatePlugin::enqueue( cudaEventDestroy(torch_event); return 0; -#else - // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen - // kernels HACK: WAR because there is a segfault if you try to create a CUDA - // Tensor in the context of TensorRT execution - float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float)); - cudaMemcpyAsync( - input_blob, - static_cast(inputs[0]), - util::volume(inputDesc->dims) * sizeof(float), - cudaMemcpyDeviceToHost, - stream); - cudaStreamSynchronize(stream); - - at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_); - at::Tensor output; - if (use_scales_) { - if (mode_ == "linear") { - output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]}); - } else if (mode_ == "bilinear") { - output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_); - } else if (mode_ == "trilinear") { - output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_); - } - } else { - if (mode_ == "linear") { - output = at::upsample_linear1d(input, {size_[0]}, align_corners_); - } else if (mode_ == "bilinear") { - output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_); - } else if (mode_ == "trilinear") { - output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_); - } else if (mode_ == "adaptive_pool2d") { - output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]}); - } - } - - cudaMemcpyAsync( - outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream); - cudaStreamSynchronize(stream); - - free(input_blob); - - return 0; -#endif } /* diff --git a/core/plugins/impl/interpolate_plugin.h b/core/plugins/impl/interpolate_plugin.h index 93d279a180..eef91869a9 100644 --- a/core/plugins/impl/interpolate_plugin.h +++ b/core/plugins/impl/interpolate_plugin.h @@ -20,7 +20,6 @@ namespace impl { class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { private: - at::TensorOptions tensor_options_; nvinfer1::DataType dtype_; std::vector in_shape_; diff --git a/core/plugins/impl/normalize_plugin.cpp b/core/plugins/impl/normalize_plugin.cpp index 307f3bb97a..796964152c 100644 --- a/core/plugins/impl/normalize_plugin.cpp +++ b/core/plugins/impl/normalize_plugin.cpp @@ -103,15 +103,6 @@ nvinfer1::DataType NormalizePlugin::getOutputDataType(int index, const nvinfer1: } int NormalizePlugin::initialize() { -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - tensor_options_ = tensor_options_.device(c10::kCUDA); -#else - tensor_options_ = tensor_options_.device(c10::kCPU); -#endif - - // c10::kFloat = FLOAT32 - tensor_options_ = tensor_options_.dtype(c10::kFloat); - return 0; } @@ -181,11 +172,10 @@ int NormalizePlugin::enqueue( void* const* outputs, void* workspace, cudaStream_t stream) { - // TRT <= 7.0 -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_); - at::Tensor output = at::from_blob( - outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_); + at::Tensor input = + at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat); + at::Tensor output = + at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat); at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); at::cuda::CUDAStreamGuard torch_guard(torch_stream); @@ -195,7 +185,9 @@ int NormalizePlugin::enqueue( cudaEventRecord(event, stream); cudaStreamWaitEvent(torch_stream.stream(), event, 0); - at::Tensor result = at::norm(input, order_, axes_, keep_dims_); + + std::vector axes_double(axes_.begin(), axes_.end()); + at::Tensor result = at::norm(input, (int64_t)order_, axes_double, (bool)keep_dims_); output.copy_(result); cudaEvent_t torch_event; cudaEventCreate(&torch_event); @@ -206,29 +198,6 @@ int NormalizePlugin::enqueue( cudaEventDestroy(event); cudaEventDestroy(torch_event); return 0; -#else - // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen - // kernels HACK: WAR because there is a segfault if you try to create a CUDA - // Tensor in the context of TensorRT execution - float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float)); - cudaMemcpyAsync( - input_blob, - static_cast(inputs[0]), - util::volume(inputDesc->dims) * sizeof(float), - cudaMemcpyDeviceToHost, - stream); - cudaStreamSynchronize(stream); - - at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_); - std::vector axes_new(axes_.begin(), axes_.end()); - at::Tensor output = at::norm(input, (int64_t)order_, axes_new, (bool)keep_dims_); - cudaMemcpyAsync( - outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream); - cudaStreamSynchronize(stream); - - free(input_blob); - return 0; -#endif } /* diff --git a/core/plugins/impl/normalize_plugin.h b/core/plugins/impl/normalize_plugin.h index bb9d156ff7..a2eaa54832 100644 --- a/core/plugins/impl/normalize_plugin.h +++ b/core/plugins/impl/normalize_plugin.h @@ -22,7 +22,6 @@ namespace impl { class NormalizePlugin : public nvinfer1::IPluginV2DynamicExt { private: - at::TensorOptions tensor_options_; nvinfer1::DataType dtype_; int32_t order_; std::vector axes_; diff --git a/tests/core/conversion/converters/test_pooling.cpp b/tests/core/conversion/converters/test_pooling.cpp index e5c52fa2cb..b804bdc0d2 100644 --- a/tests/core/conversion/converters/test_pooling.cpp +++ b/tests/core/conversion/converters/test_pooling.cpp @@ -413,8 +413,8 @@ TEST(Converters, ATenAvgPool3DNoCountPadConvertsCorrectly) { TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor): - %2 : int = prim::Constant[value=3]() - %3 : int = prim::Constant[value=4]() + %2 : int = prim::Constant[value=7]() + %3 : int = prim::Constant[value=7]() %6 : int[] = prim::ListConstruct(%2, %3) %10 : Tensor = aten::adaptive_avg_pool2d(%0, %6) return (%10))IR"; @@ -423,7 +423,7 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) { torch::jit::parseIR(graph, g.get()); // PyTorch MaxPool needs a 3D input - auto in = at::randint(-5, 5, {1, 12, 16}, at::kCUDA); + auto in = at::randint(-5, 5, {512, 32, 32}, at::kCUDA); auto jit_in = at::clone(in); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -449,7 +449,7 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) { torch::jit::parseIR(graph, g.get()); // PyTorch MaxPool needs a 3D input - auto in = at::randint(-5, 5, {10, 18, 36}, at::kCUDA); + auto in = at::randint(-5, 5, {512, 32, 32}, at::kCUDA); auto jit_in = at::clone(in); auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); @@ -464,11 +464,11 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) { TEST(Converters, ATenAdaptiveAvgPool1DConvertsCorrectly) { const auto graph = - R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=1]() - %6 : int[] = prim::ListConstruct(%2) - %10 : Tensor = aten::adaptive_avg_pool1d(%0, %6) + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=1]() + %6 : int[] = prim::ListConstruct(%2) + %10 : Tensor = aten::adaptive_avg_pool1d(%0, %6) return (%10))IR"; auto g = std::make_shared(); @@ -487,3 +487,55 @@ TEST(Converters, ATenAdaptiveAvgPool1DConvertsCorrectly) { ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 1.0)); } + +TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=7]() + %3 : int = prim::Constant[value=7]() + %6 : int[] = prim::ListConstruct(%2, %3) + %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool2d(%0, %6) + return (%10, %11))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + // PyTorch MaxPool needs a 3D input + auto in = at::randint(-5, 5, {512, 32, 32}, at::kCUDA); + + auto jit_in = at::clone(in); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectlyWithDynamicInput) { + const auto graph = R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=7]() + %3 : int = prim::Constant[value=7]() + %6 : int[] = prim::ListConstruct(%2, %3) + %10 : Tensor, %11 : Tensor = aten::adaptive_max_pool2d(%0, %6) + return (%10, %11))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + // PyTorch MaxPool needs a 3D input + auto in = at::rand({512, 32, 32}, at::kCUDA); + + auto jit_in = at::clone(in); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +}