From 1424ca4f8db1fee2f48cda1a22688fc195b3e831 Mon Sep 17 00:00:00 2001 From: ynimmaga Date: Wed, 27 Mar 2024 22:16:13 -0700 Subject: [PATCH] Added support for quantized fx ops for export path --- .../pytorch/torchdynamo/op_support.py | 6 + src/frontends/pytorch/src/op/quantize.cpp | 21 ++++ src/frontends/pytorch/src/op_table.cpp | 7 ++ src/frontends/pytorch/src/utils_quantize.cpp | 110 ++++++++++++++---- src/frontends/pytorch/src/utils_quantize.hpp | 43 +++++++ 5 files changed, 164 insertions(+), 23 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index a1071c1af0e3b8..f3a5cbb2aec63a 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -69,6 +69,7 @@ def __init__(self, options): "torch.ops.aten.argmax.default": None, "torch.ops.aten.argmin.default": None, "torch.ops.aten.as_strided.default": None, + "torch.ops.aten.as_strided_.default": None, "torch.ops.aten.asin.default": None, "torch.ops.aten.asinh.default": None, "torch.ops.aten.asinh.default": None, @@ -240,6 +241,11 @@ def __init__(self, options): "torch.ops.aten.zeros_like.default": None, "torch.ops.torchvision.deform_conv2d.default": None, "torch.ops.torchvision.roi_align.default": None, + "torch.ops.quantized_decomposed.quantize_per_tensor.default": None, + "torch.ops.quantized_decomposed.quantize_per_channel.default": None, + "torch.ops.quantized_decomposed.dequantize_per_tensor.default": None, + "torch.ops.quantized_decomposed.dequantize_per_channel.default": None + } for op in _get_disabled_ops(options): diff --git a/src/frontends/pytorch/src/op/quantize.cpp b/src/frontends/pytorch/src/op/quantize.cpp index 9048a4c0a1295b..f9a3950a1628ea 100644 --- a/src/frontends/pytorch/src/op/quantize.cpp +++ b/src/frontends/pytorch/src/op/quantize.cpp @@ -32,6 +32,27 @@ OutputVector translate_quantize_per_channel(const NodeContext& context) { return {quantize(context, input, scales, zero_points, axis, dtype, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)}; } +OutputVector translate_quantize_per_tensor_fx(const NodeContext& context) { + num_inputs_check(context, 4, 8); + const auto input = context.get_input(0); + const auto scale = context.get_input(1); + const auto zero_point = context.get_input(2); + auto low = context.const_input(3); + auto high = context.const_input(4); + return {quantize_fx(context, input, scale, zero_point, low, high, element::i8, QuantizedPtNodeType::QUANTIZE_PER_TENSOR)}; +} + +OutputVector translate_quantize_per_channel_fx(const NodeContext& context) { + num_inputs_check(context, 4, 8); + const auto input = context.get_input(0); + const auto scales = context.get_input(1); + const auto zero_points = context.get_input(2); + const auto axis = context.get_input(3); + auto low = context.const_input(4); + auto high = context.const_input(5); + return {quantize_fx(context, input, scales, zero_points, axis, low, high, element::i8, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)}; +} + OutputVector translate_fake_quantize_per_tensor_affine_fx(const NodeContext& context) { num_inputs_check(context, 6, 6); auto out = translate_quantize_per_tensor(context); diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index ea2ff9cf6c5a59..e075364f5741af 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -295,6 +295,8 @@ OP_CONVERTER(translate_sum_fx); OP_CONVERTER(translate_topk_fx); OP_CONVERTER(translate_to_fx); OP_CONVERTER(translate_transpose_fx); +OP_CONVERTER(translate_quantize_per_channel_fx); +OP_CONVERTER(translate_quantize_per_tensor_fx); OP_CONVERTER(translate_var_fx); OP_CONVERTER(translate_var_mean_fx); OP_CONVERTER(translate_unbind_int_fx); @@ -756,6 +758,7 @@ const std::map get_supported_ops_fx() { {"aten.argmax.default", op::translate_argmax}, {"aten.argmin.default", op::translate_argmin}, {"aten.as_strided.default", op::translate_as_strided}, + {"aten.as_strided_.default", op::translate_as_strided}, {"aten.asin.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -939,6 +942,10 @@ const std::map get_supported_ops_fx() { {"get_attr", op::translate_constant}, {"torchvision.deform_conv2d.default", op::translate_deform_conv}, {"torchvision.roi_align.default", op::translate_roi_align}, + {"quantized_decomposed.quantize_per_tensor.default", op::translate_quantize_per_tensor_fx}, + {"quantized_decomposed.quantize_per_channel.default", op::translate_quantize_per_channel_fx}, + {"quantized_decomposed.dequantize_per_tensor.default", op::skip_node}, + {"quantized_decomposed.dequantize_per_channel.default", op::skip_node}, }; }; diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp index b7d3979020a115..8a03ffdc7607f6 100644 --- a/src/frontends/pytorch/src/utils_quantize.cpp +++ b/src/frontends/pytorch/src/utils_quantize.cpp @@ -23,11 +23,13 @@ namespace pytorch { using namespace ov::op; -Output quantize(const NodeContext& context, +Output quantize_common(const NodeContext& context, const Output& input, const Output& scale, const Output& zero_point, const Output& axis, + int64_t out_low_i64, + int64_t out_high_i64, element::Type dtype, QuantizedPtNodeType quantization_type) { if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) { @@ -35,17 +37,6 @@ Output quantize(const NodeContext& context, const auto scale_convert = context.mark_node(std::make_shared(scale, element::f32)); const auto zero_point_convert = context.mark_node(std::make_shared(zero_point, element::f32)); - int64_t out_low_i64, out_high_i64; - if (dtype == element::u8) { - out_low_i64 = (int64_t)std::numeric_limits::lowest(); - out_high_i64 = (int64_t)std::numeric_limits::max(); - } else if (dtype == element::i8) { - out_low_i64 = (int64_t)std::numeric_limits::lowest(); - out_high_i64 = (int64_t)std::numeric_limits::max(); - } else { // i32 - out_low_i64 = (int64_t)std::numeric_limits::lowest(); - out_high_i64 = (int64_t)std::numeric_limits::max(); - } int64_t levels = out_high_i64 - out_low_i64 + 1; const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64})); const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64})); @@ -75,17 +66,6 @@ Output quantize(const NodeContext& context, const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); - int64_t out_low_i64, out_high_i64; - if (dtype == element::u8) { - out_low_i64 = (int64_t)std::numeric_limits::lowest(); - out_high_i64 = (int64_t)std::numeric_limits::max(); - } else if (dtype == element::i8) { - out_low_i64 = (int64_t)std::numeric_limits::lowest(); - out_high_i64 = (int64_t)std::numeric_limits::max(); - } else { // i32 - out_low_i64 = (int64_t)std::numeric_limits::lowest(); - out_high_i64 = (int64_t)std::numeric_limits::max(); - } int64_t levels = out_high_i64 - out_low_i64 + 1; const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64})); const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64})); @@ -120,6 +100,27 @@ Output quantize(const NodeContext& context, FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize."); } +Output quantize(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + const Output& axis, + element::Type dtype, + QuantizedPtNodeType quantization_type) { + int64_t out_low_i64, out_high_i64; + if (dtype == element::u8) { + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } else if (dtype == element::i8) { + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } else { // i32 + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } + return quantize_common(context, input, scale, zero_point, axis, out_low_i64, out_high_i64, dtype, quantization_type); +} + Output quantize(const NodeContext& context, const Output& input, const Output& scale, @@ -159,6 +160,69 @@ Output quantize(const NodeContext& context, FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode"); } +Output quantize_fx(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + const Output& axis, + int64_t out_low_i64, + int64_t out_high_i64, + element::Type dtype, + QuantizedPtNodeType quantization_type) { + return quantize_common(context, input, scale, zero_point, axis, out_low_i64, out_high_i64, dtype, quantization_type); +} + +Output quantize_fx(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + int64_t out_low_i64, + int64_t out_high_i64, + element::Type dtype, + QuantizedPtNodeType quantization_type) { + return quantize_fx(context, input, scale, zero_point, Output(), out_low_i64, out_high_i64, dtype, quantization_type); +} + +Output quantize_fx(const NodeContext& context, + const Output& input, + int64_t out_low_i64, + int64_t out_high_i64, + const Output& quantized_node) { + if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) { + return quantize_fx(context, + input, + quantized_pt_node->get_scale(), + quantized_pt_node->get_zero_point(), + quantized_pt_node->get_axis(), + out_low_i64, + out_high_i64, + quantized_pt_node->get_dtype(), + quantized_pt_node->get_type()); + } + FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode"); +} + +Output quantize_fx(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + int64_t out_low_i64, + int64_t out_high_i64, + const Output& quantized_node) { + if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) { + return quantize_fx(context, + input, + scale, + zero_point, + quantized_pt_node->get_axis(), + out_low_i64, + out_high_i64, + quantized_pt_node->get_dtype(), + quantized_pt_node->get_type()); + } + FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode"); +} + std::shared_ptr cast_quantized_fw_node(std::shared_ptr node) { auto quant_node = std::dynamic_pointer_cast(node); if (!quant_node) { diff --git a/src/frontends/pytorch/src/utils_quantize.hpp b/src/frontends/pytorch/src/utils_quantize.hpp index 2379507d57bbb5..f7c3b40338a61b 100644 --- a/src/frontends/pytorch/src/utils_quantize.hpp +++ b/src/frontends/pytorch/src/utils_quantize.hpp @@ -144,6 +144,49 @@ Output quantize(const NodeContext& context, const Output& zero_point, const Output& quantized_node); +/** + * Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode. + */ +Output quantize_fx(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + int64_t out_low_i64, + int64_t out_high_i64, + element::Type dtype, + QuantizedPtNodeType quantization_type); +Output quantize_fx(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + const Output& axis, + int64_t out_low_i64, + int64_t out_high_i64, + element::Type dtype, + QuantizedPtNodeType quantization_type); + +/** + * Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode. + */ +Output quantize_fx(const NodeContext& context, + Output input, + int64_t out_low_i64, + int64_t out_high_i64, + Output quantized_node); + +/** + * Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to + * the new QuantizedPtNode. + */ +Output quantize_fx(const NodeContext& context, + const Output& input, + const Output& scale, + const Output& zero_point, + int64_t out_low_i64, + int64_t out_high_i64, + const Output& quantized_node); + + std::shared_ptr cast_quantized_fw_node(std::shared_ptr node); namespace op {