diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index 3d843992e0..a81ad8832c 100644 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -102,6 +102,58 @@ void resize_layer_size( LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); } +bool upsample_triilinear3d(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); + + if (args[1].IValue()->isNone() && (args[3].IValue()->isNone() || args[4].IValue()->isNone() || args[5].IValue()->isNone())) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of size or scale_factor should be defined"); + } else if (!args[3].IValue()->isNone() && !args[4].IValue()->isNone() && !args[5].IValue()->isNone()) { + // Case 1: user uses scales + float scale_d = args[3].IValue()->toDouble(); + float scale_h = args[4].IValue()->toDouble(); + float scale_w = args[5].IValue()->toDouble(); + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 3] = scale_d; + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; +#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear3d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } +#else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); +#endif + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 3, + "aten::upsample_trilinear3d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); +#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); + } +#else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); +#endif + } + + return true; +} + /* * Interpolate Converter */ @@ -109,73 +161,135 @@ void resize_layer_size( auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() .pattern( - {"aten::upsample_nearest1d(Tensor self, int[] output_size, float? scales=None) -> (Tensor)", + {"aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + + if (args[1].IValue()->isNone() && args[2].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scales should be defined"); + } else if (!args[2].IValue()->isNone()) { + // Case 1: user uses scales + float scale = args[2].IValue()->toDouble(); + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 1] = scale; + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + } + + return true; + }}) + .pattern( + {"aten::upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + + if (args[1].IValue()->isNone() && args[2].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scale_factors should be defined"); + } else if (!args[2].IValue()->isNone()) { + // Case 1: user uses scales + auto scale_factors = args[2].unwrapToDoubleList(); + TRTORCH_ASSERT(scale_factors.size() == 1, "Number of scale factors should match the input size"); + float scale = scale_factors[0]; + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 1] = scale; + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + } + + return true; + }}) + .pattern( + {"aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); - - if (args[1].IValue()->isNone() && args[2].IValue()->isNone()) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nOne of size or scale_factor should be defined"); - } else if (!args[2].IValue()->isNone()) { - // Case 1: user uses scales - float scale = args[2].IValue()->toDouble(); - std::vector padded_scales(in_shape.size(), 1); - padded_scales[padded_scales.size() - 1] = scale; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); - } else { - // Case 2: user uses output size - auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); - TRTORCH_ASSERT( - out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch"); - - auto out_shape = in_shape; - std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); - } - - return true; + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + + if (args[1].IValue()->isNone() && (args[2].IValue()->isNone() || args[3].IValue()->isNone())) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scales should be defined"); + } else if (!args[2].IValue()->isNone() && !args[3].IValue()->isNone()) { + // Case 1: user uses scales + float scale_h = args[2].IValue()->toDouble(); + float scale_w = args[3].IValue()->toDouble(); + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + } + + return true; }}) .pattern( - {"aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)", + {"aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); - - if (args[1].IValue()->isNone() && (args[2].IValue()->isNone() || args[3].IValue()->isNone())) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nOne of size or scale_factor should be defined"); - } else if (!args[2].IValue()->isNone() && !args[3].IValue()->isNone()) { - // Case 1: user uses scales - float scale_h = args[2].IValue()->toDouble(); - float scale_w = args[3].IValue()->toDouble(); - std::vector padded_scales(in_shape.size(), 1); - padded_scales[padded_scales.size() - 2] = scale_h; - padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); - } else { - // Case 2: user uses output size - auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); - TRTORCH_ASSERT( - out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch"); - - auto out_shape = in_shape; - std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); - } - - return true; + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + + if (args[1].IValue()->isNone() && args[2].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scale_factors should be defined"); + } else if (!args[2].IValue()->isNone()) { + // Case 1: user uses scales + auto scale_factors = args[2].unwrapToDoubleList(); + TRTORCH_ASSERT(scale_factors.size() == 2, "Number of scale factors should match the input size"); + float scale_h = scale_factors[0]; + float scale_w = scale_factors[1]; + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + } + + return true; }}) .pattern( - {"aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)", + {"aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); - if (args[1].IValue()->isNone() && (args[2].IValue()->isNone() || args[3].IValue()->isNone() || - args[4].IValue()->isNone())) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nOne of size or scale_factor should be defined"); + if (args[1].IValue()->isNone() && (args[2].IValue()->isNone() || args[3].IValue()->isNone() || + args[4].IValue()->isNone())) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scales should be defined"); } else if (!args[2].IValue()->isNone() && !args[3].IValue()->isNone() && !args[4].IValue()->isNone()) { // Case 1: user uses scales float scale_d = args[2].IValue()->toDouble(); @@ -187,170 +301,363 @@ auto interpolate_registrations TRTORCH_UNUSED = padded_scales[padded_scales.size() - 1] = scale_w; resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); } else { - // Case 2: user uses output size - auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); - TRTORCH_ASSERT( - out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch"); - - auto out_shape = in_shape; - std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); } - return true; + return true; }}) .pattern( - {"aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners, float? scales) -> (Tensor)", + {"aten::upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor", // FIX THIS [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); - bool align_corners = args[2].unwrapToBool(); - - if (args[1].IValue()->isNone() && args[3].IValue()->isNone()) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nOne of size or scale_factor should be defined"); - } else if (!args[3].IValue()->isNone()) { - // Case 1: user uses scales - float scale = args[3].IValue()->toDouble(); - std::vector padded_scales(in_shape.size(), 1); - padded_scales[padded_scales.size() - 1] = scale; -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - if (!align_corners) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nupsample_linear1d only supports align_corner with TensorRT <= 7.0."); - } else { - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); - } -#else - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); -#endif - } else { - // Case 2: user uses output size - auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); - TRTORCH_ASSERT( - out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch"); - - auto out_shape = in_shape; - std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - if (!align_corners) { - // align_corners not supported in TensorRT, create plugin and - // run layer through PyTorch - create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear")); - } else { - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); - } -#else - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); -#endif - } + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); - return true; + if (args[1].IValue()->isNone() && args[2].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scale_factors should be defined"); + } else if (!args[2].IValue()->isNone()) { + // Case 1: user uses scales + auto scale_factors = args[2].unwrapToDoubleList(); + TRTORCH_ASSERT(scale_factors.size() == 3, "Number of scale factors should match the input size"); + float scale_d = scale_factors[0]; + float scale_h = scale_factors[1]; + float scale_w = scale_factors[2]; + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 3] = scale_d; + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + } + + return true; }}) .pattern( - {"aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)", + {"aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); - bool align_corners = args[2].unwrapToBool(); - - if (args[1].IValue()->isNone() && (args[3].IValue()->isNone() || args[4].IValue()->isNone())) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nOne of size or scale_factor should be defined"); - } else if (!args[3].IValue()->isNone() && !args[4].IValue()->isNone()) { - // Case 1: user uses scales - float scale_h = args[3].IValue()->toDouble(); - float scale_w = args[4].IValue()->toDouble(); - std::vector padded_scales(in_shape.size(), 1); - padded_scales[padded_scales.size() - 2] = scale_h; - padded_scales[padded_scales.size() - 1] = scale_w; -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - if (!align_corners) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nupsample_linear2d only supports align_corner with TensorRT <= 7.0."); - } else { - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); - } -#else - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); -#endif - } else { - // Case 2: user uses output size - auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); + + if (args[1].IValue()->isNone() && args[3].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scales should be defined"); + } else if (!args[3].IValue()->isNone()) { + // Case 1: user uses scales + float scale = args[3].IValue()->toDouble(); + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 1] = scale; + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear1d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } - TRTORCH_ASSERT( - out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch"); + return true; + }}) + .pattern( + {"aten::upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); + + if (args[1].IValue()->isNone() && args[3].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scale_factors should be defined"); + } else if (!args[3].IValue()->isNone()) { + // Case 1: user uses scales + auto scale_factors = args[3].unwrapToDoubleList(); + TRTORCH_ASSERT(scale_factors.size() == 1, "Number of scale factors should match the input size"); + float scale = scale_factors[0]; + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 1] = scale; + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear1d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } - auto out_shape = in_shape; - std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + return true; + }}) + .pattern( + {"aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); + + if (args[1].IValue()->isNone() && (args[3].IValue()->isNone() || args[4].IValue()->isNone())) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scales should be defined"); + } else if (!args[3].IValue()->isNone() && !args[4].IValue()->isNone()) { + // Case 1: user uses scales + float scale_h = args[3].IValue()->toDouble(); + float scale_w = args[4].IValue()->toDouble(); + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear2d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + + TRTORCH_ASSERT( + out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - if (!align_corners) { - // align_corners not supported in TensorRT, create plugin and - // run layer through PyTorch - create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear")); - } else { - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); - } -#else - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); -#endif - } + return true; + }}) + .pattern( + {"aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); + + if (args[1].IValue()->isNone() && args[3].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of output_size or scale_factors should be defined"); + } else if (!args[3].IValue()->isNone()) { + // Case 1: user uses scales + auto scale_factors = args[3].unwrapToDoubleList(); + TRTORCH_ASSERT(scale_factors.size() == 2, "Number of scale factors should match the input size"); + float scale_h = scale_factors[0]; + float scale_w = scale_factors[1]; + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear2d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + + TRTORCH_ASSERT( + out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } - return true; + return true; }}) .pattern( {"aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto in_shape = util::toVec(in->getDimensions()); - bool align_corners = args[2].unwrapToBool(); + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); if (args[1].IValue()->isNone() && (args[3].IValue()->isNone() || args[4].IValue()->isNone() || args[5].IValue()->isNone())) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nOne of size or scale_factor should be defined"); + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of size or scales should be defined"); } else if (!args[3].IValue()->isNone() && !args[4].IValue()->isNone() && !args[5].IValue()->isNone()) { - // Case 1: user uses scales - float scale_d = args[3].IValue()->toDouble(); - float scale_h = args[4].IValue()->toDouble(); - float scale_w = args[5].IValue()->toDouble(); - std::vector padded_scales(in_shape.size(), 1); - padded_scales[padded_scales.size() - 3] = scale_d; - padded_scales[padded_scales.size() - 2] = scale_h; - padded_scales[padded_scales.size() - 1] = scale_w; -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - if (!align_corners) { - TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) - << "\nupsample_linear3d only supports align_corner with TensorRT <= 7.0."); - } else { - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); - } -#else - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); -#endif + // Case 1: user uses scales + float scale_d = args[3].IValue()->toDouble(); + float scale_h = args[4].IValue()->toDouble(); + float scale_w = args[5].IValue()->toDouble(); + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 3] = scale_d; + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear3d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif } else { // Case 2: user uses output size - auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); - TRTORCH_ASSERT( - out_size.size() == 3, - "aten::upsample_trilinear3d input Tensor and output size dimension mismatch"); + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 3, + "aten::upsample_trilinear3d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } - auto out_shape = in_shape; - std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); -#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) - if (!align_corners) { - // align_corners not supported in TensorRT, create plugin and - // run layer through PyTorch - create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear")); - } else { - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); - } -#else - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); -#endif + return true; + }}) + .pattern( + {"aten::upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto in_shape = util::toVec(in->getDimensions()); + bool align_corners = args[2].unwrapToBool(); + + if (args[1].IValue()->isNone() && args[3].IValue()->isNone()) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nOne of size or scale_factors should be defined"); + } else if (!args[3].IValue()->isNone()) { + // Case 1: user uses scales + auto scale_factors = args[3].unwrapToDoubleList(); + TRTORCH_ASSERT(scale_factors.size() == 3, "Number of scale factors should match the input size"); + float scale_d = scale_factors[0]; + float scale_h = scale_factors[1]; + float scale_w = scale_factors[2]; + std::vector padded_scales(in_shape.size(), 1); + padded_scales[padded_scales.size() - 3] = scale_d; + padded_scales[padded_scales.size() - 2] = scale_h; + padded_scales[padded_scales.size() - 1] = scale_w; + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) + << "\nupsample_linear3d only supports align_corner with TensorRT <= 7.0."); + } else { + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true); + } + #else + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } else { + // Case 2: user uses output size + auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); + TRTORCH_ASSERT( + out_size.size() == 3, + "aten::upsample_trilinear3d input Tensor and output size dimension mismatch"); + + auto out_shape = in_shape; + std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); + #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) + if (!align_corners) { + // align_corners not supported in TensorRT, create plugin and + // run layer through PyTorch + create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear")); + } else { + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true); } + #else + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + #endif + } - return true; + return true; }}); } // namespace diff --git a/tests/core/conversion/converters/test_interpolate.cpp b/tests/core/conversion/converters/test_interpolate.cpp index 17115c3fa0..a3d3aee5e5 100644 --- a/tests/core/conversion/converters/test_interpolate.cpp +++ b/tests/core/conversion/converters/test_interpolate.cpp @@ -4,491 +4,330 @@ #include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" -TEST(Converters, ATenUpsampleNearest1dOutputSizeConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2) - %4 : None = prim::Constant() - %5 : Tensor = aten::upsample_nearest1d(%0, %3, %4) - return (%5))IR"; - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 3D for TensorRT upsample_nearest1d - auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleNearest1dScaleFactorConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %1 : int = prim::Constant[value=8]() - %2 : int[] = prim::ListConstruct(%1) - %3 : float = prim::Constant[value=4.0]() - %5 : Tensor = aten::upsample_nearest1d(%0, %2, %3) - return (%5))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 3D for TensorRT upsample_nearest1d - auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleNearest2dOutputSizeConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2, %2) - %4 : None = prim::Constant() - %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4) - return (%5))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 4D for TensorRT upsample_nearest2d - auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleNearest2dScaleFactorConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=8]() - %3 : int[] = prim::ListConstruct(%2, %2) - %4 : float = prim::Constant[value=4.0]() - %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4) - return (%5))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 4D for TensorRT upsample_nearest2d - auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleNearest3dOutputSizeConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2, %2, %2) - %4 : None = prim::Constant() - %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4) - return (%5))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 5D for TensorRT upsample_nearest3d - auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleNearest3dScaleFactorConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=8]() - %3 : int[] = prim::ListConstruct(%2, %2, %2) - %4 : float = prim::Constant[value=4.0]() - %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4) - return (%5))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 5D for TensorRT upsample_nearest3d - auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleLinear1dOutputSizeWithAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2) - %4 : bool = prim::Constant[value=1]() - %5 : None = prim::Constant() - %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 3D for TensorRT upsample_linear1d - auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleLinear1dOutputSizeWithoutAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2) - %4 : bool = prim::Constant[value=0]() - %5 : None = prim::Constant() - %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 3D for TensorRT upsample_linear1d - auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleLinear1dScaleFactorWithoutAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=8]() - %3 : int[] = prim::ListConstruct(%2) - %4 : bool = prim::Constant[value=0]() - %5 : float = prim::Constant[value=4.0]() - %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 3D for TensorRT upsample_linear1d - auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleBilinear2dOutputSizeWithAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2, %2) - %4 : bool = prim::Constant[value=1]() - %5 : None = prim::Constant() - %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 4D for TensorRT upsample_bilinear2d - auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleBilinear2dOutputSizeWithoutAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2, %2) - %4 : bool = prim::Constant[value=0]() - %5 : None = prim::Constant() - %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 4D for TensorRT upsample_bilinear2d - auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleBilinear2dScaleFactorWithoutAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=8]() - %3 : int[] = prim::ListConstruct(%2, %2) - %4 : bool = prim::Constant[value=0]() - %5 : float = prim::Constant[value=4.0]() - %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 4D for TensorRT upsample_bilinear2d - auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleTrilinear3dOutputSizeWithAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2, %2, %2) - %4 : bool = prim::Constant[value=1]() - %5 : None = prim::Constant() - %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 5D for TensorRT upsample_trilinear3d - auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleTrilinear3dOutputSizeWithoutAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=10]() - %3 : int[] = prim::ListConstruct(%2, %2, %2) - %4 : bool = prim::Constant[value=0]() - %5 : None = prim::Constant() - %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 5D for TensorRT upsample_trilinear3d - auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} - -TEST(Converters, ATenUpsampleTrilinear3dScaleFactorWithoutAlignCornersConvertsCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor): - %2 : int = prim::Constant[value=8]() - %3 : int[] = prim::ListConstruct(%2, %2, %2) - %4 : bool = prim::Constant[value=0]() - %5 : float = prim::Constant[value=4.0]() - %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5) - return (%6))IR"; - - auto g = std::make_shared(); - - torch::jit::parseIR(graph, &*g); - - // Input Tensor needs to be 5D for TensorRT upsample_trilinear3d - auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA}); - - auto jit_in = at::clone(in); - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - auto trt_in = at::clone(in); - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); - - trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); - trt = trt_results[0].reshape(jit_results[0].sizes()); - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); -} \ No newline at end of file +#define ATEN_UPSAMPLE_TESTS(name, graph_src, input_shape) \ + TEST(Converters, name##StaticConvertsCorrectly) { \ + const auto graph = graph_src; \ + \ + auto g = std::make_shared(); \ + torch::jit::parseIR(graph, &*g); \ + \ + auto in = at::randint(1, 10, input_shape, {at::kCUDA}); \ + auto jit_in = at::clone(in); \ + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); \ + \ + auto trt_in = at::clone(in); \ + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + \ + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); \ + auto trt = trt_results[0].reshape(jit_results[0].sizes()); \ + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); \ + } \ + \ + TEST(Converters, name##DynamicConvertsCorrectly) { \ + const auto graph = graph_src; \ + \ + auto g = std::make_shared(); \ + torch::jit::parseIR(graph, &*g); \ + \ + auto in = at::randint(1, 10, input_shape, {at::kCUDA}); \ + auto jit_in = at::clone(in); \ + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); \ + \ + auto trt_in = at::clone(in); \ + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ + \ + auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); \ + auto trt = trt_results[0].reshape(jit_results[0].sizes()); \ + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); \ + } + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest1dOutputSize, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2) + %4 : None = prim::Constant() + %5 : Tensor = aten::upsample_nearest1d(%0, %3, %4) + return (%5))IR", + std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest1dScales, + R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=8]() + %2 : int[] = prim::ListConstruct(%1) + %3 : float = prim::Constant[value=4.0]() + %5 : Tensor = aten::upsample_nearest1d(%0, %2, %3) + return (%5))IR", + std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest1dVecScaleFactors, + R"IR( + graph(%0 : Tensor): + %2 : None = prim::Constant() + %3 : float = prim::Constant[value=4.0]() + %4 : float[] = prim::ListConstruct(%3) + %5 : Tensor = aten::upsample_nearest1d(%0, %2, %4) + return (%5))IR", + std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest2dOutputSize, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2, %2) + %4 : None = prim::Constant() + %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4) + return (%5))IR", + std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest2dScales, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=8]() + %3 : int[] = prim::ListConstruct(%2, %2) + %4 : float = prim::Constant[value=4.0]() + %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4) + return (%5))IR", + std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest2dVecScaleFactors, + R"IR( + graph(%0 : Tensor): + %2 : None = prim::Constant() + %3 : float = prim::Constant[value=4.0]() + %4 : float[] = prim::ListConstruct(%3, %3) + %5 : Tensor = aten::upsample_nearest2d(%0, %2, %4) + return (%5))IR", + std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest3dOutputSize, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2, %2, %2) + %4 : None = prim::Constant() + %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4) + return (%5))IR", + std::vector({10, 2, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest3dScales, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=8]() + %3 : int[] = prim::ListConstruct(%2, %2, %2) + %4 : float = prim::Constant[value=4.0]() + %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4) + return (%5))IR", + std::vector({10, 2, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleNearest3dVecScaleFactors, + R"IR( + graph(%0 : Tensor): + %2 : None = prim::Constant() + %3 : float = prim::Constant[value=4.0]() + %4 : float[] = prim::ListConstruct(%3, %3, %3) + %5 : Tensor = aten::upsample_nearest3d(%0, %2, %4) + return (%5))IR", + std::vector({10, 2, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleLinear1dOutputSizeWithAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2) + %4 : bool = prim::Constant[value=1]() + %5 : None = prim::Constant() + %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) + return (%6))IR", + std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleLinear1dOutputSizeWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2) + %4 : bool = prim::Constant[value=0]() + %5 : None = prim::Constant() + %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) + return (%6))IR", + std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleLinear1dScalesWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=8]() + %3 : int[] = prim::ListConstruct(%2) + %4 : bool = prim::Constant[value=0]() + %5 : float = prim::Constant[value=4.0]() + %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) + return (%6))IR", + std::vector({10, 2, 2})); + +// ATEN_UPSAMPLE_TESTS(ATenUpsampleLinear1dScalesWithAlignCorners, +// R"IR( +// graph(%0 : Tensor): +// %2 : int = prim::Constant[value=8]() +// %3 : int[] = prim::ListConstruct(%2) +// %4 : bool = prim::Constant[value=1]() +// %5 : float = prim::Constant[value=4.0]() +// %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5) +// return (%6))IR", +// std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleLinear1dVecScaleFactorsWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %3 : None = prim::Constant() + %4 : bool = prim::Constant[value=0]() + %5 : float = prim::Constant[value=4.0]() + %6 : float[] = prim::ListConstruct(%5) + %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %6) + return (%6))IR", + std::vector({10, 2, 2})); + +// ATEN_UPSAMPLE_TESTS(ATenUpsampleLinear1dVecScaleFactorsWithAlignCorners, + // R"IR( + // graph(%0 : Tensor): + // %3 : None = prim::Constant() + // %4 : bool = prim::Constant[value=1]() + // %5 : float = prim::Constant[value=4.0]() + // %6 : float[] = prim::ListConstruct(%5) + // %6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %6) + // return (%6))IR", + // std::vector({10, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleBilinear2dOutputSizeWithAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2, %2) + %4 : bool = prim::Constant[value=1]() + %5 : None = prim::Constant() + %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) + return (%6))IR", + std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleBilinear2dOutputSizeWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2, %2) + %4 : bool = prim::Constant[value=0]() + %5 : None = prim::Constant() + %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) + return (%6))IR", + std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleBilinear2dScalesWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=8]() + %3 : int[] = prim::ListConstruct(%2, %2) + %4 : bool = prim::Constant[value=0]() + %5 : float = prim::Constant[value=4.0]() + %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) + return (%6))IR", + std::vector({10, 2, 2, 2})); + +// ATEN_UPSAMPLE_TESTS(ATenUpsampleBilinear2dScalesWithAlignCorners, + // R"IR( + // graph(%0 : Tensor): + // %2 : int = prim::Constant[value=8]() + // %3 : int[] = prim::ListConstruct(%2, %2) + // %4 : bool = prim::Constant[value=1]() + // %5 : float = prim::Constant[value=4.0]() + // %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) + // return (%6))IR", +// std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleBilinear2dVecScaleFactorsWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %3 : None = prim::Constant() + %4 : bool = prim::Constant[value=0]() + %5 : float = prim::Constant[value=4.0]() + %6 : float[] = prim::ListConstruct(%5, %5) + %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %6) + return (%6))IR", + std::vector({10, 2, 2, 2})); + +// ATEN_UPSAMPLE_TESTS(ATenUpsampleBilinear2dVecScaleFactorsWithAlignCorners, +// R"IR( +// graph(%0 : Tensor): +// %3 : None = prim::Constant() +// %4 : bool = prim::Constant[value=1]() +// %5 : float = prim::Constant[value=4.0]() +// %6 : float[] = prim::ListConstruct(%5, %5) +// %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %6) +// return (%6))IR", +// std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleTrilinear3dOutputSizeWithAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2, %2, %2) + %4 : bool = prim::Constant[value=1]() + %5 : None = prim::Constant() + %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5) + return (%6))IR", + std::vector({10, 2, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleTrilinear3dOutputSizeWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=10]() + %3 : int[] = prim::ListConstruct(%2, %2, %2) + %4 : bool = prim::Constant[value=0]() + %5 : None = prim::Constant() + %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5) + return (%6))IR", + std::vector({10, 2, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleTrilinear3dScalesWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %2 : int = prim::Constant[value=8]() + %3 : int[] = prim::ListConstruct(%2, %2, %2) + %4 : bool = prim::Constant[value=0]() + %5 : float = prim::Constant[value=4.0]() + %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5) + return (%6))IR", + std::vector({10, 2, 2, 2, 2})); + +// ATEN_UPSAMPLE_TESTS(ATenUpsampleTrilinear3dScalesWithAlignCorners, + // R"IR( + // graph(%0 : Tensor): + // %2 : int = prim::Constant[value=8]() + // %3 : int[] = prim::ListConstruct(%2, %2) + // %4 : bool = prim::Constant[value=1]() + // %5 : float = prim::Constant[value=4.0]() + // %6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5) + // return (%6))IR", +// std::vector({10, 2, 2, 2})); + +ATEN_UPSAMPLE_TESTS(ATenUpsampleTrilinear3dVecScaleFactorsWithoutAlignCorners, + R"IR( + graph(%0 : Tensor): + %3 : None = prim::Constant() + %4 : bool = prim::Constant[value=0]() + %5 : float = prim::Constant[value=4.0]() + %6 : float[] = prim::ListConstruct(%5, %5, %5) + %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6) + return (%6))IR", + std::vector({10, 2, 2, 2, 2})); + +// ATEN_UPSAMPLE_TESTS(ATenUpsampleTrilinear3dVecScaleFactorsWithAlignCorners, +// R"IR( +// graph(%0 : Tensor): +// %3 : None = prim::Constant() +// %4 : bool = prim::Constant[value=1]() +// %5 : float = prim::Constant[value=4.0]() +// %6 : float[] = prim::ListConstruct(%5, %5, %5) +// %6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6) +// return (%6))IR", +// std::vector({10, 2, 2, 2})); \ No newline at end of file